radv/rt,aco: Always dispatch 1D workgroups for RT

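The workgroups will be swizzled by the driver itself in the next commit.
This removes the need for the 1D dispatch workarounds (converting 1D
dispatches into 2D ones and dispatching the remaining rays separately).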

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39142>
This commit is contained in:
Natalie Vock 2026-01-06 14:09:07 +01:00
parent 8baa95e4aa
commit 1f6ac3fa93
5 changed files with 21 additions and 116 deletions

View file

@ -18,9 +18,6 @@
extern "C" {
#endif
/* Special launch size to indicate this dispatch is a 1D dispatch converted into a 2D one */
#define ACO_RT_CONVERTED_2D_LAUNCH_SIZE -1u
struct ac_shader_config;
struct aco_shader_info;
struct aco_vs_prolog_info;

View file

@ -60,10 +60,7 @@ select_rt_prolog(Program* program, ac_shader_config* config,
in_scratch_offset = get_arg_reg(in_args, in_args->scratch_offset);
struct ac_arg arg_id = options->gfx_level >= GFX11 ? in_args->local_invocation_ids_packed
: in_args->local_invocation_id_x;
PhysReg in_local_ids[2] = {
get_arg_reg(in_args, arg_id),
get_arg_reg(in_args, arg_id).advance(4),
};
PhysReg in_local_id = get_arg_reg(in_args, arg_id);
/* Outputs:
* Callee shader PC: s[0-1]
@ -91,15 +88,18 @@ select_rt_prolog(Program* program, ac_shader_config* config,
PhysReg out_record_ptr = get_arg_reg(out_args, out_args->rt.shader_record);
/* Temporaries: */
PhysReg tmp_wg_id_y;
if (program->gfx_level >= GFX12) {
tmp_wg_id_y = PhysReg{num_sgprs};
num_sgprs++;
} else {
tmp_wg_id_y = in_wg_id_y;
}
num_sgprs = align(num_sgprs, 2);
PhysReg tmp_raygen_sbt = PhysReg{num_sgprs};
num_sgprs += 2;
PhysReg tmp_ring_offsets = PhysReg{num_sgprs};
num_sgprs += 2;
PhysReg tmp_wg_id_x_times_size = PhysReg{num_sgprs};
num_sgprs++;
PhysReg tmp_invocation_idx = PhysReg{256 + num_vgprs++};
/* Confirm some assumptions about register aliasing */
assert(in_ring_offsets == out_uniform_shader_addr);
@ -113,7 +113,7 @@ select_rt_prolog(Program* program, ac_shader_config* config,
get_arg_reg(out_args, out_args->rt.traversal_shader_addr));
assert(in_launch_size_addr == out_launch_size_x);
assert(in_stack_base == out_launch_size_z);
assert(in_local_ids[0] == out_launch_ids[0]);
assert(in_local_id == out_launch_ids[0]);
/* <gfx9 reads in_scratch_offset at the end of the prolog to write out the scratch_offset
* arg. Make sure no other outputs have overwritten it by then.
@ -154,28 +154,18 @@ select_rt_prolog(Program* program, ac_shader_config* config,
}
/* calculate ray launch ids */
if (options->gfx_level >= GFX11) {
/* Thread IDs are packed in VGPR0, 10 bits per component. */
bld.vop3(aco_opcode::v_bfe_u32, Definition(in_local_ids[1], v1), Operand(in_local_ids[0], v1),
Operand::c32(10u), Operand::c32(3u));
bld.vop2(aco_opcode::v_and_b32, Definition(in_local_ids[0], v1), Operand::c32(0x7),
Operand(in_local_ids[0], v1));
}
/* Do this backwards to reduce some RAW hazards on GFX11+ */
if (options->gfx_level >= GFX12) {
bld.vop2_e64(aco_opcode::v_lshrrev_b32, Definition(out_launch_ids[2], v1), Operand::c32(16),
Operand(in_wg_id_y, s1));
bld.vop3(aco_opcode::v_mad_u32_u16, Definition(out_launch_ids[1], v1),
Operand(in_wg_id_y, s1), Operand::c32(program->workgroup_size == 32 ? 4 : 8),
Operand(in_local_ids[1], v1));
bld.sop2(aco_opcode::s_pack_ll_b32_b16, Definition(tmp_wg_id_y, s1), Operand(in_wg_id_y, s1),
Operand::c32(0));
} else {
bld.vop1(aco_opcode::v_mov_b32, Definition(out_launch_ids[2], v1), Operand(in_wg_id_z, s1));
bld.vop3(aco_opcode::v_mad_u32_u24, Definition(out_launch_ids[1], v1),
Operand(in_wg_id_y, s1), Operand::c32(program->workgroup_size == 32 ? 4 : 8),
Operand(in_local_ids[1], v1));
}
bld.vop1(aco_opcode::v_mov_b32, Definition(out_launch_ids[1], v1), Operand(tmp_wg_id_y, s1));
bld.vop3(aco_opcode::v_mad_u32_u24, Definition(out_launch_ids[0], v1), Operand(in_wg_id_x, s1),
Operand::c32(8), Operand(in_local_ids[0], v1));
Operand::c32(program->workgroup_size), Operand(in_local_id, v1));
/* calculate shader record ptr: SBT + RADV_RT_HANDLE_SIZE */
if (options->gfx_level < GFX9) {
@ -188,38 +178,6 @@ select_rt_prolog(Program* program, ac_shader_config* config,
bld.vop1(aco_opcode::v_mov_b32, Definition(out_record_ptr.advance(4), v1),
Operand(tmp_raygen_sbt.advance(4), s1));
/* For 1D dispatches converted into 2D ones, we need to fix up the launch IDs.
* Calculating the 1D launch ID is: id = local_invocation_index + (wg_id.x * wg_size).
* tmp_wg_id_x_times_size now holds wg_id.x * wg_size.
*/
bld.sop2(aco_opcode::s_lshl_b32, Definition(tmp_wg_id_x_times_size, s1), Definition(scc, s1),
Operand(in_wg_id_x, s1), Operand::c32(program->workgroup_size == 32 ? 5 : 6));
/* Calculate and add local_invocation_index */
bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, Definition(tmp_invocation_idx, v1), Operand::c32(-1u),
Operand(tmp_wg_id_x_times_size, s1));
if (program->wave_size == 64) {
if (program->gfx_level <= GFX7)
bld.vop2(aco_opcode::v_mbcnt_hi_u32_b32, Definition(tmp_invocation_idx, v1),
Operand::c32(-1u), Operand(tmp_invocation_idx, v1));
else
bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32_e64, Definition(tmp_invocation_idx, v1),
Operand::c32(-1u), Operand(tmp_invocation_idx, v1));
}
/* Make fixup operations a no-op if this is not a converted 2D dispatch. */
bld.sopc(aco_opcode::s_cmp_lg_u32, Definition(scc, s1),
Operand::c32(ACO_RT_CONVERTED_2D_LAUNCH_SIZE), Operand(out_launch_size_y, s1));
bld.sop2(Builder::s_cselect, Definition(vcc, bld.lm),
Operand::c32_or_c64(-1u, program->wave_size == 64),
Operand::c32_or_c64(0, program->wave_size == 64), Operand(scc, s1));
bld.sop2(aco_opcode::s_cselect_b32, Definition(out_launch_size_y, s1),
Operand(out_launch_size_y, s1), Operand::c32(1), Operand(scc, s1));
bld.vop2(aco_opcode::v_cndmask_b32, Definition(out_launch_ids[0], v1),
Operand(tmp_invocation_idx, v1), Operand(out_launch_ids[0], v1), Operand(vcc, bld.lm));
bld.vop2(aco_opcode::v_cndmask_b32, Definition(out_launch_ids[1], v1), Operand::zero(),
Operand(out_launch_ids[1], v1), Operand(vcc, bld.lm));
if (options->gfx_level < GFX9) {
/* write scratch/ring offsets to outputs, if needed */
bld.sop1(aco_opcode::s_mov_b32,

View file

@ -1143,8 +1143,7 @@ radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_
* invalid variable modes.*/
nir_builder b = radv_meta_nir_init_shader(device, MESA_SHADER_INTERSECTION, "rt_traversal");
b.shader->info.internal = false;
b.shader->info.workgroup_size[0] = 8;
b.shader->info.workgroup_size[1] = pdev->rt_wave_size == 64 ? 8 : 4;
b.shader->info.workgroup_size[0] = pdev->rt_wave_size;
b.shader->info.api_subgroup_size = pdev->rt_wave_size;
b.shader->info.max_subgroup_size = pdev->rt_wave_size;
b.shader->info.min_subgroup_size = pdev->rt_wave_size;

View file

@ -13895,17 +13895,6 @@ radv_after_trace_rays(struct radv_cmd_buffer *cmd_buffer, bool dgc)
radv_cmd_buffer_after_draw(cmd_buffer, RADV_CMD_FLAG_CS_PARTIAL_FLUSH, dgc);
}
/* Emit one ray tracing dispatch described by `info`, bracketed by the
 * before/after trace-rays hooks. The currently bound RT pipeline is handed to
 * the "before" hook and the RT prolog shader is the shader that gets
 * dispatched.
 */
static void
radv_rt_dispatch(struct radv_cmd_buffer *cmd_buffer, const struct radv_dispatch_info *info)
{
   /* Snapshot both pieces of command buffer state before any of the calls below run. */
   const struct radv_shader *prolog_shader = cmd_buffer->state.rt_prolog;
   struct radv_ray_tracing_pipeline *bound_pipeline = cmd_buffer->state.rt_pipeline;

   radv_before_trace_rays(cmd_buffer, bound_pipeline);
   radv_emit_dispatch_packets(cmd_buffer, prolog_shader, info);

   /* `false` is the value for the `dgc` parameter of radv_after_trace_rays(). */
   radv_after_trace_rays(cmd_buffer, false);
}
VKAPI_ATTR void VKAPI_CALL
radv_CmdDispatchBase(VkCommandBuffer commandBuffer, uint32_t base_x, uint32_t base_y, uint32_t base_z, uint32_t x,
uint32_t y, uint32_t z)
@ -14059,11 +14048,7 @@ radv_trace_rays(struct radv_cmd_buffer *cmd_buffer, VkTraceRaysIndirectCommand2K
radv_trace_trace_rays(cmd_buffer, tables, indirect_va);
struct radv_shader *rt_prolog = cmd_buffer->state.rt_prolog;
/* Since the workgroup size is 8x4 (or 8x8), 1D dispatches can only fill 8 threads per wave at most. To increase
* occupancy, it's beneficial to convert to a 2D dispatch in these cases. */
if (tables && tables->height == 1 && tables->width >= cmd_buffer->state.rt_prolog->info.cs.block_size[0])
tables->height = ACO_RT_CONVERTED_2D_LAUNCH_SIZE;
struct radv_ray_tracing_pipeline *rt_pipeline = cmd_buffer->state.rt_pipeline;
struct radv_dispatch_info info = {0};
info.unaligned = true;
@ -14079,24 +14064,10 @@ radv_trace_rays(struct radv_cmd_buffer *cmd_buffer, VkTraceRaysIndirectCommand2K
sbt_va = indirect_va;
}
uint32_t remaining_ray_count = 0;
if (mode == radv_rt_mode_direct) {
info.blocks[0] = tables->width;
info.blocks[1] = tables->height;
info.blocks[2] = tables->depth;
if (tables->height == ACO_RT_CONVERTED_2D_LAUNCH_SIZE) {
/* We need the ray count for the 2D dispatch to be a multiple of the y block size for the division to work, and
* a multiple of the x block size because the invocation offset must be a multiple of the block size when
* dispatching the remaining rays. Fortunately, the x block size is itself a multiple of the y block size, so
* we only need to ensure that the ray count is a multiple of the x block size. */
remaining_ray_count = tables->width % rt_prolog->info.cs.block_size[0];
uint32_t ray_count = tables->width - remaining_ray_count;
info.blocks[0] = ray_count / rt_prolog->info.cs.block_size[1];
info.blocks[1] = rt_prolog->info.cs.block_size[1];
}
} else
info.indirect_va = launch_size_va;
@ -14126,28 +14097,9 @@ radv_trace_rays(struct radv_cmd_buffer *cmd_buffer, VkTraceRaysIndirectCommand2K
assert(cs->b->cdw <= cdw_max);
radv_rt_dispatch(cmd_buffer, &info);
if (remaining_ray_count) {
info.blocks[0] = remaining_ray_count;
info.blocks[1] = 1;
info.offsets[0] = tables->width - remaining_ray_count;
/* Reset the ray launch size so the prolog doesn't think this is a converted dispatch */
tables->height = 1;
radv_upload_trace_rays_params(cmd_buffer, tables, mode, &launch_size_va, NULL);
if (ray_launch_size_addr_offset) {
radeon_begin(cs);
if (pdev->info.gfx_level >= GFX12) {
gfx12_push_64bit_pointer(ray_launch_size_addr_offset, launch_size_va);
} else {
radeon_emit_64bit_pointer(ray_launch_size_addr_offset, launch_size_va);
}
radeon_end();
}
radv_rt_dispatch(cmd_buffer, &info);
}
radv_before_trace_rays(cmd_buffer, rt_pipeline);
radv_emit_dispatch_packets(cmd_buffer, rt_prolog, &info);
radv_after_trace_rays(cmd_buffer, false);
radv_resume_conditional_rendering(cmd_buffer);
}

View file

@ -3424,11 +3424,10 @@ radv_create_rt_prolog(struct radv_device *device)
info.workgroup_size = info.wave_size;
info.user_data_0 = R_00B900_COMPUTE_USER_DATA_0;
info.type = RADV_SHADER_TYPE_RT_PROLOG;
info.cs.block_size[0] = 8;
info.cs.block_size[1] = pdev->rt_wave_size == 64 ? 8 : 4;
info.cs.block_size[0] = pdev->rt_wave_size;
info.cs.block_size[1] = 1;
info.cs.block_size[2] = 1;
info.cs.uses_thread_id[0] = true;
info.cs.uses_thread_id[1] = true;
for (unsigned i = 0; i < 3; i++)
info.cs.uses_block_id[i] = true;