diff --git a/src/amd/vulkan/bvh/bvh.h b/src/amd/vulkan/bvh/bvh.h
index e21b92c93c7..6a13cc8023b 100644
--- a/src/amd/vulkan/bvh/bvh.h
+++ b/src/amd/vulkan/bvh/bvh.h
@@ -128,6 +128,8 @@ struct radv_bvh_box32_node {
 
 #define RADV_BVH_ROOT_NODE radv_bvh_node_box32
 #define RADV_BVH_INVALID_NODE 0xffffffffu
+/* Used by GFX11's ds_bvh_stack* instructions only. */
+#define RADV_BVH_STACK_TERMINAL_NODE 0xfffffffeu
 
 /* GFX12 */
diff --git a/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c b/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c
index 2dc30e5a474..5188eaf1ac6 100644
--- a/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c
+++ b/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c
@@ -80,6 +80,7 @@ enum radv_ray_query_field {
    radv_ray_query_trav_instance_top_node,
    radv_ray_query_trav_instance_bottom_node,
    radv_ray_query_stack,
+   radv_ray_query_break_flag,
    radv_ray_query_field_count,
 };
 
@@ -116,6 +117,7 @@ radv_get_ray_query_type()
    FIELD(trav_instance_top_node, glsl_uint_type());
    FIELD(trav_instance_bottom_node, glsl_uint_type());
    FIELD(stack, glsl_array_type(glsl_uint_type(), MAX_SCRATCH_STACK_ENTRY_COUNT, 0));
+   FIELD(break_flag, glsl_bool_type());
 
 #undef FIELD
 
@@ -145,17 +147,20 @@ struct ray_query_vars {
 
 static void
 init_ray_query_vars(nir_shader *shader, const glsl_type *opaque_type, struct ray_query_vars *dst, const char *base_name,
-                    uint32_t max_shared_size)
+                    const struct radv_physical_device *pdev)
 {
    memset(dst, 0, sizeof(*dst));
    uint32_t workgroup_size =
      shader->info.workgroup_size[0] * shader->info.workgroup_size[1] * shader->info.workgroup_size[2];
    uint32_t shared_stack_entries = shader->info.ray_queries == 1 ? 16 : 8;
+   /* ds_bvh_stack* instructions use a fixed stride of 32 dwords. */
+   if (radv_use_bvh_stack_rtn(pdev))
+      workgroup_size = MAX2(workgroup_size, 32);
    uint32_t shared_stack_size = workgroup_size * shared_stack_entries * 4;
    uint32_t shared_offset = align(shader->info.shared_size, 4);
 
    if (shader->info.stage != MESA_SHADER_COMPUTE || glsl_type_is_array(opaque_type) ||
-       shared_offset + shared_stack_size > max_shared_size) {
+       shared_offset + shared_stack_size > pdev->max_shared_size) {
       dst->stack_entries = MAX_SCRATCH_STACK_ENTRY_COUNT;
    } else {
       dst->shared_stack = true;
@@ -170,11 +175,11 @@ init_ray_query_vars(nir_shader *shader, const glsl_type *opaque_type, struct ray
 }
 
 static void
-lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *ht, uint32_t max_shared_size)
+lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *ht, const struct radv_physical_device *pdev)
 {
    struct ray_query_vars *vars = ralloc(ht, struct ray_query_vars);
 
-   init_ray_query_vars(shader, ray_query->type, vars, ray_query->name == NULL ? "" : ray_query->name, max_shared_size);
+   init_ray_query_vars(shader, ray_query->type, vars, ray_query->name == NULL ? "" : ray_query->name, pdev);
"" : ray_query->name, pdev); _mesa_hash_table_insert(ht, ray_query, vars); } @@ -279,10 +284,18 @@ lower_rq_initialize(nir_builder *b, nir_intrinsic_instr *instr, struct ray_query rq_store(b, rq, trav_bvh_base, bvh_base); if (vars->shared_stack) { - nir_def *base_offset = nir_imul_imm(b, nir_load_local_invocation_index(b), sizeof(uint32_t)); - base_offset = nir_iadd_imm(b, base_offset, vars->shared_base); - rq_store(b, rq, trav_stack, base_offset); - rq_store(b, rq, trav_stack_low_watermark, base_offset); + if (radv_use_bvh_stack_rtn(pdev)) { + uint32_t workgroup_size = + b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] * b->shader->info.workgroup_size[2]; + nir_def *addr = radv_build_bvh_stack_rtn_addr(b, pdev, workgroup_size, vars->shared_base, vars->stack_entries); + rq_store(b, rq, trav_stack, addr); + rq_store(b, rq, trav_stack_low_watermark, addr); + } else { + nir_def *base_offset = nir_imul_imm(b, nir_load_local_invocation_index(b), sizeof(uint32_t)); + base_offset = nir_iadd_imm(b, base_offset, vars->shared_base); + rq_store(b, rq, trav_stack, base_offset); + rq_store(b, rq, trav_stack_low_watermark, base_offset); + } } else { rq_store(b, rq, trav_stack, nir_imm_int(b, 0)); rq_store(b, rq, trav_stack_low_watermark, nir_imm_int(b, 0)); @@ -387,6 +400,7 @@ lower_rq_load(struct radv_device *device, nir_builder *b, nir_intrinsic_instr *i } struct traversal_data { + const struct radv_device *device; struct ray_query_vars *vars; nir_deref_instr *rq; }; @@ -404,7 +418,10 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio isec_store(b, candidate, opaque, intersection->opaque); isec_store(b, candidate, intersection_type, nir_imm_int(b, intersection_type_aabb)); - nir_jump(b, nir_jump_break); + if (args->use_bvh_stack_rtn) + rq_store(b, data->rq, break_flag, nir_imm_true(b)); + else + nir_jump(b, nir_jump_break); } static void @@ -430,7 +447,10 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int } nir_push_else(b, NULL); { - nir_jump(b, nir_jump_break); + if (args->use_bvh_stack_rtn) + rq_store(b, data->rq, break_flag, nir_imm_true(b)); + else + nir_jump(b, nir_jump_break); } nir_pop_if(b, NULL); } @@ -493,9 +513,11 @@ lower_rq_proceed(nir_builder *b, nir_intrinsic_instr *instr, struct ray_query_va .instance_bottom_node = rq_deref(b, rq, trav_instance_bottom_node), .instance_addr = isec_deref(b, candidate, instance_addr), .sbt_offset_and_flags = isec_deref(b, candidate, sbt_offset_and_flags), + .break_flag = rq_deref(b, rq, break_flag), }; struct traversal_data data = { + .device = device, .vars = vars, .rq = rq, }; @@ -518,15 +540,23 @@ lower_rq_proceed(nir_builder *b, nir_intrinsic_instr *instr, struct ray_query_va }; if (vars->shared_stack) { - uint32_t workgroup_size = - b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] * b->shader->info.workgroup_size[2]; - args.stack_stride = workgroup_size * 4; - args.stack_base = vars->shared_base; + args.use_bvh_stack_rtn = radv_use_bvh_stack_rtn(pdev); + if (args.use_bvh_stack_rtn) { + args.stack_stride = 1; + args.stack_base = 0; + } else { + uint32_t workgroup_size = + b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] * b->shader->info.workgroup_size[2]; + args.stack_stride = workgroup_size * 4; + args.stack_base = vars->shared_base; + } } else { args.stack_stride = 1; args.stack_base = 0; } + rq_store(b, rq, break_flag, nir_imm_false(b)); + nir_push_if(b, rq_load(b, rq, incomplete)); { nir_def *incomplete; @@ -569,7 
@@ -569,7 +599,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device
       if (!var->data.ray_query)
          continue;
 
-      lower_ray_query(shader, var, query_ht, pdev->max_shared_size);
+      lower_ray_query(shader, var, query_ht, pdev);
 
       progress = true;
    }
@@ -581,7 +611,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device
          if (!var->data.ray_query)
            continue;
 
-         lower_ray_query(shader, var, query_ht, pdev->max_shared_size);
+         lower_ray_query(shader, var, query_ht, pdev);
 
         progress = true;
      }
diff --git a/src/amd/vulkan/nir/radv_nir_rt_common.c b/src/amd/vulkan/nir/radv_nir_rt_common.c
index c8dbdf3b923..ebd97dc7a4e 100644
--- a/src/amd/vulkan/nir/radv_nir_rt_common.c
+++ b/src/amd/vulkan/nir/radv_nir_rt_common.c
@@ -11,6 +11,42 @@
 static nir_def *build_node_to_addr(struct radv_device *device, nir_builder *b, nir_def *node, bool skip_type_and);
 
+bool
+radv_use_bvh_stack_rtn(const struct radv_physical_device *pdevice)
+{
+   return (pdevice->info.gfx_level == GFX11 || pdevice->info.gfx_level == GFX11_5) && !radv_emulate_rt(pdevice);
+}
+
+nir_def *
+radv_build_bvh_stack_rtn_addr(nir_builder *b, const struct radv_physical_device *pdev, uint32_t workgroup_size,
+                              uint32_t stack_base, uint32_t max_stack_entries)
+{
+   assert(stack_base % 4 == 0);
+
+   nir_def *stack_idx = nir_load_local_invocation_index(b);
+   /* RDNA3's ds_bvh_stack_rtn instruction uses a special encoding for the stack address:
+    * bits 0-17 encode the current stack index (set to 0 initially),
+    * bits 18-31 encode the stack base in multiples of 4.
+    *
+    * The hardware uses a stride of 128 bytes (32 entries) for the stack index, so with wave64 the upper
+    * 32 threads need a different base offset.
+    */
+   if (workgroup_size > 32) {
+      nir_def *wave32_thread_id = nir_iand_imm(b, stack_idx, 0x1f);
+      nir_def *wave32_group_id = nir_ushr_imm(b, stack_idx, 5);
+      uint32_t stack_entries_per_group = max_stack_entries * 32;
+      nir_def *group_stack_base = nir_imul_imm(b, wave32_group_id, stack_entries_per_group);
+      stack_idx = nir_iadd(b, wave32_thread_id, group_stack_base);
+   }
+   stack_idx = nir_iadd_imm(b, stack_idx, stack_base / 4);
+   /* There are 4 bytes in each stack entry, so no further arithmetic is needed. */
+   if (pdev->info.gfx_level >= GFX12)
+      stack_idx = nir_ishl_imm(b, stack_idx, 15);
+   else
+      stack_idx = nir_ishl_imm(b, stack_idx, 18);
+   return stack_idx;
+}
+
 static void
 nir_sort_hit_pair(nir_builder *b, nir_variable *var_distances, nir_variable *var_indices, uint32_t chan_1,
                   uint32_t chan_2)
@@ -693,12 +729,36 @@ build_bvh_base(nir_builder *b, const struct radv_physical_device *pdev, nir_def
    return nir_pack_64_2x32(b, nir_vector_insert_imm(b, base_addr_vec, addr_hi, 1));
 }
 
+static void
+build_instance_exit(nir_builder *b, const struct radv_physical_device *pdev, const struct radv_ray_traversal_args *args,
+                    nir_def *stack_instance_exit, nir_def *ptr_flags)
+{
+   nir_def *root_instance_exit = nir_iand(
+      b, nir_ieq_imm(b, nir_load_deref(b, args->vars.current_node), RADV_BVH_INVALID_NODE),
+      nir_ieq(b, nir_load_deref(b, args->vars.previous_node), nir_load_deref(b, args->vars.instance_bottom_node)));
+   nir_if *instance_exit = nir_push_if(b, nir_ior(b, stack_instance_exit, root_instance_exit));
+   instance_exit->control = nir_selection_control_dont_flatten;
+   {
+      nir_store_deref(b, args->vars.top_stack, nir_imm_int(b, -1), 1);
+      nir_store_deref(b, args->vars.previous_node, nir_load_deref(b, args->vars.instance_top_node), 1);
+      nir_store_deref(b, args->vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 1);
+
+      nir_store_deref(b, args->vars.bvh_base, build_bvh_base(b, pdev, args->root_bvh_base, ptr_flags, true), 0x1);
+      nir_store_deref(b, args->vars.origin, args->origin, 7);
+      nir_store_deref(b, args->vars.dir, args->dir, 7);
+      nir_store_deref(b, args->vars.inv_dir, nir_frcp(b, args->dir), 7);
+   }
+   nir_pop_if(b, NULL);
+}
+
 nir_def *
 radv_build_ray_traversal(struct radv_device *device, nir_builder *b, const struct radv_ray_traversal_args *args)
 {
    const struct radv_physical_device *pdev = radv_device_physical(device);
    nir_variable *incomplete = nir_local_variable_create(b->impl, glsl_bool_type(), "incomplete");
    nir_store_var(b, incomplete, nir_imm_true(b), 0x1);
+   nir_variable *intrinsic_result = nir_local_variable_create(b->impl, glsl_uvec4_type(), "intrinsic_result");
+   nir_variable *last_visited_node = nir_local_variable_create(b->impl, glsl_uint_type(), "last_visited_node");
 
    struct radv_ray_flags ray_flags = {
       .force_opaque = radv_test_flag(b, args, SpvRayFlagsOpaqueKHRMask, true),
@@ -722,38 +782,47 @@ radv_build_ray_traversal(struct radv_device *device, nir_builder *b, const struc
    nir_push_loop(b);
    {
-      nir_push_if(b, nir_ieq_imm(b, nir_load_deref(b, args->vars.current_node), RADV_BVH_INVALID_NODE));
-      {
-         /* Early exit if we never overflowed the stack, to avoid having to backtrack to
-          * the root for no reason. */
-         nir_push_if(b, nir_ilt_imm(b, nir_load_deref(b, args->vars.stack), args->stack_base + args->stack_stride));
+      /* When exiting an instance via the stack, current_node is never invalid with ds_bvh_stack_rtn,
+       * so handle instance exits here instead of in the invalid-node branch below. */
+      if (args->use_bvh_stack_rtn) {
+         /* Early-exit when the stack is empty and there are no more nodes to process. */
+         nir_push_if(b, nir_ieq_imm(b, nir_load_deref(b, args->vars.current_node), RADV_BVH_STACK_TERMINAL_NODE));
          {
             nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
             nir_jump(b, nir_jump_break);
          }
          nir_pop_if(b, NULL);
+         build_instance_exit(b, pdev, args,
+                             nir_ilt(b, nir_load_deref(b, args->vars.stack), nir_load_deref(b, args->vars.top_stack)),
+                             ptr_flags);
+      }
 
-         nir_def *stack_instance_exit =
-            nir_ige(b, nir_load_deref(b, args->vars.top_stack), nir_load_deref(b, args->vars.stack));
-         nir_def *root_instance_exit =
-            nir_ieq(b, nir_load_deref(b, args->vars.previous_node), nir_load_deref(b, args->vars.instance_bottom_node));
-         nir_if *instance_exit = nir_push_if(b, nir_ior(b, stack_instance_exit, root_instance_exit));
-         instance_exit->control = nir_selection_control_dont_flatten;
-         {
-            nir_store_deref(b, args->vars.top_stack, nir_imm_int(b, -1), 1);
-            nir_store_deref(b, args->vars.previous_node, nir_load_deref(b, args->vars.instance_top_node), 1);
-            nir_store_deref(b, args->vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 1);
-
-            nir_store_deref(b, args->vars.bvh_base, build_bvh_base(b, pdev, args->root_bvh_base, ptr_flags, true), 0x1);
-            nir_store_deref(b, args->vars.origin, args->origin, 7);
-            nir_store_deref(b, args->vars.dir, args->dir, 7);
-            nir_store_deref(b, args->vars.inv_dir, nir_fdiv(b, vec3ones, args->dir), 7);
+      nir_push_if(b, nir_ieq_imm(b, nir_load_deref(b, args->vars.current_node), RADV_BVH_INVALID_NODE));
+      {
+         /* Early exit if we never overflowed the stack, to avoid having to backtrack to
+          * the root for no reason. */
+         if (!args->use_bvh_stack_rtn) {
+            nir_push_if(b, nir_ilt_imm(b, nir_load_deref(b, args->vars.stack), args->stack_base + args->stack_stride));
+            {
+               nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
+               nir_jump(b, nir_jump_break);
+            }
+            nir_pop_if(b, NULL);
+            build_instance_exit(
+               b, pdev, args, nir_ige(b, nir_load_deref(b, args->vars.top_stack), nir_load_deref(b, args->vars.stack)),
+               ptr_flags);
          }
-         nir_pop_if(b, NULL);
 
-         nir_push_if(
-            b, nir_ige(b, nir_load_deref(b, args->vars.stack_low_watermark), nir_load_deref(b, args->vars.stack)));
+         nir_def *overflow_cond =
+            nir_ige(b, nir_load_deref(b, args->vars.stack_low_watermark), nir_load_deref(b, args->vars.stack));
+         /* ds_bvh_stack_rtn returns 0xFFFFFFFF if and only if there was a stack overflow. */
+         if (args->use_bvh_stack_rtn)
+            overflow_cond = nir_imm_true(b);
+
+         nir_push_if(b, overflow_cond);
          {
+            /* Fix up the stack pointer if we overflowed. The HW will decrement the stack pointer by one in that case. */
+            if (args->use_bvh_stack_rtn)
+               nir_store_deref(b, args->vars.stack, nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), 1), 0x1);
             nir_def *prev = nir_load_deref(b, args->vars.previous_node);
             nir_def *bvh_addr = build_node_to_addr(device, b, nir_load_deref(b, args->vars.bvh_base), true);
@@ -768,13 +837,15 @@ radv_build_ray_traversal(struct radv_device *device, nir_builder *b, const struc
          }
          nir_push_else(b, NULL);
          {
-            nir_store_deref(b, args->vars.stack,
-                            nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), -args->stack_stride), 1);
+            if (!args->use_bvh_stack_rtn) {
+               nir_store_deref(b, args->vars.stack,
+                               nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), -args->stack_stride), 1);
 
-            nir_def *stack_ptr =
-               nir_umod_imm(b, nir_load_deref(b, args->vars.stack), args->stack_stride * args->stack_entries);
-            nir_def *bvh_node = args->stack_load_cb(b, stack_ptr, args);
-            nir_store_deref(b, args->vars.current_node, bvh_node, 0x1);
+               nir_def *stack_ptr =
+                  nir_umod_imm(b, nir_load_deref(b, args->vars.stack), args->stack_stride * args->stack_entries);
+               nir_def *bvh_node = args->stack_load_cb(b, stack_ptr, args);
+               nir_store_deref(b, args->vars.current_node, bvh_node, 0x1);
+            }
             nir_store_deref(b, args->vars.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
          }
          nir_pop_if(b, NULL);
@@ -786,19 +857,25 @@ radv_build_ray_traversal(struct radv_device *device, nir_builder *b, const struc
       nir_pop_if(b, NULL);
 
       nir_def *bvh_node = nir_load_deref(b, args->vars.current_node);
+      if (args->use_bvh_stack_rtn)
+         nir_store_var(b, last_visited_node, nir_imm_int(b, RADV_BVH_STACK_TERMINAL_NODE), 0x1);
+      else
+         nir_store_deref(b, args->vars.current_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
 
       nir_def *prev_node = nir_load_deref(b, args->vars.previous_node);
       nir_store_deref(b, args->vars.previous_node, bvh_node, 0x1);
-      nir_store_deref(b, args->vars.current_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
 
       nir_def *global_bvh_node = nir_iadd(b, nir_load_deref(b, args->vars.bvh_base), nir_u2u64(b, bvh_node));
 
-      nir_def *intrinsic_result = NULL;
+      bool has_result = false;
       if (pdev->info.has_image_bvh_intersect_ray && !radv_emulate_rt(pdev)) {
-         intrinsic_result =
+         nir_store_var(
+            b, intrinsic_result,
             nir_bvh64_intersect_ray_amd(b, 32, desc, nir_unpack_64_2x32(b, global_bvh_node),
                                         nir_load_deref(b, args->vars.tmax), nir_load_deref(b, args->vars.origin),
-                                        nir_load_deref(b, args->vars.dir), nir_load_deref(b, args->vars.inv_dir));
+                                        nir_load_deref(b, args->vars.dir), nir_load_deref(b, args->vars.inv_dir)),
+            0xf);
+         has_result = true;
       }
 
       nir_push_if(b, nir_test_mask(b, bvh_node, BITFIELD64_BIT(ffs(radv_bvh_node_box16) - 1)));
@@ -833,9 +910,10 @@ radv_build_ray_traversal(struct radv_device *device, nir_builder *b, const struc
                nir_def *instance_and_mask = nir_channel(b, instance_data, 2);
                nir_push_if(b, nir_ult(b, nir_iand(b, instance_and_mask, args->cull_mask), nir_imm_int(b, 1 << 24)));
                {
-                  nir_jump(b, nir_jump_continue);
+                  if (!args->use_bvh_stack_rtn)
+                     nir_jump(b, nir_jump_continue);
               }
-               nir_pop_if(b, NULL);
+               nir_push_else(b, NULL);
            }
 
            nir_store_deref(b, args->vars.top_stack, nir_load_deref(b, args->vars.stack), 1);
@@ -856,7 +934,15 @@ radv_build_ray_traversal(struct radv_device *device, nir_builder *b, const struc
                             build_bvh_base(b, pdev, instance_pointer, ptr_flags, false), 0x1);
 
            /* Push the instance root node onto the stack */
-           nir_store_deref(b, args->vars.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
+           if (args->use_bvh_stack_rtn) {
+              nir_store_var(b, last_visited_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
+              nir_store_var(b, intrinsic_result,
+                            nir_imm_ivec4(b, RADV_BVH_ROOT_NODE, RADV_BVH_INVALID_NODE, RADV_BVH_INVALID_NODE,
+                                          RADV_BVH_INVALID_NODE),
+                            0xf);
+           } else {
+              nir_store_deref(b, args->vars.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
+           }
 
            nir_store_deref(b, args->vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 1);
            nir_store_deref(b, args->vars.instance_top_node, bvh_node, 1);
@@ -864,13 +950,17 @@ radv_build_ray_traversal(struct radv_device *device, nir_builder *b, const struc
            nir_store_deref(b, args->vars.origin, nir_build_vec3_mat_mult(b, args->origin, wto_matrix, true), 7);
            nir_store_deref(b, args->vars.dir, nir_build_vec3_mat_mult(b, args->dir, wto_matrix, false), 7);
            nir_store_deref(b, args->vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_deref(b, args->vars.dir)), 7);
+
+           if (!args->ignore_cull_mask)
+              nir_pop_if(b, NULL);
         }
         nir_pop_if(b, NULL);
      }
      nir_push_else(b, NULL);
      {
-         nir_def *result = intrinsic_result;
-         if (!result) {
+         nir_def *result;
+         if (has_result) {
+            result = nir_load_var(b, intrinsic_result);
+         } else {
            /* If we didn't run the intrinsic cause the hardware didn't support it,
             * emulate ray/box intersection here */
            result = intersect_ray_amd_software_box(
@@ -879,49 +969,55 @@ radv_build_ray_traversal(struct radv_device *device, nir_builder *b, const struc
         }
 
        /* box */
-        nir_push_if(b, nir_ieq_imm(b, prev_node, RADV_BVH_INVALID_NODE));
-        {
-           nir_def *new_nodes[4];
-           for (unsigned i = 0; i < 4; ++i)
-              new_nodes[i] = nir_channel(b, result, i);
+        if (args->use_bvh_stack_rtn) {
+           nir_store_var(b, last_visited_node, prev_node, 0x1);
+        } else {
+           nir_push_if(b, nir_ieq_imm(b, prev_node, RADV_BVH_INVALID_NODE));
+           {
+              nir_def *new_nodes[4];
+              for (unsigned i = 0; i < 4; ++i)
+                 new_nodes[i] = nir_channel(b, result, i);
 
-           for (unsigned i = 1; i < 4; ++i)
-              nir_push_if(b, nir_ine_imm(b, new_nodes[i], RADV_BVH_INVALID_NODE));
+              for (unsigned i = 1; i < 4; ++i)
+                 nir_push_if(b, nir_ine_imm(b, new_nodes[i], RADV_BVH_INVALID_NODE));
 
-           for (unsigned i = 4; i-- > 1;) {
-              nir_def *stack = nir_load_deref(b, args->vars.stack);
-              nir_def *stack_ptr = nir_umod_imm(b, stack, args->stack_entries * args->stack_stride);
-              args->stack_store_cb(b, stack_ptr, new_nodes[i], args);
-              nir_store_deref(b, args->vars.stack, nir_iadd_imm(b, stack, args->stack_stride), 1);
+              for (unsigned i = 4; i-- > 1;) {
+                 nir_def *stack = nir_load_deref(b, args->vars.stack);
+                 nir_def *stack_ptr = nir_umod_imm(b, stack, args->stack_entries * args->stack_stride);
+                 args->stack_store_cb(b, stack_ptr, new_nodes[i], args);
+                 nir_store_deref(b, args->vars.stack, nir_iadd_imm(b, stack, args->stack_stride), 1);
 
-              if (i == 1) {
-                 nir_def *new_watermark =
-                    nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), -args->stack_entries * args->stack_stride);
-                 new_watermark = nir_imax(b, nir_load_deref(b, args->vars.stack_low_watermark), new_watermark);
-                 nir_store_deref(b, args->vars.stack_low_watermark, new_watermark, 0x1);
+                 if (i == 1) {
+                    nir_def *new_watermark = nir_iadd_imm(b, nir_load_deref(b, args->vars.stack),
+                                                          -args->stack_entries * args->stack_stride);
+                    new_watermark = nir_imax(b, nir_load_deref(b, args->vars.stack_low_watermark), new_watermark);
+                    nir_store_deref(b, args->vars.stack_low_watermark, new_watermark, 0x1);
+                 }
+
+                 nir_pop_if(b, NULL);
              }
-
-              nir_pop_if(b, NULL);
+              nir_store_deref(b, args->vars.current_node, new_nodes[0], 0x1);
           }
-           nir_store_deref(b, args->vars.current_node, new_nodes[0], 0x1);
-        }
-        nir_push_else(b, NULL);
-        {
-           nir_def *next = nir_imm_int(b, RADV_BVH_INVALID_NODE);
-           for (unsigned i = 0; i < 3; ++i) {
-              next = nir_bcsel(b, nir_ieq(b, prev_node, nir_channel(b, result, i)), nir_channel(b, result, i + 1),
-                               next);
+           nir_push_else(b, NULL);
+           {
+              nir_def *next = nir_imm_int(b, RADV_BVH_INVALID_NODE);
+              for (unsigned i = 0; i < 3; ++i) {
+                 next = nir_bcsel(b, nir_ieq(b, prev_node, nir_channel(b, result, i)),
+                                  nir_channel(b, result, i + 1), next);
+              }
+              nir_store_deref(b, args->vars.current_node, next, 0x1);
           }
-           nir_store_deref(b, args->vars.current_node, next, 0x1);
+           nir_pop_if(b, NULL);
        }
-        nir_pop_if(b, NULL);
     }
      nir_pop_if(b, NULL);
   }
   nir_push_else(b, NULL);
   {
-      nir_def *result = intrinsic_result;
-      if (!result) {
+      nir_def *result;
+      if (has_result) {
+         result = nir_load_var(b, intrinsic_result);
+      } else {
         /* If we didn't run the intrinsic cause the hardware didn't support it,
          * emulate ray/tri intersection here */
         result = intersect_ray_amd_software_tri(
@@ -937,6 +1033,21 @@ radv_build_ray_traversal(struct radv_device *device, nir_builder *b, const struc
         iteration_instance_count = nir_iadd_imm(b, iteration_instance_count, 1);
         nir_store_deref(b, args->vars.iteration_instance_count, iteration_instance_count, 0x1);
      }
+      if (args->use_bvh_stack_rtn) {
+         nir_def *stack_result =
+            nir_bvh_stack_rtn_amd(b, 32, nir_load_deref(b, args->vars.stack), nir_load_var(b, last_visited_node),
+                                  nir_load_var(b, intrinsic_result), .stack_size = args->stack_entries);
+         nir_store_deref(b, args->vars.stack, nir_channel(b, stack_result, 0), 0x1);
+         nir_store_deref(b, args->vars.current_node, nir_channel(b, stack_result, 1), 0x1);
+      }
+
+      if (args->vars.break_flag) {
+         nir_push_if(b, nir_load_deref(b, args->vars.break_flag));
+         {
+            nir_jump(b, nir_jump_break);
+         }
+         nir_pop_if(b, NULL);
+      }
   }
   nir_pop_loop(b, NULL);
 
diff --git a/src/amd/vulkan/nir/radv_nir_rt_common.h b/src/amd/vulkan/nir/radv_nir_rt_common.h
index aaef6ca72cd..c0da9348c0c 100644
--- a/src/amd/vulkan/nir/radv_nir_rt_common.h
+++ b/src/amd/vulkan/nir/radv_nir_rt_common.h
@@ -13,6 +13,12 @@
 #include "compiler/spirv/spirv.h"
 
 struct radv_device;
+struct radv_physical_device;
+
+bool radv_use_bvh_stack_rtn(const struct radv_physical_device *pdevice);
+
+nir_def *radv_build_bvh_stack_rtn_addr(nir_builder *b, const struct radv_physical_device *pdev, uint32_t workgroup_size,
+                                       uint32_t stack_base, uint32_t max_stack_entries);
 
 nir_def *build_addr_to_node(struct radv_device *device, nir_builder *b, nir_def *addr, nir_def *flags);
 
@@ -107,6 +113,9 @@ struct radv_ray_traversal_vars {
    nir_deref_instr *instance_addr;
    nir_deref_instr *sbt_offset_and_flags;
 
+   /* If non-NULL, a boolean deref indicating whether to break out of the traversal loop after the current iteration. */
+   nir_deref_instr *break_flag;
+
    /* Statistics. Iteration count in the low 16 bits, candidate instance counts in the high 16 bits. */
    nir_deref_instr *iteration_instance_count;
 };
 
@@ -132,6 +141,7 @@ struct radv_ray_traversal_args {
 
    bool ignore_cull_mask;
 
+   bool use_bvh_stack_rtn;
    radv_rt_stack_store_cb stack_store_cb;
    radv_rt_stack_load_cb stack_load_cb;
 
diff --git a/src/amd/vulkan/nir/radv_nir_rt_shader.c b/src/amd/vulkan/nir/radv_nir_rt_shader.c
index 086726052ec..1d9d97586d1 100644
--- a/src/amd/vulkan/nir/radv_nir_rt_shader.c
+++ b/src/amd/vulkan/nir/radv_nir_rt_shader.c
@@ -1411,23 +1411,26 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
       nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
       {
          nir_store_var(b, data->barycentrics, prev_barycentrics, 0x3);
-         nir_jump(b, nir_jump_continue);
       }
       nir_pop_if(b, NULL);
    }
    nir_pop_if(b, NULL);
 
-   nir_store_var(b, data->vars->primitive_id, intersection->base.primitive_id, 1);
-   nir_store_var(b, data->vars->geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
-   nir_store_var(b, data->vars->tmax, intersection->t, 0x1);
-   nir_store_var(b, data->vars->instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
-   nir_store_var(b, data->vars->hit_kind, hit_kind, 0x1);
+   nir_push_if(b, nir_load_var(b, data->vars->ahit_accept));
+   {
+      nir_store_var(b, data->vars->primitive_id, intersection->base.primitive_id, 1);
+      nir_store_var(b, data->vars->geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
+      nir_store_var(b, data->vars->tmax, intersection->t, 0x1);
+      nir_store_var(b, data->vars->instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
+      nir_store_var(b, data->vars->hit_kind, hit_kind, 0x1);
 
-   nir_store_var(b, data->vars->idx, sbt_idx, 1);
-   nir_store_var(b, data->trav_vars->hit, nir_imm_true(b), 1);
+      nir_store_var(b, data->vars->idx, sbt_idx, 1);
+      nir_store_var(b, data->trav_vars->hit, nir_imm_true(b), 1);
 
-   nir_def *ray_terminated = nir_load_var(b, data->vars->ahit_terminate);
-   nir_break_if(b, nir_ior(b, ray_flags->terminate_on_first_hit, ray_terminated));
+      nir_def *ray_terminated = nir_load_var(b, data->vars->ahit_terminate);
+      nir_break_if(b, nir_ior(b, ray_flags->terminate_on_first_hit, ray_terminated));
+   }
+   nir_pop_if(b, NULL);
 }
 
 static void
@@ -1528,6 +1531,17 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin
    nir_store_var(b, trav_vars.bvh_base, root_bvh_base, 1);
 
+   nir_def *stack_idx = nir_load_local_invocation_index(b);
+   uint32_t stack_stride;
+
+   if (radv_use_bvh_stack_rtn(pdev)) {
+      stack_idx = radv_build_bvh_stack_rtn_addr(b, pdev, pdev->rt_wave_size, 0, MAX_STACK_ENTRY_COUNT);
+      stack_stride = 1;
+   } else {
+      stack_idx = nir_imul_imm(b, stack_idx, sizeof(uint32_t));
+      stack_stride = pdev->rt_wave_size * sizeof(uint32_t);
+   }
+
    nir_def *vec3ones = nir_imm_vec3(b, 1.0, 1.0, 1.0);
 
    nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
@@ -1536,7 +1550,7 @@
    nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_imm_int(b, 0), 1);
    nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
 
-   nir_store_var(b, trav_vars.stack, nir_imul_imm(b, nir_load_local_invocation_index(b), sizeof(uint32_t)), 1);
+   nir_store_var(b, trav_vars.stack, stack_idx, 1);
    nir_store_var(b, trav_vars.stack_low_watermark, nir_load_var(b, trav_vars.stack), 1);
    nir_store_var(b, trav_vars.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
    nir_store_var(b, trav_vars.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
@@ -1588,7 +1602,7 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin
       .tmin = nir_load_var(b, vars->tmin),
       .dir = nir_load_var(b, vars->direction),
       .vars = trav_vars_args,
-      .stack_stride = pdev->rt_wave_size * sizeof(uint32_t),
+      .stack_stride = stack_stride,
       .stack_entries = MAX_STACK_ENTRY_COUNT,
       .stack_base = 0,
       .ignore_cull_mask = ignore_cull_mask,
@@ -1602,6 +1616,7 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin
       .triangle_cb = (pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR)
                         ? NULL
                         : handle_candidate_triangle,
+      .use_bvh_stack_rtn = radv_use_bvh_stack_rtn(pdev),
       .data = &data,
    };
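
For illustration, the stack-address encoding that radv_build_bvh_stack_rtn_addr() emits can be sketched on the CPU side in plain C. This is a hypothetical helper, not part of the patch; it assumes the GFX11 layout described in the comment above (current stack index in bits 0-17, dword-granular base in bits 18-31, fixed 32-entry hardware stride; per the patch, GFX12 would shift by 15 instead of 18):

#include <assert.h>
#include <stdint.h>

/* Hypothetical CPU-side mirror of radv_build_bvh_stack_rtn_addr(). */
static uint32_t
bvh_stack_rtn_addr(uint32_t invocation_index, uint32_t workgroup_size,
                   uint32_t stack_base, uint32_t max_stack_entries)
{
   assert(stack_base % 4 == 0);

   uint32_t idx = invocation_index;
   if (workgroup_size > 32) {
      /* The HW strides the stack index by 32 entries (128 bytes), so with
       * wave64 the upper 32 lanes get their own block of stack entries. */
      uint32_t wave32_thread_id = idx & 0x1f;
      uint32_t wave32_group_id = idx >> 5;
      idx = wave32_thread_id + wave32_group_id * max_stack_entries * 32;
   }
   /* Stack entries are 4 bytes, so the byte-based base divides evenly into
    * entry units. */
   idx += stack_base / 4;

   /* Bits 0-17: current stack index (starts at 0); bits 18-31: per-lane base. */
   return idx << 18;
}

The wave64 special case mirrors the NIR path: lanes 32-63 are given a base one full 32-lane stack region further into LDS, so the hardware's fixed 128-byte stride never collides with the lower half of the wave.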
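Likewise, the shared-stack sizing done in init_ray_query_vars(), including the clamp this patch adds for ds_bvh_stack_rtn, can be summarized with a small hypothetical helper (assumed names; the real code also adds shader->info.shared_size and falls back to a scratch stack when the LDS budget is exceeded):

/* Hypothetical sketch of the LDS stack size computed in init_ray_query_vars(). */
static uint32_t
shared_stack_size(uint32_t workgroup_size, uint32_t ray_queries, bool use_bvh_stack_rtn)
{
   uint32_t entries = ray_queries == 1 ? 16 : 8;
   /* ds_bvh_stack* strides by a fixed 32 dwords, so reserve at least 32 lanes. */
   if (use_bvh_stack_rtn && workgroup_size < 32)
      workgroup_size = 32;                /* MAX2(workgroup_size, 32) */
   return workgroup_size * entries * 4;   /* 4 bytes per stack entry */
}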