diff --git a/src/amd/common/ac_nir_lower_tess_io_to_mem.c b/src/amd/common/ac_nir_lower_tess_io_to_mem.c index 2521ed97ec3..4043a56f4f7 100644 --- a/src/amd/common/ac_nir_lower_tess_io_to_mem.c +++ b/src/amd/common/ac_nir_lower_tess_io_to_mem.c @@ -122,9 +122,6 @@ typedef struct { uint64_t tes_inputs_read; uint32_t tes_patch_inputs_read; - unsigned tcs_num_reserved_outputs; - unsigned tcs_num_reserved_patch_outputs; - /* True if the output patch fits the subgroup, so all TCS outputs are always written in the same * subgroup that reads them. */ @@ -400,9 +397,15 @@ hs_output_lds_offset(nir_builder *b, (intrin->intrinsic == nir_intrinsic_store_per_vertex_output || intrin->intrinsic == nir_intrinsic_load_per_vertex_output); - unsigned output_vertex_size = st->tcs_num_reserved_outputs * 16u; + const uint64_t per_vertex_mask = tcs_lds_per_vtx_out_mask(b->shader); + const uint64_t tf_mask = tcs_lds_tf_out_mask(b->shader, st); + const uint32_t patch_out_mask = tcs_lds_per_patch_out_mask(b->shader); + + unsigned tcs_num_reserved_outputs = util_bitcount64(per_vertex_mask); + unsigned tcs_num_reserved_patch_outputs = util_bitcount64(tf_mask) + util_bitcount(patch_out_mask); + unsigned output_vertex_size = tcs_num_reserved_outputs * 16u; unsigned pervertex_output_patch_size = b->shader->info.tess.tcs_vertices_out * output_vertex_size; - unsigned output_patch_stride = pervertex_output_patch_size + st->tcs_num_reserved_patch_outputs * 16u; + unsigned output_patch_stride = pervertex_output_patch_size + tcs_num_reserved_patch_outputs * 16u; nir_def *off = intrin ? ac_nir_calc_io_offset_mapped(b, intrin, nir_imm_int(b, 16u), 4u, @@ -956,8 +959,6 @@ ac_nir_lower_hs_outputs_to_mem(nir_shader *shader, .gfx_level = gfx_level, .tes_inputs_read = tes_inputs_read, .tes_patch_inputs_read = tes_patch_inputs_read, - .tcs_num_reserved_outputs = num_reserved_tcs_outputs, - .tcs_num_reserved_patch_outputs = num_reserved_tcs_patch_outputs, .tcs_out_patch_fits_subgroup = wave_size % shader->info.tess.tcs_vertices_out == 0, .tcs_pass_tessfactors_by_reg = pass_tessfactors_by_reg, .tcs_no_inputs_in_lds = no_inputs_in_lds,