From 60dd9d797e80da32cce5f9b165d97b2d7698de5b Mon Sep 17 00:00:00 2001 From: Natalie Vock Date: Tue, 6 Jan 2026 14:12:18 +0100 Subject: [PATCH] aco: Swizzle ray launch IDs in the RT prolog This converts from 1D workgroups to 2D ray launch IDs entirely via shader ALU, including handling partial/cut-off workgroups optimally. Doing this entirely in-shader means it Just Works(TM) with indirect dispatches as well. Previous approaches manipulating various things on CPU depending on the dispatch size couldn't handle indirect dispatches. The swizzle implemented here also swizzles with a recursive Z-order pattern, which should be a little more optimal than arranging invocations linearly within the wave. Part-of: --- .../aco_select_rt_prolog.cpp | 140 ++++++++++++++++++ 1 file changed, 140 insertions(+) diff --git a/src/amd/compiler/instruction_selection/aco_select_rt_prolog.cpp b/src/amd/compiler/instruction_selection/aco_select_rt_prolog.cpp index 3b473c9fcde..76870f87ec7 100644 --- a/src/amd/compiler/instruction_selection/aco_select_rt_prolog.cpp +++ b/src/amd/compiler/instruction_selection/aco_select_rt_prolog.cpp @@ -88,6 +88,12 @@ select_rt_prolog(Program* program, ac_shader_config* config, PhysReg out_record_ptr = get_arg_reg(out_args, out_args->rt.shader_record); /* Temporaries: */ + PhysReg tmp_wg_start_x = PhysReg{num_sgprs}; + num_sgprs++; + PhysReg tmp_wg_start_y = PhysReg{num_sgprs}; + num_sgprs++; + PhysReg tmp_swizzle_bound_y = PhysReg{num_sgprs}; + num_sgprs++; PhysReg tmp_wg_id_y; if (program->gfx_level >= GFX12) { tmp_wg_id_y = PhysReg{num_sgprs}; @@ -101,6 +107,11 @@ select_rt_prolog(Program* program, ac_shader_config* config, PhysReg tmp_ring_offsets = PhysReg{num_sgprs}; num_sgprs += 2; + PhysReg tmp_swizzled_id_x = PhysReg{256 + num_vgprs++}; + PhysReg tmp_swizzled_id_y = PhysReg{256 + num_vgprs++}; + PhysReg tmp_swizzled_id_shifted_x = PhysReg{256 + num_vgprs++}; + PhysReg tmp_swizzled_id_shifted_y = PhysReg{256 + num_vgprs++}; + /* Confirm some assumptions about register aliasing */ assert(in_ring_offsets == out_uniform_shader_addr); assert(get_arg_reg(in_args, in_args->push_constants) == @@ -163,10 +174,139 @@ select_rt_prolog(Program* program, ac_shader_config* config, bld.vop1(aco_opcode::v_mov_b32, Definition(out_launch_ids[2], v1), Operand(in_wg_id_z, s1)); } + /* Swizzle ray launch IDs. We dispatch a 1D 32x1/64x1 workgroup natively. Many games dispatch + * rays in a 2D grid and write RT results to an image indexed by the x/y launch ID. + * In image space, a 1D workgroup maps to a 32/64-pixel wide line, which is inefficient for two + * reasons: + * - Image data is usually arranged on a Z-order curve, a long line makes for inefficient + * memory access patterns. + * - Each wave working on a "line" in image space may increase divergence. It's better to trace + * rays in a small square, since that makes it more likely all rays hit the same or similar + * objects. + * + * It turns out arranging rays along a Z-order curve is best for both image access patterns and + * ray divergence. Since image data is swizzled along a Z-order curve as well, swizzling the + * launch ID should result in each lane accessing whole cachelines at once. For traced rays, + * the Z-order curve means that each quad is arranged in a 2x2 square in image space as well. + * Since the RT unit processes 4 lanes at a time, reducing divergence per quad may result in + * better RT unit utilization (for example by the RT unit being able to skip the quad entirely + * if all 4 lanes are inactive). + * + * To swizzle along a Z-order curve, treat the 1D lane ID as a morton code. Then, do the inverse + * of morton code generation (i.e. deinterleaving the bits) to recover the x-y + * coordinates on the Z-order curve. + */ + + /* Deinterleave bits - odd bits go to tmp_swizzled_id_x, even ones to tmp_swizzled_id_y */ + bld.vop2(aco_opcode::v_and_b32, Definition(tmp_swizzled_id_x, v1), Operand::c32(0x55), + Operand(in_local_id, v1)); + bld.vop2(aco_opcode::v_and_b32, Definition(tmp_swizzled_id_y, v1), Operand::c32(0xaa), + Operand(in_local_id, v1)); + bld.vop2(aco_opcode::v_lshrrev_b32, Definition(tmp_swizzled_id_y, v1), Operand::c32(1), + Operand(tmp_swizzled_id_y, v1)); + + /* The deinterleaved bits are currently padded with a zero between each bit, like so: + * 0 A 0 B 0 C 0 D + * Compact the deinterleaved bits by factor 2 to remove the padding, resulting in + * A B C D + */ + bld.vop2(aco_opcode::v_lshrrev_b32, Definition(tmp_swizzled_id_shifted_x, v1), Operand::c32(1), + Operand(tmp_swizzled_id_x, v1)); + bld.vop2(aco_opcode::v_lshrrev_b32, Definition(tmp_swizzled_id_shifted_y, v1), Operand::c32(1), + Operand(tmp_swizzled_id_y, v1)); + bld.vop2(aco_opcode::v_or_b32, Definition(tmp_swizzled_id_x, v1), Operand(tmp_swizzled_id_x, v1), + Operand(tmp_swizzled_id_shifted_x, v1)); + bld.vop2(aco_opcode::v_or_b32, Definition(tmp_swizzled_id_y, v1), Operand(tmp_swizzled_id_y, v1), + Operand(tmp_swizzled_id_shifted_y, v1)); + bld.vop2(aco_opcode::v_and_b32, Definition(tmp_swizzled_id_x, v1), Operand::c32(0x33u), + Operand(tmp_swizzled_id_x, v1)); + bld.vop2(aco_opcode::v_and_b32, Definition(tmp_swizzled_id_y, v1), Operand::c32(0x33u), + Operand(tmp_swizzled_id_y, v1)); + + bld.vop2(aco_opcode::v_lshrrev_b32, Definition(tmp_swizzled_id_shifted_x, v1), Operand::c32(2), + Operand(tmp_swizzled_id_x, v1)); + bld.vop2(aco_opcode::v_lshrrev_b32, Definition(tmp_swizzled_id_shifted_y, v1), Operand::c32(2), + Operand(tmp_swizzled_id_y, v1)); + bld.vop2(aco_opcode::v_or_b32, Definition(tmp_swizzled_id_x, v1), Operand(tmp_swizzled_id_x, v1), + Operand(tmp_swizzled_id_shifted_x, v1)); + bld.vop2(aco_opcode::v_or_b32, Definition(tmp_swizzled_id_y, v1), Operand(tmp_swizzled_id_y, v1), + Operand(tmp_swizzled_id_shifted_y, v1)); + bld.vop2(aco_opcode::v_and_b32, Definition(tmp_swizzled_id_x, v1), Operand::c32(0x0Fu), + Operand(tmp_swizzled_id_x, v1)); + bld.vop2(aco_opcode::v_and_b32, Definition(tmp_swizzled_id_y, v1), Operand::c32(0x0Fu), + Operand(tmp_swizzled_id_y, v1)); + + /* Fix up the workgroup IDs after converting from 32x1/64x1 to 8x4/8x8. The X dimension of the + * workgroup size gets divided by 4/8, while the Y dimension gets multiplied by the same amount. + * Rearrange the workgroups to make up for that, by rounding the Y component of the workgroup ID + * to the nearest multiple of 4/8. The remainder gets added to the X dimension, to make up for + * the fact we divided the X component of the ID. + */ + uint32_t workgroup_size_log2 = util_logbase2(program->workgroup_size); + bld.sop2(aco_opcode::s_lshl_b32, Definition(tmp_wg_start_x, s1), Definition(scc, s1), + Operand(in_wg_id_x, s1), Operand::c32(workgroup_size_log2)); + + /* unsigned y_remainder = tmp_wg_id_y % wg_height + * We use tmp_wg_start_y to store y_rem, and overwrite it later with the real wg_start_y. + */ + uint32_t workgroup_width_log2 = 3u; + uint32_t workgroup_height_mask = program->workgroup_size == 32 ? 0x3u : 0x7u; + bld.sop2(aco_opcode::s_and_b32, Definition(tmp_wg_start_y, s1), Definition(scc, s1), + Operand(tmp_wg_id_y, s1), Operand::c32(workgroup_height_mask)); + /* wg_start_x += y_remainder * workgroup_width (workgroup_width == 8) */ + bld.sop2(aco_opcode::s_lshl_b32, Definition(tmp_wg_start_y, s1), Definition(scc, s1), + Operand(tmp_wg_start_y, s1), Operand::c32(workgroup_width_log2)); + bld.sop2(aco_opcode::s_add_u32, Definition(tmp_wg_start_x, s1), Definition(scc, s1), + Operand(tmp_wg_start_x, s1), Operand(tmp_wg_start_y, s1)); + /* wg_start_y = ROUND_DOWN_TO(in_wg_y, workgroup_height) */ + bld.sop2(aco_opcode::s_and_b32, Definition(tmp_wg_start_y, s1), Definition(scc, s1), + Operand(tmp_wg_id_y, s1), Operand::c32(~workgroup_height_mask)); + + bld.vop2(aco_opcode::v_add_u32, Definition(tmp_swizzled_id_x, v1), Operand(tmp_wg_start_x, s1), + Operand(tmp_swizzled_id_x, v1)); + bld.vop2(aco_opcode::v_add_u32, Definition(tmp_swizzled_id_y, v1), Operand(tmp_wg_start_y, s1), + Operand(tmp_swizzled_id_y, v1)); + + /* We can only swizzle launch IDs if we run a full workgroup, and the resulting launch IDs + * won't exceed the launch size. Calculate unswizzled launch IDs here to fall back to them + * if the swizzled launch IDs are out of bounds. + */ bld.vop1(aco_opcode::v_mov_b32, Definition(out_launch_ids[1], v1), Operand(tmp_wg_id_y, s1)); bld.vop3(aco_opcode::v_mad_u32_u24, Definition(out_launch_ids[0], v1), Operand(in_wg_id_x, s1), Operand::c32(program->workgroup_size), Operand(in_local_id, v1)); + /* Round the launch size down to the nearest multiple of workgroup_height. If the workgroup ID + * exceeds this, then the swizzled IDs' Y component will exceed the Y launch size and we have to + * fall back to unswizzled IDs. + */ + bld.sop2(aco_opcode::s_and_b32, Definition(tmp_swizzle_bound_y, s1), Definition(scc, s1), + Operand(out_launch_size_y, s1), Operand::c32(~workgroup_height_mask)); + /* If we are only running a partial workgroup, swizzling would yield a wrong result. */ + if (program->gfx_level >= GFX8) { + bld.sopc(Builder::s_cmp_lg, Definition(scc, s1), Operand(exec, bld.lm), + Operand::c32_or_c64(-1u, program->workgroup_size == 64)); + } else { + /* Write the XOR result to vcc because it's currently unused and a convenient register (always + * the same size as exec). We only care about the value of scc, i.e. if the result is nonzero + * (vcc is about to be overwritten anyway). + */ + bld.sop2(Builder::s_xor, Definition(vcc, bld.lm), Definition(scc, s1), Operand(exec, bld.lm), + Operand::c32_or_c64(-1u, program->workgroup_size == 64)); + } + bld.sop2(Builder::s_cselect, Definition(vcc, bld.lm), + Operand::c32_or_c64(-1u, program->wave_size == 64), + Operand::c32_or_c64(0u, program->wave_size == 64), Operand(scc, s1)); + bld.sopc(aco_opcode::s_cmp_ge_u32, Definition(scc, s1), Operand(tmp_wg_id_y, s1), + Operand(tmp_swizzle_bound_y, s1)); + bld.sop2(Builder::s_cselect, Definition(vcc, bld.lm), + Operand::c32_or_c64(-1u, program->wave_size == 64), + Operand(vcc, bld.lm), Operand(scc, s1)); + + bld.vop2(aco_opcode::v_cndmask_b32, Definition(out_launch_ids[0], v1), + Operand(tmp_swizzled_id_x, v1), Operand(out_launch_ids[0], v1), Operand(vcc, bld.lm)); + bld.vop2(aco_opcode::v_cndmask_b32, Definition(out_launch_ids[1], v1), + Operand(tmp_swizzled_id_y, v1), Operand(out_launch_ids[1], v1), Operand(vcc, bld.lm)); + /* calculate shader record ptr: SBT + RADV_RT_HANDLE_SIZE */ if (options->gfx_level < GFX9) { bld.vop2_e64(aco_opcode::v_add_co_u32, Definition(out_record_ptr, v1), Definition(vcc, s2),