diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 5b99940b469..0631964fb06 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -111,6 +111,7 @@ enum {
 typedef enum {
    ms_out_mode_lds,
    ms_out_mode_vram,
+   ms_out_mode_var,
 } ms_out_mode;
 
 typedef struct
@@ -134,6 +135,11 @@ typedef struct
       ms_out_part vtx_attr;
       ms_out_part prm_attr;
    } vram;
+   /* Outputs without cross-invocation access can be stored in variables. */
+   struct {
+      ms_out_part vtx_attr;
+      ms_out_part prm_attr;
+   } var;
 } ms_out_mem_layout;
 
 typedef struct
@@ -148,6 +154,7 @@ typedef struct
    unsigned hw_workgroup_size;
 
    nir_ssa_def *workgroup_index;
+   nir_variable *out_variables[VARYING_SLOT_MAX * 4];
 
    /* True if the lowering needs to insert the layer output. */
    bool insert_layer_output;
@@ -155,8 +162,6 @@ typedef struct
    struct {
      /* Bitmask of components used: 4 bits per slot, 1 bit per component. */
      uint32_t components_mask;
-     /* Driver location of the output slot, if used. */
-     unsigned driver_location;
   } output_info[VARYING_SLOT_MAX];
 } lower_ngg_ms_state;
 
@@ -2128,17 +2133,10 @@ ms_arrayed_output_base_addr(nir_builder *b,
 
 static void
 update_ms_output_info_slot(lower_ngg_ms_state *s,
-                           unsigned slot, unsigned base, unsigned base_off,
+                           unsigned slot, unsigned base_off,
                            uint32_t components_mask)
 {
    while (components_mask) {
-      unsigned driver_location = base + base_off;
-
-      /* If already set, it must match. */
-      if (s->output_info[slot + base_off].driver_location)
-         assert(s->output_info[slot + base_off].driver_location == driver_location);
-
-      s->output_info[slot + base_off].driver_location = driver_location;
       s->output_info[slot + base_off].components_mask |= components_mask & 0xF;
 
       components_mask >>= 4;
@@ -2148,13 +2146,13 @@ update_ms_output_info_slot(lower_ngg_ms_state *s,
 
 static void
 update_ms_output_info(nir_intrinsic_instr *intrin,
+                      const ms_out_part *out,
                       lower_ngg_ms_state *s)
 {
    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
    nir_src *base_offset_src = nir_get_io_offset_src(intrin);
    uint32_t write_mask = nir_intrinsic_write_mask(intrin);
    unsigned component_offset = nir_intrinsic_component(intrin);
-   unsigned base = nir_intrinsic_base(intrin);
 
    nir_ssa_def *store_val = intrin->src[0].ssa;
    write_mask = util_widen_mask(write_mask, DIV_ROUND_UP(store_val->bit_size, 32));
@@ -2163,11 +2161,11 @@ update_ms_output_info(nir_intrinsic_instr *intrin,
    if (nir_src_is_const(*base_offset_src)) {
       /* Simply mark the components of the current slot as used. */
       unsigned base_off = nir_src_as_uint(*base_offset_src);
-      update_ms_output_info_slot(s, io_sem.location, base, base_off, components_mask);
+      update_ms_output_info_slot(s, io_sem.location, base_off, components_mask);
    } else {
       /* Indirect offset: mark the components of all slots as used. */
       for (unsigned base_off = 0; base_off < io_sem.num_slots; ++base_off)
-         update_ms_output_info_slot(s, io_sem.location, base, base_off, components_mask);
+         update_ms_output_info_slot(s, io_sem.location, base_off, components_mask);
    }
 }
 
@@ -2228,6 +2226,9 @@ ms_get_out_layout_part(unsigned location,
       } else if (mask & s->layout.vram.prm_attr.mask) {
          *out_mode = ms_out_mode_vram;
          return &s->layout.vram.prm_attr;
+      } else if (mask & s->layout.var.prm_attr.mask) {
+         *out_mode = ms_out_mode_var;
+         return &s->layout.var.prm_attr;
       }
    } else {
       if (mask & s->layout.lds.vtx_attr.mask) {
@@ -2236,6 +2237,9 @@ ms_get_out_layout_part(unsigned location,
       } else if (mask & s->layout.vram.vtx_attr.mask) {
          *out_mode = ms_out_mode_vram;
          return &s->layout.vram.vtx_attr;
+      } else if (mask & s->layout.var.vtx_attr.mask) {
+         *out_mode = ms_out_mode_var;
+         return &s->layout.var.vtx_attr;
       }
    }
 
@@ -2250,8 +2254,13 @@ ms_store_arrayed_output_intrin(nir_builder *b,
    ms_out_mode out_mode;
    unsigned location = nir_intrinsic_io_semantics(intrin).location;
    const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
+   update_ms_output_info(intrin, out, s);
 
-   unsigned driver_location = nir_intrinsic_base(intrin);
+   /* We compact the LDS size (we don't reserve LDS space for outputs which can
+    * be stored in variables), so we can't rely on the original driver_location.
+    * Instead, we compute the first free location based on the output mask.
+    */
+   unsigned driver_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location));
    unsigned component_offset = nir_intrinsic_component(intrin);
    unsigned write_mask = nir_intrinsic_write_mask(intrin);
    unsigned num_outputs = util_bitcount64(out->mask);
@@ -2271,10 +2280,23 @@ ms_store_arrayed_output_intrin(nir_builder *b,
    } else if (out_mode == ms_out_mode_vram) {
       nir_ssa_def *ring = nir_load_ring_mesh_scratch_amd(b);
       nir_ssa_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
-      nir_store_buffer_amd(b, store_val, ring, base_addr, off,
+      nir_store_buffer_amd(b, store_val, ring, addr, off,
                            .base = const_off, .write_mask = write_mask,
                            .memory_modes = nir_var_shader_out);
+   } else if (out_mode == ms_out_mode_var) {
+      if (store_val->bit_size > 32) {
+         /* Widen the write mask so it is in 32-bit components. */
+         write_mask = util_widen_mask(write_mask, store_val->bit_size / 32);
+         /* Split 64-bit store values to 32-bit components. */
+         store_val = nir_bitcast_vector(b, store_val, 32);
+      }
+
+      u_foreach_bit(comp, write_mask) {
+         nir_ssa_def *val = nir_channel(b, store_val, comp);
+         unsigned idx = location * 4 + comp + component_offset;
+         nir_store_var(b, s->out_variables[idx], val, 0x1);
+      }
    } else {
       unreachable("Invalid MS output mode for store");
    }
@@ -2285,7 +2307,6 @@ ms_load_arrayed_output(nir_builder *b,
                        nir_ssa_def *arr_index,
                        nir_ssa_def *base_offset,
                        unsigned location,
-                       unsigned driver_location,
                        unsigned component_offset,
                        unsigned num_components,
                        unsigned load_bit_size,
@@ -2298,6 +2319,9 @@ ms_load_arrayed_output(nir_builder *b,
    unsigned num_outputs = util_bitcount64(out->mask);
    unsigned const_off = out->addr + component_offset * 4;
 
+   /* Use compacted driver location instead of the original. */
+   unsigned driver_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location));
+
    nir_ssa_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, driver_location, num_outputs);
    nir_ssa_def *base_addr_off = nir_imul_imm(b, base_offset, 16);
    nir_ssa_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
 
@@ -2312,6 +2336,16 @@ ms_load_arrayed_output(nir_builder *b,
       return nir_load_buffer_amd(b, num_components, load_bit_size, ring, addr, off,
                                  .base = const_off,
                                  .memory_modes = nir_var_shader_out);
+   } else if (out_mode == ms_out_mode_var) {
+      nir_ssa_def *arr[8] = {0};
+      unsigned num_32bit_components = num_components * load_bit_size / 32;
+      for (unsigned comp = 0; comp < num_32bit_components; ++comp) {
+         unsigned idx = location * 4 + comp + component_offset;
+         arr[comp] = nir_load_var(b, s->out_variables[idx]);
+      }
+      if (load_bit_size > 32)
+         return nir_extract_bits(b, arr, num_32bit_components, 0, num_components, load_bit_size);
+      return nir_vec(b, arr, num_components);
    } else {
       unreachable("Invalid MS output mode for load");
    }
@@ -2326,15 +2360,14 @@ ms_load_arrayed_output_intrin(nir_builder *b,
    nir_ssa_def *base_offset = nir_get_io_offset_src(intrin)->ssa;
 
    unsigned location = nir_intrinsic_io_semantics(intrin).location;
-   unsigned driver_location = nir_intrinsic_base(intrin);
    unsigned component_offset = nir_intrinsic_component(intrin);
    unsigned bit_size = intrin->dest.ssa.bit_size;
    unsigned num_components = intrin->dest.ssa.num_components;
    unsigned load_bit_size = MAX2(bit_size, 32);
 
    nir_ssa_def *load =
-      ms_load_arrayed_output(b, arr_index, base_offset, location, driver_location,
-                             component_offset, num_components, load_bit_size, s);
+      ms_load_arrayed_output(b, arr_index, base_offset, location, component_offset,
+                             num_components, load_bit_size, s);
 
    return regroup_load_val(b, load, bit_size);
 }
@@ -2384,7 +2417,6 @@ lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
       return lower_ms_load_output(b, intrin, s);
    case nir_intrinsic_store_per_vertex_output:
    case nir_intrinsic_store_per_primitive_output:
-      update_ms_output_info(intrin, s);
       ms_store_arrayed_output_intrin(b, intrin, s);
       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
    case nir_intrinsic_load_per_vertex_output:
@@ -2436,7 +2468,6 @@ ms_emit_arrayed_outputs(nir_builder *b,
       assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES);
 
       const nir_io_semantics io_sem = { .location = slot, .num_slots = 1 };
-      const unsigned driver_location = s->output_info[slot].driver_location;
      unsigned component_mask = s->output_info[slot].components_mask;
 
      while (component_mask) {
@@ -2444,7 +2475,7 @@ ms_emit_arrayed_outputs(nir_builder *b,
         u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components);
 
         nir_ssa_def *load =
-           ms_load_arrayed_output(b, invocation_index, zero, slot, driver_location, start_comp,
-                                  num_components, 32, s);
+           ms_load_arrayed_output(b, invocation_index, zero, slot, start_comp,
+                                  num_components, 32, s);
 
         nir_store_output(b, load, nir_imm_int(b, 0), .base = slot, .component = start_comp,
@@ -2456,14 +2487,25 @@ ms_emit_arrayed_outputs(nir_builder *b,
 static void
 emit_ms_prelude(nir_builder *b, lower_ngg_ms_state *s)
 {
+   b->cursor = nir_before_cf_list(&b->impl->body);
+
+   /* Initialize NIR variables for same-invocation outputs. */
+   uint64_t same_invocation_output_mask = s->layout.var.prm_attr.mask | s->layout.var.vtx_attr.mask;
+
+   u_foreach_bit64(slot, same_invocation_output_mask) {
+      for (unsigned comp = 0; comp < 4; ++comp) {
+         unsigned idx = slot * 4 + comp;
+         s->out_variables[idx] = nir_local_variable_create(b->impl, glsl_uint_type(), "ms_var_output");
+         nir_store_var(b, s->out_variables[idx], nir_imm_int(b, 0), 0x1);
+      }
+   }
+
    bool uses_workgroup_id =
       BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID);
 
    if (!uses_workgroup_id)
       return;
 
-   b->cursor = nir_before_cf_list(&b->impl->body);
-
    /* The HW doesn't support a proper workgroup index for vertex processing stages,
     * so we use the vertex ID which is equivalent to the index of the current workgroup
     * within the current dispatch.
@@ -2800,21 +2842,27 @@
 static ms_out_mem_layout
 ms_calculate_output_layout(unsigned api_shared_size,
                            uint64_t per_vertex_output_mask,
                            uint64_t per_primitive_output_mask,
+                           uint64_t cross_invocation_output_access,
                            unsigned max_vertices,
                            unsigned max_primitives,
                            unsigned vertices_per_prim)
 {
-   uint64_t lds_per_vertex_output_mask = per_vertex_output_mask;
-   uint64_t lds_per_primitive_output_mask = per_primitive_output_mask;
+   uint64_t lds_per_vertex_output_mask = per_vertex_output_mask & cross_invocation_output_access;
+   uint64_t lds_per_primitive_output_mask = per_primitive_output_mask & cross_invocation_output_access;
 
    /* Shared memory used by the API shader. */
    ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } };
 
+   /* Outputs without cross-invocation access can be stored in variables. */
+   l.var.vtx_attr.mask = per_vertex_output_mask & ~lds_per_vertex_output_mask;
+   l.var.prm_attr.mask = per_primitive_output_mask & ~lds_per_primitive_output_mask;
+
    /* Workgroup information, see ms_workgroup_* for the layout. */
    l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16);
    l.lds.total_size = l.lds.workgroup_info_addr + 16;
 
    /* Per-vertex and per-primitive output attributes.
+    * Outputs without cross-invocation access are not included here.
     * First, try to put all outputs into LDS (shared memory).
     * If they don't fit, try to move them to VRAM one by one.
     */
@@ -2867,12 +2915,16 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
    uint64_t per_primitive_outputs =
       shader->info.per_primitive_outputs & shader->info.outputs_written & ~special_outputs;
 
+   /* Can't handle indirectly addressed outputs here, so treat them as cross-invocation. */
+   uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access |
+                                      shader->info.outputs_accessed_indirectly;
+
    unsigned max_vertices = shader->info.mesh.max_vertices_out;
    unsigned max_primitives = shader->info.mesh.max_primitives_out;
 
    ms_out_mem_layout layout =
      ms_calculate_output_layout(shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
-                                max_vertices, max_primitives, vertices_per_prim);
+                                cross_invocation_access, max_vertices, max_primitives, vertices_per_prim);
 
    shader->info.shared_size = layout.lds.total_size;
    *out_needs_scratch_ring = layout.vram.vtx_attr.mask || layout.vram.prm_attr.mask;
@@ -2920,5 +2972,10 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
    nir_metadata_preserve(impl, nir_metadata_none);
 
    /* Cleanup */
+   nir_lower_vars_to_ssa(shader);
+   nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
+   nir_lower_alu_to_scalar(shader, NULL, NULL);
+   nir_lower_phis_to_scalar(shader, true);
+
    nir_validate_shader(shader, "after emitting NGG MS");
 }
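
Note on the compaction idiom used in ms_store_arrayed_output_intrin and ms_load_arrayed_output
above: the compacted driver_location of a slot is simply the number of used output slots below it,
so unwritten slots consume no LDS and the original nir_intrinsic_base() can no longer be relied on.
The following is a minimal standalone sketch (not part of the patch) demonstrating the idiom with
local stand-ins for Mesa's util_bitcount64() and u_bit_consecutive64(), using a hypothetical
output mask:

#include <stdint.h>
#include <stdio.h>

/* Stand-in for u_bit_consecutive64(start, count): 'count' consecutive bits set, starting at 'start'. */
static uint64_t bit_consecutive64(unsigned start, unsigned count)
{
   return (count == 64 ? ~0ull : ((1ull << count) - 1)) << start;
}

/* Stand-in for util_bitcount64(): population count of a 64-bit mask. */
static unsigned bitcount64(uint64_t v)
{
   return (unsigned)__builtin_popcountll(v);
}

int main(void)
{
   /* Hypothetical mask: only slots 2, 5 and 31 are written by the shader. */
   const uint64_t out_mask = (1ull << 2) | (1ull << 5) | (1ull << 31);
   const unsigned locations[] = {2, 5, 31};

   for (unsigned i = 0; i < 3; ++i) {
      /* The compacted location of a slot is the number of used slots below it. */
      unsigned driver_location = bitcount64(out_mask & bit_consecutive64(0, locations[i]));
      printf("slot %u -> compacted location %u\n", locations[i], driver_location);
   }
   return 0;
}

This prints compacted locations 0, 1 and 2 for slots 2, 5 and 31, i.e. the gaps in the mask are
squeezed out of the LDS layout.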