From e978f6e24718ca17c38ae246c22efff7a4604d35 Mon Sep 17 00:00:00 2001
From: Natalie Vock
Date: Thu, 27 Mar 2025 18:47:43 +0100
Subject: [PATCH] radv/rt: Use ds_bvh_stack_push8_pop1_rtn_b32

Part-of: 
---
 src/amd/vulkan/bvh/bvh.h                |  14 +-
 src/amd/vulkan/nir/radv_nir_rt_common.c | 224 +++++++++++++++---------
 2 files changed, 159 insertions(+), 79 deletions(-)

diff --git a/src/amd/vulkan/bvh/bvh.h b/src/amd/vulkan/bvh/bvh.h
index 6a13cc8023b..2fb14839e87 100644
--- a/src/amd/vulkan/bvh/bvh.h
+++ b/src/amd/vulkan/bvh/bvh.h
@@ -128,8 +128,20 @@ struct radv_bvh_box32_node {
 #define RADV_BVH_ROOT_NODE    radv_bvh_node_box32
 #define RADV_BVH_INVALID_NODE 0xffffffffu
 
-/* used by gfx11's ds_bvh_stack* only */
+/* Used by gfx11's ds_bvh_stack* only.
+ * Indicator to ignore everything in the intrinsic result (i.e. push nothing to the stack) and only pop the next node
+ * from the stack.
+ */
 #define RADV_BVH_STACK_TERMINAL_NODE 0xfffffffeu
+/* Used by gfx12's ds_bvh_stack* only. */
+#define RADV_BVH_STACK_SKIP_0_TO_3 0xfffffffdu
+#define RADV_BVH_STACK_SKIP_4_TO_7 0xfffffffbu
+#define RADV_BVH_STACK_SKIP_0_TO_7 0xfffffff9u
+
+/* On gfx12, bits 29-31 of the stack pointer contain flags. */
+#define RADV_BVH_STACK_FLAG_HAS_BLAS (1u << 29)
+#define RADV_BVH_STACK_FLAG_OVERFLOW (1u << 30)
+#define RADV_BVH_STACK_FLAG_TLAS_POP (1u << 31)
 
diff --git a/src/amd/vulkan/nir/radv_nir_rt_common.c b/src/amd/vulkan/nir/radv_nir_rt_common.c
index ebd97dc7a4e..3917119cbe1 100644
--- a/src/amd/vulkan/nir/radv_nir_rt_common.c
+++ b/src/amd/vulkan/nir/radv_nir_rt_common.c
@@ -14,7 +14,9 @@ static nir_def *build_node_to_addr(struct radv_device *device, nir_builder *b, n
 bool
 radv_use_bvh_stack_rtn(const struct radv_physical_device *pdevice)
 {
-   return (pdevice->info.gfx_level == GFX11 || pdevice->info.gfx_level == GFX11_5) && !radv_emulate_rt(pdevice);
+   /* gfx12 requires using the bvh4 ds_bvh_stack_rtn differently - enable HW stack instructions on gfx12 only with bvh8. */
+   return (pdevice->info.gfx_level == GFX11 || pdevice->info.gfx_level == GFX11_5 || radv_use_bvh8(pdevice)) &&
+          !radv_emulate_rt(pdevice);
 }
 
 nir_def *
@@ -739,11 +741,18 @@ build_instance_exit(nir_builder *b, const struct radv_physical_device *pdev, con
    nir_if *instance_exit = nir_push_if(b, nir_ior(b, stack_instance_exit, root_instance_exit));
    instance_exit->control = nir_selection_control_dont_flatten;
    {
-      nir_store_deref(b, args->vars.top_stack, nir_imm_int(b, -1), 1);
+      if (radv_use_bvh8(pdev) && args->use_bvh_stack_rtn)
+         nir_store_deref(b, args->vars.stack,
+                         nir_ior_imm(b, nir_load_deref(b, args->vars.stack), RADV_BVH_STACK_FLAG_TLAS_POP), 0x1);
+      else
+         nir_store_deref(b, args->vars.top_stack, nir_imm_int(b, -1), 1);
       nir_store_deref(b, args->vars.previous_node, nir_load_deref(b, args->vars.instance_top_node), 1);
       nir_store_deref(b, args->vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 1);
 
-      nir_store_deref(b, args->vars.bvh_base, build_bvh_base(b, pdev, args->root_bvh_base, ptr_flags, true), 0x1);
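+      /* Editorial note (inferred from this patch): with bvh8 (gfx12), the traversal passes the raw 64-bit base
+       * address straight to nir_bvh8_intersect_ray_amd, so the gfx11-style pointer-flag encoding from
+       * build_bvh_base is not applied. */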
+      nir_def *root_bvh_base =
+         radv_use_bvh8(pdev) ? args->root_bvh_base : build_bvh_base(b, pdev, args->root_bvh_base, ptr_flags, true);
+
+      nir_store_deref(b, args->vars.bvh_base, root_bvh_base, 0x1);
       nir_store_deref(b, args->vars.origin, args->origin, 7);
       nir_store_deref(b, args->vars.dir, args->dir, 7);
       nir_store_deref(b, args->vars.inv_dir, nir_frcp(b, args->dir), 7);
@@ -1061,6 +1070,8 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
    nir_variable *incomplete = nir_local_variable_create(b->impl, glsl_bool_type(), "incomplete");
    nir_store_var(b, incomplete, nir_imm_true(b), 0x1);
 
+   nir_variable *intrinsic_result = nir_local_variable_create(b->impl, glsl_uvec_type(8), "intrinsic_result");
+   nir_variable *last_visited_node = nir_local_variable_create(b->impl, glsl_uint_type(), "last_visited_node");
 
    struct radv_ray_flags ray_flags = {
       .force_opaque = radv_test_flag(b, args, SpvRayFlagsOpaqueKHRMask, true),
@@ -1078,36 +1089,42 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
 
    nir_push_loop(b);
    {
-      nir_push_if(b, nir_ieq_imm(b, nir_load_deref(b, args->vars.current_node), RADV_BVH_INVALID_NODE));
-      {
-         /* Early exit if we never overflowed the stack, to avoid having to backtrack to
-          * the root for no reason. */
-         nir_push_if(b, nir_ilt_imm(b, nir_load_deref(b, args->vars.stack), args->stack_base + args->stack_stride));
+      /* When exiting instances via the stack, current_node is never invalid with ds_bvh_stack_rtn. */
+      if (args->use_bvh_stack_rtn) {
+         /* Early-exit when the stack is empty and there are no more nodes to process. */
+         nir_push_if(b, nir_ieq_imm(b, nir_load_deref(b, args->vars.current_node), RADV_BVH_STACK_TERMINAL_NODE));
          {
            nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
            nir_jump(b, nir_jump_break);
         }
         nir_pop_if(b, NULL);
+         build_instance_exit(b, pdev, args,
+                             nir_test_mask(b, nir_load_deref(b, args->vars.stack), RADV_BVH_STACK_FLAG_TLAS_POP), NULL);
+      }
 
-         nir_def *stack_instance_exit =
-            nir_ige(b, nir_load_deref(b, args->vars.top_stack), nir_load_deref(b, args->vars.stack));
-         nir_def *root_instance_exit =
-            nir_ieq(b, nir_load_deref(b, args->vars.previous_node), nir_load_deref(b, args->vars.instance_bottom_node));
-         nir_if *instance_exit = nir_push_if(b, nir_ior(b, stack_instance_exit, root_instance_exit));
-         instance_exit->control = nir_selection_control_dont_flatten;
-         {
-            nir_store_deref(b, args->vars.top_stack, nir_imm_int(b, -1), 1);
-            nir_store_deref(b, args->vars.previous_node, nir_load_deref(b, args->vars.instance_top_node), 1);
-            nir_store_deref(b, args->vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 1);
-
-            nir_store_deref(b, args->vars.bvh_base, args->root_bvh_base, 1);
-            nir_store_deref(b, args->vars.origin, args->origin, 7);
-            nir_store_deref(b, args->vars.dir, args->dir, 7);
+      nir_push_if(b, nir_ieq_imm(b, nir_load_deref(b, args->vars.current_node), RADV_BVH_INVALID_NODE));
+      {
+         /* Early exit if we never overflowed the stack, to avoid having to backtrack to
+          * the root for no reason. */
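+         /* With ds_bvh_stack_rtn, the empty-stack exit already happened above via RADV_BVH_STACK_TERMINAL_NODE,
+          * so this fallback only applies to the software stack. */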
+         if (!args->use_bvh_stack_rtn) {
+            nir_push_if(b, nir_ilt_imm(b, nir_load_deref(b, args->vars.stack), args->stack_base + args->stack_stride));
+            {
+               nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
+               nir_jump(b, nir_jump_break);
+            }
+            nir_pop_if(b, NULL);
+            build_instance_exit(
+               b, pdev, args, nir_ige(b, nir_load_deref(b, args->vars.top_stack), nir_load_deref(b, args->vars.stack)),
+               NULL);
          }
-         nir_pop_if(b, NULL);
 
-         nir_push_if(
-            b, nir_ige(b, nir_load_deref(b, args->vars.stack_low_watermark), nir_load_deref(b, args->vars.stack)));
+         nir_def *overflow_cond =
+            nir_ige(b, nir_load_deref(b, args->vars.stack_low_watermark), nir_load_deref(b, args->vars.stack));
+         /* ds_bvh_stack_rtn returns 0xFFFFFFFF if and only if there was a stack overflow. */
+         if (args->use_bvh_stack_rtn)
+            overflow_cond = nir_imm_true(b);
+
+         nir_push_if(b, overflow_cond);
          {
            nir_def *prev = nir_load_deref(b, args->vars.previous_node);
           nir_def *bvh_addr = build_node_to_addr(device, b, nir_load_deref(b, args->vars.bvh_base), true);
@@ -1123,13 +1140,15 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
         }
         nir_push_else(b, NULL);
         {
-            nir_store_deref(b, args->vars.stack,
-                            nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), -args->stack_stride), 1);
+            if (!args->use_bvh_stack_rtn) {
+               nir_store_deref(b, args->vars.stack,
+                               nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), -args->stack_stride), 1);
 
-            nir_def *stack_ptr =
-               nir_umod_imm(b, nir_load_deref(b, args->vars.stack), args->stack_stride * args->stack_entries);
-            nir_def *bvh_node = args->stack_load_cb(b, stack_ptr, args);
-            nir_store_deref(b, args->vars.current_node, bvh_node, 0x1);
+               nir_def *stack_ptr =
+                  nir_umod_imm(b, nir_load_deref(b, args->vars.stack), args->stack_stride * args->stack_entries);
+               nir_def *bvh_node = args->stack_load_cb(b, stack_ptr, args);
+               nir_store_deref(b, args->vars.current_node, bvh_node, 0x1);
+            }
             nir_store_deref(b, args->vars.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
         }
         nir_pop_if(b, NULL);
@@ -1144,7 +1163,10 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
             nir_def *prev_node = nir_load_deref(b, args->vars.previous_node);
 
             nir_store_deref(b, args->vars.previous_node, bvh_node, 0x1);
-            nir_store_deref(b, args->vars.current_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
+            if (args->use_bvh_stack_rtn)
+               nir_store_var(b, last_visited_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
+            else
+               nir_store_deref(b, args->vars.current_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
 
             nir_def *global_bvh_node = nir_iadd(b, nir_load_deref(b, args->vars.bvh_base), nir_u2u64(b, bvh_node));
@@ -1152,6 +1174,7 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
                nir_bvh8_intersect_ray_amd(b, 32, desc, nir_unpack_64_2x32(b, nir_load_deref(b, args->vars.bvh_base)),
                                           nir_ishr_imm(b, args->cull_mask, 24), nir_load_deref(b, args->vars.tmax),
                                           nir_load_deref(b, args->vars.origin), nir_load_deref(b, args->vars.dir), bvh_node);
+            nir_store_var(b, intrinsic_result, nir_channels(b, result, 0xff), 0xff);
 
             nir_store_deref(b, args->vars.origin, nir_channels(b, result, 0x7 << 10), 0x7);
             nir_store_deref(b, args->vars.dir, nir_channels(b, result, 0x7 << 13), 0x7);
@@ -1167,64 +1190,87 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
             nir_def *next_node = nir_iand_imm(b, nir_channel(b, result, 7), 0xff);
 
             nir_push_if(b, nir_ieq_imm(b, next_node, 0xff));
-            nir_store_deref(b, args->vars.origin, args->origin, 7);
-            nir_store_deref(b, args->vars.dir, args->dir, 7);
-            nir_jump(b, nir_jump_continue);
+            {
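+               /* Editorial note (inferred): the intersect intrinsic may have rewritten origin/dir (e.g. into
+                * instance space), so restore the world-space ray before moving on. */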
+               nir_store_deref(b, args->vars.origin, args->origin, 7);
+               nir_store_deref(b, args->vars.dir, args->dir, 7);
+               if (args->use_bvh_stack_rtn) {
+                  nir_def *skip_0_7 = nir_imm_int(b, RADV_BVH_STACK_SKIP_0_TO_7);
+                  nir_store_var(b, intrinsic_result,
+                                nir_vector_insert_imm(b, nir_load_var(b, intrinsic_result), skip_0_7, 7), 0xff);
+               } else {
+                  nir_jump(b, nir_jump_continue);
+               }
+            }
+            nir_push_else(b, NULL);
+            {
+               /* instance */
+               nir_def *instance_node_addr = build_node_to_addr(device, b, global_bvh_node, false);
+               nir_store_deref(b, args->vars.instance_addr, instance_node_addr, 1);
+
+               nir_store_deref(b, args->vars.sbt_offset_and_flags, nir_channel(b, result, 6), 1);
+
+               nir_store_deref(b, args->vars.top_stack, nir_load_deref(b, args->vars.stack), 1);
+               nir_store_deref(b, args->vars.bvh_base, nir_pack_64_2x32(b, nir_channels(b, result, 0x3 << 2)), 1);
+
+               /* Push the instance root node onto the stack */
+               if (args->use_bvh_stack_rtn) {
+                  nir_def *comps[8];
+                  for (unsigned i = 0; i < 6; ++i)
+                     comps[i] = nir_channel(b, result, i);
+                  comps[6] = nir_imm_int(b, RADV_BVH_STACK_SKIP_0_TO_7);
+                  comps[7] = next_node;
+                  nir_store_var(b, intrinsic_result, nir_vec(b, comps, 8), 0xff);
+               } else {
+                  nir_store_deref(b, args->vars.current_node, next_node, 0x1);
+               }
+               nir_store_deref(b, args->vars.instance_bottom_node, next_node, 1);
+               nir_store_deref(b, args->vars.instance_top_node, bvh_node, 1);
+            }
             nir_pop_if(b, NULL);
-
-            /* instance */
-            nir_def *instance_node_addr = build_node_to_addr(device, b, global_bvh_node, false);
-            nir_store_deref(b, args->vars.instance_addr, instance_node_addr, 1);
-
-            nir_store_deref(b, args->vars.sbt_offset_and_flags, nir_channel(b, result, 6), 1);
-
-            nir_store_deref(b, args->vars.top_stack, nir_load_deref(b, args->vars.stack), 1);
-            nir_store_deref(b, args->vars.bvh_base, nir_pack_64_2x32(b, nir_channels(b, result, 0x3 << 2)), 1);
-
-            /* Push the instance root node onto the stack */
-            nir_store_deref(b, args->vars.current_node, next_node, 0x1);
-            nir_store_deref(b, args->vars.instance_bottom_node, next_node, 1);
-            nir_store_deref(b, args->vars.instance_top_node, bvh_node, 1);
          }
         nir_push_else(b, NULL);
         {
             /* box */
-            nir_push_if(b, nir_ieq_imm(b, prev_node, RADV_BVH_INVALID_NODE));
-            {
-               nir_def *new_nodes[8];
-               for (unsigned i = 0; i < 8; ++i)
-                  new_nodes[i] = nir_channel(b, result, i);
+            if (args->use_bvh_stack_rtn) {
+               nir_store_var(b, last_visited_node, prev_node, 0x1);
+            } else {
+               nir_push_if(b, nir_ieq_imm(b, prev_node, RADV_BVH_INVALID_NODE));
+               {
+                  nir_def *new_nodes[8];
+                  for (unsigned i = 0; i < 8; ++i)
+                     new_nodes[i] = nir_channel(b, result, i);
 
-               for (unsigned i = 1; i < 8; ++i)
-                  nir_push_if(b, nir_ine_imm(b, new_nodes[i], RADV_BVH_INVALID_NODE));
+                  for (unsigned i = 1; i < 8; ++i)
+                     nir_push_if(b, nir_ine_imm(b, new_nodes[i], RADV_BVH_INVALID_NODE));
 
-               for (unsigned i = 8; i-- > 1;) {
-                  nir_def *stack = nir_load_deref(b, args->vars.stack);
-                  nir_def *stack_ptr = nir_umod_imm(b, stack, args->stack_entries * args->stack_stride);
-                  args->stack_store_cb(b, stack_ptr, new_nodes[i], args);
-                  nir_store_deref(b, args->vars.stack, nir_iadd_imm(b, stack, args->stack_stride), 1);
+                  for (unsigned i = 8; i-- > 1;) {
+                     nir_def *stack = nir_load_deref(b, args->vars.stack);
+                     nir_def *stack_ptr = nir_umod_imm(b, stack, args->stack_entries * args->stack_stride);
+                     args->stack_store_cb(b, stack_ptr, new_nodes[i], args);
+                     nir_store_deref(b, args->vars.stack, nir_iadd_imm(b, stack, args->stack_stride), 1);
 
-                  if (i == 1) {
-                     nir_def *new_watermark =
-                        nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), -args->stack_entries * args->stack_stride);
-                     new_watermark = nir_imax(b, nir_load_deref(b, args->vars.stack_low_watermark), new_watermark);
-                     nir_store_deref(b, args->vars.stack_low_watermark, new_watermark, 0x1);
+                     if (i == 1) {
+                        nir_def *new_watermark = nir_iadd_imm(b, nir_load_deref(b, args->vars.stack),
+                                                              -args->stack_entries * args->stack_stride);
+                        new_watermark = nir_imax(b, nir_load_deref(b, args->vars.stack_low_watermark), new_watermark);
+                        nir_store_deref(b, args->vars.stack_low_watermark, new_watermark, 0x1);
+                     }
+
+                     nir_pop_if(b, NULL);
                   }
-
-                  nir_pop_if(b, NULL);
+                  nir_store_deref(b, args->vars.current_node, new_nodes[0], 0x1);
                }
-               nir_store_deref(b, args->vars.current_node, new_nodes[0], 0x1);
-            }
-            nir_push_else(b, NULL);
-            {
-               nir_def *next = nir_imm_int(b, RADV_BVH_INVALID_NODE);
-               for (unsigned i = 0; i < 7; ++i) {
-                  next = nir_bcsel(b, nir_ieq(b, prev_node, nir_channel(b, result, i)), nir_channel(b, result, i + 1),
-                                   next);
+               nir_push_else(b, NULL);
+               {
+                  nir_def *next = nir_imm_int(b, RADV_BVH_INVALID_NODE);
+                  for (unsigned i = 0; i < 7; ++i) {
+                     next = nir_bcsel(b, nir_ieq(b, prev_node, nir_channel(b, result, i)),
+                                      nir_channel(b, result, i + 1), next);
+                  }
+                  nir_store_deref(b, args->vars.current_node, next, 0x1);
                }
-               nir_store_deref(b, args->vars.current_node, next, 0x1);
+               nir_pop_if(b, NULL);
             }
-            nir_pop_if(b, NULL);
          }
         nir_pop_if(b, NULL);
       }
@@ -1243,6 +1289,11 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
                nir_pop_if(b, NULL);
             }
             nir_pop_if(b, NULL);
+            if (args->use_bvh_stack_rtn) {
+               nir_def *skip_0_7 = nir_imm_int(b, RADV_BVH_STACK_SKIP_0_TO_7);
+               nir_store_var(b, intrinsic_result, nir_vector_insert_imm(b, nir_load_var(b, intrinsic_result), skip_0_7, 7),
+                             0xff);
+            }
          }
         nir_pop_if(b, NULL);
 
@@ -1251,6 +1302,23 @@ radv_build_ray_traversal_gfx12(struct radv_device *device, nir_builder *b, const
          iteration_instance_count = nir_iadd_imm(b, iteration_instance_count, 1);
          nir_store_deref(b, args->vars.iteration_instance_count, iteration_instance_count, 0x1);
       }
+
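+      /* Editorial note (inferred from the SKIP_* usage above): ds_bvh_stack_push8_pop1 pushes the (up to) 8
+       * entries of intrinsic_result (RADV_BVH_STACK_SKIP_* values suppress part or all of the push), advances
+       * past last_visited_node, and returns the updated stack pointer plus the next node to visit. */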
+      if (args->use_bvh_stack_rtn) {
+         nir_def *stack_result;
+         stack_result =
+            nir_bvh_stack_rtn_amd(b, 32, nir_load_deref(b, args->vars.stack), nir_load_var(b, last_visited_node),
+                                  nir_load_var(b, intrinsic_result), .stack_size = args->stack_entries);
+         nir_store_deref(b, args->vars.stack, nir_channel(b, stack_result, 0), 0x1);
+         nir_store_deref(b, args->vars.current_node, nir_channel(b, stack_result, 1), 0x1);
+      }
+
+      if (args->vars.break_flag) {
+         nir_push_if(b, nir_load_deref(b, args->vars.break_flag));
+         {
+            nir_jump(b, nir_jump_break);
+         }
+         nir_pop_if(b, NULL);
+      }
    }
    nir_pop_loop(b, NULL);