radv: Rewrite the RT prolog in NIR

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40008>
This commit is contained in:
Natalie Vock 2026-02-20 13:41:32 +01:00 committed by Marge Bot
parent b53dc3f052
commit afe519406b
5 changed files with 192 additions and 33 deletions

View file

@ -10,6 +10,9 @@
#include "nir/radv_nir_rt_stage_functions.h"
#include "aco_nir_call_attribs.h"
#include "nir_builder.h"
#include "radv_device.h"
#include "radv_meta_nir.h"
#include "radv_physical_device.h"
struct radv_nir_sbt_data
radv_nir_load_sbt_entry(nir_builder *b, nir_def *base, nir_def *idx, enum radv_nir_sbt_type binding,
@ -252,3 +255,156 @@ radv_nir_return_param_from_type(nir_parameter *param, const glsl_type *type, boo
param->driver_attributes = driver_attribs;
param->is_return = true;
}
void
radv_build_rt_prolog(struct radv_device *device, struct radv_shader_stage *stage)
{
const struct radv_physical_device *pdev = radv_device_physical(device);
nir_builder b = radv_meta_nir_init_shader(MESA_SHADER_COMPUTE, "rt_prolog");
stage->stage = MESA_SHADER_COMPUTE;
stage->nir = b.shader;
stage->info.stage = MESA_SHADER_COMPUTE;
stage->info.loads_push_constants = true;
stage->info.loads_dynamic_offsets = true;
stage->info.force_indirect_descriptors = true;
stage->info.wave_size = pdev->rt_wave_size;
stage->info.workgroup_size = stage->info.wave_size;
stage->info.user_data_0 = R_00B900_COMPUTE_USER_DATA_0;
stage->info.type = RADV_SHADER_TYPE_RT_PROLOG;
stage->info.cs.block_size[0] = pdev->rt_wave_size;
stage->info.cs.block_size[1] = 1;
stage->info.cs.block_size[2] = 1;
stage->info.cs.uses_thread_id[0] = true;
for (unsigned i = 0; i < 3; i++)
stage->info.cs.uses_block_id[i] = true;
radv_declare_shader_args(device, NULL, &stage->info, MESA_SHADER_COMPUTE, MESA_SHADER_NONE, &stage->args);
stage->info.user_sgprs_locs = stage->args.user_sgprs_locs;
b.shader->info.workgroup_size[0] = pdev->rt_wave_size;
b.shader->info.api_subgroup_size = pdev->rt_wave_size;
b.shader->info.max_subgroup_size = pdev->rt_wave_size;
b.shader->info.min_subgroup_size = pdev->rt_wave_size;
nir_function *raygen_function = nir_function_create(b.shader, "raygen_func");
radv_nir_init_rt_function_params(raygen_function, MESA_SHADER_RAYGEN, 0, 0);
nir_def *descriptors = ac_nir_load_arg(&b, &stage->args.ac, stage->args.descriptors[0]);
nir_def *push_constants = ac_nir_load_arg(&b, &stage->args.ac, stage->args.ac.push_constants);
nir_def *dynamic_descriptors = ac_nir_load_arg(&b, &stage->args.ac, stage->args.ac.dynamic_descriptors);
nir_def *sbt_desc = nir_pack_64_2x32(&b, ac_nir_load_arg(&b, &stage->args.ac, stage->args.ac.rt.sbt_descriptors));
nir_def *launch_size_addr = nir_pack_64_2x32(&b, ac_nir_load_arg(&b, &stage->args.ac, stage->args.ac.rt.launch_size_addr));
nir_def *traversal_addr =
nir_pack_64_2x32_split(&b, ac_nir_load_arg(&b, &stage->args.ac, stage->args.ac.rt.traversal_shader_addr),
nir_imm_int(&b, pdev->info.address32_hi));
nir_def *raygen_sbt = nir_pack_64_2x32(&b, ac_nir_load_smem(&b, 2, sbt_desc, nir_imm_int(&b, 0), 4, 0));
nir_def *launch_sizes = ac_nir_load_smem(&b, 3, launch_size_addr, nir_imm_int(&b, 0), 4, 0);
nir_def *wg_id_vec = nir_load_workgroup_id(&b);
nir_def *wg_ids[3] = {
nir_channel(&b, wg_id_vec, 0),
nir_channel(&b, wg_id_vec, 1),
nir_channel(&b, wg_id_vec, 2),
};
nir_def *local_id = nir_channel(&b, nir_load_local_invocation_id(&b), 0);
nir_def *unswizzled_id_x = nir_iadd(&b, nir_imul_imm(&b, wg_ids[0], pdev->rt_wave_size), local_id);
nir_def *unswizzled_id_y = wg_ids[1];
/* Swizzle ray launch IDs. We dispatch a 1D 32x1/64x1 workgroup natively. Many games dispatch
* rays in a 2D grid and write RT results to an image indexed by the x/y launch ID.
* In image space, a 1D workgroup maps to a 32/64-pixel wide line, which is inefficient for two
* reasons:
* - Image data is usually arranged on a Z-order curve, a long line makes for inefficient
* memory access patterns.
* - Each wave working on a "line" in image space may increase divergence. It's better to trace
* rays in a small square, since that makes it more likely all rays hit the same or similar
* objects.
*
* It turns out arranging rays along a Z-order curve is best for both image access patterns and
* ray divergence. Since image data is swizzled along a Z-order curve as well, swizzling the
* launch ID should result in each lane accessing whole cachelines at once. For traced rays,
* the Z-order curve means that each quad is arranged in a 2x2 square in image space as well.
* Since the RT unit processes 4 lanes at a time, reducing divergence per quad may result in
* better RT unit utilization (for example by the RT unit being able to skip the quad entirely
* if all 4 lanes are inactive).
*
* To swizzle along a Z-order curve, treat the 1D lane ID as a morton code. Then, do the inverse
* of morton code generation (i.e. deinterleaving the bits) to recover the x-y
* coordinates on the Z-order curve.
*/
/* Deinterleave bits - even bits go to swizzled_id_x, odd ones to swizzled_id_y */
nir_def *swizzled_id_x = local_id;
nir_def *swizzled_id_y = nir_ushr_imm(&b, local_id, 1);
/* The deinterleaved bits are currently separated by single bit, like so:
* ...0 0 0 A ? B ? C
* Compact the deinterleaved bits by factor 2 to remove the padding, resulting in
* ...0 0 0 0 0 A B C
*/
nir_def *swizzled_id_shifted_x = nir_ushr_imm(&b, swizzled_id_x, 1);
nir_def *swizzled_id_shifted_y = nir_ushr_imm(&b, swizzled_id_y, 1);
swizzled_id_x = nir_bitfield_select(&b, nir_imm_int(&b, 0x11), swizzled_id_x, swizzled_id_shifted_x);
swizzled_id_y = nir_bitfield_select(&b, nir_imm_int(&b, 0x11), swizzled_id_y, swizzled_id_shifted_y);
swizzled_id_shifted_x = nir_ushr_imm(&b, swizzled_id_x, 2);
swizzled_id_shifted_y = nir_ushr_imm(&b, swizzled_id_y, 2);
swizzled_id_x = nir_bitfield_select(&b, nir_imm_int(&b, 0x3), swizzled_id_x, swizzled_id_shifted_x);
swizzled_id_y = nir_bitfield_select(&b, nir_imm_int(&b, 0x3), swizzled_id_y, swizzled_id_shifted_y);
uint32_t workgroup_width = 8;
uint32_t workgroup_height = pdev->rt_wave_size == 32 ? 4 : 8;
uint32_t workgroup_height_mask = workgroup_height - 1;
/* Fix up the workgroup IDs after converting from 32x1/64x1 to 8x4/8x8. The X dimension of the
* workgroup size gets divided by 4/8, while the Y dimension gets multiplied by the same amount.
* Rearrange the workgroups to make up for that, by rounding the Y component of the workgroup ID
* to the nearest multiple of 4/8. The remainder gets added to the X dimension, to make up for
* the fact we divided the X component of the ID.
*/
nir_def *wg_id_y_rem = nir_iand_imm(&b, wg_ids[1], workgroup_height_mask);
nir_def *new_wg_start_x = nir_imul_imm(&b, wg_ids[0], pdev->rt_wave_size);
new_wg_start_x = nir_iadd(&b, new_wg_start_x, nir_imul_imm(&b, wg_id_y_rem, workgroup_width));
nir_def *new_wg_start_y = nir_iand_imm(&b, wg_ids[1], ~workgroup_height_mask);
swizzled_id_x = nir_iadd(&b, swizzled_id_x, new_wg_start_x);
swizzled_id_y = nir_iadd(&b, swizzled_id_y, new_wg_start_y);
/* Round the launch size down to the nearest multiple of workgroup_height. If the workgroup ID
* exceeds this, then the swizzled IDs' Y component will exceed the Y launch size and we have to
* fall back to unswizzled IDs.
*/
nir_def *y_wg_bound = nir_iand_imm(&b, nir_channel(&b, launch_sizes, 1), ~workgroup_height_mask);
/* If parts of this wave would've exceeded the launch size in the X dimension, their threads will be masked out and
* exec won't equal -1. In that case, using swizzled IDs is invalid.
*/
nir_def *partial_oob_x = nir_ine_imm(&b, nir_ballot(&b, 1, pdev->rt_wave_size, nir_imm_true(&b)), -1);
nir_def *partial_oob_y = nir_uge(&b, wg_ids[1], y_wg_bound);
nir_def *partial_oob = nir_ior(&b, partial_oob_x, partial_oob_y);
nir_def *id_x = nir_bcsel(&b, partial_oob, unswizzled_id_x, swizzled_id_x);
nir_def *id_y = nir_bcsel(&b, partial_oob, unswizzled_id_y, swizzled_id_y);
/* shaderGroupBaseAlignment is RADV_RT_HANDLE_SIZE */
nir_def *raygen_addr = nir_pack_64_2x32(&b, ac_nir_load_smem(&b, 2, raygen_sbt, nir_imm_int(&b, 0), RADV_RT_HANDLE_SIZE, 0));
nir_def *shader_record_ptr = nir_iadd_imm(&b, raygen_sbt, RADV_RT_HANDLE_SIZE);
nir_def *params[RAYGEN_ARG_COUNT];
params[RT_ARG_LAUNCH_ID] = nir_vec3(&b, id_x, id_y, wg_ids[2]);
params[RT_ARG_LAUNCH_SIZE] = launch_sizes;
params[RT_ARG_DESCRIPTORS] = descriptors;
params[RT_ARG_DYNAMIC_DESCRIPTORS] = dynamic_descriptors;
params[RT_ARG_PUSH_CONSTANTS] = push_constants;
params[RT_ARG_SBT_DESCRIPTORS] = sbt_desc;
params[RAYGEN_ARG_SHADER_RECORD_PTR] = shader_record_ptr;
params[RAYGEN_ARG_TRAVERSAL_ADDR] = traversal_addr;
nir_build_indirect_call(&b, raygen_function, raygen_addr, RAYGEN_ARG_COUNT, params);
}

View file

@ -158,4 +158,6 @@ struct radv_nir_rt_traversal_result radv_build_traversal(struct radv_device *dev
struct radv_ray_tracing_pipeline *pipeline, nir_builder *b,
struct radv_nir_rt_traversal_params *params,
struct radv_ray_tracing_stage_info *info);
/* Builds the RT prolog as a NIR compute shader: fills in stage->info and stage->args and stores
 * the newly created shader in stage->nir.
 */
void radv_build_rt_prolog(struct radv_device *device, struct radv_shader_stage *stage);
#endif // MESA_RADV_NIR_RT_STAGE_COMMON_H

View file

@ -25,7 +25,10 @@
#include "radv_pipeline_layout.h"
#include "radv_pipeline_rt.h"
#include "nir/radv_nir_rt_stage_common.h"
#include "aco_interface.h"
#include "aco_nir_call_attribs.h"
#include "radv_aco_shader_info.h"
#include "radv_rmv.h"
#include "radv_shader.h"
@ -1042,13 +1045,21 @@ static void
compile_rt_prolog(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline)
{
const struct radv_physical_device *pdev = radv_device_physical(device);
struct nir_function raygen_stub = {0};
uint32_t push_constant_size = 0;
/* Create a dummy function signature for raygen shaders in order to pass parameter info to the prolog */
radv_nir_init_rt_function_params(&raygen_stub, MESA_SHADER_RAYGEN, 0, 0);
radv_nir_lower_callee_signature(&raygen_stub);
pipeline->prolog = radv_create_rt_prolog(device, raygen_stub.num_params, raygen_stub.params);
struct radv_shader_stage prolog_stage = {};
radv_build_rt_prolog(device, &prolog_stage);
prolog_stage.nir->options = &pdev->nir_options[MESA_SHADER_COMPUTE];
radv_optimize_nir(prolog_stage.nir, false);
radv_postprocess_nir(device, NULL, &prolog_stage);
NIR_PASS(_, prolog_stage.nir, radv_nir_lower_call_abi, prolog_stage.info.wave_size);
NIR_PASS(_, prolog_stage.nir, nir_lower_global_vars_to_local);
NIR_PASS(_, prolog_stage.nir, nir_lower_vars_to_ssa);
NIR_PASS(_, prolog_stage.nir, nir_opt_copy_prop);
NIR_PASS(_, prolog_stage.nir, nir_opt_remove_phis);
pipeline->prolog = radv_compile_rt_prolog(device, &prolog_stage);
bool has_traversal = !!pipeline->base.base.shaders[MESA_SHADER_INTERSECTION];

View file

@ -3470,33 +3470,23 @@ radv_aco_build_shader_part(void **bin, uint32_t num_sgprs, uint32_t num_vgprs, c
}
struct radv_shader *
radv_create_rt_prolog(struct radv_device *device, unsigned raygen_param_count, nir_parameter *raygen_params)
radv_compile_rt_prolog(struct radv_device *device, struct radv_shader_stage *stage)
{
const struct radv_physical_device *pdev = radv_device_physical(device);
const struct radv_instance *instance = radv_physical_device_instance(pdev);
struct radv_instance *instance = radv_physical_device_instance(pdev);
struct radv_shader *prolog;
struct radv_shader_args in_args = {0};
struct radv_nir_compiler_options options = {0};
radv_fill_nir_compiler_options(&options, device, NULL, false, instance->debug_flags & RADV_DEBUG_DUMP_PROLOGS,
radv_device_fault_detection_enabled(device), false);
struct radv_shader_info info = {0};
info.stage = MESA_SHADER_COMPUTE;
info.loads_push_constants = true;
info.loads_dynamic_offsets = true;
info.force_indirect_descriptors = true;
info.wave_size = pdev->rt_wave_size;
info.workgroup_size = info.wave_size;
info.user_data_0 = R_00B900_COMPUTE_USER_DATA_0;
info.type = RADV_SHADER_TYPE_RT_PROLOG;
info.cs.block_size[0] = pdev->rt_wave_size;
info.cs.block_size[1] = 1;
info.cs.block_size[2] = 1;
info.cs.uses_thread_id[0] = true;
for (unsigned i = 0; i < 3; i++)
info.cs.uses_block_id[i] = true;
radv_declare_shader_args(device, NULL, &info, MESA_SHADER_COMPUTE, MESA_SHADER_NONE, &in_args);
info.user_sgprs_locs = in_args.user_sgprs_locs;
if (options.dump_shader) {
simple_mtx_lock(&instance->shader_dump_mtx);
if (instance->debug_flags & RADV_DEBUG_DUMP_NIR)
nir_print_shader(stage->nir, stderr);
}
#if AMD_LLVM_AVAILABLE
if (options.dump_shader || options.record_ir)
@ -3507,13 +3497,13 @@ radv_create_rt_prolog(struct radv_device *device, unsigned raygen_param_count, n
struct radv_shader_stage_key stage_key = {0};
struct aco_shader_info ac_info;
struct aco_compiler_options ac_opts;
radv_aco_convert_shader_info(&ac_info, &info, &in_args, &device->cache_key, pdev->info.gfx_level);
radv_aco_convert_opts(&ac_opts, &options, &in_args, &stage_key);
aco_compile_rt_prolog(&ac_opts, &ac_info, &in_args.ac, &in_args.descriptors[0], raygen_param_count, raygen_params,
&radv_aco_build_shader_binary, (void **)&binary);
binary->info = info;
radv_aco_convert_shader_info(&ac_info, &stage->info, &stage->args, &device->cache_key, pdev->info.gfx_level);
radv_aco_convert_opts(&ac_opts, &options, &stage->args, &stage_key);
aco_compile_shader(&ac_opts, &ac_info, 1, &stage->nir, &stage->args.ac, &radv_aco_build_shader_binary,
(void **)&binary);
binary->info = stage->info;
radv_postprocess_binary_config(device, binary, &in_args);
radv_postprocess_binary_config(device, binary, &stage->args);
radv_shader_create_uncached(device, binary, false, NULL, &prolog);
if (!prolog || radv_parse_binary_debug_info(device, binary, &prolog->dbg) != VK_SUCCESS)
goto done;
@ -3521,6 +3511,7 @@ radv_create_rt_prolog(struct radv_device *device, unsigned raygen_param_count, n
if (options.dump_shader) {
fprintf(stderr, "Raytracing prolog");
fprintf(stderr, "\ndisasm:\n%s\n", prolog->dbg.disasm_string);
simple_mtx_unlock(&instance->shader_dump_mtx);
}
done:

View file

@ -563,8 +563,7 @@ void radv_free_shader_memory(struct radv_device *device, union radv_shader_arena
struct radv_shader *radv_create_trap_handler_shader(struct radv_device *device);
struct radv_shader *radv_create_rt_prolog(struct radv_device *device, unsigned raygen_param_count,
nir_parameter *raygen_params);
/* Compiles the RT prolog held in stage->nir (using stage->info and stage->args) and creates an
 * uncached radv_shader from the resulting binary.
 */
struct radv_shader *radv_compile_rt_prolog(struct radv_device *device, struct radv_shader_stage *stage);
struct radv_shader_part *radv_shader_part_create(struct radv_device *device, struct radv_shader_part_binary *binary,
unsigned wave_size);