radv/rt: Use ds_bvh_stack_push8_pop1_rtn_b32

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35269>
Natalie Vock 2025-03-27 18:47:43 +01:00 committed by Marge Bot
parent ea66a8d1c5
commit e978f6e247
2 changed files with 159 additions and 79 deletions

@@ -128,8 +128,20 @@ struct radv_bvh_box32_node {
#define RADV_BVH_ROOT_NODE radv_bvh_node_box32
#define RADV_BVH_INVALID_NODE 0xffffffffu
/* used by gfx11's ds_bvh_stack* only */
/* used by gfx11's ds_bvh_stack* only
* Indicator to ignore everything in the intrinsic result (i.e. push nothing to the stack) and only pop the next node
* from the stack.
*/
#define RADV_BVH_STACK_TERMINAL_NODE 0xfffffffeu
/* used by gfx12's ds_bvh_stack* only */
#define RADV_BVH_STACK_SKIP_0_TO_3 0xfffffffdu
#define RADV_BVH_STACK_SKIP_4_TO_7 0xfffffffbu
#define RADV_BVH_STACK_SKIP_0_TO_7 0xfffffff9u
/* On gfx12, bits 29-31 of the stack pointer contain flags. */
#define RADV_BVH_STACK_FLAG_HAS_BLAS (1u << 29)
#define RADV_BVH_STACK_FLAG_OVERFLOW (1u << 30)
#define RADV_BVH_STACK_FLAG_TLAS_POP (1u << 31)
/* GFX12 */
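
As an illustration of the stack-pointer layout these defines describe, here is a minimal C sketch; the helpers stack_addr() and needs_tlas_pop() are hypothetical, and the macros are repeated from the hunk above only for self-containment:

#include <stdbool.h>
#include <stdint.h>

#define RADV_BVH_STACK_FLAG_HAS_BLAS (1u << 29)
#define RADV_BVH_STACK_FLAG_OVERFLOW (1u << 30)
#define RADV_BVH_STACK_FLAG_TLAS_POP (1u << 31)

/* Hypothetical helper: on gfx12 the stack pointer carries flags in bits
 * 29-31, so the plain stack address is recovered by masking them off. */
static inline uint32_t
stack_addr(uint32_t stack_ptr)
{
   return stack_ptr & ~(RADV_BVH_STACK_FLAG_HAS_BLAS | RADV_BVH_STACK_FLAG_OVERFLOW |
                        RADV_BVH_STACK_FLAG_TLAS_POP);
}

/* Hypothetical helper: set when the next pop should return traversal to the
 * top-level BVH. */
static inline bool
needs_tlas_pop(uint32_t stack_ptr)
{
   return (stack_ptr & RADV_BVH_STACK_FLAG_TLAS_POP) != 0;
}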

@@ -14,7 +14,9 @@ static nir_def *build_node_to_addr(struct radv_device *device, nir_builder *b, n
bool
radv_use_bvh_stack_rtn(const struct radv_physical_device *pdevice)
{
return (pdevice->info.gfx_level == GFX11 || pdevice->info.gfx_level == GFX11_5) && !radv_emulate_rt(pdevice);
/* gfx12 requires using the bvh4 ds_bvh_stack_rtn differently - enable hw stack instrs on gfx12 only with bvh8 */
return (pdevice->info.gfx_level == GFX11 || pdevice->info.gfx_level == GFX11_5 || radv_use_bvh8(pdevice)) &&
!radv_emulate_rt(pdevice);
}
nir_def *
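
For illustration, a hypothetical stand-in mirroring the predicate above; the enum values and the use_bvh8/emulate_rt parameters are assumptions standing in for radv_use_bvh8() and radv_emulate_rt():

#include <assert.h>
#include <stdbool.h>

enum gfx_level { GFX11, GFX11_5, GFX12 }; /* stand-in values */

/* Hypothetical mirror of radv_use_bvh_stack_rtn(): gfx11/gfx11.5 always
 * qualify, gfx12 only with the bvh8 layout, and emulated RT never does. */
static bool
use_stack_rtn(enum gfx_level level, bool use_bvh8, bool emulate_rt)
{
   return (level == GFX11 || level == GFX11_5 || use_bvh8) && !emulate_rt;
}

int
main(void)
{
   assert(use_stack_rtn(GFX11, false, false));
   assert(use_stack_rtn(GFX11_5, false, false));
   assert(use_stack_rtn(GFX12, true, false));   /* bvh8: hardware stack */
   assert(!use_stack_rtn(GFX12, false, false)); /* bvh4: software stack */
   assert(!use_stack_rtn(GFX11, false, true));  /* emulated RT */
   return 0;
}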
@@ -739,11 +741,18 @@ build_instance_exit(nir_builder *b, const struct radv_physical_device *pdev, con
nir_if *instance_exit = nir_push_if(b, nir_ior(b, stack_instance_exit, root_instance_exit));
instance_exit->control = nir_selection_control_dont_flatten;
{
nir_store_deref(b, args->vars.top_stack, nir_imm_int(b, -1), 1);
if (radv_use_bvh8(pdev) && args->use_bvh_stack_rtn)
nir_store_deref(b, args->vars.stack,
nir_ior_imm(b, nir_load_deref(b, args->vars.stack), RADV_BVH_STACK_FLAG_TLAS_POP), 0x1);
else
nir_store_deref(b, args->vars.top_stack, nir_imm_int(b, -1), 1);
nir_store_deref(b, args->vars.previous_node, nir_load_deref(b, args->vars.instance_top_node), 1);
nir_store_deref(b, args->vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 1);
nir_store_deref(b, args->vars.bvh_base, build_bvh_base(b, pdev, args->root_bvh_base, ptr_flags, true), 0x1);
nir_def *root_bvh_base =
radv_use_bvh8(pdev) ? args->root_bvh_base : build_bvh_base(b, pdev, args->root_bvh_base, ptr_flags, true);
nir_store_deref(b, args->vars.bvh_base, root_bvh_base, 0x1);
nir_store_deref(b, args->vars.origin, args->origin, 7);
nir_store_deref(b, args->vars.dir, args->dir, 7);
nir_store_deref(b, args->vars.inv_dir, nir_frcp(b, args->dir), 7);
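
A compilable sketch of the two instance-exit strategies this hunk selects between; mark_instance_exit() and its scalar state are hypothetical simplifications of the NIR deref stores above:

#include <stdbool.h>
#include <stdint.h>

#define RADV_BVH_STACK_FLAG_TLAS_POP (1u << 31) /* repeated for self-containment */

/* Illustrative only: on the gfx12 bvh8 hardware-stack path the stack pointer
 * itself is tagged so the next hardware pop resumes in the TLAS; the other
 * paths reset the software top_stack marker instead. */
static void
mark_instance_exit(bool bvh8_hw_stack, uint32_t *stack, int32_t *top_stack)
{
   if (bvh8_hw_stack)
      *stack |= RADV_BVH_STACK_FLAG_TLAS_POP;
   else
      *top_stack = -1;
}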
@@ -1061,6 +1070,8 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
nir_variable *incomplete = nir_local_variable_create(b->impl, glsl_bool_type(), "incomplete");
nir_store_var(b, incomplete, nir_imm_true(b), 0x1);
nir_variable *intrinsic_result = nir_local_variable_create(b->impl, glsl_uvec_type(8), "intrinsic_result");
nir_variable *last_visited_node = nir_local_variable_create(b->impl, glsl_uint_type(), "last_visited_node");
struct radv_ray_flags ray_flags = {
.force_opaque = radv_test_flag(b, args, SpvRayFlagsOpaqueKHRMask, true),
@@ -1078,36 +1089,42 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
nir_push_loop(b);
{
nir_push_if(b, nir_ieq_imm(b, nir_load_deref(b, args->vars.current_node), RADV_BVH_INVALID_NODE));
{
/* Early exit if we never overflowed the stack, to avoid having to backtrack to
* the root for no reason. */
nir_push_if(b, nir_ilt_imm(b, nir_load_deref(b, args->vars.stack), args->stack_base + args->stack_stride));
/* When exiting instances via stack, current_node won't ever be invalid with ds_bvh_stack_rtn */
if (args->use_bvh_stack_rtn) {
/* Early-exit when the stack is empty and there are no more nodes to process. */
nir_push_if(b, nir_ieq_imm(b, nir_load_deref(b, args->vars.current_node), RADV_BVH_STACK_TERMINAL_NODE));
{
nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
nir_jump(b, nir_jump_break);
}
nir_pop_if(b, NULL);
build_instance_exit(b, pdev, args,
nir_test_mask(b, nir_load_deref(b, args->vars.stack), RADV_BVH_STACK_FLAG_TLAS_POP), NULL);
}
nir_def *stack_instance_exit =
nir_ige(b, nir_load_deref(b, args->vars.top_stack), nir_load_deref(b, args->vars.stack));
nir_def *root_instance_exit =
nir_ieq(b, nir_load_deref(b, args->vars.previous_node), nir_load_deref(b, args->vars.instance_bottom_node));
nir_if *instance_exit = nir_push_if(b, nir_ior(b, stack_instance_exit, root_instance_exit));
instance_exit->control = nir_selection_control_dont_flatten;
{
nir_store_deref(b, args->vars.top_stack, nir_imm_int(b, -1), 1);
nir_store_deref(b, args->vars.previous_node, nir_load_deref(b, args->vars.instance_top_node), 1);
nir_store_deref(b, args->vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 1);
nir_store_deref(b, args->vars.bvh_base, args->root_bvh_base, 1);
nir_store_deref(b, args->vars.origin, args->origin, 7);
nir_store_deref(b, args->vars.dir, args->dir, 7);
nir_push_if(b, nir_ieq_imm(b, nir_load_deref(b, args->vars.current_node), RADV_BVH_INVALID_NODE));
{
/* Early exit if we never overflowed the stack, to avoid having to backtrack to
* the root for no reason. */
if (!args->use_bvh_stack_rtn) {
nir_push_if(b, nir_ilt_imm(b, nir_load_deref(b, args->vars.stack), args->stack_base + args->stack_stride));
{
nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
nir_jump(b, nir_jump_break);
}
nir_pop_if(b, NULL);
build_instance_exit(
b, pdev, args, nir_ige(b, nir_load_deref(b, args->vars.top_stack), nir_load_deref(b, args->vars.stack)),
NULL);
}
nir_pop_if(b, NULL);
nir_push_if(
b, nir_ige(b, nir_load_deref(b, args->vars.stack_low_watermark), nir_load_deref(b, args->vars.stack)));
nir_def *overflow_cond =
nir_ige(b, nir_load_deref(b, args->vars.stack_low_watermark), nir_load_deref(b, args->vars.stack));
/* ds_bvh_stack_rtn returns 0xFFFFFFFF if and only if there was a stack overflow. */
if (args->use_bvh_stack_rtn)
overflow_cond = nir_imm_true(b);
nir_push_if(b, overflow_cond);
{
nir_def *prev = nir_load_deref(b, args->vars.previous_node);
nir_def *bvh_addr = build_node_to_addr(device, b, nir_load_deref(b, args->vars.bvh_base), true);
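
To summarize the restructured loop exit, a self-contained sketch under my reading of this hunk: with use_bvh_stack_rtn the terminal-node sentinel reports an empty stack, while the software path still compares the stack pointer against its base. traversal_done() is a hypothetical condensation:

#include <stdbool.h>
#include <stdint.h>

#define RADV_BVH_INVALID_NODE        0xffffffffu
#define RADV_BVH_STACK_TERMINAL_NODE 0xfffffffeu

/* Illustrative only: true when the traversal loop may stop. */
static bool
traversal_done(bool use_bvh_stack_rtn, uint32_t current_node, int32_t stack,
               int32_t stack_base, int32_t stack_stride)
{
   if (use_bvh_stack_rtn)
      /* The hardware pop reports an empty stack via the terminal sentinel. */
      return current_node == RADV_BVH_STACK_TERMINAL_NODE;
   /* Software stack: done once no node is pending and the stack never grew
    * past its first entry, so there is nothing to backtrack to. */
   return current_node == RADV_BVH_INVALID_NODE && stack < stack_base + stack_stride;
}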
@@ -1123,13 +1140,15 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
}
nir_push_else(b, NULL);
{
nir_store_deref(b, args->vars.stack,
nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), -args->stack_stride), 1);
if (!args->use_bvh_stack_rtn) {
nir_store_deref(b, args->vars.stack,
nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), -args->stack_stride), 1);
nir_def *stack_ptr =
nir_umod_imm(b, nir_load_deref(b, args->vars.stack), args->stack_stride * args->stack_entries);
nir_def *bvh_node = args->stack_load_cb(b, stack_ptr, args);
nir_store_deref(b, args->vars.current_node, bvh_node, 0x1);
nir_def *stack_ptr =
nir_umod_imm(b, nir_load_deref(b, args->vars.stack), args->stack_stride * args->stack_entries);
nir_def *bvh_node = args->stack_load_cb(b, stack_ptr, args);
nir_store_deref(b, args->vars.current_node, bvh_node, 0x1);
}
nir_store_deref(b, args->vars.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
}
nir_pop_if(b, NULL);
@@ -1144,7 +1163,10 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
nir_def *prev_node = nir_load_deref(b, args->vars.previous_node);
nir_store_deref(b, args->vars.previous_node, bvh_node, 0x1);
nir_store_deref(b, args->vars.current_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
if (args->use_bvh_stack_rtn)
nir_store_var(b, last_visited_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
else
nir_store_deref(b, args->vars.current_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
nir_def *global_bvh_node = nir_iadd(b, nir_load_deref(b, args->vars.bvh_base), nir_u2u64(b, bvh_node));
@@ -1152,6 +1174,7 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
nir_bvh8_intersect_ray_amd(b, 32, desc, nir_unpack_64_2x32(b, nir_load_deref(b, args->vars.bvh_base)),
nir_ishr_imm(b, args->cull_mask, 24), nir_load_deref(b, args->vars.tmax),
nir_load_deref(b, args->vars.origin), nir_load_deref(b, args->vars.dir), bvh_node);
nir_store_var(b, intrinsic_result, nir_channels(b, result, 0xff), 0xff);
nir_store_deref(b, args->vars.origin, nir_channels(b, result, 0x7 << 10), 0x7);
nir_store_deref(b, args->vars.dir, nir_channels(b, result, 0x7 << 13), 0x7);
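
For reference, a hypothetical view of the intersect result consumed above; the channel layout (eight node words in channels 0-7, updated origin in channels 10-12, updated direction in channels 13-15) is inferred from the nir_channels() masks in this hunk:

#include <stdint.h>

/* Illustrative only: the slices of the bvh8 intersect result that the
 * traversal code reads back. */
struct bvh8_result_view {
   uint32_t nodes[8]; /* channels 0-7: child/primitive words fed to the stack */
   float origin[3];   /* channels 10-12: ray origin, possibly transformed */
   float dir[3];      /* channels 13-15: ray direction, possibly transformed */
};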
@@ -1167,64 +1190,87 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
nir_def *next_node = nir_iand_imm(b, nir_channel(b, result, 7), 0xff);
nir_push_if(b, nir_ieq_imm(b, next_node, 0xff));
nir_store_deref(b, args->vars.origin, args->origin, 7);
nir_store_deref(b, args->vars.dir, args->dir, 7);
nir_jump(b, nir_jump_continue);
{
nir_store_deref(b, args->vars.origin, args->origin, 7);
nir_store_deref(b, args->vars.dir, args->dir, 7);
if (args->use_bvh_stack_rtn) {
nir_def *skip_0_7 = nir_imm_int(b, RADV_BVH_STACK_SKIP_0_TO_7);
nir_store_var(b, intrinsic_result,
nir_vector_insert_imm(b, nir_load_var(b, intrinsic_result), skip_0_7, 7), 0xff);
} else {
nir_jump(b, nir_jump_continue);
}
}
nir_push_else(b, NULL);
{
/* instance */
nir_def *instance_node_addr = build_node_to_addr(device, b, global_bvh_node, false);
nir_store_deref(b, args->vars.instance_addr, instance_node_addr, 1);
nir_store_deref(b, args->vars.sbt_offset_and_flags, nir_channel(b, result, 6), 1);
nir_store_deref(b, args->vars.top_stack, nir_load_deref(b, args->vars.stack), 1);
nir_store_deref(b, args->vars.bvh_base, nir_pack_64_2x32(b, nir_channels(b, result, 0x3 << 2)), 1);
/* Push the instance root node onto the stack */
if (args->use_bvh_stack_rtn) {
nir_def *comps[8];
for (unsigned i = 0; i < 6; ++i)
comps[i] = nir_channel(b, result, i);
comps[6] = nir_imm_int(b, RADV_BVH_STACK_SKIP_0_TO_7);
comps[7] = next_node;
nir_store_var(b, intrinsic_result, nir_vec(b, comps, 8), 0xff);
} else {
nir_store_deref(b, args->vars.current_node, next_node, 0x1);
}
nir_store_deref(b, args->vars.instance_bottom_node, next_node, 1);
nir_store_deref(b, args->vars.instance_top_node, bvh_node, 1);
}
nir_pop_if(b, NULL);
/* instance */
nir_def *instance_node_addr = build_node_to_addr(device, b, global_bvh_node, false);
nir_store_deref(b, args->vars.instance_addr, instance_node_addr, 1);
nir_store_deref(b, args->vars.sbt_offset_and_flags, nir_channel(b, result, 6), 1);
nir_store_deref(b, args->vars.top_stack, nir_load_deref(b, args->vars.stack), 1);
nir_store_deref(b, args->vars.bvh_base, nir_pack_64_2x32(b, nir_channels(b, result, 0x3 << 2)), 1);
/* Push the instance root node onto the stack */
nir_store_deref(b, args->vars.current_node, next_node, 0x1);
nir_store_deref(b, args->vars.instance_bottom_node, next_node, 1);
nir_store_deref(b, args->vars.instance_top_node, bvh_node, 1);
}
nir_push_else(b, NULL);
{
/* box */
nir_push_if(b, nir_ieq_imm(b, prev_node, RADV_BVH_INVALID_NODE));
{
nir_def *new_nodes[8];
for (unsigned i = 0; i < 8; ++i)
new_nodes[i] = nir_channel(b, result, i);
if (args->use_bvh_stack_rtn) {
nir_store_var(b, last_visited_node, prev_node, 0x1);
} else {
nir_push_if(b, nir_ieq_imm(b, prev_node, RADV_BVH_INVALID_NODE));
{
nir_def *new_nodes[8];
for (unsigned i = 0; i < 8; ++i)
new_nodes[i] = nir_channel(b, result, i);
for (unsigned i = 1; i < 8; ++i)
nir_push_if(b, nir_ine_imm(b, new_nodes[i], RADV_BVH_INVALID_NODE));
for (unsigned i = 1; i < 8; ++i)
nir_push_if(b, nir_ine_imm(b, new_nodes[i], RADV_BVH_INVALID_NODE));
for (unsigned i = 8; i-- > 1;) {
nir_def *stack = nir_load_deref(b, args->vars.stack);
nir_def *stack_ptr = nir_umod_imm(b, stack, args->stack_entries * args->stack_stride);
args->stack_store_cb(b, stack_ptr, new_nodes[i], args);
nir_store_deref(b, args->vars.stack, nir_iadd_imm(b, stack, args->stack_stride), 1);
for (unsigned i = 8; i-- > 1;) {
nir_def *stack = nir_load_deref(b, args->vars.stack);
nir_def *stack_ptr = nir_umod_imm(b, stack, args->stack_entries * args->stack_stride);
args->stack_store_cb(b, stack_ptr, new_nodes[i], args);
nir_store_deref(b, args->vars.stack, nir_iadd_imm(b, stack, args->stack_stride), 1);
if (i == 1) {
nir_def *new_watermark =
nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), -args->stack_entries * args->stack_stride);
new_watermark = nir_imax(b, nir_load_deref(b, args->vars.stack_low_watermark), new_watermark);
nir_store_deref(b, args->vars.stack_low_watermark, new_watermark, 0x1);
if (i == 1) {
nir_def *new_watermark = nir_iadd_imm(b, nir_load_deref(b, args->vars.stack),
-args->stack_entries * args->stack_stride);
new_watermark = nir_imax(b, nir_load_deref(b, args->vars.stack_low_watermark), new_watermark);
nir_store_deref(b, args->vars.stack_low_watermark, new_watermark, 0x1);
}
nir_pop_if(b, NULL);
}
nir_pop_if(b, NULL);
nir_store_deref(b, args->vars.current_node, new_nodes[0], 0x1);
}
nir_store_deref(b, args->vars.current_node, new_nodes[0], 0x1);
}
nir_push_else(b, NULL);
{
nir_def *next = nir_imm_int(b, RADV_BVH_INVALID_NODE);
for (unsigned i = 0; i < 7; ++i) {
next = nir_bcsel(b, nir_ieq(b, prev_node, nir_channel(b, result, i)), nir_channel(b, result, i + 1),
next);
nir_push_else(b, NULL);
{
nir_def *next = nir_imm_int(b, RADV_BVH_INVALID_NODE);
for (unsigned i = 0; i < 7; ++i) {
next = nir_bcsel(b, nir_ieq(b, prev_node, nir_channel(b, result, i)),
nir_channel(b, result, i + 1), next);
}
nir_store_deref(b, args->vars.current_node, next, 0x1);
}
nir_store_deref(b, args->vars.current_node, next, 0x1);
nir_pop_if(b, NULL);
}
nir_pop_if(b, NULL);
}
nir_pop_if(b, NULL);
}
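
A small sketch of the pop-only convention used here and in the next hunk; request_pop_only() is hypothetical and mirrors the nir_vector_insert_imm(..., skip_0_7, 7) stores:

#include <stdint.h>

#define RADV_BVH_STACK_SKIP_0_TO_7 0xfffffff9u /* repeated for self-containment */

/* Illustrative only: entries 0-7 are the candidate pushes; writing the skip
 * sentinel into the last lane tells the stack instruction to push none of
 * them and only pop the next node. */
static void
request_pop_only(uint32_t push_data[8])
{
   push_data[7] = RADV_BVH_STACK_SKIP_0_TO_7;
}

In the instance case above, the sentinel appears to be placed in lane 6 instead, so that only the instance root node in lane 7 gets pushed.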
@@ -1243,6 +1289,11 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
nir_pop_if(b, NULL);
}
nir_pop_if(b, NULL);
if (args->use_bvh_stack_rtn) {
nir_def *skip_0_7 = nir_imm_int(b, RADV_BVH_STACK_SKIP_0_TO_7);
nir_store_var(b, intrinsic_result, nir_vector_insert_imm(b, nir_load_var(b, intrinsic_result), skip_0_7, 7),
0xff);
}
}
nir_pop_if(b, NULL);
@@ -1251,6 +1302,23 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
iteration_instance_count = nir_iadd_imm(b, iteration_instance_count, 1);
nir_store_deref(b, args->vars.iteration_instance_count, iteration_instance_count, 0x1);
}
if (args->use_bvh_stack_rtn) {
nir_def *stack_result =
nir_bvh_stack_rtn_amd(b, 32, nir_load_deref(b, args->vars.stack), nir_load_var(b, last_visited_node),
nir_load_var(b, intrinsic_result), .stack_size = args->stack_entries);
nir_store_deref(b, args->vars.stack, nir_channel(b, stack_result, 0), 0x1);
nir_store_deref(b, args->vars.current_node, nir_channel(b, stack_result, 1), 0x1);
}
if (args->vars.break_flag) {
nir_push_if(b, nir_load_deref(b, args->vars.break_flag));
{
nir_jump(b, nir_jump_break);
}
nir_pop_if(b, NULL);
}
}
nir_pop_loop(b, NULL);
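
Finally, a hypothetical summary of the stack intrinsic's contract as used in the last hunk (my reading: result channel 0 is the updated stack pointer, flag bits included, and channel 1 is the next node or a sentinel):

#include <stdint.h>

/* Illustrative only: shape of a ds_bvh_stack_push8_pop1_rtn_b32-style
 * operation as modeled by nir_bvh_stack_rtn_amd above. It pushes up to eight
 * node words (honoring the skip/terminal sentinels), pops one node, and
 * returns both outputs at once. */
struct stack_rtn_result {
   uint32_t stack; /* updated stack pointer, flags in bits 29-31 on gfx12 */
   uint32_t node;  /* next node to visit, or RADV_BVH_STACK_TERMINAL_NODE */
};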