diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c index 31a620718ba..09d5a707366 100644 --- a/src/amd/vulkan/radv_pipeline_rt.c +++ b/src/amd/vulkan/radv_pipeline_rt.c @@ -185,7 +185,10 @@ fail: * Global variables for an RT pipeline */ struct rt_variables { - /* idx of the next shader to run in the next iteration of the main loop */ + /* idx of the next shader to run in the next iteration of the main loop. + * During traversal, idx is used to store the SBT index and will contain + * the correct resume index upon returning. + */ nir_variable *idx; /* scratch offset of the argument area relative to stack_ptr */ @@ -1076,7 +1079,7 @@ struct rt_traversal_vars { nir_variable *instance_id; nir_variable *custom_instance_and_mask; nir_variable *instance_addr; - nir_variable *should_return; + nir_variable *hit; nir_variable *bvh_base; nir_variable *stack; nir_variable *top_stack; @@ -1100,8 +1103,7 @@ init_traversal_vars(nir_builder *b) b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_custom_instance_and_mask"); ret.instance_addr = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr"); - ret.should_return = nir_variable_create(b->shader, nir_var_shader_temp, glsl_bool_type(), - "traversal_should_return"); + ret.hit = nir_variable_create(b->shader, nir_var_shader_temp, glsl_bool_type(), "traversal_hit"); ret.bvh_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "traversal_bvh_base"); ret.stack = @@ -1248,14 +1250,8 @@ insert_traversal_triangle_case(struct radv_device *device, nir_store_var(b, vars->custom_instance_and_mask, nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1); - load_sbt_entry(b, vars, sbt_idx, SBT_HIT, 0); - - nir_store_var(b, trav_vars->should_return, - nir_ior(b, - nir_test_mask(b, nir_load_var(b, vars->flags), - SpvRayFlagsSkipClosestHitShaderKHRMask), - nir_ieq_imm(b, nir_load_var(b, vars->idx), 0)), - 1); + nir_store_var(b, vars->idx, sbt_idx, 1); + nir_store_var(b, trav_vars->hit, nir_imm_true(b), 1); nir_ssa_def *terminate_on_first_hit = nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsTerminateOnFirstHitKHRMask); @@ -1399,14 +1395,8 @@ insert_traversal_aabb_case(struct radv_device *device, nir_store_var(b, vars->custom_instance_and_mask, nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1); - load_sbt_entry(b, vars, sbt_idx, SBT_HIT, 0); - - nir_store_var(b, trav_vars->should_return, - nir_ior(b, - nir_test_mask(b, nir_load_var(b, vars->flags), - SpvRayFlagsSkipClosestHitShaderKHRMask), - nir_ieq_imm(b, nir_load_var(b, vars->idx), 0)), - 1); + nir_store_var(b, vars->idx, sbt_idx, 1); + nir_store_var(b, trav_vars->hit, nir_imm_true(b), 1); nir_ssa_def *terminate_on_first_hit = nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsTerminateOnFirstHitKHRMask); @@ -1449,11 +1439,7 @@ build_traversal_shader(struct radv_device *device, struct rt_traversal_vars trav_vars = init_traversal_vars(&b); - /* Initialize the follow-up shader idx to 0, to be replaced by the miss shader - * if we actually miss. */ - nir_store_var(&b, vars.idx, nir_imm_int(&b, 0), 1); - - nir_store_var(&b, trav_vars.should_return, nir_imm_bool(&b, false), 1); + nir_store_var(&b, trav_vars.hit, nir_imm_false(&b), 1); nir_push_if(&b, nir_ine_imm(&b, accel_struct, 0)); { @@ -1617,22 +1603,30 @@ build_traversal_shader(struct radv_device *device, } nir_pop_if(&b, NULL); - /* should_return is set if we had a hit but we won't be calling the closest hit shader and hence - * need to return immediately to the calling shader. */ - nir_push_if(&b, nir_load_var(&b, trav_vars.should_return)); + /* Initialize follow-up shader. */ + nir_push_if(&b, nir_load_var(&b, trav_vars.hit)); { - insert_rt_return(&b, &vars); + /* vars.idx contains the SBT index at this point. */ + load_sbt_entry(&b, &vars, nir_load_var(&b, vars.idx), SBT_HIT, 0); + + nir_ssa_def *should_return = nir_ior(&b, + nir_test_mask(&b, nir_load_var(&b, vars.flags), + SpvRayFlagsSkipClosestHitShaderKHRMask), + nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0)); + + /* should_return is set if we had a hit but we won't be calling the closest hit shader and hence + * need to return immediately to the calling shader. */ + nir_push_if(&b, should_return); + { + insert_rt_return(&b, &vars); + } + nir_pop_if(&b, NULL); } nir_push_else(&b, NULL); { - /* Only load the miss shader if we actually miss, which we determining by not having set - * a closest hit shader. It is valid to not specify an SBT pointer for miss shaders if none - * of the rays miss. */ - nir_push_if(&b, nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0)); - { - load_sbt_entry(&b, &vars, nir_load_var(&b, vars.miss_index), SBT_MISS, 0); - } - nir_pop_if(&b, NULL); + /* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer + * for miss shaders if none of the rays miss. */ + load_sbt_entry(&b, &vars, nir_load_var(&b, vars.miss_index), SBT_MISS, 0); } nir_pop_if(&b, NULL);