radv: Use provided handles for switch cases in RT shaders.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21406>
This commit is contained in:
Bas Nieuwenhuizen 2023-01-11 01:30:24 +01:00 committed by Marge Bot
parent 430170702e
commit 913de78731
3 changed files with 66 additions and 25 deletions

View file

@@ -435,7 +435,8 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
goto pipeline_fail;
}
shader = create_rt_shader(device, &local_create_info, rt_pipeline->stack_sizes, &key);
shader = create_rt_shader(device, &local_create_info, rt_pipeline->stack_sizes,
rt_pipeline->group_handles, &key);
module.nir = shader;
result = radv_compute_pipeline_compile(
&rt_pipeline->base, pipeline_layout, device, cache, &key, &stage, pCreateInfo->flags,

View file

@@ -1078,10 +1078,20 @@ init_traversal_vars(nir_builder *b)
return ret;
}
struct traversal_data {
struct radv_device *device;
const VkRayTracingPipelineCreateInfoKHR *createInfo;
struct rt_variables *vars;
struct rt_traversal_vars *trav_vars;
nir_variable *barycentrics;
const struct radv_pipeline_group_handle *handles;
};
static void
visit_any_hit_shaders(struct radv_device *device,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
struct rt_variables *vars)
struct traversal_data *data, struct rt_variables *vars)
{
nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx);
@@ -1102,25 +1112,26 @@ visit_any_hit_shaders(struct radv_device *device,
if (shader_id == VK_SHADER_UNUSED_KHR)
continue;
/* Avoid emitting stages with the same shaders/handles multiple times. */
bool is_dup = false;
for (unsigned j = 0; j < i; ++j)
if (data->handles[j].any_hit_index == data->handles[i].any_hit_index)
is_dup = true;
if (is_dup)
continue;
const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
nir_shader *nir_stage = parse_rt_stage(device, stage, vars->key);
vars->stage_idx = shader_id;
insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2);
insert_rt_case(b, nir_stage, vars, sbt_idx, 0, data->handles[i].any_hit_index);
}
if (!(vars->create_info->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR))
nir_pop_if(b, NULL);
}
struct traversal_data {
struct radv_device *device;
const VkRayTracingPipelineCreateInfoKHR *createInfo;
struct rt_variables *vars;
struct rt_traversal_vars *trav_vars;
nir_variable *barycentrics;
};
static void
handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *intersection,
const struct radv_ray_traversal_args *args,
@@ -1158,7 +1169,7 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX);
visit_any_hit_shaders(data->device, data->createInfo, b, &inner_vars);
visit_any_hit_shaders(data->device, data->createInfo, b, args->data, &inner_vars);
nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
{
@@ -1237,6 +1248,15 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
if (shader_id == VK_SHADER_UNUSED_KHR)
continue;
/* Avoid emitting stages with the same shaders/handles multiple times. */
bool is_dup = false;
for (unsigned j = 0; j < i; ++j)
if (data->handles[j].intersection_index == data->handles[i].intersection_index)
is_dup = true;
if (is_dup)
continue;
const VkPipelineShaderStageCreateInfo *stage = &data->createInfo->pStages[shader_id];
nir_shader *nir_stage = parse_rt_stage(data->device, stage, data->vars->key);
@@ -1250,7 +1270,8 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
}
inner_vars.stage_idx = shader_id;
insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2);
insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0,
data->handles[i].intersection_index);
}
if (!(data->vars->create_info->flags &
@@ -1297,6 +1318,7 @@ static nir_shader *
build_traversal_shader(struct radv_device *device,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
struct radv_pipeline_shader_stack_size *stack_sizes,
const struct radv_pipeline_group_handle *handles,
const struct radv_pipeline_key *key)
{
/* Create the traversal shader as an intersection shader to prevent validation failures due to
@@ -1383,6 +1405,7 @@ build_traversal_shader(struct radv_device *device,
.vars = &vars,
.trav_vars = &trav_vars,
.barycentrics = barycentrics,
.handles = handles,
};
struct radv_ray_traversal_args args = {
@@ -1518,6 +1541,7 @@ lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs)
nir_shader *
create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
struct radv_pipeline_shader_stack_size *stack_sizes,
const struct radv_pipeline_group_handle *handles,
const struct radv_pipeline_key *key)
{
nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_combined");
@@ -1554,23 +1578,37 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
nir_ssa_def *idx = nir_load_var(&b, vars.idx);
/* Insert traversal shader */
nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes, key);
nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes, handles, key);
assert(b.shader->info.shared_size == 0);
b.shader->info.shared_size = traversal->info.shared_size;
assert(b.shader->info.shared_size <= 32768);
insert_rt_case(&b, traversal, &vars, idx, 0, 1);
/* We do a trick with the indexing of the resume shaders so that the first
* shader of stage x always gets id x and the resume shader ids then come after
* stageCount. This makes the shadergroup handles independent of compilation. */
unsigned call_idx_base = pCreateInfo->stageCount + 1;
for (unsigned i = 0; i < pCreateInfo->stageCount; ++i) {
const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[i];
gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage);
if (type != MESA_SHADER_RAYGEN && type != MESA_SHADER_CALLABLE &&
type != MESA_SHADER_CLOSEST_HIT && type != MESA_SHADER_MISS)
unsigned call_idx_base = 1;
for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
unsigned stage_idx = VK_SHADER_UNUSED_KHR;
if (pCreateInfo->pGroups[i].type == VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR)
stage_idx = pCreateInfo->pGroups[i].generalShader;
else
stage_idx = pCreateInfo->pGroups[i].closestHitShader;
if (stage_idx == VK_SHADER_UNUSED_KHR)
continue;
/* Avoid emitting stages with the same shaders/handles multiple times. */
bool is_dup = false;
for (unsigned j = 0; j < i; ++j)
if (handles[j].general_index == handles[i].general_index)
is_dup = true;
if (is_dup)
continue;
const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[stage_idx];
ASSERTED gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage);
assert(type == MESA_SHADER_RAYGEN || type == MESA_SHADER_CALLABLE ||
type == MESA_SHADER_CLOSEST_HIT || type == MESA_SHADER_MISS);
nir_shader *nir_stage = parse_rt_stage(device, stage, key);
/* Move ray tracing system values to the top that are set by rt_trace_ray
@@ -1588,8 +1626,8 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
nir_shader **resume_shaders = NULL;
nir_lower_shader_calls(nir_stage, &opts, &resume_shaders, &num_resume_shaders, nir_stage);
vars.stage_idx = i;
insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, i + 2);
vars.stage_idx = stage_idx;
insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, handles[i].general_index);
for (unsigned j = 0; j < num_resume_shaders; ++j) {
insert_rt_case(&b, resume_shaders[j], &vars, idx, call_idx_base, call_idx_base + 1 + j);
}

View file

@@ -47,6 +47,7 @@ struct radv_physical_device;
struct radv_device;
struct radv_pipeline;
struct radv_pipeline_cache;
struct radv_pipeline_group_handle;
struct radv_pipeline_key;
struct radv_shader_args;
struct radv_vs_input_state;
@@ -755,6 +756,7 @@ bool radv_lower_fs_intrinsics(nir_shader *nir, const struct radv_pipeline_stage
nir_shader *create_rt_shader(struct radv_device *device,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
struct radv_pipeline_shader_stack_size *stack_sizes,
const struct radv_pipeline_group_handle *handles,
const struct radv_pipeline_key *key);
#endif