diff --git a/src/amd/common/ac_nir_helpers.h b/src/amd/common/ac_nir_helpers.h index 4776600be55..624a7be5bb9 100644 --- a/src/amd/common/ac_nir_helpers.h +++ b/src/amd/common/ac_nir_helpers.h @@ -174,6 +174,11 @@ ac_nir_gs_shader_query(nir_builder *b, nir_def *vertex_count[4], nir_def *primitive_count[4]); +nir_def * +ac_nir_pack_ngg_prim_exp_arg(nir_builder *b, unsigned num_vertices_per_primitives, + nir_def *vertex_indices[3], nir_def *is_null_prim, + enum amd_gfx_level gfx_level); + #ifdef __cplusplus } #endif diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c index 6ac41bd8c92..5eeedca17a0 100644 --- a/src/amd/common/ac_nir_lower_ngg.c +++ b/src/amd/common/ac_nir_lower_ngg.c @@ -12,23 +12,6 @@ #include "util/u_math.h" #include "util/u_vector.h" -#define SPECIAL_MS_OUT_MASK \ - (BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | \ - BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | \ - BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) - -#define MS_PRIM_ARG_EXP_MASK \ - (VARYING_BIT_LAYER | \ - VARYING_BIT_VIEWPORT | \ - VARYING_BIT_PRIMITIVE_SHADING_RATE) - -#define MS_VERT_ARG_EXP_MASK \ - (VARYING_BIT_CULL_DIST0 | \ - VARYING_BIT_CULL_DIST1 | \ - VARYING_BIT_CLIP_DIST0 | \ - VARYING_BIT_CLIP_DIST1 | \ - VARYING_BIT_PSIZ) - enum { nggc_passflag_used_by_pos = 1, nggc_passflag_used_by_other = 2, @@ -109,103 +92,6 @@ typedef struct nir_def *primitive_count[4]; } lower_ngg_gs_state; -/* LDS layout of Mesh Shader workgroup info. */ -enum { - /* DW0: number of primitives */ - lds_ms_num_prims = 0, - /* DW1: number of vertices */ - lds_ms_num_vtx = 4, - /* DW2: workgroup index within the current dispatch */ - lds_ms_wg_index = 8, - /* DW3: number of API workgroups in flight */ - lds_ms_num_api_waves = 12, -}; - -/* Potential location for Mesh Shader outputs. */ -typedef enum { - ms_out_mode_lds, - ms_out_mode_scratch_ring, - ms_out_mode_attr_ring, - ms_out_mode_var, -} ms_out_mode; - -typedef struct -{ - uint64_t mask; /* Mask of output locations */ - uint32_t addr; /* Base address */ -} ms_out_part; - -typedef struct -{ - /* Mesh shader LDS layout. For details, see ms_calculate_output_layout. */ - struct { - uint32_t workgroup_info_addr; - ms_out_part vtx_attr; - ms_out_part prm_attr; - uint32_t indices_addr; - uint32_t cull_flags_addr; - uint32_t total_size; - } lds; - - /* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS. - * Not to be confused with scratch memory. - */ - struct { - ms_out_part vtx_attr; - ms_out_part prm_attr; - } scratch_ring; - - /* VRAM attributes ring (GFX11 only) for all non-position outputs. - * GFX11 doesn't have to reload attributes from this ring at the end of the shader. - */ - struct { - ms_out_part vtx_attr; - ms_out_part prm_attr; - } attr_ring; - - /* Outputs without cross-invocation access can be stored in variables. */ - struct { - ms_out_part vtx_attr; - ms_out_part prm_attr; - } var; -} ms_out_mem_layout; - -typedef struct -{ - enum amd_gfx_level gfx_level; - bool fast_launch_2; - bool vert_multirow_export; - bool prim_multirow_export; - - ms_out_mem_layout layout; - uint64_t per_vertex_outputs; - uint64_t per_primitive_outputs; - unsigned vertices_per_prim; - - unsigned wave_size; - unsigned api_workgroup_size; - unsigned hw_workgroup_size; - - nir_def *workgroup_index; - nir_variable *out_variables[VARYING_SLOT_MAX * 4]; - nir_variable *primitive_count_var; - nir_variable *vertex_count_var; - - ac_nir_prerast_out out; - - /* True if the lowering needs to insert the layer output. 
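/* Illustrative host-side sketch (hypothetical helper name, not part of this patch):
 * a plain-C model of the primitive export argument that the newly exported
 * ac_nir_pack_ngg_prim_exp_arg() builds (see the static helper removed further
 * down).  Each vertex index occupies a 10-bit field (9-bit starting with GFX12)
 * and bit 31 marks a null primitive; the caller's initial edge-flag bits are
 * simply OR'd in.
 */
#include <stdbool.h>
#include <stdint.h>

static uint32_t
model_pack_ngg_prim_exp_arg(unsigned num_vertices_per_primitive,
                            const uint32_t vertex_indices[3],
                            bool is_null_prim, bool is_gfx12,
                            uint32_t initial_edgeflags)
{
   const unsigned field_shift = is_gfx12 ? 9u : 10u;
   uint32_t arg = initial_edgeflags;

   for (unsigned i = 0; i < num_vertices_per_primitive; ++i)
      arg |= vertex_indices[i] << (field_shift * i);

   if (is_null_prim)
      arg |= 1u << 31;

   return arg;
}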
*/ - bool insert_layer_output; - /* True if cull flags are used */ - bool uses_cull_flags; - - uint32_t clipdist_enable_mask; - const uint8_t *vs_output_param_offset; - bool has_param_exports; - - /* True if the lowering needs to insert shader query. */ - bool has_query; -} lower_ngg_ms_state; - /* Per-vertex LDS layout of culling shaders */ enum { /* Position of the ES vertex (at the beginning for alignment reasons) */ @@ -430,29 +316,6 @@ pervertex_lds_addr(nir_builder *b, nir_def *vertex_idx, unsigned per_vtx_bytes) return nir_imul_imm(b, vertex_idx, per_vtx_bytes); } -static nir_def * -emit_pack_ngg_prim_exp_arg(nir_builder *b, unsigned num_vertices_per_primitives, - nir_def *vertex_indices[3], nir_def *is_null_prim, - enum amd_gfx_level gfx_level) -{ - nir_def *arg = nir_load_initial_edgeflags_amd(b); - - for (unsigned i = 0; i < num_vertices_per_primitives; ++i) { - assert(vertex_indices[i]); - arg = nir_ior(b, arg, nir_ishl_imm(b, vertex_indices[i], - (gfx_level >= GFX12 ? 9u : 10u) * i)); - } - - if (is_null_prim) { - if (is_null_prim->bit_size == 1) - is_null_prim = nir_b2i32(b, is_null_prim); - assert(is_null_prim->bit_size == 32); - arg = nir_ior(b, arg, nir_ishl_imm(b, is_null_prim, 31u)); - } - - return arg; -} - static void alloc_vertices_and_primitives(nir_builder *b, nir_def *num_vtx, @@ -546,7 +409,7 @@ emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *s) for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v) vtx_idx[v] = nir_load_var(b, s->gs_vtx_indices_vars[v]); - return emit_pack_ngg_prim_exp_arg(b, s->options->num_vertices_per_primitive, vtx_idx, NULL, + return ac_nir_pack_ngg_prim_exp_arg(b, s->options->num_vertices_per_primitive, vtx_idx, NULL, s->options->gfx_level); } } @@ -1116,7 +979,7 @@ compact_vertices_after_culling(nir_builder *b, } nir_def *prim_exp_arg = - emit_pack_ngg_prim_exp_arg(b, s->options->num_vertices_per_primitive, + ac_nir_pack_ngg_prim_exp_arg(b, s->options->num_vertices_per_primitive, exporter_vtx_indices, NULL, s->options->gfx_level); nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u); } @@ -3200,7 +3063,7 @@ ngg_gs_export_primitives(nir_builder *b, nir_def *max_num_out_prims, nir_def *ti nir_isub(b, vtx_indices[2], is_odd), vtx_indices[2]); } - nir_def *arg = emit_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices, + nir_def *arg = ac_nir_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices, is_null_prim, s->options->gfx_level); ac_nir_export_primitive(b, arg, NULL); nir_pop_if(b, if_prim_export_thread); @@ -3799,1334 +3662,3 @@ ac_ngg_get_scratch_lds_size(gl_shader_stage stage, return scratch_lds_size; } - -static void -ms_store_prim_indices(nir_builder *b, - nir_intrinsic_instr *intrin, - lower_ngg_ms_state *s) -{ - /* EXT_mesh_shader primitive indices: array of vectors. - * They don't count as per-primitive outputs, but the array is indexed - * by the primitive index, so they are practically per-primitive. 
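/* Illustrative host-side sketch (hypothetical helper names, not part of this
 * patch): byte addressing of the flat 8-bit primitive-index array in LDS that
 * the code below uses.  Because a mesh shader emits at most 256 vertices, each
 * index fits in one byte, so primitive P's indices start at
 * indices_addr + P * vertices_per_prim, and P's optional cull flag is a single
 * byte at cull_flags_addr + P.
 */
#include <stdint.h>

static uint32_t
model_prim_index_lds_byte(uint32_t indices_addr, uint32_t prim_index,
                          uint32_t vertices_per_prim, uint32_t component)
{
   return indices_addr + prim_index * vertices_per_prim + component;
}

static uint32_t
model_cull_flag_lds_byte(uint32_t cull_flags_addr, uint32_t prim_index)
{
   return cull_flags_addr + prim_index;
}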
- */ - assert(nir_src_is_const(*nir_get_io_offset_src(intrin))); - assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0); - - const unsigned component_offset = nir_intrinsic_component(intrin); - nir_def *store_val = intrin->src[0].ssa; - assert(store_val->num_components <= 3); - - if (store_val->num_components > s->vertices_per_prim) - store_val = nir_trim_vector(b, store_val, s->vertices_per_prim); - - if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) { - for (unsigned c = 0; c < store_val->num_components; ++c) { - const unsigned i = VARYING_SLOT_PRIMITIVE_INDICES * 4 + c + component_offset; - nir_store_var(b, s->out_variables[i], nir_channel(b, store_val, c), 0x1); - } - return; - } - - nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa; - nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim); - - /* The max vertex count is 256, so these indices always fit 8 bits. - * To reduce LDS use, store these as a flat array of 8-bit values. - */ - nir_store_shared(b, nir_u2u8(b, store_val), offset, .base = s->layout.lds.indices_addr + component_offset); -} - -static void -ms_store_cull_flag(nir_builder *b, - nir_intrinsic_instr *intrin, - lower_ngg_ms_state *s) -{ - /* EXT_mesh_shader cull primitive: per-primitive bool. */ - assert(nir_src_is_const(*nir_get_io_offset_src(intrin))); - assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0); - assert(nir_intrinsic_component(intrin) == 0); - assert(nir_intrinsic_write_mask(intrin) == 1); - - nir_def *store_val = intrin->src[0].ssa; - - assert(store_val->num_components == 1); - assert(store_val->bit_size == 1); - - if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) { - nir_store_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4], nir_b2i32(b, store_val), 0x1); - return; - } - - nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa; - nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim); - - /* To reduce LDS use, store these as an array of 8-bit values. */ - nir_store_shared(b, nir_b2i8(b, store_val), offset, .base = s->layout.lds.cull_flags_addr); -} - -static nir_def * -ms_arrayed_output_base_addr(nir_builder *b, - nir_def *arr_index, - unsigned mapped_location, - unsigned num_arrayed_outputs) -{ - /* Address offset of the array item (vertex or primitive). */ - unsigned arr_index_stride = num_arrayed_outputs * 16u; - nir_def *arr_index_off = nir_imul_imm(b, arr_index, arr_index_stride); - - /* IO address offset within the vertex or primitive data. */ - unsigned io_offset = mapped_location * 16u; - nir_def *io_off = nir_imm_int(b, io_offset); - - return nir_iadd_nuw(b, arr_index_off, io_off); -} - -static void -update_ms_output_info(const nir_io_semantics io_sem, - const nir_src *base_offset_src, - const uint32_t write_mask, - const unsigned component_offset, - const unsigned bit_size, - const ms_out_part *out, - lower_ngg_ms_state *s) -{ - const uint32_t components_mask = write_mask << component_offset; - - /* 64-bit outputs should have already been lowered to 32-bit. */ - assert(bit_size <= 32); - assert(components_mask <= 0xf); - - /* When the base offset is constant, only mark the components of the current slot as used. - * Otherwise, mark the components of all possibly affected slots as used. - */ - const unsigned base_off_start = nir_src_is_const(*base_offset_src) ? nir_src_as_uint(*base_offset_src) : 0; - const unsigned num_slots = nir_src_is_const(*base_offset_src) ? 
1 : io_sem.num_slots; - - for (unsigned base_off = base_off_start; base_off < num_slots; ++base_off) { - ac_nir_prerast_per_output_info *info = &s->out.infos[io_sem.location + base_off]; - info->components_mask |= components_mask; - - if (!io_sem.no_sysval_output) - info->as_sysval_mask |= components_mask; - if (!io_sem.no_varying) - info->as_varying_mask |= components_mask; - } -} - -static const ms_out_part * -ms_get_out_layout_part(unsigned location, - shader_info *info, - ms_out_mode *out_mode, - lower_ngg_ms_state *s) -{ - uint64_t mask = BITFIELD64_BIT(location); - - if (info->per_primitive_outputs & mask) { - if (mask & s->layout.lds.prm_attr.mask) { - *out_mode = ms_out_mode_lds; - return &s->layout.lds.prm_attr; - } else if (mask & s->layout.scratch_ring.prm_attr.mask) { - *out_mode = ms_out_mode_scratch_ring; - return &s->layout.scratch_ring.prm_attr; - } else if (mask & s->layout.attr_ring.prm_attr.mask) { - *out_mode = ms_out_mode_attr_ring; - return &s->layout.attr_ring.prm_attr; - } else if (mask & s->layout.var.prm_attr.mask) { - *out_mode = ms_out_mode_var; - return &s->layout.var.prm_attr; - } - } else { - if (mask & s->layout.lds.vtx_attr.mask) { - *out_mode = ms_out_mode_lds; - return &s->layout.lds.vtx_attr; - } else if (mask & s->layout.scratch_ring.vtx_attr.mask) { - *out_mode = ms_out_mode_scratch_ring; - return &s->layout.scratch_ring.vtx_attr; - } else if (mask & s->layout.attr_ring.vtx_attr.mask) { - *out_mode = ms_out_mode_attr_ring; - return &s->layout.attr_ring.vtx_attr; - } else if (mask & s->layout.var.vtx_attr.mask) { - *out_mode = ms_out_mode_var; - return &s->layout.var.vtx_attr; - } - } - - unreachable("Couldn't figure out mesh shader output mode."); -} - -static void -ms_store_arrayed_output(nir_builder *b, - nir_src *base_off_src, - nir_def *store_val, - nir_def *arr_index, - const nir_io_semantics io_sem, - const unsigned component_offset, - const unsigned write_mask, - lower_ngg_ms_state *s) -{ - ms_out_mode out_mode; - const ms_out_part *out = ms_get_out_layout_part(io_sem.location, &b->shader->info, &out_mode, s); - update_ms_output_info(io_sem, base_off_src, write_mask, component_offset, store_val->bit_size, out, s); - - bool hi_16b = io_sem.high_16bits; - bool lo_16b = !hi_16b && store_val->bit_size == 16; - - unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, io_sem.location)); - unsigned num_outputs = util_bitcount64(out->mask); - unsigned const_off = out->addr + component_offset * 4 + (hi_16b ? 2 : 0); - - nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs); - nir_def *base_offset = base_off_src->ssa; - nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16u); - nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off); - - if (out_mode == ms_out_mode_lds) { - nir_store_shared(b, store_val, addr, .base = const_off, - .write_mask = write_mask, .align_mul = 16, - .align_offset = const_off % 16); - } else if (out_mode == ms_out_mode_scratch_ring) { - nir_def *ring = nir_load_ring_mesh_scratch_amd(b); - nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b); - nir_def *zero = nir_imm_int(b, 0); - nir_store_buffer_amd(b, store_val, ring, addr, off, zero, - .base = const_off, - .write_mask = write_mask, - .memory_modes = nir_var_shader_out, - .access = ACCESS_COHERENT); - } else if (out_mode == ms_out_mode_attr_ring) { - /* GFX11+: Store params straight to the attribute ring. 
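/* Illustrative host-side sketch (hypothetical helper names, not part of this
 * patch): the arrayed-output addressing used by ms_arrayed_output_base_addr()
 * and ms_store_arrayed_output() above.  Every output slot is a 16-byte vec4,
 * and slots are compacted per storage class (LDS / scratch ring / attribute
 * ring / variables), so a slot's position is the popcount of lower mask bits
 * rather than its semantic location (location < 64 assumed here).
 */
#include <stdint.h>

static unsigned
model_popcount64(uint64_t v)
{
   unsigned n = 0;
   for (; v; v &= v - 1)
      ++n;
   return n;
}

static uint32_t
model_arrayed_output_byte_addr(uint64_t class_mask, uint32_t class_base_addr,
                               unsigned location, unsigned arr_index,
                               unsigned indirect_slot_offset, unsigned component)
{
   const unsigned num_outputs = model_popcount64(class_mask);
   const unsigned mapped_location = model_popcount64(class_mask & ((1ull << location) - 1));

   const uint32_t base = arr_index * (num_outputs * 16u) + mapped_location * 16u;
   return base + indirect_slot_offset * 16u + class_base_addr + component * 4u;
}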
- * - * Even though the access pattern may not be the most optimal, - * this is still much better than reserving LDS and losing waves. - * (Also much better than storing and reloading from the scratch ring.) - */ - unsigned param_offset = s->vs_output_param_offset[io_sem.location]; - nir_def *ring = nir_load_ring_attr_amd(b); - nir_def *soffset = nir_load_ring_attr_offset_amd(b); - nir_store_buffer_amd(b, store_val, ring, base_addr_off, soffset, arr_index, - .base = const_off + param_offset * 16, - .write_mask = write_mask, - .memory_modes = nir_var_shader_out, - .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD); - } else if (out_mode == ms_out_mode_var) { - unsigned write_mask_32 = write_mask; - if (store_val->bit_size > 32) { - /* Split 64-bit store values to 32-bit components. */ - store_val = nir_bitcast_vector(b, store_val, 32); - /* Widen the write mask so it is in 32-bit components. */ - write_mask_32 = util_widen_mask(write_mask, store_val->bit_size / 32); - } - - u_foreach_bit(comp, write_mask_32) { - unsigned idx = io_sem.location * 4 + comp + component_offset; - nir_def *val = nir_channel(b, store_val, comp); - nir_def *v = nir_load_var(b, s->out_variables[idx]); - - if (lo_16b) { - nir_def *var_hi = nir_unpack_32_2x16_split_y(b, v); - val = nir_pack_32_2x16_split(b, val, var_hi); - } else if (hi_16b) { - nir_def *var_lo = nir_unpack_32_2x16_split_x(b, v); - val = nir_pack_32_2x16_split(b, var_lo, val); - } - - nir_store_var(b, s->out_variables[idx], val, 0x1); - } - } else { - unreachable("Invalid MS output mode for store"); - } -} - -static void -ms_store_arrayed_output_intrin(nir_builder *b, - nir_intrinsic_instr *intrin, - lower_ngg_ms_state *s) -{ - const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin); - - if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) { - ms_store_prim_indices(b, intrin, s); - return; - } else if (io_sem.location == VARYING_SLOT_CULL_PRIMITIVE) { - ms_store_cull_flag(b, intrin, s); - return; - } - - unsigned component_offset = nir_intrinsic_component(intrin); - unsigned write_mask = nir_intrinsic_write_mask(intrin); - - nir_def *store_val = intrin->src[0].ssa; - nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa; - nir_src *base_off_src = nir_get_io_offset_src(intrin); - - if (store_val->bit_size < 32) { - /* Split 16-bit output stores to ensure each 16-bit component is stored - * in the correct location, without overwriting the other 16 bits there. - */ - u_foreach_bit(c, write_mask) { - nir_def *store_component = nir_channel(b, store_val, c); - ms_store_arrayed_output(b, base_off_src, store_component, arr_index, io_sem, c + component_offset, 1, s); - } - } else { - ms_store_arrayed_output(b, base_off_src, store_val, arr_index, io_sem, component_offset, write_mask, s); - } -} - -static nir_def * -ms_load_arrayed_output(nir_builder *b, - nir_def *arr_index, - nir_def *base_offset, - unsigned location, - unsigned component_offset, - unsigned num_components, - unsigned load_bit_size, - lower_ngg_ms_state *s) -{ - ms_out_mode out_mode; - const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s); - - unsigned component_addr_off = component_offset * 4; - unsigned num_outputs = util_bitcount64(out->mask); - unsigned const_off = out->addr + component_offset * 4; - - /* Use compacted location instead of the original semantic location. 
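/* Illustrative host-side sketch (hypothetical helper name, not part of this
 * patch): what the pack/unpack_32_2x16_split sequence above does for 16-bit
 * stores kept in 32-bit variables - the new 16-bit value replaces one half of
 * the existing dword while the other half is preserved.
 */
#include <stdbool.h>
#include <stdint.h>

static uint32_t
model_merge_16bit_store(uint32_t existing_dword, uint16_t new_val, bool high_16bits)
{
   if (high_16bits)
      return (existing_dword & 0x0000ffffu) | ((uint32_t)new_val << 16);
   else
      return (existing_dword & 0xffff0000u) | new_val;
}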
*/ - unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location)); - - nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs); - nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16); - nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off); - - if (out_mode == ms_out_mode_lds) { - return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16, - .align_offset = component_addr_off % 16, - .base = const_off); - } else if (out_mode == ms_out_mode_scratch_ring) { - nir_def *ring = nir_load_ring_mesh_scratch_amd(b); - nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b); - nir_def *zero = nir_imm_int(b, 0); - return nir_load_buffer_amd(b, num_components, load_bit_size, ring, addr, off, zero, - .base = const_off, - .memory_modes = nir_var_shader_out, - .access = ACCESS_COHERENT); - } else if (out_mode == ms_out_mode_var) { - assert(load_bit_size == 32); - nir_def *arr[8] = {0}; - for (unsigned comp = 0; comp < num_components; ++comp) { - unsigned idx = location * 4 + comp + component_addr_off; - arr[comp] = nir_load_var(b, s->out_variables[idx]); - } - return nir_vec(b, arr, num_components); - } else { - unreachable("Invalid MS output mode for load"); - } -} - -static nir_def * -lower_ms_load_workgroup_index(nir_builder *b, - UNUSED nir_intrinsic_instr *intrin, - lower_ngg_ms_state *s) -{ - return s->workgroup_index; -} - -static nir_def * -lower_ms_set_vertex_and_primitive_count(nir_builder *b, - nir_intrinsic_instr *intrin, - lower_ngg_ms_state *s) -{ - /* If either the number of vertices or primitives is zero, set both of them to zero. */ - nir_def *num_vtx = nir_read_first_invocation(b, intrin->src[0].ssa); - nir_def *num_prm = nir_read_first_invocation(b, intrin->src[1].ssa); - nir_def *zero = nir_imm_int(b, 0); - nir_def *is_either_zero = nir_ieq(b, nir_umin(b, num_vtx, num_prm), zero); - num_vtx = nir_bcsel(b, is_either_zero, zero, num_vtx); - num_prm = nir_bcsel(b, is_either_zero, zero, num_prm); - - nir_store_var(b, s->vertex_count_var, num_vtx, 0x1); - nir_store_var(b, s->primitive_count_var, num_prm, 0x1); - - return NIR_LOWER_INSTR_PROGRESS_REPLACE; -} - -static nir_def * -update_ms_barrier(nir_builder *b, - nir_intrinsic_instr *intrin, - lower_ngg_ms_state *s) -{ - /* Output loads and stores are lowered to shared memory access, - * so we have to update the barriers to also reflect this. 
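/* Illustrative host-side sketch (hypothetical helper name, not part of this
 * patch): the scalar rule implemented by lower_ms_set_vertex_and_primitive_count()
 * above.  The HW requires both counts to be zero whenever either one is zero,
 * and the wave uses the first invocation's values.
 */
#include <stdint.h>

static void
model_clamp_ms_output_counts(uint32_t *num_vertices, uint32_t *num_primitives)
{
   if (*num_vertices == 0 || *num_primitives == 0) {
      *num_vertices = 0;
      *num_primitives = 0;
   }
}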
- */ - unsigned mem_modes = nir_intrinsic_memory_modes(intrin); - if (mem_modes & nir_var_shader_out) - mem_modes |= nir_var_mem_shared; - else - return NULL; - - nir_intrinsic_set_memory_modes(intrin, mem_modes); - - return NIR_LOWER_INSTR_PROGRESS; -} - -static nir_def * -lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state) -{ - lower_ngg_ms_state *s = (lower_ngg_ms_state *) state; - - if (instr->type != nir_instr_type_intrinsic) - return NULL; - - nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); - - switch (intrin->intrinsic) { - case nir_intrinsic_store_per_vertex_output: - case nir_intrinsic_store_per_primitive_output: - ms_store_arrayed_output_intrin(b, intrin, s); - return NIR_LOWER_INSTR_PROGRESS_REPLACE; - case nir_intrinsic_barrier: - return update_ms_barrier(b, intrin, s); - case nir_intrinsic_load_workgroup_index: - return lower_ms_load_workgroup_index(b, intrin, s); - case nir_intrinsic_set_vertex_and_primitive_count: - return lower_ms_set_vertex_and_primitive_count(b, intrin, s); - default: - unreachable("Not a lowerable mesh shader intrinsic."); - } -} - -static bool -filter_ms_intrinsic(const nir_instr *instr, - UNUSED const void *s) -{ - if (instr->type != nir_instr_type_intrinsic) - return false; - - nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); - return intrin->intrinsic == nir_intrinsic_store_output || - intrin->intrinsic == nir_intrinsic_load_output || - intrin->intrinsic == nir_intrinsic_store_per_vertex_output || - intrin->intrinsic == nir_intrinsic_store_per_primitive_output || - intrin->intrinsic == nir_intrinsic_barrier || - intrin->intrinsic == nir_intrinsic_load_workgroup_index || - intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count; -} - -static void -lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s) -{ - nir_shader_lower_instructions(shader, filter_ms_intrinsic, lower_ms_intrinsic, s); -} - -static void -ms_emit_arrayed_outputs(nir_builder *b, - nir_def *invocation_index, - uint64_t mask, - lower_ngg_ms_state *s) -{ - nir_def *zero = nir_imm_int(b, 0); - - u_foreach_bit64(slot, mask) { - /* Should not occur here, handled separately. */ - assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES); - - unsigned component_mask = s->out.infos[slot].components_mask; - - while (component_mask) { - int start_comp = 0, num_components = 1; - u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components); - - nir_def *load = - ms_load_arrayed_output(b, invocation_index, zero, slot, start_comp, - num_components, 32, s); - - for (int i = 0; i < num_components; i++) - s->out.outputs[slot][start_comp + i] = nir_channel(b, load, i); - } - } -} - -static void -ms_create_same_invocation_vars(nir_builder *b, lower_ngg_ms_state *s) -{ - /* Initialize NIR variables for same-invocation outputs. */ - uint64_t same_invocation_output_mask = s->layout.var.prm_attr.mask | s->layout.var.vtx_attr.mask; - - u_foreach_bit64(slot, same_invocation_output_mask) { - for (unsigned comp = 0; comp < 4; ++comp) { - unsigned idx = slot * 4 + comp; - s->out_variables[idx] = nir_local_variable_create(b->impl, glsl_uint_type(), "ms_var_output"); - } - } -} - -static void -ms_emit_legacy_workgroup_index(nir_builder *b, lower_ngg_ms_state *s) -{ - /* Workgroup ID should have been lowered to workgroup index. */ - assert(!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID)); - - /* No need to do anything if the shader doesn't use the workgroup index. 
*/ - if (!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_INDEX)) - return; - - b->cursor = nir_before_impl(b->impl); - - /* Legacy fast launch mode (FAST_LAUNCH=1): - * - * The HW doesn't support a proper workgroup index for vertex processing stages, - * so we use the vertex ID which is equivalent to the index of the current workgroup - * within the current dispatch. - * - * Due to the register programming of mesh shaders, this value is only filled for - * the first invocation of the first wave. To let other waves know, we use LDS. - */ - nir_def *workgroup_index = nir_load_vertex_id_zero_base(b); - - if (s->api_workgroup_size <= s->wave_size) { - /* API workgroup is small, so we don't need to use LDS. */ - s->workgroup_index = nir_read_first_invocation(b, workgroup_index); - return; - } - - unsigned workgroup_index_lds_addr = s->layout.lds.workgroup_info_addr + lds_ms_wg_index; - - nir_def *zero = nir_imm_int(b, 0); - nir_def *dont_care = nir_undef(b, 1, 32); - nir_def *loaded_workgroup_index = NULL; - - /* Use elect to make sure only 1 invocation uses LDS. */ - nir_if *if_elected = nir_push_if(b, nir_elect(b, 1)); - { - nir_def *wave_id = nir_load_subgroup_id(b); - nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0)); - { - nir_store_shared(b, workgroup_index, zero, .base = workgroup_index_lds_addr); - nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, - .memory_scope = SCOPE_WORKGROUP, - .memory_semantics = NIR_MEMORY_ACQ_REL, - .memory_modes = nir_var_mem_shared); - } - nir_push_else(b, if_wave_0); - { - nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, - .memory_scope = SCOPE_WORKGROUP, - .memory_semantics = NIR_MEMORY_ACQ_REL, - .memory_modes = nir_var_mem_shared); - loaded_workgroup_index = nir_load_shared(b, 1, 32, zero, .base = workgroup_index_lds_addr); - } - nir_pop_if(b, if_wave_0); - - workgroup_index = nir_if_phi(b, workgroup_index, loaded_workgroup_index); - } - nir_pop_if(b, if_elected); - - workgroup_index = nir_if_phi(b, workgroup_index, dont_care); - s->workgroup_index = nir_read_first_invocation(b, workgroup_index); -} - -static void -set_ms_final_output_counts(nir_builder *b, - lower_ngg_ms_state *s, - nir_def **out_num_prm, - nir_def **out_num_vtx) -{ - /* The spec allows the numbers to be divergent, and in that case we need to - * use the values from the first invocation. Also the HW requires us to set - * both to 0 if either was 0. - * - * These are already done by the lowering. - */ - nir_def *num_prm = nir_load_var(b, s->primitive_count_var); - nir_def *num_vtx = nir_load_var(b, s->vertex_count_var); - - if (s->hw_workgroup_size <= s->wave_size) { - /* Single-wave mesh shader workgroup. */ - alloc_vertices_and_primitives(b, num_vtx, num_prm); - *out_num_prm = num_prm; - *out_num_vtx = num_vtx; - return; - } - - /* Multi-wave mesh shader workgroup: - * We need to use LDS to distribute the correct values to the other waves. - * - * TODO: - * If we can prove that the values are workgroup-uniform, we can skip this - * and just use whatever the current wave has. However, NIR divergence analysis - * currently doesn't support this. 
- */ - - nir_def *zero = nir_imm_int(b, 0); - - nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0)); - { - nir_if *if_elected = nir_push_if(b, nir_elect(b, 1)); - { - nir_store_shared(b, nir_vec2(b, num_prm, num_vtx), zero, - .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims); - } - nir_pop_if(b, if_elected); - - nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, - .memory_scope = SCOPE_WORKGROUP, - .memory_semantics = NIR_MEMORY_ACQ_REL, - .memory_modes = nir_var_mem_shared); - - alloc_vertices_and_primitives(b, num_vtx, num_prm); - } - nir_push_else(b, if_wave_0); - { - nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, - .memory_scope = SCOPE_WORKGROUP, - .memory_semantics = NIR_MEMORY_ACQ_REL, - .memory_modes = nir_var_mem_shared); - - nir_def *prm_vtx = NULL; - nir_def *dont_care_2x32 = nir_undef(b, 2, 32); - nir_if *if_elected = nir_push_if(b, nir_elect(b, 1)); - { - prm_vtx = nir_load_shared(b, 2, 32, zero, - .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims); - } - nir_pop_if(b, if_elected); - - prm_vtx = nir_if_phi(b, prm_vtx, dont_care_2x32); - num_prm = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 0)); - num_vtx = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 1)); - - nir_store_var(b, s->primitive_count_var, num_prm, 0x1); - nir_store_var(b, s->vertex_count_var, num_vtx, 0x1); - } - nir_pop_if(b, if_wave_0); - - *out_num_prm = nir_load_var(b, s->primitive_count_var); - *out_num_vtx = nir_load_var(b, s->vertex_count_var); -} - -static void -ms_emit_attribute_ring_output_stores(nir_builder *b, const uint64_t outputs_mask, - nir_def *idx, lower_ngg_ms_state *s) -{ - if (!outputs_mask) - return; - - nir_def *ring = nir_load_ring_attr_amd(b); - nir_def *off = nir_load_ring_attr_offset_amd(b); - nir_def *zero = nir_imm_int(b, 0); - - u_foreach_bit64 (slot, outputs_mask) { - if (s->vs_output_param_offset[slot] > AC_EXP_PARAM_OFFSET_31) - continue; - - nir_def *soffset = nir_iadd_imm(b, off, s->vs_output_param_offset[slot] * 16 * 32); - nir_def *store_val = nir_undef(b, 4, 32); - unsigned store_val_components = 0; - for (unsigned c = 0; c < 4; ++c) { - if (s->out.outputs[slot][c]) { - store_val = nir_vector_insert_imm(b, store_val, s->out.outputs[slot][c], c); - store_val_components = c + 1; - } - } - - store_val = nir_trim_vector(b, store_val, store_val_components); - nir_store_buffer_amd(b, store_val, ring, zero, soffset, idx, - .memory_modes = nir_var_shader_out, - .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD); - } -} - -static nir_def * -ms_prim_exp_arg_ch1(nir_builder *b, nir_def *invocation_index, nir_def *num_vtx, lower_ngg_ms_state *s) -{ - /* Primitive connectivity data: describes which vertices the primitive uses. 
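/* Illustrative host-side sketch (hypothetical helper name, not part of this
 * patch): the per-slot scalar offset used by ms_emit_attribute_ring_output_stores()
 * above.  The code advances the scalar offset by 16 * 32 bytes per parameter
 * slot (one vec4 for each lane of a 32-wide row), so a slot's ring offset is
 * ring_attr_offset + param_offset * 512.
 */
#include <stdint.h>

static uint32_t
model_attr_ring_soffset(uint32_t ring_attr_offset, uint8_t vs_output_param_offset)
{
   return ring_attr_offset + (uint32_t)vs_output_param_offset * 16u * 32u;
}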
*/ - nir_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim); - nir_def *indices_loaded = NULL; - nir_def *cull_flag = NULL; - - if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) { - nir_def *indices[3] = {0}; - for (unsigned c = 0; c < s->vertices_per_prim; ++c) - indices[c] = nir_load_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c]); - indices_loaded = nir_vec(b, indices, s->vertices_per_prim); - } else { - indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr); - indices_loaded = nir_u2u32(b, indices_loaded); - } - - if (s->uses_cull_flags) { - nir_def *loaded_cull_flag = NULL; - if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) - loaded_cull_flag = nir_load_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4]); - else - loaded_cull_flag = nir_u2u32(b, nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr)); - - cull_flag = nir_i2b(b, loaded_cull_flag); - } - - nir_def *indices[3]; - nir_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u); - - for (unsigned i = 0; i < s->vertices_per_prim; ++i) { - indices[i] = nir_channel(b, indices_loaded, i); - indices[i] = nir_umin(b, indices[i], max_vtx_idx); - } - - return emit_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, cull_flag, s->gfx_level); -} - -static nir_def * -ms_prim_exp_arg_ch2(nir_builder *b, uint64_t outputs_mask, lower_ngg_ms_state *s) -{ - nir_def *prim_exp_arg_ch2 = NULL; - - if (outputs_mask) { - /* When layer, viewport etc. are per-primitive, they need to be encoded in - * the primitive export instruction's second channel. The encoding is: - * - * --- GFX10.3 --- - * bits 31..30: VRS rate Y - * bits 29..28: VRS rate X - * bits 23..20: viewport - * bits 19..17: layer - * - * --- GFX11 --- - * bits 31..28: VRS rate enum - * bits 23..20: viewport - * bits 12..00: layer - */ - prim_exp_arg_ch2 = nir_imm_int(b, 0); - - if (outputs_mask & VARYING_BIT_LAYER) { - nir_def *layer = - nir_ishl_imm(b, s->out.outputs[VARYING_SLOT_LAYER][0], s->gfx_level >= GFX11 ? 
0 : 17); - prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, layer); - } - - if (outputs_mask & VARYING_BIT_VIEWPORT) { - nir_def *view = nir_ishl_imm(b, s->out.outputs[VARYING_SLOT_VIEWPORT][0], 20); - prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, view); - } - - if (outputs_mask & VARYING_BIT_PRIMITIVE_SHADING_RATE) { - nir_def *rate = s->out.outputs[VARYING_SLOT_PRIMITIVE_SHADING_RATE][0]; - prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, rate); - } - } - - return prim_exp_arg_ch2; -} - -static void -ms_prim_gen_query(nir_builder *b, - nir_def *invocation_index, - nir_def *num_prm, - lower_ngg_ms_state *s) -{ - if (!s->has_query) - return; - - nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0)); - { - nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b)); - { - nir_atomic_add_gen_prim_count_amd(b, num_prm, .stream_id = 0); - } - nir_pop_if(b, if_shader_query); - } - nir_pop_if(b, if_invocation_index_zero); -} - -static void -ms_invocation_query(nir_builder *b, - nir_def *invocation_index, - lower_ngg_ms_state *s) -{ - if (!s->has_query) - return; - - nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0)); - { - nir_if *if_pipeline_query = nir_push_if(b, nir_load_pipeline_stat_query_enabled_amd(b)); - { - nir_atomic_add_shader_invocation_count_amd(b, nir_imm_int(b, s->api_workgroup_size)); - } - nir_pop_if(b, if_pipeline_query); - } - nir_pop_if(b, if_invocation_index_zero); -} - -static void -emit_ms_vertex(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters, - uint64_t per_vertex_outputs, lower_ngg_ms_state *s) -{ - ms_emit_arrayed_outputs(b, index, per_vertex_outputs, s); - - if (exports) { - ac_nir_export_position(b, s->gfx_level, s->clipdist_enable_mask, - !s->has_param_exports, false, true, - s->per_vertex_outputs | VARYING_BIT_POS, &s->out, row); - } - - if (parameters) { - /* Export generic attributes on GFX10.3 - * (On GFX11 they are already stored in the attribute ring.) - */ - if (s->has_param_exports && s->gfx_level == GFX10_3) { - ac_nir_export_parameters(b, s->vs_output_param_offset, per_vertex_outputs, 0, &s->out); - } - - /* GFX11+: also store special outputs to the attribute ring so PS can load them. */ - if (s->gfx_level >= GFX11 && (per_vertex_outputs & MS_VERT_ARG_EXP_MASK)) - ms_emit_attribute_ring_output_stores(b, per_vertex_outputs & MS_VERT_ARG_EXP_MASK, index, s); - } -} - -static void -emit_ms_primitive(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters, - uint64_t per_primitive_outputs, lower_ngg_ms_state *s) -{ - ms_emit_arrayed_outputs(b, index, per_primitive_outputs, s); - - /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */ - if (s->insert_layer_output) { - s->out.outputs[VARYING_SLOT_LAYER][0] = nir_load_view_index(b); - s->out.infos[VARYING_SLOT_LAYER].as_sysval_mask |= 1; - } - - if (exports) { - const uint64_t outputs_mask = per_primitive_outputs & MS_PRIM_ARG_EXP_MASK; - nir_def *num_vtx = nir_load_var(b, s->vertex_count_var); - nir_def *prim_exp_arg_ch1 = ms_prim_exp_arg_ch1(b, index, num_vtx, s); - nir_def *prim_exp_arg_ch2 = ms_prim_exp_arg_ch2(b, outputs_mask, s); - - nir_def *prim_exp_arg = prim_exp_arg_ch2 ? - nir_vec2(b, prim_exp_arg_ch1, prim_exp_arg_ch2) : prim_exp_arg_ch1; - - ac_nir_export_primitive(b, prim_exp_arg, row); - } - - if (parameters) { - /* Export generic attributes on GFX10.3 - * (On GFX11 they are already stored in the attribute ring.) 
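/* Illustrative host-side sketch (hypothetical helper name, not part of this
 * patch): the second primitive-export channel built by ms_prim_exp_arg_ch2()
 * above.  Layer goes to bits 12..0 on GFX11+ (bits 19..17 on GFX10.3),
 * viewport to bits 23..20, and the shading-rate value is assumed to already
 * be encoded in its final bit position by earlier lowering.
 */
#include <stdbool.h>
#include <stdint.h>

static uint32_t
model_prim_exp_arg_ch2(uint32_t layer, uint32_t viewport, uint32_t shading_rate,
                       bool is_gfx11_plus)
{
   uint32_t ch2 = 0;
   ch2 |= layer << (is_gfx11_plus ? 0 : 17);
   ch2 |= viewport << 20;
   ch2 |= shading_rate;
   return ch2;
}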
- */ - if (s->has_param_exports && s->gfx_level == GFX10_3) { - ac_nir_export_parameters(b, s->vs_output_param_offset, per_primitive_outputs, 0, &s->out); - } - - /* GFX11+: also store special outputs to the attribute ring so PS can load them. */ - if (s->gfx_level >= GFX11) - ms_emit_attribute_ring_output_stores(b, per_primitive_outputs & MS_PRIM_ARG_EXP_MASK, index, s); - } -} - -static void -emit_ms_outputs(nir_builder *b, nir_def *invocation_index, nir_def *row_start, - nir_def *count, bool exports, bool parameters, uint64_t mask, - void (*cb)(nir_builder *, nir_def *, nir_def *, bool, bool, - uint64_t, lower_ngg_ms_state *), - lower_ngg_ms_state *s) -{ - if (cb == &emit_ms_primitive ? s->prim_multirow_export : s->vert_multirow_export) { - assert(s->hw_workgroup_size % s->wave_size == 0); - const unsigned num_waves = s->hw_workgroup_size / s->wave_size; - - nir_loop *row_loop = nir_push_loop(b); - { - nir_block *preheader = nir_cf_node_as_block(nir_cf_node_prev(&row_loop->cf_node)); - - nir_phi_instr *index = nir_phi_instr_create(b->shader); - nir_phi_instr *row = nir_phi_instr_create(b->shader); - nir_def_init(&index->instr, &index->def, 1, 32); - nir_def_init(&row->instr, &row->def, 1, 32); - - nir_phi_instr_add_src(index, preheader, invocation_index); - nir_phi_instr_add_src(row, preheader, row_start); - - nir_if *if_break = nir_push_if(b, nir_uge(b, &index->def, count)); - { - nir_jump(b, nir_jump_break); - } - nir_pop_if(b, if_break); - - cb(b, &index->def, &row->def, exports, parameters, mask, s); - - nir_block *body = nir_cursor_current_block(b->cursor); - nir_phi_instr_add_src(index, body, - nir_iadd_imm(b, &index->def, s->hw_workgroup_size)); - nir_phi_instr_add_src(row, body, - nir_iadd_imm(b, &row->def, num_waves)); - - nir_instr_insert_before_cf_list(&row_loop->body, &row->instr); - nir_instr_insert_before_cf_list(&row_loop->body, &index->instr); - } - nir_pop_loop(b, row_loop); - } else { - nir_def *has_output = nir_ilt(b, invocation_index, count); - nir_if *if_has_output = nir_push_if(b, has_output); - { - cb(b, invocation_index, row_start, exports, parameters, mask, s); - } - nir_pop_if(b, if_has_output); - } -} - -static void -emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s) -{ - /* We assume there is always a single end block in the shader. */ - nir_block *last_block = nir_impl_last_block(b->impl); - b->cursor = nir_after_block(last_block); - - nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP, - .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared); - - nir_def *num_prm; - nir_def *num_vtx; - - set_ms_final_output_counts(b, s, &num_prm, &num_vtx); - - nir_def *invocation_index = nir_load_local_invocation_index(b); - - ms_prim_gen_query(b, invocation_index, num_prm, s); - - nir_def *row_start = NULL; - if (s->fast_launch_2) - row_start = s->hw_workgroup_size <= s->wave_size ? nir_imm_int(b, 0) : nir_load_subgroup_id(b); - - /* Load vertex/primitive attributes from shared memory and - * emit store_output intrinsics for them. - * - * Contrary to the semantics of the API mesh shader, these are now - * compliant with NGG HW semantics, meaning that these store the - * current thread's vertex attributes in a way the HW can export. 
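/* Illustrative host-side sketch (hypothetical helper name, not part of this
 * patch): the iteration space of the multi-row export loop in emit_ms_outputs()
 * above, written as a plain loop.  Each invocation handles one vertex or
 * primitive per trip, advancing by the HW workgroup size while the export row
 * advances by the number of waves.
 */
static void
model_multirow_export(unsigned invocation_index, unsigned row_start,
                      unsigned count, unsigned hw_workgroup_size,
                      unsigned wave_size,
                      void (*emit)(unsigned index, unsigned row))
{
   const unsigned num_waves = hw_workgroup_size / wave_size;

   for (unsigned index = invocation_index, row = row_start;
        index < count;
        index += hw_workgroup_size, row += num_waves)
      emit(index, row);
}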
- */ - - uint64_t per_vertex_outputs = - s->per_vertex_outputs & ~s->layout.attr_ring.vtx_attr.mask; - uint64_t per_primitive_outputs = - s->per_primitive_outputs & ~s->layout.attr_ring.prm_attr.mask & ~SPECIAL_MS_OUT_MASK; - - /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */ - if (s->insert_layer_output) { - b->shader->info.outputs_written |= VARYING_BIT_LAYER; - b->shader->info.per_primitive_outputs |= VARYING_BIT_LAYER; - per_primitive_outputs |= VARYING_BIT_LAYER; - } - - const bool has_special_param_exports = - (per_vertex_outputs & MS_VERT_ARG_EXP_MASK) || - (per_primitive_outputs & MS_PRIM_ARG_EXP_MASK); - - const bool wait_attr_ring = must_wait_attr_ring(s->gfx_level, has_special_param_exports); - - /* Export vertices. */ - if ((per_vertex_outputs & ~VARYING_BIT_POS) || !wait_attr_ring) { - emit_ms_outputs(b, invocation_index, row_start, num_vtx, !wait_attr_ring, true, - per_vertex_outputs, &emit_ms_vertex, s); - } - - /* Export primitives. */ - if (per_primitive_outputs || !wait_attr_ring) { - emit_ms_outputs(b, invocation_index, row_start, num_prm, !wait_attr_ring, true, - per_primitive_outputs, &emit_ms_primitive, s); - } - - /* When we need to wait for attribute ring stores, we emit both position and primitive - * export instructions after a barrier to make sure both per-vertex and per-primitive - * attribute ring stores are finished before the GPU starts rasterization. - */ - if (wait_attr_ring) { - /* Wait for attribute stores to finish. */ - nir_barrier(b, .execution_scope = SCOPE_SUBGROUP, - .memory_scope = SCOPE_DEVICE, - .memory_semantics = NIR_MEMORY_RELEASE, - .memory_modes = nir_var_shader_out); - - /* Position/primitive export only */ - emit_ms_outputs(b, invocation_index, row_start, num_vtx, true, false, - per_vertex_outputs, &emit_ms_vertex, s); - emit_ms_outputs(b, invocation_index, row_start, num_prm, true, false, - per_primitive_outputs, &emit_ms_primitive, s); - } -} - -static void -handle_smaller_ms_api_workgroup(nir_builder *b, - lower_ngg_ms_state *s) -{ - if (s->api_workgroup_size >= s->hw_workgroup_size) - return; - - /* Handle barriers manually when the API workgroup - * size is less than the HW workgroup size. - * - * The problem is that the real workgroup launched on NGG HW - * will be larger than the size specified by the API, and the - * extra waves need to keep up with barriers in the API waves. - * - * There are 2 different cases: - * 1. The whole API workgroup fits in a single wave. - * We can shrink the barriers to subgroup scope and - * don't need to insert any extra ones. - * 2. The API workgroup occupies multiple waves, but not - * all. In this case, we emit code that consumes every - * barrier on the extra waves. - */ - assert(s->hw_workgroup_size % s->wave_size == 0); - bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size; - bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size; - bool need_additional_barriers = scan_barriers && !can_shrink_barriers; - - unsigned api_waves_in_flight_addr = s->layout.lds.workgroup_info_addr + lds_ms_num_api_waves; - unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size); - - /* Scan the shader for workgroup barriers. 
*/ - if (scan_barriers) { - bool has_any_workgroup_barriers = false; - - nir_foreach_block(block, b->impl) { - nir_foreach_instr_safe(instr, block) { - if (instr->type != nir_instr_type_intrinsic) - continue; - - nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); - bool is_workgroup_barrier = - intrin->intrinsic == nir_intrinsic_barrier && - nir_intrinsic_execution_scope(intrin) == SCOPE_WORKGROUP; - - if (!is_workgroup_barrier) - continue; - - if (can_shrink_barriers) { - /* Every API invocation runs in the first wave. - * In this case, we can change the barriers to subgroup scope - * and avoid adding additional barriers. - */ - nir_intrinsic_set_memory_scope(intrin, SCOPE_SUBGROUP); - nir_intrinsic_set_execution_scope(intrin, SCOPE_SUBGROUP); - } else { - has_any_workgroup_barriers = true; - } - } - } - - need_additional_barriers &= has_any_workgroup_barriers; - } - - /* Extract the full control flow of the shader. */ - nir_cf_list extracted; - nir_cf_extract(&extracted, nir_before_impl(b->impl), - nir_after_cf_list(&b->impl->body)); - b->cursor = nir_before_impl(b->impl); - - /* Wrap the shader in an if to ensure that only the necessary amount of lanes run it. */ - nir_def *invocation_index = nir_load_local_invocation_index(b); - nir_def *zero = nir_imm_int(b, 0); - - if (need_additional_barriers) { - /* First invocation stores 0 to number of API waves in flight. */ - nir_if *if_first_in_workgroup = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0)); - { - nir_store_shared(b, nir_imm_int(b, num_api_waves), zero, .base = api_waves_in_flight_addr); - } - nir_pop_if(b, if_first_in_workgroup); - - nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, - .memory_scope = SCOPE_WORKGROUP, - .memory_semantics = NIR_MEMORY_ACQ_REL, - .memory_modes = nir_var_shader_out | nir_var_mem_shared); - } - - nir_def *has_api_ms_invocation = nir_ult_imm(b, invocation_index, s->api_workgroup_size); - nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation); - { - nir_cf_reinsert(&extracted, b->cursor); - b->cursor = nir_after_cf_list(&if_has_api_ms_invocation->then_list); - - if (need_additional_barriers) { - /* One invocation in each API wave decrements the number of API waves in flight. */ - nir_if *if_elected_again = nir_push_if(b, nir_elect(b, 1)); - { - nir_shared_atomic(b, 32, zero, nir_imm_int(b, -1u), - .base = api_waves_in_flight_addr, - .atomic_op = nir_atomic_op_iadd); - } - nir_pop_if(b, if_elected_again); - - nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, - .memory_scope = SCOPE_WORKGROUP, - .memory_semantics = NIR_MEMORY_ACQ_REL, - .memory_modes = nir_var_shader_out | nir_var_mem_shared); - } - - ms_invocation_query(b, invocation_index, s); - } - nir_pop_if(b, if_has_api_ms_invocation); - - if (need_additional_barriers) { - /* Make sure that waves that don't run any API invocations execute - * the same amount of barriers as those that do. - * - * We do this by executing a barrier until the number of API waves - * in flight becomes zero. 
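/* Illustrative host-side sketch (hypothetical helper names, not part of this
 * patch): the decision logic of handle_smaller_ms_api_workgroup() above.
 * Extra barriers are only needed when some waves run no API invocations at
 * all and the shader actually contains workgroup barriers; if the whole API
 * workgroup fits in one wave, the barriers are shrunk to subgroup scope
 * instead.
 */
#include <stdbool.h>

static unsigned
model_align(unsigned v, unsigned a)
{
   return (v + a - 1) / a * a;
}

static bool
model_needs_additional_barriers(unsigned api_workgroup_size, unsigned wave_size,
                                unsigned hw_workgroup_size,
                                bool has_any_workgroup_barriers)
{
   const bool scan_barriers = model_align(api_workgroup_size, wave_size) < hw_workgroup_size;
   const bool can_shrink_barriers = api_workgroup_size <= wave_size;

   return scan_barriers && !can_shrink_barriers && has_any_workgroup_barriers;
}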
- */ - nir_def *has_api_ms_ballot = nir_ballot(b, 1, s->wave_size, has_api_ms_invocation); - nir_def *wave_has_no_api_ms = nir_ieq_imm(b, has_api_ms_ballot, 0); - nir_if *if_wave_has_no_api_ms = nir_push_if(b, wave_has_no_api_ms); - { - nir_if *if_elected = nir_push_if(b, nir_elect(b, 1)); - { - nir_loop *loop = nir_push_loop(b); - { - nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, - .memory_scope = SCOPE_WORKGROUP, - .memory_semantics = NIR_MEMORY_ACQ_REL, - .memory_modes = nir_var_shader_out | nir_var_mem_shared); - - nir_def *loaded = nir_load_shared(b, 1, 32, zero, .base = api_waves_in_flight_addr); - nir_if *if_break = nir_push_if(b, nir_ieq_imm(b, loaded, 0)); - { - nir_jump(b, nir_jump_break); - } - nir_pop_if(b, if_break); - } - nir_pop_loop(b, loop); - } - nir_pop_if(b, if_elected); - } - nir_pop_if(b, if_wave_has_no_api_ms); - } -} - -static void -ms_move_output(ms_out_part *from, ms_out_part *to) -{ - uint64_t loc = util_logbase2_64(from->mask); - uint64_t bit = BITFIELD64_BIT(loc); - from->mask ^= bit; - to->mask |= bit; -} - -static void -ms_calculate_arrayed_output_layout(ms_out_mem_layout *l, - unsigned max_vertices, - unsigned max_primitives) -{ - uint32_t lds_vtx_attr_size = util_bitcount64(l->lds.vtx_attr.mask) * max_vertices * 16; - uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16; - l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16); - l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size; - - uint32_t scratch_ring_vtx_attr_size = - util_bitcount64(l->scratch_ring.vtx_attr.mask) * max_vertices * 16; - l->scratch_ring.prm_attr.addr = - ALIGN(l->scratch_ring.vtx_attr.addr + scratch_ring_vtx_attr_size, 16); -} - -static ms_out_mem_layout -ms_calculate_output_layout(enum amd_gfx_level gfx_level, unsigned api_shared_size, - uint64_t per_vertex_output_mask, uint64_t per_primitive_output_mask, - uint64_t cross_invocation_output_access, unsigned max_vertices, - unsigned max_primitives, unsigned vertices_per_prim) -{ - /* These outputs always need export instructions and can't use the attributes ring. */ - const uint64_t always_export_mask = - VARYING_BIT_POS | VARYING_BIT_CULL_DIST0 | VARYING_BIT_CULL_DIST1 | VARYING_BIT_CLIP_DIST0 | - VARYING_BIT_CLIP_DIST1 | VARYING_BIT_PSIZ | VARYING_BIT_VIEWPORT | - VARYING_BIT_PRIMITIVE_SHADING_RATE | VARYING_BIT_LAYER | - BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | - BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE); - - const bool use_attr_ring = gfx_level >= GFX11; - const uint64_t attr_ring_per_vertex_output_mask = - use_attr_ring ? per_vertex_output_mask & ~always_export_mask : 0; - const uint64_t attr_ring_per_primitive_output_mask = - use_attr_ring ? per_primitive_output_mask & ~always_export_mask : 0; - - const uint64_t lds_per_vertex_output_mask = - per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & cross_invocation_output_access & - ~SPECIAL_MS_OUT_MASK; - const uint64_t lds_per_primitive_output_mask = - per_primitive_output_mask & ~attr_ring_per_primitive_output_mask & - cross_invocation_output_access & ~SPECIAL_MS_OUT_MASK; - - const bool cross_invocation_indices = - cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES); - const bool cross_invocation_cull_primitive = - cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE); - - /* Shared memory used by the API shader. 
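/* Illustrative host-side sketch (hypothetical helper names, not part of this
 * patch): the size arithmetic of ms_calculate_arrayed_output_layout() above.
 * Every output slot takes one 16-byte vec4 per vertex (or per primitive), and
 * the per-primitive block is placed, 16-byte aligned, right after the
 * per-vertex block.
 */
#include <stdint.h>

static uint32_t
model_align16(uint32_t v)
{
   return (v + 15u) & ~15u;
}

static uint32_t
model_ms_lds_attr_total(unsigned num_vtx_slots, unsigned num_prm_slots,
                        unsigned max_vertices, unsigned max_primitives,
                        uint32_t vtx_attr_addr)
{
   const uint32_t vtx_size = num_vtx_slots * max_vertices * 16u;
   const uint32_t prm_addr = model_align16(vtx_attr_addr + vtx_size);
   const uint32_t prm_size = num_prm_slots * max_primitives * 16u;
   return prm_addr + prm_size; /* corresponds to l->lds.total_size */
}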
*/ - ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } }; - - /* GFX11+: use attribute ring for all generic attributes. */ - l.attr_ring.vtx_attr.mask = attr_ring_per_vertex_output_mask; - l.attr_ring.prm_attr.mask = attr_ring_per_primitive_output_mask; - - /* Outputs without cross-invocation access can be stored in variables. */ - l.var.vtx_attr.mask = - per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & ~cross_invocation_output_access; - l.var.prm_attr.mask = per_primitive_output_mask & ~attr_ring_per_primitive_output_mask & - ~cross_invocation_output_access; - - /* Workgroup information, see ms_workgroup_* for the layout. */ - l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16); - l.lds.total_size = l.lds.workgroup_info_addr + 16; - - /* Per-vertex and per-primitive output attributes. - * Outputs without cross-invocation access are not included here. - * First, try to put all outputs into LDS (shared memory). - * If they don't fit, try to move them to VRAM one by one. - */ - l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16); - l.lds.vtx_attr.mask = lds_per_vertex_output_mask; - l.lds.prm_attr.mask = lds_per_primitive_output_mask; - ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives); - - /* NGG shaders can only address up to 32K LDS memory. - * The spec requires us to allow the application to use at least up to 28K - * shared memory. Additionally, we reserve 2K for driver internal use - * (eg. primitive indices and such, see below). - * - * Move the outputs that do not fit LDS, to VRAM. - * Start with per-primitive attributes, because those are grouped at the end. - */ - const unsigned usable_lds_kbytes = - (cross_invocation_cull_primitive || cross_invocation_indices) ? 30 : 31; - while (l.lds.total_size >= usable_lds_kbytes * 1024) { - if (l.lds.prm_attr.mask) - ms_move_output(&l.lds.prm_attr, &l.scratch_ring.prm_attr); - else if (l.lds.vtx_attr.mask) - ms_move_output(&l.lds.vtx_attr, &l.scratch_ring.vtx_attr); - else - unreachable("API shader uses too much shared memory."); - - ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives); - } - - if (cross_invocation_indices) { - /* Indices: flat array of 8-bit vertex indices for each primitive. */ - l.lds.indices_addr = ALIGN(l.lds.total_size, 16); - l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim; - } - - if (cross_invocation_cull_primitive) { - /* Cull flags: array of 8-bit cull flags for each primitive, 1=cull, 0=keep. */ - l.lds.cull_flags_addr = ALIGN(l.lds.total_size, 16); - l.lds.total_size = l.lds.cull_flags_addr + max_primitives; - } - - /* NGG is only allowed to address up to 32K of LDS. 
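/* Illustrative host-side sketch (hypothetical helper names, not part of this
 * patch): the LDS budget rule used by ms_calculate_output_layout() above.
 * NGG can address 32K of LDS; the layout keeps 1K in reserve (2K when the
 * 8-bit index / cull-flag arrays are also needed), and per-primitive
 * attributes are the first candidates to be demoted to the VRAM scratch ring
 * when the budget is exceeded.
 */
#include <stdbool.h>
#include <stdint.h>

static uint32_t
model_usable_ms_lds_bytes(bool cross_invocation_indices, bool cross_invocation_cull)
{
   return (cross_invocation_indices || cross_invocation_cull) ? 30u * 1024 : 31u * 1024;
}

static bool
model_must_demote_to_scratch_ring(uint32_t lds_total_size,
                                  bool cross_invocation_indices,
                                  bool cross_invocation_cull)
{
   return lds_total_size >=
          model_usable_ms_lds_bytes(cross_invocation_indices, cross_invocation_cull);
}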
*/ - assert(l.lds.total_size <= 32 * 1024); - return l; -} - -void -ac_nir_lower_ngg_mesh(nir_shader *shader, - enum amd_gfx_level gfx_level, - uint32_t clipdist_enable_mask, - const uint8_t *vs_output_param_offset, - bool has_param_exports, - bool *out_needs_scratch_ring, - unsigned wave_size, - unsigned hw_workgroup_size, - bool multiview, - bool has_query, - bool fast_launch_2) -{ - unsigned vertices_per_prim = - mesa_vertices_per_prim(shader->info.mesh.primitive_type); - - uint64_t per_vertex_outputs = - shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~SPECIAL_MS_OUT_MASK; - uint64_t per_primitive_outputs = - shader->info.per_primitive_outputs & shader->info.outputs_written; - - /* Whether the shader uses CullPrimitiveEXT */ - bool uses_cull = shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE); - /* Can't handle indirect register addressing, pretend as if they were cross-invocation. */ - uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access | - shader->info.outputs_accessed_indirectly; - - unsigned max_vertices = shader->info.mesh.max_vertices_out; - unsigned max_primitives = shader->info.mesh.max_primitives_out; - - ms_out_mem_layout layout = ms_calculate_output_layout( - gfx_level, shader->info.shared_size, per_vertex_outputs, per_primitive_outputs, - cross_invocation_access, max_vertices, max_primitives, vertices_per_prim); - - shader->info.shared_size = layout.lds.total_size; - *out_needs_scratch_ring = layout.scratch_ring.vtx_attr.mask || layout.scratch_ring.prm_attr.mask; - - /* The workgroup size that is specified by the API shader may be different - * from the size of the workgroup that actually runs on the HW, due to the - * limitations of NGG: max 0/1 vertex and 0/1 primitive per lane is allowed. - * - * Therefore, we must make sure that when the API workgroup size is smaller, - * we don't run the API shader on more HW invocations than is necessary. 
- */ - unsigned api_workgroup_size = shader->info.workgroup_size[0] * - shader->info.workgroup_size[1] * - shader->info.workgroup_size[2]; - - lower_ngg_ms_state state = { - .layout = layout, - .wave_size = wave_size, - .per_vertex_outputs = per_vertex_outputs, - .per_primitive_outputs = per_primitive_outputs, - .vertices_per_prim = vertices_per_prim, - .api_workgroup_size = api_workgroup_size, - .hw_workgroup_size = hw_workgroup_size, - .insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER), - .uses_cull_flags = uses_cull, - .gfx_level = gfx_level, - .fast_launch_2 = fast_launch_2, - .vert_multirow_export = fast_launch_2 && max_vertices > hw_workgroup_size, - .prim_multirow_export = fast_launch_2 && max_primitives > hw_workgroup_size, - .clipdist_enable_mask = clipdist_enable_mask, - .vs_output_param_offset = vs_output_param_offset, - .has_param_exports = has_param_exports, - .has_query = has_query, - }; - - nir_function_impl *impl = nir_shader_get_entrypoint(shader); - assert(impl); - - state.vertex_count_var = - nir_local_variable_create(impl, glsl_uint_type(), "vertex_count_var"); - state.primitive_count_var = - nir_local_variable_create(impl, glsl_uint_type(), "primitive_count_var"); - - nir_builder builder = nir_builder_at(nir_before_impl(impl)); - nir_builder *b = &builder; /* This is to avoid the & */ - - handle_smaller_ms_api_workgroup(b, &state); - if (!fast_launch_2) - ms_emit_legacy_workgroup_index(b, &state); - ms_create_same_invocation_vars(b, &state); - nir_metadata_preserve(impl, nir_metadata_none); - - lower_ms_intrinsics(shader, &state); - - emit_ms_finale(b, &state); - nir_metadata_preserve(impl, nir_metadata_none); - - /* Cleanup */ - nir_lower_vars_to_ssa(shader); - nir_remove_dead_variables(shader, nir_var_function_temp, NULL); - nir_lower_alu_to_scalar(shader, NULL, NULL); - nir_lower_phis_to_scalar(shader, true); - - /* Optimize load_local_invocation_index. When the API workgroup is smaller than the HW workgroup, - * local_invocation_id isn't initialized for all lanes and we can't perform this optimization for - * all load_local_invocation_index. - */ - if (fast_launch_2 && api_workgroup_size == hw_workgroup_size && - ((shader->info.workgroup_size[0] == 1) + (shader->info.workgroup_size[1] == 1) + - (shader->info.workgroup_size[2] == 1)) == 2) { - nir_lower_compute_system_values_options csv_options = { - .lower_local_invocation_index = true, - }; - nir_lower_compute_system_values(shader, &csv_options); - } - - nir_validate_shader(shader, "after emitting NGG MS"); -} diff --git a/src/amd/common/ac_nir_lower_ngg_mesh.c b/src/amd/common/ac_nir_lower_ngg_mesh.c new file mode 100644 index 00000000000..160701083f8 --- /dev/null +++ b/src/amd/common/ac_nir_lower_ngg_mesh.c @@ -0,0 +1,1463 @@ +/* + * Copyright © 2021 Valve Corporation + * + * SPDX-License-Identifier: MIT + */ + +#include "ac_nir.h" +#include "ac_nir_helpers.h" + +#include "nir_builder.h" + +#define SPECIAL_MS_OUT_MASK \ + (BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | \ + BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | \ + BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) + +#define MS_PRIM_ARG_EXP_MASK \ + (VARYING_BIT_LAYER | \ + VARYING_BIT_VIEWPORT | \ + VARYING_BIT_PRIMITIVE_SHADING_RATE) + +#define MS_VERT_ARG_EXP_MASK \ + (VARYING_BIT_CULL_DIST0 | \ + VARYING_BIT_CULL_DIST1 | \ + VARYING_BIT_CLIP_DIST0 | \ + VARYING_BIT_CLIP_DIST1 | \ + VARYING_BIT_PSIZ) + +/* LDS layout of Mesh Shader workgroup info. 
*/ +enum { + /* DW0: number of primitives */ + lds_ms_num_prims = 0, + /* DW1: number of vertices */ + lds_ms_num_vtx = 4, + /* DW2: workgroup index within the current dispatch */ + lds_ms_wg_index = 8, + /* DW3: number of API workgroups in flight */ + lds_ms_num_api_waves = 12, +}; + +/* Potential location for Mesh Shader outputs. */ +typedef enum { + ms_out_mode_lds, + ms_out_mode_scratch_ring, + ms_out_mode_attr_ring, + ms_out_mode_var, +} ms_out_mode; + +typedef struct +{ + uint64_t mask; /* Mask of output locations */ + uint32_t addr; /* Base address */ +} ms_out_part; + +typedef struct +{ + /* Mesh shader LDS layout. For details, see ms_calculate_output_layout. */ + struct { + uint32_t workgroup_info_addr; + ms_out_part vtx_attr; + ms_out_part prm_attr; + uint32_t indices_addr; + uint32_t cull_flags_addr; + uint32_t total_size; + } lds; + + /* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS. + * Not to be confused with scratch memory. + */ + struct { + ms_out_part vtx_attr; + ms_out_part prm_attr; + } scratch_ring; + + /* VRAM attributes ring (GFX11 only) for all non-position outputs. + * GFX11 doesn't have to reload attributes from this ring at the end of the shader. + */ + struct { + ms_out_part vtx_attr; + ms_out_part prm_attr; + } attr_ring; + + /* Outputs without cross-invocation access can be stored in variables. */ + struct { + ms_out_part vtx_attr; + ms_out_part prm_attr; + } var; +} ms_out_mem_layout; + +typedef struct +{ + enum amd_gfx_level gfx_level; + bool fast_launch_2; + bool vert_multirow_export; + bool prim_multirow_export; + + ms_out_mem_layout layout; + uint64_t per_vertex_outputs; + uint64_t per_primitive_outputs; + unsigned vertices_per_prim; + + unsigned wave_size; + unsigned api_workgroup_size; + unsigned hw_workgroup_size; + + nir_def *workgroup_index; + nir_variable *out_variables[VARYING_SLOT_MAX * 4]; + nir_variable *primitive_count_var; + nir_variable *vertex_count_var; + + ac_nir_prerast_out out; + + /* True if the lowering needs to insert the layer output. */ + bool insert_layer_output; + /* True if cull flags are used */ + bool uses_cull_flags; + + uint32_t clipdist_enable_mask; + const uint8_t *vs_output_param_offset; + bool has_param_exports; + + /* True if the lowering needs to insert shader query. */ + bool has_query; +} lower_ngg_ms_state; + +static bool must_wait_attr_ring(enum amd_gfx_level gfx_level, bool has_param_exports) +{ + return (gfx_level == GFX11 || gfx_level == GFX11_5) && has_param_exports; +} + +static void +ms_store_prim_indices(nir_builder *b, + nir_intrinsic_instr *intrin, + lower_ngg_ms_state *s) +{ + /* EXT_mesh_shader primitive indices: array of vectors. + * They don't count as per-primitive outputs, but the array is indexed + * by the primitive index, so they are practically per-primitive. 
+ */ + assert(nir_src_is_const(*nir_get_io_offset_src(intrin))); + assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0); + + const unsigned component_offset = nir_intrinsic_component(intrin); + nir_def *store_val = intrin->src[0].ssa; + assert(store_val->num_components <= 3); + + if (store_val->num_components > s->vertices_per_prim) + store_val = nir_trim_vector(b, store_val, s->vertices_per_prim); + + if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) { + for (unsigned c = 0; c < store_val->num_components; ++c) { + const unsigned i = VARYING_SLOT_PRIMITIVE_INDICES * 4 + c + component_offset; + nir_store_var(b, s->out_variables[i], nir_channel(b, store_val, c), 0x1); + } + return; + } + + nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa; + nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim); + + /* The max vertex count is 256, so these indices always fit 8 bits. + * To reduce LDS use, store these as a flat array of 8-bit values. + */ + nir_store_shared(b, nir_u2u8(b, store_val), offset, .base = s->layout.lds.indices_addr + component_offset); +} + +static void +ms_store_cull_flag(nir_builder *b, + nir_intrinsic_instr *intrin, + lower_ngg_ms_state *s) +{ + /* EXT_mesh_shader cull primitive: per-primitive bool. */ + assert(nir_src_is_const(*nir_get_io_offset_src(intrin))); + assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0); + assert(nir_intrinsic_component(intrin) == 0); + assert(nir_intrinsic_write_mask(intrin) == 1); + + nir_def *store_val = intrin->src[0].ssa; + + assert(store_val->num_components == 1); + assert(store_val->bit_size == 1); + + if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) { + nir_store_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4], nir_b2i32(b, store_val), 0x1); + return; + } + + nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa; + nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim); + + /* To reduce LDS use, store these as an array of 8-bit values. */ + nir_store_shared(b, nir_b2i8(b, store_val), offset, .base = s->layout.lds.cull_flags_addr); +} + +static nir_def * +ms_arrayed_output_base_addr(nir_builder *b, + nir_def *arr_index, + unsigned mapped_location, + unsigned num_arrayed_outputs) +{ + /* Address offset of the array item (vertex or primitive). */ + unsigned arr_index_stride = num_arrayed_outputs * 16u; + nir_def *arr_index_off = nir_imul_imm(b, arr_index, arr_index_stride); + + /* IO address offset within the vertex or primitive data. */ + unsigned io_offset = mapped_location * 16u; + nir_def *io_off = nir_imm_int(b, io_offset); + + return nir_iadd_nuw(b, arr_index_off, io_off); +} + +static void +update_ms_output_info(const nir_io_semantics io_sem, + const nir_src *base_offset_src, + const uint32_t write_mask, + const unsigned component_offset, + const unsigned bit_size, + const ms_out_part *out, + lower_ngg_ms_state *s) +{ + const uint32_t components_mask = write_mask << component_offset; + + /* 64-bit outputs should have already been lowered to 32-bit. */ + assert(bit_size <= 32); + assert(components_mask <= 0xf); + + /* When the base offset is constant, only mark the components of the current slot as used. + * Otherwise, mark the components of all possibly affected slots as used. + */ + const unsigned base_off_start = nir_src_is_const(*base_offset_src) ? nir_src_as_uint(*base_offset_src) : 0; + const unsigned num_slots = nir_src_is_const(*base_offset_src) ? 
1 : io_sem.num_slots; + + for (unsigned base_off = base_off_start; base_off < num_slots; ++base_off) { + ac_nir_prerast_per_output_info *info = &s->out.infos[io_sem.location + base_off]; + info->components_mask |= components_mask; + + if (!io_sem.no_sysval_output) + info->as_sysval_mask |= components_mask; + if (!io_sem.no_varying) + info->as_varying_mask |= components_mask; + } +} + +static const ms_out_part * +ms_get_out_layout_part(unsigned location, + shader_info *info, + ms_out_mode *out_mode, + lower_ngg_ms_state *s) +{ + uint64_t mask = BITFIELD64_BIT(location); + + if (info->per_primitive_outputs & mask) { + if (mask & s->layout.lds.prm_attr.mask) { + *out_mode = ms_out_mode_lds; + return &s->layout.lds.prm_attr; + } else if (mask & s->layout.scratch_ring.prm_attr.mask) { + *out_mode = ms_out_mode_scratch_ring; + return &s->layout.scratch_ring.prm_attr; + } else if (mask & s->layout.attr_ring.prm_attr.mask) { + *out_mode = ms_out_mode_attr_ring; + return &s->layout.attr_ring.prm_attr; + } else if (mask & s->layout.var.prm_attr.mask) { + *out_mode = ms_out_mode_var; + return &s->layout.var.prm_attr; + } + } else { + if (mask & s->layout.lds.vtx_attr.mask) { + *out_mode = ms_out_mode_lds; + return &s->layout.lds.vtx_attr; + } else if (mask & s->layout.scratch_ring.vtx_attr.mask) { + *out_mode = ms_out_mode_scratch_ring; + return &s->layout.scratch_ring.vtx_attr; + } else if (mask & s->layout.attr_ring.vtx_attr.mask) { + *out_mode = ms_out_mode_attr_ring; + return &s->layout.attr_ring.vtx_attr; + } else if (mask & s->layout.var.vtx_attr.mask) { + *out_mode = ms_out_mode_var; + return &s->layout.var.vtx_attr; + } + } + + unreachable("Couldn't figure out mesh shader output mode."); +} + +static void +ms_store_arrayed_output(nir_builder *b, + nir_src *base_off_src, + nir_def *store_val, + nir_def *arr_index, + const nir_io_semantics io_sem, + const unsigned component_offset, + const unsigned write_mask, + lower_ngg_ms_state *s) +{ + ms_out_mode out_mode; + const ms_out_part *out = ms_get_out_layout_part(io_sem.location, &b->shader->info, &out_mode, s); + update_ms_output_info(io_sem, base_off_src, write_mask, component_offset, store_val->bit_size, out, s); + + bool hi_16b = io_sem.high_16bits; + bool lo_16b = !hi_16b && store_val->bit_size == 16; + + unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, io_sem.location)); + unsigned num_outputs = util_bitcount64(out->mask); + unsigned const_off = out->addr + component_offset * 4 + (hi_16b ? 2 : 0); + + nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs); + nir_def *base_offset = base_off_src->ssa; + nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16u); + nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off); + + if (out_mode == ms_out_mode_lds) { + nir_store_shared(b, store_val, addr, .base = const_off, + .write_mask = write_mask, .align_mul = 16, + .align_offset = const_off % 16); + } else if (out_mode == ms_out_mode_scratch_ring) { + nir_def *ring = nir_load_ring_mesh_scratch_amd(b); + nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b); + nir_def *zero = nir_imm_int(b, 0); + nir_store_buffer_amd(b, store_val, ring, addr, off, zero, + .base = const_off, + .write_mask = write_mask, + .memory_modes = nir_var_shader_out, + .access = ACCESS_COHERENT); + } else if (out_mode == ms_out_mode_attr_ring) { + /* GFX11+: Store params straight to the attribute ring. 
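+       * (Here, "params" are the generic attributes read by a later stage;
+       * vs_output_param_offset[] assigns each output its slot in the ring.)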
+ * + * Even though the access pattern may not be the most optimal, + * this is still much better than reserving LDS and losing waves. + * (Also much better than storing and reloading from the scratch ring.) + */ + unsigned param_offset = s->vs_output_param_offset[io_sem.location]; + nir_def *ring = nir_load_ring_attr_amd(b); + nir_def *soffset = nir_load_ring_attr_offset_amd(b); + nir_store_buffer_amd(b, store_val, ring, base_addr_off, soffset, arr_index, + .base = const_off + param_offset * 16, + .write_mask = write_mask, + .memory_modes = nir_var_shader_out, + .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD); + } else if (out_mode == ms_out_mode_var) { + unsigned write_mask_32 = write_mask; + if (store_val->bit_size > 32) { + /* Split 64-bit store values to 32-bit components. */ + store_val = nir_bitcast_vector(b, store_val, 32); + /* Widen the write mask so it is in 32-bit components. */ + write_mask_32 = util_widen_mask(write_mask, store_val->bit_size / 32); + } + + u_foreach_bit(comp, write_mask_32) { + unsigned idx = io_sem.location * 4 + comp + component_offset; + nir_def *val = nir_channel(b, store_val, comp); + nir_def *v = nir_load_var(b, s->out_variables[idx]); + + if (lo_16b) { + nir_def *var_hi = nir_unpack_32_2x16_split_y(b, v); + val = nir_pack_32_2x16_split(b, val, var_hi); + } else if (hi_16b) { + nir_def *var_lo = nir_unpack_32_2x16_split_x(b, v); + val = nir_pack_32_2x16_split(b, var_lo, val); + } + + nir_store_var(b, s->out_variables[idx], val, 0x1); + } + } else { + unreachable("Invalid MS output mode for store"); + } +} + +static void +ms_store_arrayed_output_intrin(nir_builder *b, + nir_intrinsic_instr *intrin, + lower_ngg_ms_state *s) +{ + const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin); + + if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) { + ms_store_prim_indices(b, intrin, s); + return; + } else if (io_sem.location == VARYING_SLOT_CULL_PRIMITIVE) { + ms_store_cull_flag(b, intrin, s); + return; + } + + unsigned component_offset = nir_intrinsic_component(intrin); + unsigned write_mask = nir_intrinsic_write_mask(intrin); + + nir_def *store_val = intrin->src[0].ssa; + nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa; + nir_src *base_off_src = nir_get_io_offset_src(intrin); + + if (store_val->bit_size < 32) { + /* Split 16-bit output stores to ensure each 16-bit component is stored + * in the correct location, without overwriting the other 16 bits there. + */ + u_foreach_bit(c, write_mask) { + nir_def *store_component = nir_channel(b, store_val, c); + ms_store_arrayed_output(b, base_off_src, store_component, arr_index, io_sem, c + component_offset, 1, s); + } + } else { + ms_store_arrayed_output(b, base_off_src, store_val, arr_index, io_sem, component_offset, write_mask, s); + } +} + +static nir_def * +ms_load_arrayed_output(nir_builder *b, + nir_def *arr_index, + nir_def *base_offset, + unsigned location, + unsigned component_offset, + unsigned num_components, + unsigned load_bit_size, + lower_ngg_ms_state *s) +{ + ms_out_mode out_mode; + const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s); + + unsigned component_addr_off = component_offset * 4; + unsigned num_outputs = util_bitcount64(out->mask); + unsigned const_off = out->addr + component_offset * 4; + + /* Use compacted location instead of the original semantic location. 
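+    * For example, if this part's mask contains only two outputs, they are
+    * stored at mapped locations 0 and 1 (the lower-numbered slot maps to 0),
+    * whatever their VARYING_SLOT_* values are.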
*/ + unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location)); + + nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs); + nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16); + nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off); + + if (out_mode == ms_out_mode_lds) { + return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16, + .align_offset = component_addr_off % 16, + .base = const_off); + } else if (out_mode == ms_out_mode_scratch_ring) { + nir_def *ring = nir_load_ring_mesh_scratch_amd(b); + nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b); + nir_def *zero = nir_imm_int(b, 0); + return nir_load_buffer_amd(b, num_components, load_bit_size, ring, addr, off, zero, + .base = const_off, + .memory_modes = nir_var_shader_out, + .access = ACCESS_COHERENT); + } else if (out_mode == ms_out_mode_var) { + assert(load_bit_size == 32); + nir_def *arr[8] = {0}; + for (unsigned comp = 0; comp < num_components; ++comp) { + unsigned idx = location * 4 + comp + component_addr_off; + arr[comp] = nir_load_var(b, s->out_variables[idx]); + } + return nir_vec(b, arr, num_components); + } else { + unreachable("Invalid MS output mode for load"); + } +} + +static nir_def * +lower_ms_load_workgroup_index(nir_builder *b, + UNUSED nir_intrinsic_instr *intrin, + lower_ngg_ms_state *s) +{ + return s->workgroup_index; +} + +static nir_def * +lower_ms_set_vertex_and_primitive_count(nir_builder *b, + nir_intrinsic_instr *intrin, + lower_ngg_ms_state *s) +{ + /* If either the number of vertices or primitives is zero, set both of them to zero. */ + nir_def *num_vtx = nir_read_first_invocation(b, intrin->src[0].ssa); + nir_def *num_prm = nir_read_first_invocation(b, intrin->src[1].ssa); + nir_def *zero = nir_imm_int(b, 0); + nir_def *is_either_zero = nir_ieq(b, nir_umin(b, num_vtx, num_prm), zero); + num_vtx = nir_bcsel(b, is_either_zero, zero, num_vtx); + num_prm = nir_bcsel(b, is_either_zero, zero, num_prm); + + nir_store_var(b, s->vertex_count_var, num_vtx, 0x1); + nir_store_var(b, s->primitive_count_var, num_prm, 0x1); + + return NIR_LOWER_INSTR_PROGRESS_REPLACE; +} + +static nir_def * +update_ms_barrier(nir_builder *b, + nir_intrinsic_instr *intrin, + lower_ngg_ms_state *s) +{ + /* Output loads and stores are lowered to shared memory access, + * so we have to update the barriers to also reflect this. 
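+    * (In other words, a barrier whose memory modes include nir_var_shader_out
+    * must also cover nir_var_mem_shared after this lowering.)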
+ */ + unsigned mem_modes = nir_intrinsic_memory_modes(intrin); + if (mem_modes & nir_var_shader_out) + mem_modes |= nir_var_mem_shared; + else + return NULL; + + nir_intrinsic_set_memory_modes(intrin, mem_modes); + + return NIR_LOWER_INSTR_PROGRESS; +} + +static nir_def * +lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state) +{ + lower_ngg_ms_state *s = (lower_ngg_ms_state *) state; + + if (instr->type != nir_instr_type_intrinsic) + return NULL; + + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + + switch (intrin->intrinsic) { + case nir_intrinsic_store_per_vertex_output: + case nir_intrinsic_store_per_primitive_output: + ms_store_arrayed_output_intrin(b, intrin, s); + return NIR_LOWER_INSTR_PROGRESS_REPLACE; + case nir_intrinsic_barrier: + return update_ms_barrier(b, intrin, s); + case nir_intrinsic_load_workgroup_index: + return lower_ms_load_workgroup_index(b, intrin, s); + case nir_intrinsic_set_vertex_and_primitive_count: + return lower_ms_set_vertex_and_primitive_count(b, intrin, s); + default: + unreachable("Not a lowerable mesh shader intrinsic."); + } +} + +static bool +filter_ms_intrinsic(const nir_instr *instr, + UNUSED const void *s) +{ + if (instr->type != nir_instr_type_intrinsic) + return false; + + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + return intrin->intrinsic == nir_intrinsic_store_output || + intrin->intrinsic == nir_intrinsic_load_output || + intrin->intrinsic == nir_intrinsic_store_per_vertex_output || + intrin->intrinsic == nir_intrinsic_store_per_primitive_output || + intrin->intrinsic == nir_intrinsic_barrier || + intrin->intrinsic == nir_intrinsic_load_workgroup_index || + intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count; +} + +static void +lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s) +{ + nir_shader_lower_instructions(shader, filter_ms_intrinsic, lower_ms_intrinsic, s); +} + +static void +ms_emit_arrayed_outputs(nir_builder *b, + nir_def *invocation_index, + uint64_t mask, + lower_ngg_ms_state *s) +{ + nir_def *zero = nir_imm_int(b, 0); + + u_foreach_bit64(slot, mask) { + /* Should not occur here, handled separately. */ + assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES); + + unsigned component_mask = s->out.infos[slot].components_mask; + + while (component_mask) { + int start_comp = 0, num_components = 1; + u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components); + + nir_def *load = + ms_load_arrayed_output(b, invocation_index, zero, slot, start_comp, + num_components, 32, s); + + for (int i = 0; i < num_components; i++) + s->out.outputs[slot][start_comp + i] = nir_channel(b, load, i); + } + } +} + +static void +ms_create_same_invocation_vars(nir_builder *b, lower_ngg_ms_state *s) +{ + /* Initialize NIR variables for same-invocation outputs. */ + uint64_t same_invocation_output_mask = s->layout.var.prm_attr.mask | s->layout.var.vtx_attr.mask; + + u_foreach_bit64(slot, same_invocation_output_mask) { + for (unsigned comp = 0; comp < 4; ++comp) { + unsigned idx = slot * 4 + comp; + s->out_variables[idx] = nir_local_variable_create(b->impl, glsl_uint_type(), "ms_var_output"); + } + } +} + +static void +ms_emit_legacy_workgroup_index(nir_builder *b, lower_ngg_ms_state *s) +{ + /* Workgroup ID should have been lowered to workgroup index. */ + assert(!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID)); + + /* No need to do anything if the shader doesn't use the workgroup index. 
*/ + if (!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_INDEX)) + return; + + b->cursor = nir_before_impl(b->impl); + + /* Legacy fast launch mode (FAST_LAUNCH=1): + * + * The HW doesn't support a proper workgroup index for vertex processing stages, + * so we use the vertex ID which is equivalent to the index of the current workgroup + * within the current dispatch. + * + * Due to the register programming of mesh shaders, this value is only filled for + * the first invocation of the first wave. To let other waves know, we use LDS. + */ + nir_def *workgroup_index = nir_load_vertex_id_zero_base(b); + + if (s->api_workgroup_size <= s->wave_size) { + /* API workgroup is small, so we don't need to use LDS. */ + s->workgroup_index = nir_read_first_invocation(b, workgroup_index); + return; + } + + unsigned workgroup_index_lds_addr = s->layout.lds.workgroup_info_addr + lds_ms_wg_index; + + nir_def *zero = nir_imm_int(b, 0); + nir_def *dont_care = nir_undef(b, 1, 32); + nir_def *loaded_workgroup_index = NULL; + + /* Use elect to make sure only 1 invocation uses LDS. */ + nir_if *if_elected = nir_push_if(b, nir_elect(b, 1)); + { + nir_def *wave_id = nir_load_subgroup_id(b); + nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0)); + { + nir_store_shared(b, workgroup_index, zero, .base = workgroup_index_lds_addr); + nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, + .memory_scope = SCOPE_WORKGROUP, + .memory_semantics = NIR_MEMORY_ACQ_REL, + .memory_modes = nir_var_mem_shared); + } + nir_push_else(b, if_wave_0); + { + nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, + .memory_scope = SCOPE_WORKGROUP, + .memory_semantics = NIR_MEMORY_ACQ_REL, + .memory_modes = nir_var_mem_shared); + loaded_workgroup_index = nir_load_shared(b, 1, 32, zero, .base = workgroup_index_lds_addr); + } + nir_pop_if(b, if_wave_0); + + workgroup_index = nir_if_phi(b, workgroup_index, loaded_workgroup_index); + } + nir_pop_if(b, if_elected); + + workgroup_index = nir_if_phi(b, workgroup_index, dont_care); + s->workgroup_index = nir_read_first_invocation(b, workgroup_index); +} + +static void +set_ms_final_output_counts(nir_builder *b, + lower_ngg_ms_state *s, + nir_def **out_num_prm, + nir_def **out_num_vtx) +{ + /* The spec allows the numbers to be divergent, and in that case we need to + * use the values from the first invocation. Also the HW requires us to set + * both to 0 if either was 0. + * + * These are already done by the lowering. + */ + nir_def *num_prm = nir_load_var(b, s->primitive_count_var); + nir_def *num_vtx = nir_load_var(b, s->vertex_count_var); + + if (s->hw_workgroup_size <= s->wave_size) { + /* Single-wave mesh shader workgroup. */ + nir_def *m0 = nir_ior(b, nir_ishl_imm(b, num_prm, 12), num_vtx); + nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_GS_ALLOC_REQ); + + *out_num_prm = num_prm; + *out_num_vtx = num_vtx; + return; + } + + /* Multi-wave mesh shader workgroup: + * We need to use LDS to distribute the correct values to the other waves. + * + * TODO: + * If we can prove that the values are workgroup-uniform, we can skip this + * and just use whatever the current wave has. However, NIR divergence analysis + * currently doesn't support this. 
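+    * (Below: wave 0 writes the {num_prm, num_vtx} pair to LDS and sends the
+    * GS_ALLOC_REQ message; the other waves read the pair back after a barrier.)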
+ */ + + nir_def *zero = nir_imm_int(b, 0); + + nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0)); + { + nir_if *if_elected = nir_push_if(b, nir_elect(b, 1)); + { + nir_store_shared(b, nir_vec2(b, num_prm, num_vtx), zero, + .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims); + } + nir_pop_if(b, if_elected); + + nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, + .memory_scope = SCOPE_WORKGROUP, + .memory_semantics = NIR_MEMORY_ACQ_REL, + .memory_modes = nir_var_mem_shared); + + nir_def *m0 = nir_ior(b, nir_ishl_imm(b, num_prm, 12), num_vtx); + nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_GS_ALLOC_REQ); + } + nir_push_else(b, if_wave_0); + { + nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, + .memory_scope = SCOPE_WORKGROUP, + .memory_semantics = NIR_MEMORY_ACQ_REL, + .memory_modes = nir_var_mem_shared); + + nir_def *prm_vtx = NULL; + nir_def *dont_care_2x32 = nir_undef(b, 2, 32); + nir_if *if_elected = nir_push_if(b, nir_elect(b, 1)); + { + prm_vtx = nir_load_shared(b, 2, 32, zero, + .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims); + } + nir_pop_if(b, if_elected); + + prm_vtx = nir_if_phi(b, prm_vtx, dont_care_2x32); + num_prm = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 0)); + num_vtx = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 1)); + + nir_store_var(b, s->primitive_count_var, num_prm, 0x1); + nir_store_var(b, s->vertex_count_var, num_vtx, 0x1); + } + nir_pop_if(b, if_wave_0); + + *out_num_prm = nir_load_var(b, s->primitive_count_var); + *out_num_vtx = nir_load_var(b, s->vertex_count_var); +} + +static void +ms_emit_attribute_ring_output_stores(nir_builder *b, const uint64_t outputs_mask, + nir_def *idx, lower_ngg_ms_state *s) +{ + if (!outputs_mask) + return; + + nir_def *ring = nir_load_ring_attr_amd(b); + nir_def *off = nir_load_ring_attr_offset_amd(b); + nir_def *zero = nir_imm_int(b, 0); + + u_foreach_bit64 (slot, outputs_mask) { + if (s->vs_output_param_offset[slot] > AC_EXP_PARAM_OFFSET_31) + continue; + + nir_def *soffset = nir_iadd_imm(b, off, s->vs_output_param_offset[slot] * 16 * 32); + nir_def *store_val = nir_undef(b, 4, 32); + unsigned store_val_components = 0; + for (unsigned c = 0; c < 4; ++c) { + if (s->out.outputs[slot][c]) { + store_val = nir_vector_insert_imm(b, store_val, s->out.outputs[slot][c], c); + store_val_components = c + 1; + } + } + + store_val = nir_trim_vector(b, store_val, store_val_components); + nir_store_buffer_amd(b, store_val, ring, zero, soffset, idx, + .memory_modes = nir_var_shader_out, + .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD); + } +} + +static nir_def * +ms_prim_exp_arg_ch1(nir_builder *b, nir_def *invocation_index, nir_def *num_vtx, lower_ngg_ms_state *s) +{ + /* Primitive connectivity data: describes which vertices the primitive uses. 
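+    * (These indices are later packed by ac_nir_pack_ngg_prim_exp_arg into the
+    * export argument: the i-th vertex index is shifted by 10*i bits, or 9*i
+    * bits on GFX12+, and the cull flag goes to bit 31, the null-primitive bit.)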
*/ + nir_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim); + nir_def *indices_loaded = NULL; + nir_def *cull_flag = NULL; + + if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) { + nir_def *indices[3] = {0}; + for (unsigned c = 0; c < s->vertices_per_prim; ++c) + indices[c] = nir_load_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c]); + indices_loaded = nir_vec(b, indices, s->vertices_per_prim); + } else { + indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr); + indices_loaded = nir_u2u32(b, indices_loaded); + } + + if (s->uses_cull_flags) { + nir_def *loaded_cull_flag = NULL; + if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) + loaded_cull_flag = nir_load_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4]); + else + loaded_cull_flag = nir_u2u32(b, nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr)); + + cull_flag = nir_i2b(b, loaded_cull_flag); + } + + nir_def *indices[3]; + nir_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u); + + for (unsigned i = 0; i < s->vertices_per_prim; ++i) { + indices[i] = nir_channel(b, indices_loaded, i); + indices[i] = nir_umin(b, indices[i], max_vtx_idx); + } + + return ac_nir_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, cull_flag, s->gfx_level); +} + +static nir_def * +ms_prim_exp_arg_ch2(nir_builder *b, uint64_t outputs_mask, lower_ngg_ms_state *s) +{ + nir_def *prim_exp_arg_ch2 = NULL; + + if (outputs_mask) { + /* When layer, viewport etc. are per-primitive, they need to be encoded in + * the primitive export instruction's second channel. The encoding is: + * + * --- GFX10.3 --- + * bits 31..30: VRS rate Y + * bits 29..28: VRS rate X + * bits 23..20: viewport + * bits 19..17: layer + * + * --- GFX11 --- + * bits 31..28: VRS rate enum + * bits 23..20: viewport + * bits 12..00: layer + */ + prim_exp_arg_ch2 = nir_imm_int(b, 0); + + if (outputs_mask & VARYING_BIT_LAYER) { + nir_def *layer = + nir_ishl_imm(b, s->out.outputs[VARYING_SLOT_LAYER][0], s->gfx_level >= GFX11 ? 
0 : 17); + prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, layer); + } + + if (outputs_mask & VARYING_BIT_VIEWPORT) { + nir_def *view = nir_ishl_imm(b, s->out.outputs[VARYING_SLOT_VIEWPORT][0], 20); + prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, view); + } + + if (outputs_mask & VARYING_BIT_PRIMITIVE_SHADING_RATE) { + nir_def *rate = s->out.outputs[VARYING_SLOT_PRIMITIVE_SHADING_RATE][0]; + prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, rate); + } + } + + return prim_exp_arg_ch2; +} + +static void +ms_prim_gen_query(nir_builder *b, + nir_def *invocation_index, + nir_def *num_prm, + lower_ngg_ms_state *s) +{ + if (!s->has_query) + return; + + nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0)); + { + nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b)); + { + nir_atomic_add_gen_prim_count_amd(b, num_prm, .stream_id = 0); + } + nir_pop_if(b, if_shader_query); + } + nir_pop_if(b, if_invocation_index_zero); +} + +static void +ms_invocation_query(nir_builder *b, + nir_def *invocation_index, + lower_ngg_ms_state *s) +{ + if (!s->has_query) + return; + + nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0)); + { + nir_if *if_pipeline_query = nir_push_if(b, nir_load_pipeline_stat_query_enabled_amd(b)); + { + nir_atomic_add_shader_invocation_count_amd(b, nir_imm_int(b, s->api_workgroup_size)); + } + nir_pop_if(b, if_pipeline_query); + } + nir_pop_if(b, if_invocation_index_zero); +} + +static void +emit_ms_vertex(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters, + uint64_t per_vertex_outputs, lower_ngg_ms_state *s) +{ + ms_emit_arrayed_outputs(b, index, per_vertex_outputs, s); + + if (exports) { + ac_nir_export_position(b, s->gfx_level, s->clipdist_enable_mask, + !s->has_param_exports, false, true, + s->per_vertex_outputs | VARYING_BIT_POS, &s->out, row); + } + + if (parameters) { + /* Export generic attributes on GFX10.3 + * (On GFX11 they are already stored in the attribute ring.) + */ + if (s->has_param_exports && s->gfx_level == GFX10_3) { + ac_nir_export_parameters(b, s->vs_output_param_offset, per_vertex_outputs, 0, &s->out); + } + + /* GFX11+: also store special outputs to the attribute ring so PS can load them. */ + if (s->gfx_level >= GFX11 && (per_vertex_outputs & MS_VERT_ARG_EXP_MASK)) + ms_emit_attribute_ring_output_stores(b, per_vertex_outputs & MS_VERT_ARG_EXP_MASK, index, s); + } +} + +static void +emit_ms_primitive(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters, + uint64_t per_primitive_outputs, lower_ngg_ms_state *s) +{ + ms_emit_arrayed_outputs(b, index, per_primitive_outputs, s); + + /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */ + if (s->insert_layer_output) { + s->out.outputs[VARYING_SLOT_LAYER][0] = nir_load_view_index(b); + s->out.infos[VARYING_SLOT_LAYER].as_sysval_mask |= 1; + } + + if (exports) { + const uint64_t outputs_mask = per_primitive_outputs & MS_PRIM_ARG_EXP_MASK; + nir_def *num_vtx = nir_load_var(b, s->vertex_count_var); + nir_def *prim_exp_arg_ch1 = ms_prim_exp_arg_ch1(b, index, num_vtx, s); + nir_def *prim_exp_arg_ch2 = ms_prim_exp_arg_ch2(b, outputs_mask, s); + + nir_def *prim_exp_arg = prim_exp_arg_ch2 ? + nir_vec2(b, prim_exp_arg_ch1, prim_exp_arg_ch2) : prim_exp_arg_ch1; + + ac_nir_export_primitive(b, prim_exp_arg, row); + } + + if (parameters) { + /* Export generic attributes on GFX10.3 + * (On GFX11 they are already stored in the attribute ring.) 
+ */ + if (s->has_param_exports && s->gfx_level == GFX10_3) { + ac_nir_export_parameters(b, s->vs_output_param_offset, per_primitive_outputs, 0, &s->out); + } + + /* GFX11+: also store special outputs to the attribute ring so PS can load them. */ + if (s->gfx_level >= GFX11) + ms_emit_attribute_ring_output_stores(b, per_primitive_outputs & MS_PRIM_ARG_EXP_MASK, index, s); + } +} + +static void +emit_ms_outputs(nir_builder *b, nir_def *invocation_index, nir_def *row_start, + nir_def *count, bool exports, bool parameters, uint64_t mask, + void (*cb)(nir_builder *, nir_def *, nir_def *, bool, bool, + uint64_t, lower_ngg_ms_state *), + lower_ngg_ms_state *s) +{ + if (cb == &emit_ms_primitive ? s->prim_multirow_export : s->vert_multirow_export) { + assert(s->hw_workgroup_size % s->wave_size == 0); + const unsigned num_waves = s->hw_workgroup_size / s->wave_size; + + nir_loop *row_loop = nir_push_loop(b); + { + nir_block *preheader = nir_cf_node_as_block(nir_cf_node_prev(&row_loop->cf_node)); + + nir_phi_instr *index = nir_phi_instr_create(b->shader); + nir_phi_instr *row = nir_phi_instr_create(b->shader); + nir_def_init(&index->instr, &index->def, 1, 32); + nir_def_init(&row->instr, &row->def, 1, 32); + + nir_phi_instr_add_src(index, preheader, invocation_index); + nir_phi_instr_add_src(row, preheader, row_start); + + nir_if *if_break = nir_push_if(b, nir_uge(b, &index->def, count)); + { + nir_jump(b, nir_jump_break); + } + nir_pop_if(b, if_break); + + cb(b, &index->def, &row->def, exports, parameters, mask, s); + + nir_block *body = nir_cursor_current_block(b->cursor); + nir_phi_instr_add_src(index, body, + nir_iadd_imm(b, &index->def, s->hw_workgroup_size)); + nir_phi_instr_add_src(row, body, + nir_iadd_imm(b, &row->def, num_waves)); + + nir_instr_insert_before_cf_list(&row_loop->body, &row->instr); + nir_instr_insert_before_cf_list(&row_loop->body, &index->instr); + } + nir_pop_loop(b, row_loop); + } else { + nir_def *has_output = nir_ilt(b, invocation_index, count); + nir_if *if_has_output = nir_push_if(b, has_output); + { + cb(b, invocation_index, row_start, exports, parameters, mask, s); + } + nir_pop_if(b, if_has_output); + } +} + +static void +emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s) +{ + /* We assume there is always a single end block in the shader. */ + nir_block *last_block = nir_impl_last_block(b->impl); + b->cursor = nir_after_block(last_block); + + nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP, + .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared); + + nir_def *num_prm; + nir_def *num_vtx; + + set_ms_final_output_counts(b, s, &num_prm, &num_vtx); + + nir_def *invocation_index = nir_load_local_invocation_index(b); + + ms_prim_gen_query(b, invocation_index, num_prm, s); + + nir_def *row_start = NULL; + if (s->fast_launch_2) + row_start = s->hw_workgroup_size <= s->wave_size ? nir_imm_int(b, 0) : nir_load_subgroup_id(b); + + /* Load vertex/primitive attributes from shared memory and + * emit store_output intrinsics for them. + * + * Contrary to the semantics of the API mesh shader, these are now + * compliant with NGG HW semantics, meaning that these store the + * current thread's vertex attributes in a way the HW can export. 
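+    * (That is, invocation i of the HW workgroup now handles vertex i and
+    * primitive i, regardless of which API invocation originally wrote them.)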
+ */ + + uint64_t per_vertex_outputs = + s->per_vertex_outputs & ~s->layout.attr_ring.vtx_attr.mask; + uint64_t per_primitive_outputs = + s->per_primitive_outputs & ~s->layout.attr_ring.prm_attr.mask & ~SPECIAL_MS_OUT_MASK; + + /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */ + if (s->insert_layer_output) { + b->shader->info.outputs_written |= VARYING_BIT_LAYER; + b->shader->info.per_primitive_outputs |= VARYING_BIT_LAYER; + per_primitive_outputs |= VARYING_BIT_LAYER; + } + + const bool has_special_param_exports = + (per_vertex_outputs & MS_VERT_ARG_EXP_MASK) || + (per_primitive_outputs & MS_PRIM_ARG_EXP_MASK); + + const bool wait_attr_ring = must_wait_attr_ring(s->gfx_level, has_special_param_exports); + + /* Export vertices. */ + if ((per_vertex_outputs & ~VARYING_BIT_POS) || !wait_attr_ring) { + emit_ms_outputs(b, invocation_index, row_start, num_vtx, !wait_attr_ring, true, + per_vertex_outputs, &emit_ms_vertex, s); + } + + /* Export primitives. */ + if (per_primitive_outputs || !wait_attr_ring) { + emit_ms_outputs(b, invocation_index, row_start, num_prm, !wait_attr_ring, true, + per_primitive_outputs, &emit_ms_primitive, s); + } + + /* When we need to wait for attribute ring stores, we emit both position and primitive + * export instructions after a barrier to make sure both per-vertex and per-primitive + * attribute ring stores are finished before the GPU starts rasterization. + */ + if (wait_attr_ring) { + /* Wait for attribute stores to finish. */ + nir_barrier(b, .execution_scope = SCOPE_SUBGROUP, + .memory_scope = SCOPE_DEVICE, + .memory_semantics = NIR_MEMORY_RELEASE, + .memory_modes = nir_var_shader_out); + + /* Position/primitive export only */ + emit_ms_outputs(b, invocation_index, row_start, num_vtx, true, false, + per_vertex_outputs, &emit_ms_vertex, s); + emit_ms_outputs(b, invocation_index, row_start, num_prm, true, false, + per_primitive_outputs, &emit_ms_primitive, s); + } +} + +static void +handle_smaller_ms_api_workgroup(nir_builder *b, + lower_ngg_ms_state *s) +{ + if (s->api_workgroup_size >= s->hw_workgroup_size) + return; + + /* Handle barriers manually when the API workgroup + * size is less than the HW workgroup size. + * + * The problem is that the real workgroup launched on NGG HW + * will be larger than the size specified by the API, and the + * extra waves need to keep up with barriers in the API waves. + * + * There are 2 different cases: + * 1. The whole API workgroup fits in a single wave. + * We can shrink the barriers to subgroup scope and + * don't need to insert any extra ones. + * 2. The API workgroup occupies multiple waves, but not + * all. In this case, we emit code that consumes every + * barrier on the extra waves. + */ + assert(s->hw_workgroup_size % s->wave_size == 0); + bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size; + bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size; + bool need_additional_barriers = scan_barriers && !can_shrink_barriers; + + unsigned api_waves_in_flight_addr = s->layout.lds.workgroup_info_addr + lds_ms_num_api_waves; + unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size); + + /* Scan the shader for workgroup barriers. 
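+    * If the whole API workgroup fits in a single wave, these barriers are
+    * demoted to subgroup scope below; otherwise their presence determines
+    * whether the extra waves need the matching-barrier loop at the end of
+    * this function.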
*/ + if (scan_barriers) { + bool has_any_workgroup_barriers = false; + + nir_foreach_block(block, b->impl) { + nir_foreach_instr_safe(instr, block) { + if (instr->type != nir_instr_type_intrinsic) + continue; + + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + bool is_workgroup_barrier = + intrin->intrinsic == nir_intrinsic_barrier && + nir_intrinsic_execution_scope(intrin) == SCOPE_WORKGROUP; + + if (!is_workgroup_barrier) + continue; + + if (can_shrink_barriers) { + /* Every API invocation runs in the first wave. + * In this case, we can change the barriers to subgroup scope + * and avoid adding additional barriers. + */ + nir_intrinsic_set_memory_scope(intrin, SCOPE_SUBGROUP); + nir_intrinsic_set_execution_scope(intrin, SCOPE_SUBGROUP); + } else { + has_any_workgroup_barriers = true; + } + } + } + + need_additional_barriers &= has_any_workgroup_barriers; + } + + /* Extract the full control flow of the shader. */ + nir_cf_list extracted; + nir_cf_extract(&extracted, nir_before_impl(b->impl), + nir_after_cf_list(&b->impl->body)); + b->cursor = nir_before_impl(b->impl); + + /* Wrap the shader in an if to ensure that only the necessary amount of lanes run it. */ + nir_def *invocation_index = nir_load_local_invocation_index(b); + nir_def *zero = nir_imm_int(b, 0); + + if (need_additional_barriers) { + /* First invocation stores 0 to number of API waves in flight. */ + nir_if *if_first_in_workgroup = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0)); + { + nir_store_shared(b, nir_imm_int(b, num_api_waves), zero, .base = api_waves_in_flight_addr); + } + nir_pop_if(b, if_first_in_workgroup); + + nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, + .memory_scope = SCOPE_WORKGROUP, + .memory_semantics = NIR_MEMORY_ACQ_REL, + .memory_modes = nir_var_shader_out | nir_var_mem_shared); + } + + nir_def *has_api_ms_invocation = nir_ult_imm(b, invocation_index, s->api_workgroup_size); + nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation); + { + nir_cf_reinsert(&extracted, b->cursor); + b->cursor = nir_after_cf_list(&if_has_api_ms_invocation->then_list); + + if (need_additional_barriers) { + /* One invocation in each API wave decrements the number of API waves in flight. */ + nir_if *if_elected_again = nir_push_if(b, nir_elect(b, 1)); + { + nir_shared_atomic(b, 32, zero, nir_imm_int(b, -1u), + .base = api_waves_in_flight_addr, + .atomic_op = nir_atomic_op_iadd); + } + nir_pop_if(b, if_elected_again); + + nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, + .memory_scope = SCOPE_WORKGROUP, + .memory_semantics = NIR_MEMORY_ACQ_REL, + .memory_modes = nir_var_shader_out | nir_var_mem_shared); + } + + ms_invocation_query(b, invocation_index, s); + } + nir_pop_if(b, if_has_api_ms_invocation); + + if (need_additional_barriers) { + /* Make sure that waves that don't run any API invocations execute + * the same amount of barriers as those that do. + * + * We do this by executing a barrier until the number of API waves + * in flight becomes zero. 
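+       * (Each API wave decrements the LDS counter once its API code is done,
+       * so the loop below exits when no API waves remain in flight.)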
+ */ + nir_def *has_api_ms_ballot = nir_ballot(b, 1, s->wave_size, has_api_ms_invocation); + nir_def *wave_has_no_api_ms = nir_ieq_imm(b, has_api_ms_ballot, 0); + nir_if *if_wave_has_no_api_ms = nir_push_if(b, wave_has_no_api_ms); + { + nir_if *if_elected = nir_push_if(b, nir_elect(b, 1)); + { + nir_loop *loop = nir_push_loop(b); + { + nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, + .memory_scope = SCOPE_WORKGROUP, + .memory_semantics = NIR_MEMORY_ACQ_REL, + .memory_modes = nir_var_shader_out | nir_var_mem_shared); + + nir_def *loaded = nir_load_shared(b, 1, 32, zero, .base = api_waves_in_flight_addr); + nir_if *if_break = nir_push_if(b, nir_ieq_imm(b, loaded, 0)); + { + nir_jump(b, nir_jump_break); + } + nir_pop_if(b, if_break); + } + nir_pop_loop(b, loop); + } + nir_pop_if(b, if_elected); + } + nir_pop_if(b, if_wave_has_no_api_ms); + } +} + +static void +ms_move_output(ms_out_part *from, ms_out_part *to) +{ + uint64_t loc = util_logbase2_64(from->mask); + uint64_t bit = BITFIELD64_BIT(loc); + from->mask ^= bit; + to->mask |= bit; +} + +static void +ms_calculate_arrayed_output_layout(ms_out_mem_layout *l, + unsigned max_vertices, + unsigned max_primitives) +{ + uint32_t lds_vtx_attr_size = util_bitcount64(l->lds.vtx_attr.mask) * max_vertices * 16; + uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16; + l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16); + l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size; + + uint32_t scratch_ring_vtx_attr_size = + util_bitcount64(l->scratch_ring.vtx_attr.mask) * max_vertices * 16; + l->scratch_ring.prm_attr.addr = + ALIGN(l->scratch_ring.vtx_attr.addr + scratch_ring_vtx_attr_size, 16); +} + +static ms_out_mem_layout +ms_calculate_output_layout(enum amd_gfx_level gfx_level, unsigned api_shared_size, + uint64_t per_vertex_output_mask, uint64_t per_primitive_output_mask, + uint64_t cross_invocation_output_access, unsigned max_vertices, + unsigned max_primitives, unsigned vertices_per_prim) +{ + /* These outputs always need export instructions and can't use the attributes ring. */ + const uint64_t always_export_mask = + VARYING_BIT_POS | VARYING_BIT_CULL_DIST0 | VARYING_BIT_CULL_DIST1 | VARYING_BIT_CLIP_DIST0 | + VARYING_BIT_CLIP_DIST1 | VARYING_BIT_PSIZ | VARYING_BIT_VIEWPORT | + VARYING_BIT_PRIMITIVE_SHADING_RATE | VARYING_BIT_LAYER | + BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | + BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE); + + const bool use_attr_ring = gfx_level >= GFX11; + const uint64_t attr_ring_per_vertex_output_mask = + use_attr_ring ? per_vertex_output_mask & ~always_export_mask : 0; + const uint64_t attr_ring_per_primitive_output_mask = + use_attr_ring ? per_primitive_output_mask & ~always_export_mask : 0; + + const uint64_t lds_per_vertex_output_mask = + per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & cross_invocation_output_access & + ~SPECIAL_MS_OUT_MASK; + const uint64_t lds_per_primitive_output_mask = + per_primitive_output_mask & ~attr_ring_per_primitive_output_mask & + cross_invocation_output_access & ~SPECIAL_MS_OUT_MASK; + + const bool cross_invocation_indices = + cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES); + const bool cross_invocation_cull_primitive = + cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE); + + /* Shared memory used by the API shader. 
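+    * All driver-managed regions below are appended after it, so total_size
+    * starts at api_shared_size and only grows from here.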
*/ + ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } }; + + /* GFX11+: use attribute ring for all generic attributes. */ + l.attr_ring.vtx_attr.mask = attr_ring_per_vertex_output_mask; + l.attr_ring.prm_attr.mask = attr_ring_per_primitive_output_mask; + + /* Outputs without cross-invocation access can be stored in variables. */ + l.var.vtx_attr.mask = + per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & ~cross_invocation_output_access; + l.var.prm_attr.mask = per_primitive_output_mask & ~attr_ring_per_primitive_output_mask & + ~cross_invocation_output_access; + + /* Workgroup information, see ms_workgroup_* for the layout. */ + l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16); + l.lds.total_size = l.lds.workgroup_info_addr + 16; + + /* Per-vertex and per-primitive output attributes. + * Outputs without cross-invocation access are not included here. + * First, try to put all outputs into LDS (shared memory). + * If they don't fit, try to move them to VRAM one by one. + */ + l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16); + l.lds.vtx_attr.mask = lds_per_vertex_output_mask; + l.lds.prm_attr.mask = lds_per_primitive_output_mask; + ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives); + + /* NGG shaders can only address up to 32K LDS memory. + * The spec requires us to allow the application to use at least up to 28K + * shared memory. Additionally, we reserve 2K for driver internal use + * (eg. primitive indices and such, see below). + * + * Move the outputs that do not fit LDS, to VRAM. + * Start with per-primitive attributes, because those are grouped at the end. + */ + const unsigned usable_lds_kbytes = + (cross_invocation_cull_primitive || cross_invocation_indices) ? 30 : 31; + while (l.lds.total_size >= usable_lds_kbytes * 1024) { + if (l.lds.prm_attr.mask) + ms_move_output(&l.lds.prm_attr, &l.scratch_ring.prm_attr); + else if (l.lds.vtx_attr.mask) + ms_move_output(&l.lds.vtx_attr, &l.scratch_ring.vtx_attr); + else + unreachable("API shader uses too much shared memory."); + + ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives); + } + + if (cross_invocation_indices) { + /* Indices: flat array of 8-bit vertex indices for each primitive. */ + l.lds.indices_addr = ALIGN(l.lds.total_size, 16); + l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim; + } + + if (cross_invocation_cull_primitive) { + /* Cull flags: array of 8-bit cull flags for each primitive, 1=cull, 0=keep. */ + l.lds.cull_flags_addr = ALIGN(l.lds.total_size, 16); + l.lds.total_size = l.lds.cull_flags_addr + max_primitives; + } + + /* NGG is only allowed to address up to 32K of LDS. 
*/ + assert(l.lds.total_size <= 32 * 1024); + return l; +} + +void +ac_nir_lower_ngg_mesh(nir_shader *shader, + enum amd_gfx_level gfx_level, + uint32_t clipdist_enable_mask, + const uint8_t *vs_output_param_offset, + bool has_param_exports, + bool *out_needs_scratch_ring, + unsigned wave_size, + unsigned hw_workgroup_size, + bool multiview, + bool has_query, + bool fast_launch_2) +{ + unsigned vertices_per_prim = + mesa_vertices_per_prim(shader->info.mesh.primitive_type); + + uint64_t per_vertex_outputs = + shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~SPECIAL_MS_OUT_MASK; + uint64_t per_primitive_outputs = + shader->info.per_primitive_outputs & shader->info.outputs_written; + + /* Whether the shader uses CullPrimitiveEXT */ + bool uses_cull = shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE); + /* Can't handle indirect register addressing, pretend as if they were cross-invocation. */ + uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access | + shader->info.outputs_accessed_indirectly; + + unsigned max_vertices = shader->info.mesh.max_vertices_out; + unsigned max_primitives = shader->info.mesh.max_primitives_out; + + ms_out_mem_layout layout = ms_calculate_output_layout( + gfx_level, shader->info.shared_size, per_vertex_outputs, per_primitive_outputs, + cross_invocation_access, max_vertices, max_primitives, vertices_per_prim); + + shader->info.shared_size = layout.lds.total_size; + *out_needs_scratch_ring = layout.scratch_ring.vtx_attr.mask || layout.scratch_ring.prm_attr.mask; + + /* The workgroup size that is specified by the API shader may be different + * from the size of the workgroup that actually runs on the HW, due to the + * limitations of NGG: max 0/1 vertex and 0/1 primitive per lane is allowed. + * + * Therefore, we must make sure that when the API workgroup size is smaller, + * we don't run the API shader on more HW invocations than is necessary. 
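+    * For example, a 32-invocation API workgroup that declares 64 max vertices
+    * typically runs on a 64-lane HW workgroup; handle_smaller_ms_api_workgroup
+    * then makes sure only the first 32 lanes execute the API shader code.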
+ */ + unsigned api_workgroup_size = shader->info.workgroup_size[0] * + shader->info.workgroup_size[1] * + shader->info.workgroup_size[2]; + + lower_ngg_ms_state state = { + .layout = layout, + .wave_size = wave_size, + .per_vertex_outputs = per_vertex_outputs, + .per_primitive_outputs = per_primitive_outputs, + .vertices_per_prim = vertices_per_prim, + .api_workgroup_size = api_workgroup_size, + .hw_workgroup_size = hw_workgroup_size, + .insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER), + .uses_cull_flags = uses_cull, + .gfx_level = gfx_level, + .fast_launch_2 = fast_launch_2, + .vert_multirow_export = fast_launch_2 && max_vertices > hw_workgroup_size, + .prim_multirow_export = fast_launch_2 && max_primitives > hw_workgroup_size, + .clipdist_enable_mask = clipdist_enable_mask, + .vs_output_param_offset = vs_output_param_offset, + .has_param_exports = has_param_exports, + .has_query = has_query, + }; + + nir_function_impl *impl = nir_shader_get_entrypoint(shader); + assert(impl); + + state.vertex_count_var = + nir_local_variable_create(impl, glsl_uint_type(), "vertex_count_var"); + state.primitive_count_var = + nir_local_variable_create(impl, glsl_uint_type(), "primitive_count_var"); + + nir_builder builder = nir_builder_at(nir_before_impl(impl)); + nir_builder *b = &builder; /* This is to avoid the & */ + + handle_smaller_ms_api_workgroup(b, &state); + if (!fast_launch_2) + ms_emit_legacy_workgroup_index(b, &state); + ms_create_same_invocation_vars(b, &state); + nir_metadata_preserve(impl, nir_metadata_none); + + lower_ms_intrinsics(shader, &state); + + emit_ms_finale(b, &state); + nir_metadata_preserve(impl, nir_metadata_none); + + /* Cleanup */ + nir_lower_vars_to_ssa(shader); + nir_remove_dead_variables(shader, nir_var_function_temp, NULL); + nir_lower_alu_to_scalar(shader, NULL, NULL); + nir_lower_phis_to_scalar(shader, true); + + /* Optimize load_local_invocation_index. When the API workgroup is smaller than the HW workgroup, + * local_invocation_id isn't initialized for all lanes and we can't perform this optimization for + * all load_local_invocation_index. + */ + if (fast_launch_2 && api_workgroup_size == hw_workgroup_size && + ((shader->info.workgroup_size[0] == 1) + (shader->info.workgroup_size[1] == 1) + + (shader->info.workgroup_size[2] == 1)) == 2) { + nir_lower_compute_system_values_options csv_options = { + .lower_local_invocation_index = true, + }; + nir_lower_compute_system_values(shader, &csv_options); + } + + nir_validate_shader(shader, "after emitting NGG MS"); +} diff --git a/src/amd/common/ac_nir_prerast_utils.c b/src/amd/common/ac_nir_prerast_utils.c index 1996611ad92..7848d8d010d 100644 --- a/src/amd/common/ac_nir_prerast_utils.c +++ b/src/amd/common/ac_nir_prerast_utils.c @@ -801,3 +801,26 @@ ac_nir_gs_shader_query(nir_builder *b, nir_pop_if(b, if_shader_query); return true; } + +nir_def * +ac_nir_pack_ngg_prim_exp_arg(nir_builder *b, unsigned num_vertices_per_primitives, + nir_def *vertex_indices[3], nir_def *is_null_prim, + enum amd_gfx_level gfx_level) +{ + nir_def *arg = nir_load_initial_edgeflags_amd(b); + + for (unsigned i = 0; i < num_vertices_per_primitives; ++i) { + assert(vertex_indices[i]); + arg = nir_ior(b, arg, nir_ishl_imm(b, vertex_indices[i], + (gfx_level >= GFX12 ? 
9u : 10u) * i)); + } + + if (is_null_prim) { + if (is_null_prim->bit_size == 1) + is_null_prim = nir_b2i32(b, is_null_prim); + assert(is_null_prim->bit_size == 32); + arg = nir_ior(b, arg, nir_ishl_imm(b, is_null_prim, 31u)); + } + + return arg; +} diff --git a/src/amd/common/meson.build b/src/amd/common/meson.build index 5bb1631efce..2fc1330873e 100644 --- a/src/amd/common/meson.build +++ b/src/amd/common/meson.build @@ -100,6 +100,7 @@ amd_common_files = files( 'ac_nir_lower_tess_io_to_mem.c', 'ac_nir_lower_tex.c', 'ac_nir_lower_ngg.c', + 'ac_nir_lower_ngg_mesh.c', 'ac_nir_lower_ps_early.c', 'ac_nir_lower_ps_late.c', 'ac_nir_lower_sin_cos.c',