diff --git a/src/amd/common/nir/ac_nir_lower_tess_io_to_mem.c b/src/amd/common/nir/ac_nir_lower_tess_io_to_mem.c index 3d3c545c47a..2bfc8b1f306 100644 --- a/src/amd/common/nir/ac_nir_lower_tess_io_to_mem.c +++ b/src/amd/common/nir/ac_nir_lower_tess_io_to_mem.c @@ -384,14 +384,9 @@ hs_output_lds_map_io_location(nir_shader *shader, } static nir_def * -hs_output_lds_offset(nir_builder *b, - lower_tess_io_state *st, - nir_intrinsic_instr *intrin) +hs_output_lds_offset(nir_builder *b, lower_tess_io_state *st, unsigned location, unsigned component, + nir_def *vertex_index, nir_def *io_offset) { - bool per_vertex = intrin && - (intrin->intrinsic == nir_intrinsic_store_per_vertex_output || - intrin->intrinsic == nir_intrinsic_load_per_vertex_output); - const uint64_t per_vertex_mask = tcs_lds_per_vtx_out_mask(b->shader); const uint64_t tf_mask = tcs_lds_tf_out_mask(b->shader, st); const uint32_t patch_out_mask = tcs_lds_per_patch_out_mask(b->shader); @@ -404,10 +399,10 @@ hs_output_lds_offset(nir_builder *b, nir_def *off = NULL; - if (intrin) { - const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin); - const unsigned mapped = hs_output_lds_map_io_location(b->shader, per_vertex, io_sem.location, st); - off = ac_nir_calc_io_off(b, nir_intrinsic_component(intrin), nir_get_io_offset_src(intrin)->ssa, + if (io_offset) { + const unsigned mapped = hs_output_lds_map_io_location(b->shader, vertex_index != NULL, + location, st); + off = ac_nir_calc_io_off(b, component, io_offset, nir_imm_int(b, 16u), 4, mapped); } else { off = nir_imm_int(b, 0); @@ -423,8 +418,7 @@ hs_output_lds_offset(nir_builder *b, nir_def *output_patch_offset = nir_iadd_nuw(b, patch_offset, output_patch0_offset); nir_def *lds_offset; - if (per_vertex) { - nir_def *vertex_index = nir_get_io_arrayed_index_src(intrin)->ssa; + if (vertex_index) { nir_def *vertex_index_off = nir_imul_imm(b, vertex_index, output_vertex_size); off = nir_iadd_nuw(b, off, vertex_index_off); @@ -547,6 +541,7 @@ lower_hs_output_store(nir_builder *b, const unsigned component = nir_intrinsic_component(intrin); nir_def *store_val = intrin->src[0].ssa; const unsigned write_mask = nir_intrinsic_write_mask(intrin); + const bool per_vertex = intrin->intrinsic == nir_intrinsic_store_per_vertex_output; const bool write_to_vmem = tcs_output_needs_vmem(intrin, b->shader, st); const bool write_to_lds = tcs_output_needs_lds(intrin, b->shader, st); @@ -568,7 +563,9 @@ lower_hs_output_store(nir_builder *b, } if (write_to_lds) { - nir_def *lds_off = hs_output_lds_offset(b, st, intrin); + nir_def *vertex_index = per_vertex ? nir_get_io_arrayed_index_src(intrin)->ssa : NULL; + nir_def *lds_off = hs_output_lds_offset(b, st, semantics.location, component, + vertex_index, nir_get_io_offset_src(intrin)->ssa); AC_NIR_STORE_IO(b, store_val, 0, write_mask, semantics.high_16bits, nir_store_shared, lds_off, .write_mask = store_write_mask, .base = store_const_offset); } @@ -625,7 +622,10 @@ lower_hs_output_load(nir_builder *b, if (!tcs_output_needs_lds(intrin, b->shader, st)) return nir_undef(b, intrin->def.num_components, intrin->def.bit_size); - nir_def *off = hs_output_lds_offset(b, st, intrin); + nir_def *vertex_index = intrin->intrinsic == nir_intrinsic_load_per_vertex_output ? + nir_get_io_arrayed_index_src(intrin)->ssa : NULL; + nir_def *off = hs_output_lds_offset(b, st, io_sem.location, nir_intrinsic_component(intrin), + vertex_index, nir_get_io_offset_src(intrin)->ssa); nir_def *load = NULL; AC_NIR_LOAD_IO(load, b, intrin->def.num_components, intrin->def.bit_size, io_sem.high_16bits, @@ -701,7 +701,7 @@ hs_load_tess_levels(nir_builder *b, } } else { /* Base LDS address of per-patch outputs in the current patch. */ - nir_def *lds_base = hs_output_lds_offset(b, st, NULL); + nir_def *lds_base = hs_output_lds_offset(b, st, 0, 0, NULL, NULL); /* Load all tessellation factors (aka. tess levels) from LDS. */ if (st->tcs_tess_level_outer_mask) {