ac/nir: call nir_gather_tcs_info only once for RADV

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31673>
This commit is contained in:
Marek Olšák 2024-11-16 21:33:36 -05:00 committed by Marge Bot
parent 8c2f9f0665
commit 4d8a508510
7 changed files with 17 additions and 15 deletions

View file

@@ -93,7 +93,7 @@ ac_nir_lower_hs_inputs_to_mem(nir_shader *shader,
 uint64_t tcs_inputs_via_lds);
 void
-ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
+ac_nir_lower_hs_outputs_to_mem(nir_shader *shader, const nir_tcs_info *info,
 ac_nir_map_io_driver_location map,
 enum amd_gfx_level gfx_level,
 uint64_t tes_inputs_read,

View file

@@ -1223,7 +1223,7 @@ ac_nir_lower_hs_inputs_to_mem(nir_shader *shader,
 }
 void
-ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
+ac_nir_lower_hs_outputs_to_mem(nir_shader *shader, const nir_tcs_info *info,
 ac_nir_map_io_driver_location map,
 enum amd_gfx_level gfx_level,
 uint64_t tes_inputs_read,
@@ -1234,15 +1234,13 @@ ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
 lower_tess_io_state state = {
 .gfx_level = gfx_level,
+.tcs_info = *info,
 .tes_inputs_read = tes_inputs_read,
 .tes_patch_inputs_read = tes_patch_inputs_read,
 .tcs_out_patch_fits_subgroup = wave_size % shader->info.tess.tcs_vertices_out == 0,
 .map_io = map,
 };
-nir_gather_tcs_info(shader, &state.tcs_info, shader->info.tess._primitive_mode,
-shader->info.tess.spacing);
 if (state.tcs_info.all_invocations_define_tess_levels) {
 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
 state.tcs_tess_level_outer =

View file

@@ -229,8 +229,8 @@ radv_nir_lower_io_to_mem(struct radv_device *device, struct radv_shader_stage *s
 } else if (nir->info.stage == MESA_SHADER_TESS_CTRL) {
 NIR_PASS_V(nir, ac_nir_lower_hs_inputs_to_mem, map_input, pdev->info.gfx_level, info->vs.tcs_in_out_eq,
 info->vs.tcs_inputs_via_temp, info->vs.tcs_inputs_via_lds);
-NIR_PASS_V(nir, ac_nir_lower_hs_outputs_to_mem, map_output, pdev->info.gfx_level, info->tcs.tes_inputs_read,
-info->tcs.tes_patch_inputs_read, info->wave_size);
+NIR_PASS_V(nir, ac_nir_lower_hs_outputs_to_mem, &info->tcs.info, map_output, pdev->info.gfx_level,
+info->tcs.tes_inputs_read, info->tcs.tes_patch_inputs_read, info->wave_size);
 return true;
 } else if (nir->info.stage == MESA_SHADER_TESS_EVAL) {

View file

@@ -3508,8 +3508,9 @@ radv_emit_patch_control_points(struct radv_cmd_buffer *cmd_buffer)
 radv_get_tess_wg_info(pdev, &tcs_info, d->vk.ts.patch_control_points,
 /* TODO: This should be only inputs in LDS (not VGPR inputs) to reduce LDS usage */
 vs->info.vs.num_linked_outputs, tcs->info.tcs.num_linked_outputs,
-tcs->info.tcs.num_linked_patch_outputs, tcs->info.tcs.all_invocations_define_tess_levels,
-&cmd_buffer->state.tess_num_patches, &cmd_buffer->state.tess_lds_size);
+tcs->info.tcs.num_linked_patch_outputs,
+tcs->info.tcs.info.all_invocations_define_tess_levels, &cmd_buffer->state.tess_num_patches,
+&cmd_buffer->state.tess_lds_size);
 }
 ls_hs_config = S_028B58_NUM_PATCHES(cmd_buffer->state.tess_num_patches) |

View file

@@ -633,9 +633,8 @@ gather_shader_info_tcs(struct radv_device *device, const nir_shader *nir,
 const struct radv_graphics_state_key *gfx_state, struct radv_shader_info *info)
 {
 const struct radv_physical_device *pdev = radv_device_physical(device);
-nir_tcs_info tcs_info;
-nir_gather_tcs_info(nir, &tcs_info, nir->info.tess._primitive_mode, nir->info.tess.spacing);
+nir_gather_tcs_info(nir, &info->tcs.info, nir->info.tess._primitive_mode, nir->info.tess.spacing);
 info->tcs.tcs_outputs_read = nir->info.outputs_read;
 info->tcs.tcs_outputs_written = nir->info.outputs_written;
@@ -644,7 +643,6 @@ gather_shader_info_tcs(struct radv_device *device, const nir_shader *nir,
 info->tcs.tcs_vertices_out = nir->info.tess.tcs_vertices_out;
 info->tcs.tes_inputs_read = ~0ULL;
 info->tcs.tes_patch_inputs_read = ~0ULL;
-info->tcs.all_invocations_define_tess_levels = tcs_info.all_invocations_define_tess_levels;
 if (!info->inputs_linked)
 info->tcs.num_linked_inputs = util_last_bit64(radv_gather_unlinked_io_mask(nir->info.inputs_read));
@@ -660,7 +658,7 @@ gather_shader_info_tcs(struct radv_device *device, const nir_shader *nir,
 radv_get_tess_wg_info(pdev, &nir->info, gfx_state->ts.patch_control_points,
 /* TODO: This should be only inputs in LDS (not VGPR inputs) to reduce LDS usage */
 info->tcs.num_linked_inputs, info->tcs.num_linked_outputs,
-info->tcs.num_linked_patch_outputs, tcs_info.all_invocations_define_tess_levels,
+info->tcs.num_linked_patch_outputs, info->tcs.info.all_invocations_define_tess_levels,
 &info->num_tess_patches, &info->tcs.num_lds_blocks);
 }
 }

View file

@@ -14,6 +14,7 @@
 #include <inttypes.h>
 #include <stdbool.h>
+#include "nir.h"
 #include "radv_constants.h"
 #include "radv_shader_args.h"
@@ -245,7 +246,7 @@ struct radv_shader_info {
 uint8_t num_linked_outputs; /* Number of reserved per-vertex output slots in VRAM. */
 uint8_t num_linked_patch_outputs; /* Number of reserved per-patch output slots in VRAM. */
 bool tes_reads_tess_factors : 1;
-bool all_invocations_define_tess_levels : 1;
+nir_tcs_info info;
 } tcs;
 struct {
 enum mesa_prim output_prim;

View file

@@ -1874,7 +1874,11 @@ static bool si_lower_io_to_mem(struct si_shader *shader, nir_shader *nir)
 if (nir->info.tess._primitive_mode == TESS_PRIMITIVE_UNSPECIFIED)
 nir->info.tess._primitive_mode = key->ge.opt.tes_prim_mode;
-NIR_PASS_V(nir, ac_nir_lower_hs_outputs_to_mem, si_map_io_driver_location,
+nir_tcs_info tcs_info;
+nir_gather_tcs_info(nir, &tcs_info, nir->info.tess._primitive_mode,
+nir->info.tess.spacing);
+NIR_PASS_V(nir, ac_nir_lower_hs_outputs_to_mem, &tcs_info, si_map_io_driver_location,
 sel->screen->info.gfx_level,
 ~0ULL, ~0U, /* no TES inputs filter */
 shader->wave_size);