brw: Lower task shader payload access in NIR

We keep this separate from the other lowering infrastructure because
there's no semantic IO involved here, just byte offsets.  Also, it needs
to run after nir_lower_mem_access_bit_sizes, which means it needs to be
run from brw_postprocess_nir_opts.  But we can't do the mesh URB lowering
there because that doesn't have the MUE map.

It's not that much code as a separate pass, though.

Reviewed-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38918>
Kenneth Graunke 2025-12-08 00:21:45 -08:00 committed by Marge Bot
parent bd0c173595
commit d0dc45955d
3 changed files with 50 additions and 68 deletions

@@ -133,63 +133,6 @@ brw_print_tue_map(FILE *fp, const struct brw_tue_map *map)
   fprintf(fp, "TUE (%d dwords)\n\n", map->size_dw);
}

static bool
brw_nir_adjust_task_payload_offsets_instr(struct nir_builder *b,
                                          nir_intrinsic_instr *intrin,
                                          void *data)
{
   switch (intrin->intrinsic) {
   case nir_intrinsic_store_task_payload:
   case nir_intrinsic_load_task_payload: {
      nir_src *offset_src = nir_get_io_offset_src(intrin);

      if (nir_src_is_const(*offset_src))
         assert(nir_src_as_uint(*offset_src) % 4 == 0);

      b->cursor = nir_before_instr(&intrin->instr);

      /* Regular I/O uses dwords while explicit I/O used for task payload uses
       * bytes. Normalize it to dwords.
       *
       * TODO(mesh): Figure out how to handle 8-bit, 16-bit.
       */
      nir_def *offset = nir_ishr_imm(b, offset_src->ssa, 2);
      nir_src_rewrite(offset_src, offset);

      unsigned base = nir_intrinsic_base(intrin);
      assert(base % 4 == 0);
      nir_intrinsic_set_base(intrin, base / 4);

      return true;
   }

   default:
      return false;
   }
}

static bool
brw_nir_adjust_task_payload_offsets(nir_shader *nir)
{
   return nir_shader_intrinsics_pass(nir,
                                     brw_nir_adjust_task_payload_offsets_instr,
                                     nir_metadata_control_flow,
                                     NULL);
}

void
brw_nir_adjust_payload(nir_shader *shader)
{
   /* Adjustment of task payload offsets must be performed *after* last pass
    * which interprets them as bytes, because it changes their unit.
    */
   bool adjusted = false;
   NIR_PASS(adjusted, shader, brw_nir_adjust_task_payload_offsets);
   if (adjusted) /* clean up the mess created by offset adjustments */
      NIR_PASS(_, shader, nir_opt_constant_folding);
}
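
/* A minimal before/after sketch of the unit change performed by the removed
 * brw_nir_adjust_task_payload_offsets pass above; the concrete numbers are
 * assumptions for illustration only:
 *
 *    store_task_payload val, offset=8 (bytes),  base=16 (bytes)
 * became
 *    store_task_payload val, offset=2 (dwords), base=4  (dwords)
 *
 * i.e. offset >> 2 and base / 4, which is why both were asserted to be
 * multiples of 4.
 */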

static bool
brw_nir_align_launch_mesh_workgroups_instr(nir_builder *b,
                                           nir_intrinsic_instr *intrin,

@@ -117,6 +117,10 @@ static unsigned
io_base_slot(nir_intrinsic_instr *io,
             const struct brw_lower_urb_cb_data *cb_data)
{
   if (io->intrinsic == nir_intrinsic_load_task_payload ||
       io->intrinsic == nir_intrinsic_store_task_payload)
      return nir_intrinsic_base(io) / 16; /* bytes to vec4 slots */
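
   /* Illustrative arithmetic (the numbers are assumptions, not taken from
    * the change): task payload bases remain in bytes, and the URB is
    * addressed in 16-byte vec4 slots, so e.g. a base of 48 bytes maps to
    * slot 3.  The byte offset on the access itself is converted to dwords
    * later, in lower_task_payload_to_urb.
    */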

   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(io);
   if (is_per_primitive(io)) {
@@ -437,6 +441,46 @@ brw_nir_lower_outputs_to_urb_intrinsics(nir_shader *nir,
                                     nir_metadata_control_flow, (void *) cd);
}

static bool
lower_task_payload_to_urb(nir_builder *b, nir_intrinsic_instr *io, void *data)
{
   const struct brw_lower_urb_cb_data *cb_data = data;
   const enum mesa_shader_stage stage = b->shader->info.stage;

   if (io->intrinsic != nir_intrinsic_load_task_payload &&
       io->intrinsic != nir_intrinsic_store_task_payload)
      return false;

   b->cursor = nir_before_instr(&io->instr);
   b->constant_fold_alu = true;

   /* Convert byte offset to dword offset */
   nir_def *offset = nir_ishr_imm(b, nir_get_io_offset_src(io)->ssa, 2);

   if (io->intrinsic == nir_intrinsic_store_task_payload) {
      store_urb(b, cb_data, io, output_handle(b), offset);
      nir_instr_remove(&io->instr);
   } else {
      const bool input = stage == MESA_SHADER_MESH;
      nir_def *handle = input ? input_handle(b, io) : output_handle(b);
      nir_def *load = load_urb(b, cb_data, io, handle, offset,
                               ACCESS_CAN_REORDER |
                               (input ? ACCESS_NON_WRITEABLE : 0));
      nir_def_replace(&io->def, load);
   }

   return true;
}
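
/* A rough sketch of what the rewrite above produces, assuming the store_urb()
 * and load_urb() helpers defined elsewhere in this file; the SSA names are
 * placeholders:
 *
 *    task shader:   store_task_payload %val, %off          (%off in bytes)
 *      becomes      a URB write through output_handle() at dword offset
 *                   (%off >> 2)
 *
 *    mesh shader:   %x = load_task_payload %off            (%off in bytes)
 *      becomes      %x = a URB read through input_handle() at dword offset
 *                   (%off >> 2), marked ACCESS_CAN_REORDER | ACCESS_NON_WRITEABLE
 */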

static bool
lower_task_payload_to_urb_intrinsics(nir_shader *nir,
                                     const struct intel_device_info *devinfo)
{
   struct brw_lower_urb_cb_data cb_data = { .devinfo = devinfo };

   return nir_shader_intrinsics_pass(nir, lower_task_payload_to_urb,
                                     nir_metadata_control_flow, &cb_data);
}

static bool
remap_tess_levels_legacy(nir_builder *b,
                         nir_intrinsic_instr *intrin,
@@ -2559,6 +2603,12 @@ brw_postprocess_nir_opts(nir_shader *nir, const struct brw_compiler *compiler,
   brw_vectorize_lower_mem_access(nir, compiler, robust_flags);

   /* Do this after lowering memory access bit-sizes */
   if (nir->info.stage == MESA_SHADER_MESH ||
       nir->info.stage == MESA_SHADER_TASK) {
      OPT(lower_task_payload_to_urb_intrinsics, devinfo);
   }
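
   /* Ordering note with an assumed example: per the commit message, this
    * lowering must follow nir_lower_mem_access_bit_sizes (which the
    * surrounding code suggests happens in brw_vectorize_lower_mem_access
    * above), presumably so that the surviving task payload accesses are
    * dword-aligned and the byte-to-dword shift in lower_task_payload_to_urb
    * is exact, e.g. byte offset 12 -> dword 3.
    */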

   /* Needs to be prior int64 lower because it generates 64bit address
    * manipulations
    */
@@ -2768,15 +2818,6 @@ brw_postprocess_nir_out_of_ssa(nir_shader *nir,
   if (OPT(nir_opt_rematerialize_compares))
      OPT(nir_opt_dce);

   /* The mesh stages require this pass to be called at the last minute,
    * but if anything is done by it, it will also constant fold, and that
    * undoes the work done by nir_trivialize_registers, so call it right
    * before that one instead.
    */
   if (nir->info.stage == MESA_SHADER_MESH ||
       nir->info.stage == MESA_SHADER_TASK)
      brw_nir_adjust_payload(nir);

   nir_trivialize_registers(nir);

   nir_sweep(nir);

@@ -374,8 +374,6 @@ void brw_nir_quick_pressure_estimate(nir_shader *nir,
const struct glsl_type *brw_nir_get_var_type(const struct nir_shader *nir,
                                             nir_variable *var);

void brw_nir_adjust_payload(nir_shader *shader);

static inline nir_variable_mode
brw_nir_no_indirect_mask(const struct brw_compiler *compiler,
                         mesa_shader_stage stage)