diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c index 597671510ac..44506205724 100644 --- a/src/amd/vulkan/radv_rt_shader.c +++ b/src/amd/vulkan/radv_rt_shader.c @@ -597,10 +597,15 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_HIT, SBT_CLOSEST_HIT_IDX); nir_ssa_def *should_return = - nir_ior(&b_shader, - nir_test_mask(&b_shader, nir_load_var(&b_shader, vars->flags), - SpvRayFlagsSkipClosestHitShaderKHRMask), - nir_ieq_imm(&b_shader, nir_load_var(&b_shader, vars->idx), 0)); + nir_test_mask(&b_shader, nir_load_var(&b_shader, vars->flags), + SpvRayFlagsSkipClosestHitShaderKHRMask); + + if (!(vars->create_info->flags & + VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR)) { + should_return = + nir_ior(&b_shader, should_return, + nir_ieq_imm(&b_shader, nir_load_var(&b_shader, vars->idx), 0)); + } /* should_return is set if we had a hit but we won't be calling the closest hit * shader and hence need to return immediately to the calling shader. */ @@ -619,10 +624,15 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca nir_ssa_def *miss_index = nir_load_var(&b_shader, vars->miss_index); load_sbt_entry(&b_shader, vars, miss_index, SBT_MISS, SBT_GENERAL_IDX); - /* In case of a NULL miss shader, do nothing and just return. */ - nir_push_if(&b_shader, nir_ieq_imm(&b_shader, nir_load_var(&b_shader, vars->idx), 0)); - insert_rt_return(&b_shader, vars); - nir_pop_if(&b_shader, NULL); + if (!(vars->create_info->flags & + VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR)) { + /* In case of a NULL miss shader, do nothing and just return. */ + nir_push_if(&b_shader, + nir_ieq_imm(&b_shader, nir_load_var(&b_shader, vars->idx), 0)); + insert_rt_return(&b_shader, vars); + nir_pop_if(&b_shader, NULL); + } + break; } default: @@ -1078,7 +1088,9 @@ visit_any_hit_shaders(struct radv_device *device, { nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx); - nir_push_if(b, nir_ine_imm(b, sbt_idx, 0)); + if (!(vars->create_info->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR)) + nir_push_if(b, nir_ine_imm(b, sbt_idx, 0)); + for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) { const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i]; uint32_t shader_id = VK_SHADER_UNUSED_KHR; @@ -1099,7 +1111,9 @@ visit_any_hit_shaders(struct radv_device *device, vars->stage_idx = shader_id; insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2); } - nir_pop_if(b, NULL); + + if (!(vars->create_info->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR)) + nir_pop_if(b, NULL); } struct traversal_data { @@ -1206,7 +1220,10 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio nir_store_var(b, data->vars->ahit_accept, nir_imm_false(b), 0x1); nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1); - nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0)); + if (!(data->vars->create_info->flags & + VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR)) + nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0)); + for (unsigned i = 0; i < data->createInfo->groupCount; ++i) { const VkRayTracingShaderGroupCreateInfoKHR *group_info = &data->createInfo->pGroups[i]; uint32_t shader_id = VK_SHADER_UNUSED_KHR; @@ -1238,7 +1255,10 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio inner_vars.stage_idx = shader_id; insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2); } - nir_pop_if(b, NULL); + + if (!(data->vars->create_info->flags & + VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR)) + nir_pop_if(b, NULL); nir_push_if(b, nir_load_var(b, data->vars->ahit_accept)); {