Merge branch 'ac_taskmesh_payload_dont_hardcode' into 'main'

ac/nir/lower_taskmesh_io_to_mem: Don't hardcode number of entries and payload entry size in shaders

See merge request mesa/mesa!39032
Timur Kristóf 2025-12-20 00:47:17 +00:00
commit 766617c2e7
4 changed files with 41 additions and 33 deletions
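
Before this change, the lowering baked the task ring's number of entries and the payload entry size into the compiled shader as immediates. After it, the shader derives both at draw time from the ring descriptors: channel 2 of the descriptor gives the ring size in bytes, num_entries is that size divided by the fixed 16-byte draw entry, and the payload entry size is the payload ring size divided by num_entries. A minimal scalar sketch of that arithmetic, with hypothetical C names standing in for the NIR the pass actually emits (nir_udiv_imm, nir_find_lsb, nir_ushr):

   /* Hypothetical scalar model of the new helpers; not Mesa code. */
   #include <assert.h>
   #include <stdint.h>
   #include <stdio.h>

   #define DRAW_ENTRY_BYTES 16u /* matches .draw_entry_bytes = 16 in the pass */

   static uint32_t task_num_entries(uint32_t draw_ring_bytes)
   {
      /* Ring size in bytes (descriptor channel 2) / bytes per draw entry. */
      return draw_ring_bytes / DRAW_ENTRY_BYTES;
   }

   static uint32_t task_payload_entry_bytes(uint32_t payload_ring_bytes,
                                            uint32_t num_entries)
   {
      /* num_entries is a power of two, so dividing the payload ring size
       * by it reduces to a right shift by find_lsb(num_entries). */
      assert(num_entries && (num_entries & (num_entries - 1)) == 0);
      return payload_ring_bytes >> __builtin_ctz(num_entries);
   }

   int main(void)
   {
      uint32_t n = task_num_entries(512 * DRAW_ENTRY_BYTES);
      printf("%u entries, %u payload bytes each\n",
             n, task_payload_entry_bytes(n * 16384u, n));
      return 0;
   }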

@@ -238,14 +238,10 @@ ac_nir_lower_ngg_mesh(nir_shader *shader,
 
 bool
 ac_nir_lower_task_outputs_to_mem(nir_shader *shader,
-                                 unsigned task_payload_entry_bytes,
-                                 unsigned task_num_entries,
                                  bool has_query);
 
 bool
-ac_nir_lower_mesh_inputs_to_mem(nir_shader *shader,
-                                unsigned task_payload_entry_bytes,
-                                unsigned task_num_entries);
+ac_nir_lower_mesh_inputs_to_mem(nir_shader *shader);
 
 bool
 ac_nir_lower_global_access(nir_shader *shader);

@@ -17,14 +17,36 @@
  */
 typedef struct {
-   unsigned payload_entry_bytes;
    unsigned draw_entry_bytes;
-   unsigned num_entries;
 
    /* True if the lowering needs to insert shader query. */
    bool has_query;
 } lower_tsms_io_state;
 
+static nir_def *
+task_num_entries(nir_builder *b,
+                 lower_tsms_io_state *s)
+{
+   nir_def *ring = nir_load_ring_task_draw_amd(b);
+   nir_def *bytes = nir_channel(b, ring, 2);
+   return nir_udiv_imm(b, bytes, s->draw_entry_bytes);
+}
+
+static nir_def *
+task_payload_entry_bytes(nir_builder *b,
+                         lower_tsms_io_state *s)
+{
+   nir_def *num_entries = task_num_entries(b, s);
+   nir_def *ring = nir_load_ring_task_payload_amd(b);
+   nir_def *bytes = nir_channel(b, ring, 2);
+
+   /* num_entries must be a power of two,
+    * use that to implement a division using a shift.
+    */
+   nir_def *lsb = nir_find_lsb(b, num_entries);
+   return nir_ushr(b, bytes, lsb);
+}
+
 static nir_def *
 task_workgroup_index(nir_builder *b,
                      lower_tsms_io_state *s)
@@ -58,8 +80,9 @@ task_ring_entry_index(nir_builder *b,
    * Note that num_entries must be a power of two.
    */
    nir_def *ring_entry = nir_load_task_ring_entry_amd(b);
+   nir_def *num_entries = task_num_entries(b, s);
    nir_def *idx = nir_iadd_nuw(b, ring_entry, task_workgroup_index(b, s));
-   return nir_iand_imm(b, idx, s->num_entries - 1);
+   return nir_iand(b, idx, nir_isub(b, num_entries, nir_imm_int(b, 1)));
 }
 
 static nir_def *
@@ -90,10 +113,13 @@ task_draw_ready_bit(nir_builder *b,
    */
    nir_def *ring_entry = nir_load_task_ring_entry_amd(b);
+   nir_def *num_entries = task_num_entries(b, s);
    nir_def *workgroup_index = task_workgroup_index(b, s);
    nir_def *idx = nir_iadd_nuw(b, ring_entry, workgroup_index);
-   return nir_u2u8(b, nir_ubfe_imm(b, idx, util_bitcount(s->num_entries - 1), 1));
+   nir_def *one = nir_imm_int(b, 1);
+   nir_def *num_entries_minus_1 = nir_isub(b, num_entries, one);
+   return nir_u2u8(b, nir_ubfe(b, idx, nir_bit_count(b, num_entries_minus_1), one));
 }
 
 static nir_def *
@@ -109,7 +135,9 @@ mesh_ring_entry_index(nir_builder *b,
    * AND with num_entries - 1 to get the correct meaning.
    * Note that num_entries must be a power of two.
    */
-   return nir_iand_imm(b, nir_load_task_ring_entry_amd(b), s->num_entries - 1);
+   nir_def *num_entries = task_num_entries(b, s);
+   nir_def *num_entries_minus_1 = nir_isub(b, num_entries, nir_imm_int(b, 1));
+   return nir_iand(b, nir_load_task_ring_entry_amd(b), num_entries_minus_1);
 }
 
 static void
@@ -219,7 +247,7 @@ lower_task_payload_store(nir_builder *b,
    nir_def *addr = intrin->src[1].ssa;
    nir_def *ring = nir_load_ring_task_payload_amd(b);
    nir_def *ptr = task_ring_entry_index(b, s);
-   nir_def *ring_off = nir_imul_imm(b, ptr, s->payload_entry_bytes);
+   nir_def *ring_off = nir_imul(b, ptr, task_payload_entry_bytes(b, s));
    nir_def *zero = nir_imm_int(b, 0);
 
    nir_store_buffer_amd(b, store_val, ring, addr, ring_off, zero, .base = base,
@@ -246,7 +274,7 @@ lower_taskmesh_payload_load(nir_builder *b,
    nir_def *addr = intrin->src[0].ssa;
    nir_def *ring = nir_load_ring_task_payload_amd(b);
-   nir_def *ring_off = nir_imul_imm(b, ptr, s->payload_entry_bytes);
+   nir_def *ring_off = nir_imul(b, ptr, task_payload_entry_bytes(b, s));
    nir_def *zero = nir_imm_int(b, 0);
 
    return nir_load_buffer_amd(b, num_components, bit_size, ring, addr, ring_off, zero, .base = base,
@@ -277,11 +305,8 @@ lower_task_intrinsics(nir_builder *b,
 bool
 ac_nir_lower_task_outputs_to_mem(nir_shader *shader,
-                                 unsigned task_payload_entry_bytes,
-                                 unsigned task_num_entries,
                                  bool has_query)
 {
-   assert(util_is_power_of_two_nonzero(task_num_entries));
 
    bool progress = false;
 
    nir_lower_task_shader_options lower_ts_opt = {
@@ -294,8 +319,6 @@ ac_nir_lower_task_outputs_to_mem(nir_shader *shader,
 
    lower_tsms_io_state state = {
       .draw_entry_bytes = 16,
-      .payload_entry_bytes = task_payload_entry_bytes,
-      .num_entries = task_num_entries,
      .has_query = has_query,
    };
 
@@ -342,16 +365,10 @@ lower_mesh_intrinsics(nir_builder *b,
 }
 
 bool
-ac_nir_lower_mesh_inputs_to_mem(nir_shader *shader,
-                                unsigned task_payload_entry_bytes,
-                                unsigned task_num_entries)
+ac_nir_lower_mesh_inputs_to_mem(nir_shader *shader)
 {
-   assert(util_is_power_of_two_nonzero(task_num_entries));
-
    lower_tsms_io_state state = {
       .draw_entry_bytes = 16,
-      .payload_entry_bytes = task_payload_entry_bytes,
-      .num_entries = task_num_entries,
    };
 
    return nir_shader_lower_instructions(shader,

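All of the ring indexing above leans on num_entries being a power of two: AND with num_entries - 1 wraps the monotonically increasing entry index into the ring, and the bit just above that mask (bit position bitcount(num_entries - 1), i.e. log2(num_entries)) flips on every wrap, which is what task_draw_ready_bit extracts with ubfe. A scalar sketch of the same trick, with hypothetical names standing in for the emitted NIR:

   #include <stdint.h>

   /* idx is the monotonically increasing ring entry index;
    * num_entries must be a nonzero power of two. */
   static uint32_t ring_slot(uint32_t idx, uint32_t num_entries)
   {
      return idx & (num_entries - 1); /* wrap into the ring */
   }

   static uint8_t draw_ready_bit(uint32_t idx, uint32_t num_entries)
   {
      /* popcount(num_entries - 1) == log2(num_entries); the bit just
       * above the slot mask toggles each time the index wraps. */
      return (uint8_t)((idx >> __builtin_popcount(num_entries - 1)) & 1);
   }
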

@@ -262,11 +262,10 @@ radv_nir_lower_io_to_mem(struct radv_device *device, struct radv_shader_stage *s
       NIR_PASS(_, nir, ac_nir_lower_gs_inputs_to_mem, map_input, pdev->info.gfx_level, false);
 
       return true;
    } else if (nir->info.stage == MESA_SHADER_TASK) {
-      ac_nir_lower_task_outputs_to_mem(nir, pdev->task_info.payload_entry_size, pdev->task_info.num_entries,
-                                       info->cs.has_query);
+      ac_nir_lower_task_outputs_to_mem(nir, info->cs.has_query);
       return true;
    } else if (nir->info.stage == MESA_SHADER_MESH) {
-      ac_nir_lower_mesh_inputs_to_mem(nir, pdev->task_info.payload_entry_size, pdev->task_info.num_entries);
+      ac_nir_lower_mesh_inputs_to_mem(nir);
       return true;
    }

@@ -324,13 +324,9 @@ static void si_lower_nir(struct si_screen *sscreen, struct nir_shader *nir)
       NIR_PASS(_, nir, nir_lower_gs_intrinsics, flags);
    } else if (nir->info.stage == MESA_SHADER_TASK) {
-      NIR_PASS(_, nir, ac_nir_lower_task_outputs_to_mem,
-               sscreen->task_info.payload_entry_size,
-               sscreen->task_info.num_entries, false);
+      NIR_PASS(_, nir, ac_nir_lower_task_outputs_to_mem, false);
    } else if (nir->info.stage == MESA_SHADER_MESH) {
-      NIR_PASS(_, nir, ac_nir_lower_mesh_inputs_to_mem,
-               sscreen->task_info.payload_entry_size,
-               sscreen->task_info.num_entries);
+      NIR_PASS(_, nir, ac_nir_lower_mesh_inputs_to_mem);
    }
 
    if (mesa_shader_stage_is_compute(nir->info.stage)) {