diff --git a/src/amd/common/ac_nir.h b/src/amd/common/ac_nir.h
index fc514185e71..bfddb2a6ba3 100644
--- a/src/amd/common/ac_nir.h
+++ b/src/amd/common/ac_nir.h
@@ -131,6 +131,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
                       bool provoking_vtx_last,
                       bool use_edgeflags,
                       bool has_prim_query,
+                      bool disable_streamout,
                       uint32_t instance_rate_inputs,
                       uint32_t clipdist_enable_mask,
                       uint32_t user_clip_plane_enable_mask);
diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index dc57923f429..2db9ee44049 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -24,6 +24,7 @@
 
 #include "ac_nir.h"
 #include "nir_builder.h"
+#include "nir_xfb_info.h"
 #include "u_math.h"
 #include "u_vector.h"
 
@@ -56,12 +57,16 @@ typedef struct
    bool early_prim_export;
    bool use_edgeflags;
    bool has_prim_query;
+   bool streamout_enabled;
    unsigned wave_size;
    unsigned max_num_waves;
    unsigned num_vertices_per_primitives;
    unsigned provoking_vtx_idx;
    unsigned max_es_num_vertices;
    unsigned position_store_base;
+
+   /* LDS params */
+   unsigned pervertex_lds_bytes;
    unsigned total_lds_bytes;
 
    uint64_t inputs_needed_by_pos;
@@ -479,25 +484,27 @@ emit_ngg_nogs_prim_id_store_shared(nir_builder *b, lower_ngg_nogs_state *st)
        */
       nir_ssa_def *prim_id = nir_load_primitive_id(b);
       nir_ssa_def *provoking_vtx_idx = nir_load_var(b, st->gs_vtx_indices_vars[st->provoking_vtx_idx]);
-      nir_ssa_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, 4u);
+      nir_ssa_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, st->pervertex_lds_bytes);
 
-      nir_store_shared(b, prim_id, addr);
+      /* The primitive ID is always stored in the last dword of a vertex. */
+      nir_store_shared(b, prim_id, addr, .base = st->pervertex_lds_bytes - 4);
    }
    nir_pop_if(b, if_gs_thread);
 }
 
 static void
-emit_store_ngg_nogs_es_primitive_id(nir_builder *b)
+emit_store_ngg_nogs_es_primitive_id(nir_builder *b, lower_ngg_nogs_state *st)
 {
    nir_ssa_def *prim_id = NULL;
 
    if (b->shader->info.stage == MESA_SHADER_VERTEX) {
       /* LDS address where the primitive ID is stored */
       nir_ssa_def *thread_id_in_threadgroup = nir_load_local_invocation_index(b);
-      nir_ssa_def *addr = pervertex_lds_addr(b, thread_id_in_threadgroup, 4u);
+      nir_ssa_def *addr =
+         pervertex_lds_addr(b, thread_id_in_threadgroup, st->pervertex_lds_bytes);
 
       /* Load primitive ID from LDS */
-      prim_id = nir_load_shared(b, 1, 32, addr);
+      prim_id = nir_load_shared(b, 1, 32, addr, .base = st->pervertex_lds_bytes - 4);
    } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
       /* Just use tess eval primitive ID, which is the same as the patch ID.
        */
       prim_id = nir_load_primitive_id(b);
@@ -1489,6 +1496,274 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
       unreachable("Should be VS or TES.");
 }
 
+static bool
+do_ngg_nogs_store_output_to_lds(nir_builder *b, nir_instr *instr, void *state)
+{
+   lower_ngg_nogs_state *st = (lower_ngg_nogs_state *)state;
+
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   if (intrin->intrinsic != nir_intrinsic_store_output)
+      return false;
+
+   unsigned component = nir_intrinsic_component(intrin);
+   unsigned write_mask = nir_instr_xfb_write_mask(intrin) >> component;
+   if (!write_mask)
+      return false;
+
+   b->cursor = nir_before_instr(instr);
+
+   unsigned base_offset = nir_src_as_uint(intrin->src[1]);
+   unsigned location = nir_intrinsic_io_semantics(intrin).location + base_offset;
+   unsigned packed_location =
+      util_bitcount64(b->shader->info.outputs_written & BITFIELD64_MASK(location));
+   unsigned offset = packed_location * 16 + component * 4;
+
+   nir_ssa_def *tid = nir_load_local_invocation_index(b);
+   nir_ssa_def *addr = pervertex_lds_addr(b, tid, st->pervertex_lds_bytes);
+
+   nir_ssa_def *store_val = intrin->src[0].ssa;
+   nir_store_shared(b, store_val, addr, .base = offset, .write_mask = write_mask);
+
+   return true;
+}
+
+static void
+ngg_nogs_store_all_outputs_to_lds(nir_shader *shader, lower_ngg_nogs_state *st)
+{
+   nir_shader_instructions_pass(shader, do_ngg_nogs_store_output_to_lds,
+                                nir_metadata_block_index | nir_metadata_dominance, st);
+}
+
+static void
+ngg_build_streamout_buffer_info(nir_builder *b,
+                                nir_xfb_info *info,
+                                unsigned scratch_base,
+                                nir_ssa_def *tid_in_tg,
+                                nir_ssa_def *gen_prim[4],
+                                nir_ssa_def *prim_stride_ret[4],
+                                nir_ssa_def *so_buffer_ret[4],
+                                nir_ssa_def *buffer_offsets_ret[4],
+                                nir_ssa_def *emit_prim_ret[4])
+{
+   /* For radeonsi, this value is passed in as a shader argument for the VS stage.
+    * Streamout needs the accurate number of vertices per primitive to write the
+    * correct amount of data to the buffer.
+    */
+   nir_ssa_def *num_vert_per_prim = nir_load_num_vertices_per_primitive_amd(b);
+   for (unsigned buffer = 0; buffer < 4; buffer++) {
+      if (!(info->buffers_written & BITFIELD_BIT(buffer)))
+         continue;
+
+      assert(info->buffers[buffer].stride);
+
+      prim_stride_ret[buffer] =
+         nir_imul_imm(b, num_vert_per_prim, info->buffers[buffer].stride * 4);
+      so_buffer_ret[buffer] = nir_load_streamout_buffer_amd(b, .base = buffer);
+   }
+
+   nir_if *if_invocation_0 = nir_push_if(b, nir_ieq_imm(b, tid_in_tg, 0));
+   {
+      nir_ssa_def *workgroup_buffer_sizes[4];
+      for (unsigned buffer = 0; buffer < 4; buffer++) {
+         if (info->buffers_written & BITFIELD_BIT(buffer)) {
+            nir_ssa_def *buffer_size = nir_channel(b, so_buffer_ret[buffer], 2);
+            /* In radeonsi we may not know at compile time whether a feedback buffer
+             * has been bound, so check the buffer size at runtime and skip the GDS
+             * update for unbound buffers. Otherwise, a previous draw that was compiled
+             * with streamout but did not bind a feedback buffer would miss the GDS
+             * update and corrupt the current draw's streamout.
+             */
+            nir_ssa_def *buffer_valid = nir_ine_imm(b, buffer_size, 0);
+            nir_ssa_def *inc_buffer_size =
+               nir_imul(b, gen_prim[info->buffer_to_stream[buffer]], prim_stride_ret[buffer]);
+            workgroup_buffer_sizes[buffer] =
+               nir_bcsel(b, buffer_valid, inc_buffer_size, nir_imm_int(b, 0));
+         } else
+            workgroup_buffer_sizes[buffer] = nir_ssa_undef(b, 1, 32);
+      }
+
+      nir_ssa_def *ordered_id = nir_load_ordered_id_amd(b);
+      /* Get the current global offset of each buffer and add this workgroup's
+       * buffer size to it. This is an ordered operation sorted by ordered_id;
+       * each buffer's info is in one channel of a vec4.
+       */
+      nir_ssa_def *buffer_offsets =
+         nir_ordered_xfb_counter_add_amd(b, ordered_id, nir_vec(b, workgroup_buffer_sizes, 4),
+                                         /* mask of buffers to update */
+                                         .write_mask = info->buffers_written);
+
+      nir_ssa_def *emit_prim[4];
+      memcpy(emit_prim, gen_prim, 4 * sizeof(nir_ssa_def *));
+
+      for (unsigned buffer = 0; buffer < 4; buffer++) {
+         if (!(info->buffers_written & BITFIELD_BIT(buffer)))
+            continue;
+
+         nir_ssa_def *buffer_size = nir_channel(b, so_buffer_ret[buffer], 2);
+         nir_ssa_def *buffer_offset = nir_channel(b, buffer_offsets, buffer);
+         nir_ssa_def *remain_size = nir_isub(b, buffer_size, buffer_offset);
+         nir_ssa_def *remain_prim = nir_idiv(b, remain_size, prim_stride_ret[buffer]);
+         nir_ssa_def *overflow = nir_ilt(b, buffer_size, buffer_offset);
+
+         unsigned stream = info->buffer_to_stream[buffer];
+         /* If a previous workgroup overflowed, we can't emit any primitives. */
+         emit_prim[stream] = nir_bcsel(
+            b, overflow, nir_imm_int(b, 0),
+            /* Otherwise, emit as many primitives as the smallest buffer allows. */
+            nir_imin(b, emit_prim[stream], remain_prim));
+
+         /* Save to LDS so other waves of this workgroup can access it. */
+         nir_store_shared(b, buffer_offset, nir_imm_int(b, buffer * 4),
+                          .base = scratch_base);
+      }
+
+      /* No need to fix up the global buffer offset once we have overflowed,
+       * because all following workgroups will overflow as well.
+       */
+
+      /* Save to LDS so other waves of this workgroup can access it. */
+      for (unsigned stream = 0; stream < 4; stream++) {
+         if (!(info->streams_written & BITFIELD_BIT(stream)))
+            continue;
+
+         nir_store_shared(b, emit_prim[stream], nir_imm_int(b, stream * 4),
+                          .base = scratch_base + 16);
+      }
+   }
+   nir_pop_if(b, if_invocation_0);
+
+   nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
+                      .memory_scope = NIR_SCOPE_WORKGROUP,
+                      .memory_semantics = NIR_MEMORY_ACQ_REL,
+                      .memory_modes = nir_var_mem_shared);
+
+   /* Fetch the per-buffer offsets in all waves. */
+   for (unsigned buffer = 0; buffer < 4; buffer++) {
+      if (!(info->buffers_written & BITFIELD_BIT(buffer)))
+         continue;
+
+      buffer_offsets_ret[buffer] =
+         nir_load_shared(b, 1, 32, nir_imm_int(b, buffer * 4), .base = scratch_base);
+   }
+
+   /* Fetch the per-stream emitted primitive counts in all waves.
+    */
+   for (unsigned stream = 0; stream < 4; stream++) {
+      if (!(info->streams_written & BITFIELD_BIT(stream)))
+         continue;
+
+      emit_prim_ret[stream] =
+         nir_load_shared(b, 1, 32, nir_imm_int(b, stream * 4), .base = scratch_base + 16);
+   }
+}
+
+static void
+ngg_build_streamout_vertex(nir_builder *b, nir_xfb_info *info,
+                           unsigned stream, int *slot_to_register,
+                           nir_ssa_def *so_buffer[4], nir_ssa_def *buffer_offsets[4],
+                           nir_ssa_def *vtx_buffer_idx, nir_ssa_def *vtx_lds_addr)
+{
+   nir_ssa_def *vtx_buffer_offsets[4];
+   for (unsigned buffer = 0; buffer < 4; buffer++) {
+      if (!(info->buffers_written & BITFIELD_BIT(buffer)))
+         continue;
+
+      nir_ssa_def *offset = nir_imul_imm(b, vtx_buffer_idx, info->buffers[buffer].stride * 4);
+      vtx_buffer_offsets[buffer] = nir_iadd(b, buffer_offsets[buffer], offset);
+   }
+
+   for (unsigned i = 0; i < info->output_count; i++) {
+      nir_xfb_output_info *out = info->outputs + i;
+      if (!out->component_mask || info->buffer_to_stream[out->buffer] != stream)
+         continue;
+
+      unsigned base = slot_to_register[out->location];
+      unsigned offset = (base * 4 + out->component_offset) * 4;
+      unsigned count = util_bitcount(out->component_mask);
+      /* component_mask is constructed this way, see nir_gather_xfb_info_from_intrinsics() */
+      assert(u_bit_consecutive(out->component_offset, count) == out->component_mask);
+
+      nir_ssa_def *out_data =
+         nir_load_shared(b, count, 32, vtx_lds_addr, .base = offset);
+
+      nir_store_buffer_amd(b, out_data, so_buffer[out->buffer],
+                           vtx_buffer_offsets[out->buffer],
+                           nir_imm_int(b, 0),
+                           .base = out->offset,
+                           .slc_amd = true);
+   }
+}
+
+static void
+ngg_nogs_build_streamout(nir_builder *b, lower_ngg_nogs_state *s)
+{
+   int slot_to_register[NUM_TOTAL_VARYING_SLOTS];
+   nir_xfb_info *info = nir_gather_xfb_info_from_intrinsics(b->shader, slot_to_register);
+   if (unlikely(!info)) {
+      s->streamout_enabled = false;
+      return;
+   }
+
+   unsigned total_es_lds_bytes = s->pervertex_lds_bytes * s->max_es_num_vertices;
+   unsigned scratch_base = ALIGN(total_es_lds_bytes, 8u);
+   /* 4 dwords for the 4 streamout buffer offsets, 1 dword for the emitted primitive count */
+   unsigned scratch_size = 20;
+   s->total_lds_bytes = MAX2(s->total_lds_bytes, scratch_base + scratch_size);
+
+   /* Get the global buffer offsets where this workgroup will stream out data to.
+    */
+   nir_ssa_def *generated_prim = nir_load_workgroup_num_input_primitives_amd(b);
+   nir_ssa_def *gen_prim_per_stream[4] = {generated_prim, 0, 0, 0};
+   nir_ssa_def *emit_prim_per_stream[4] = {0};
+   nir_ssa_def *buffer_offsets[4] = {0};
+   nir_ssa_def *so_buffer[4] = {0};
+   nir_ssa_def *prim_stride[4] = {0};
+   nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b);
+   ngg_build_streamout_buffer_info(b, info, scratch_base, tid_in_tg,
+                                   gen_prim_per_stream, prim_stride,
+                                   so_buffer, buffer_offsets,
+                                   emit_prim_per_stream);
+
+   /* Write out primitive data */
+   nir_if *if_emit = nir_push_if(b, nir_ilt(b, tid_in_tg, emit_prim_per_stream[0]));
+   {
+      unsigned vtx_lds_stride = (b->shader->num_outputs * 4 + 1) * 4;
+      nir_ssa_def *num_vert_per_prim = nir_load_num_vertices_per_primitive_amd(b);
+      nir_ssa_def *vtx_buffer_idx = nir_imul(b, tid_in_tg, num_vert_per_prim);
+
+      for (unsigned i = 0; i < s->num_vertices_per_primitives; i++) {
+         nir_if *if_valid_vertex =
+            nir_push_if(b, nir_ilt(b, nir_imm_int(b, i), num_vert_per_prim));
+         {
+            nir_ssa_def *vtx_lds_idx = nir_load_var(b, s->gs_vtx_indices_vars[i]);
+            nir_ssa_def *vtx_lds_addr = pervertex_lds_addr(b, vtx_lds_idx, vtx_lds_stride);
+            ngg_build_streamout_vertex(b, info, 0, slot_to_register,
+                                       so_buffer, buffer_offsets,
+                                       nir_iadd_imm(b, vtx_buffer_idx, i),
+                                       vtx_lds_addr);
+         }
+         nir_pop_if(b, if_valid_vertex);
+      }
+   }
+   nir_pop_if(b, if_emit);
+
+   /* Wait for the streamout memory operations to finish before the primitive export,
+    * otherwise they may not have completed by the time the shader ends.
+    *
+    * If a shader has no param exports, rasterization can start before
+    * the shader finishes and thus memory stores might not finish before
+    * the pixel shader starts.
+    *
+    * TODO: we only need this when there are no param exports.
+    *
+    * TODO: not sure if we need this barrier with late prim export, as I
+    * can't observe any test failing without this barrier.
+    */
+   nir_memory_barrier_buffer(b);
+
+   free(info);
+}
+
 void
 ac_nir_lower_ngg_nogs(nir_shader *shader,
                       enum radeon_family family,
@@ -1503,6 +1778,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
                       bool provoking_vtx_last,
                       bool use_edgeflags,
                       bool has_prim_query,
+                      bool disable_streamout,
                       uint32_t instance_rate_inputs,
                       uint32_t clipdist_enable_mask,
                       uint32_t user_clip_plane_enable_mask)
@@ -1517,12 +1793,21 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
    nir_variable *es_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
    nir_variable *gs_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
 
+   bool streamout_enabled = shader->xfb_info && !disable_streamout;
+   /* Streamout must be done before either the prim or vertex export. When there
+    * are no param exports, rasterization can start right after the prim and vertex
+    * exports, which would leave the streamout buffer writes unfinished.
+    */
+   if (streamout_enabled)
+      early_prim_export = false;
+
    lower_ngg_nogs_state state = {
       .passthrough = passthrough,
       .export_prim_id = export_prim_id,
      .early_prim_export = early_prim_export,
      .use_edgeflags = use_edgeflags,
      .has_prim_query = has_prim_query,
+     .streamout_enabled = streamout_enabled,
      .num_vertices_per_primitives = num_vertices_per_primitives,
      .provoking_vtx_idx = provoking_vtx_last ? (num_vertices_per_primitives - 1) : 0,
      .position_value_var = position_value_var,
@@ -1599,9 +1884,19 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
       }
    }
 
+   /* Determine the LDS vertex stride. */
+   if (state.streamout_enabled) {
+      /* The extra dword is used to avoid LDS bank conflicts and to store the primitive ID.
+       * TODO: only alloc space for outputs that really need streamout.
+       */
+      state.pervertex_lds_bytes = (shader->num_outputs * 4 + 1) * 4;
+   } else if (need_prim_id_store_shared)
+      state.pervertex_lds_bytes = 4;
+
    if (need_prim_id_store_shared) {
       /* We need LDS space when VS needs to export the primitive ID. */
-      state.total_lds_bytes = MAX2(state.total_lds_bytes, max_num_es_vertices * 4u);
+      state.total_lds_bytes = MAX2(state.total_lds_bytes,
+                                   state.pervertex_lds_bytes * max_num_es_vertices);
 
       emit_ngg_nogs_prim_id_store_shared(b, &state);
 
@@ -1620,13 +1915,26 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
 
       if (state.export_prim_id)
-         emit_store_ngg_nogs_es_primitive_id(b);
+         emit_store_ngg_nogs_es_primitive_id(b, &state);
 
       /* Export all vertex attributes (including the primitive ID) */
       export_vertex_instr = nir_export_vertex_amd(b);
    }
    nir_pop_if(b, if_es_thread);
 
+   if (state.streamout_enabled) {
+      /* TODO: support culling after streamout. */
+      assert(!can_cull);
+
+      ngg_nogs_build_streamout(b, &state);
+   }
+
+   /* streamout may be disabled by ngg_nogs_build_streamout() */
+   if (state.streamout_enabled) {
+      ngg_nogs_store_all_outputs_to_lds(shader, &state);
+      b->cursor = nir_after_cf_list(&impl->body);
+   }
+
    /* Take care of late primitive export */
    if (!state.early_prim_export) {
       emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var));
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index e1c9b7f8c46..e9145443bb4 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -1333,7 +1333,7 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
                         info->workgroup_size, info->wave_size, info->has_ngg_culling,
                         info->has_ngg_early_prim_export, info->is_ngg_passthrough, export_prim_id,
                         pl_key->vs.provoking_vtx_last, false, pl_key->primitives_generated_query,
-                        pl_key->vs.instance_rate_inputs, 0, 0);
+                        true, pl_key->vs.instance_rate_inputs, 0, 0);
 
    /* Increase ESGS ring size so the LLVM binary contains the correct LDS size. */
    ngg_stage->info.ngg_info.esgs_ring_size = nir->info.shared_size;
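For reference, the LDS layout the patch relies on can be summarized in a small stand-alone sketch. This is illustrative only and not part of the patch: the helper names (ngg_pervertex_lds_stride, ngg_output_lds_offset, ngg_streamout_scratch_base) are invented here, but the arithmetic mirrors the lowering code above, assuming a per-vertex stride of (num_outputs * 4 + 1) * 4 bytes, outputs packed at packed_location * 16 + component * 4, the primitive ID in the last dword of each vertex, and a 20-byte streamout scratch area placed after all vertices, aligned to 8 bytes.

/* Stand-alone sketch of the assumed NGG LDS layout; helper names are illustrative. */
#include <assert.h>
#include <stdio.h>

/* Per-vertex LDS stride: 4 dwords per output vec4 plus one extra dword that
 * avoids bank conflicts and holds the primitive ID (last dword of the vertex). */
static unsigned ngg_pervertex_lds_stride(unsigned num_outputs)
{
   return (num_outputs * 4 + 1) * 4;
}

/* Byte offset of one output component inside a vertex's LDS slice. */
static unsigned ngg_output_lds_offset(unsigned packed_location, unsigned component)
{
   return packed_location * 16 + component * 4;
}

/* Streamout scratch area starts after all per-vertex data, aligned to 8 bytes:
 * 4 dwords of buffer offsets followed by the per-stream emitted-primitive counts. */
static unsigned ngg_streamout_scratch_base(unsigned pervertex_lds_bytes,
                                           unsigned max_es_num_vertices)
{
   unsigned total = pervertex_lds_bytes * max_es_num_vertices;
   return (total + 7u) & ~7u;
}

int main(void)
{
   unsigned stride = ngg_pervertex_lds_stride(4);  /* e.g. 4 vec4 outputs */
   unsigned prim_id_offset = stride - 4;           /* matches .base = pervertex_lds_bytes - 4 */
   unsigned scratch = ngg_streamout_scratch_base(stride, 128);

   assert(ngg_output_lds_offset(2, 1) == 36);      /* 2 * 16 + 1 * 4 */
   printf("stride=%u prim_id@%u scratch_base=%u\n", stride, prim_id_offset, scratch);
   return 0;
}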