From 48ae92ceea83c84d39f3fbb2d9b9ff4a7cef947e Mon Sep 17 00:00:00 2001 From: Konstantin Seurer Date: Sat, 10 Dec 2022 12:36:13 +0100 Subject: [PATCH] radv/rt: Propagate radv_pipeline_key Reviewed-by: Bas Nieuwenhuizen Part-of: --- src/amd/vulkan/radv_pipeline_rt.c | 2 +- src/amd/vulkan/radv_rt_shader.c | 39 ++++++++++++++++--------------- src/amd/vulkan/radv_shader.c | 2 +- src/amd/vulkan/radv_shader.h | 3 ++- 4 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c index 68a8212e8c1..f9c8cbf071c 100644 --- a/src/amd/vulkan/radv_pipeline_rt.c +++ b/src/amd/vulkan/radv_pipeline_rt.c @@ -321,7 +321,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache, goto pipeline_fail; } - shader = create_rt_shader(device, &local_create_info, rt_pipeline->stack_sizes); + shader = create_rt_shader(device, &local_create_info, rt_pipeline->stack_sizes, &key); module.nir = shader; result = radv_create_shaders(&rt_pipeline->base.base, pipeline_layout, device, cache, &key, &stage, 1, pCreateInfo->flags, hash, creation_feedback, diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c index 52c5833cd96..62b9c70b4d6 100644 --- a/src/amd/vulkan/radv_rt_shader.c +++ b/src/amd/vulkan/radv_rt_shader.c @@ -82,6 +82,7 @@ lower_rt_derefs(nir_shader *shader) */ struct rt_variables { const VkRayTracingPipelineCreateInfoKHR *create_info; + const struct radv_pipeline_key *key; /* idx of the next shader to run in the next iteration of the main loop. * During traversal, idx is used to store the SBT index and will contain @@ -143,10 +144,12 @@ reserve_stack_size(struct rt_variables *vars, uint32_t size) static struct rt_variables create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *create_info, - struct radv_pipeline_shader_stack_size *stack_sizes) + struct radv_pipeline_shader_stack_size *stack_sizes, + const struct radv_pipeline_key *key) { struct rt_variables vars = { .create_info = create_info, + .key = key, }; vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx"); vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg"); @@ -754,7 +757,8 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni nir_opt_dead_cf(shader); - struct rt_variables src_vars = create_rt_variables(shader, vars->create_info, vars->stack_sizes); + struct rt_variables src_vars = + create_rt_variables(shader, vars->create_info, vars->stack_sizes, vars->key); map_rt_variables(var_remap, &src_vars, vars); NIR_PASS_V(shader, lower_rt_instructions, &src_vars, call_idx_base); @@ -776,16 +780,14 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni } static nir_shader * -parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo) +parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo, + const struct radv_pipeline_key *key) { - struct radv_pipeline_key key; - memset(&key, 0, sizeof(key)); - struct radv_pipeline_stage rt_stage; radv_pipeline_stage_init(sinfo, &rt_stage, vk_to_mesa_shader_stage(sinfo->stage)); - nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, &key); + nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, key); if (shader->info.stage == MESA_SHADER_RAYGEN || shader->info.stage == MESA_SHADER_CLOSEST_HIT || shader->info.stage == MESA_SHADER_CALLABLE || shader->info.stage == MESA_SHADER_MISS) { @@ -1087,7 +1089,7 @@ visit_any_hit_shaders(struct radv_device *device, continue; const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id]; - nir_shader *nir_stage = parse_rt_stage(device, stage); + nir_shader *nir_stage = parse_rt_stage(device, stage, vars->key); vars->stage_idx = shader_id; insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2); @@ -1217,12 +1219,12 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio continue; const VkPipelineShaderStageCreateInfo *stage = &data->createInfo->pStages[shader_id]; - nir_shader *nir_stage = parse_rt_stage(data->device, stage); + nir_shader *nir_stage = parse_rt_stage(data->device, stage, data->vars->key); nir_shader *any_hit_stage = NULL; if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) { stage = &data->createInfo->pStages[any_hit_shader_id]; - any_hit_stage = parse_rt_stage(data->device, stage); + any_hit_stage = parse_rt_stage(data->device, stage, data->vars->key); nir_lower_intersection_shader(nir_stage, any_hit_stage); ralloc_free(any_hit_stage); @@ -1272,7 +1274,8 @@ load_stack_entry(nir_builder *b, nir_ssa_def *index, const struct radv_ray_trave static nir_shader * build_traversal_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, - struct radv_pipeline_shader_stack_size *stack_sizes) + struct radv_pipeline_shader_stack_size *stack_sizes, + const struct radv_pipeline_key *key) { /* Create the traversal shader as an intersection shader to prevent validation failures due to * invalid variable modes.*/ @@ -1282,7 +1285,7 @@ build_traversal_shader(struct radv_device *device, b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4; b.shader->info.shared_size = device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t); - struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes); + struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes, key); nir_variable *barycentrics = nir_variable_create( b.shader, nir_var_ray_hit_attrib, glsl_vector_type(GLSL_TYPE_FLOAT, 2), "barycentrics"); @@ -1552,17 +1555,15 @@ lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs) nir_shader * create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, - struct radv_pipeline_shader_stack_size *stack_sizes) + struct radv_pipeline_shader_stack_size *stack_sizes, + const struct radv_pipeline_key *key) { - struct radv_pipeline_key key; - memset(&key, 0, sizeof(key)); - nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_combined"); b.shader->info.internal = false; b.shader->info.workgroup_size[0] = 8; b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4; - struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes); + struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes, key); load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, SBT_GENERAL_IDX); if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo)) nir_store_var(&b, vars.stack_ptr, nir_load_rt_dynamic_callable_stack_base_amd(&b), 0x1); @@ -1583,7 +1584,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf nir_ssa_def *idx = nir_load_var(&b, vars.idx); /* Insert traversal shader */ - nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes); + nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes, key); assert(b.shader->info.shared_size == 0); b.shader->info.shared_size = traversal->info.shared_size; assert(b.shader->info.shared_size <= 32768); @@ -1600,7 +1601,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf type != MESA_SHADER_CLOSEST_HIT && type != MESA_SHADER_MISS) continue; - nir_shader *nir_stage = parse_rt_stage(device, stage); + nir_shader *nir_stage = parse_rt_stage(device, stage, key); /* Move ray tracing system values to the top that are set by rt_trace_ray * to prevent them from being overwritten by other rt_trace_ray calls. diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c index 7a1e0e69d0a..6ea3f36ef4b 100644 --- a/src/amd/vulkan/radv_shader.c +++ b/src/amd/vulkan/radv_shader.c @@ -723,7 +723,7 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_pipeline_ /* Only compute shaders currently support requiring a * specific subgroup size. */ - assert(stage->stage == MESA_SHADER_COMPUTE); + assert(stage->stage >= MESA_SHADER_COMPUTE); subgroup_size = key->cs.compute_subgroup_size; ballot_bit_size = key->cs.compute_subgroup_size; } diff --git a/src/amd/vulkan/radv_shader.h b/src/amd/vulkan/radv_shader.h index 1723ba4bd36..062fe173f4f 100644 --- a/src/amd/vulkan/radv_shader.h +++ b/src/amd/vulkan/radv_shader.h @@ -754,6 +754,7 @@ bool radv_lower_fs_intrinsics(nir_shader *nir, const struct radv_pipeline_stage nir_shader *create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, - struct radv_pipeline_shader_stack_size *stack_sizes); + struct radv_pipeline_shader_stack_size *stack_sizes, + const struct radv_pipeline_key *key); #endif