radv/rt: radv: gather push constant size from shaders for RT

And store the total push constant size to the RT prolog.

Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37769>
This commit is contained in:
Samuel Pitoiset 2025-10-08 14:10:00 +02:00 committed by Marge Bot
parent aa44a5a4ae
commit 97dbf7b895

View file

@ -432,6 +432,9 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
/* Info might be out-of-date after inlining in radv_nir_lower_rt_abi(). */
nir_shader_gather_info(temp_stage.nir, nir_shader_get_entrypoint(temp_stage.nir));
radv_nir_shader_info_pass(device, temp_stage.nir, &stage->layout, &stage->key, NULL, RADV_PIPELINE_RAY_TRACING,
false, &stage->info);
radv_optimize_nir(temp_stage.nir, stage->key.optimisations_disabled);
radv_postprocess_nir(device, NULL, &temp_stage);
stage->info.nir_shared_size = MAX2(stage->info.nir_shared_size, temp_stage.info.nir_shared_size);
@ -829,7 +832,7 @@ compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, stru
}
static void
combine_config(struct ac_shader_config *config, struct ac_shader_config *other)
combine_config(struct ac_shader_config *config, const struct ac_shader_config *other)
{
config->num_sgprs = MAX2(config->num_sgprs, other->num_sgprs);
config->num_vgprs = MAX2(config->num_vgprs, other->num_vgprs);
@ -859,21 +862,35 @@ static void
compile_rt_prolog(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline)
{
const struct radv_physical_device *pdev = radv_device_physical(device);
uint32_t push_constant_size = 0;
pipeline->prolog = radv_create_rt_prolog(device);
/* create combined config */
struct ac_shader_config *config = &pipeline->prolog->config;
for (unsigned i = 0; i < pipeline->stage_count; i++)
if (pipeline->stages[i].shader)
combine_config(config, &pipeline->stages[i].shader->config);
for (unsigned i = 0; i < pipeline->stage_count; i++) {
const struct radv_shader *shader = pipeline->stages[i].shader;
if (pipeline->base.base.shaders[MESA_SHADER_INTERSECTION])
combine_config(config, &pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]->config);
if (!shader)
continue;
combine_config(config, &shader->config);
push_constant_size = MAX2(push_constant_size, shader->info.push_constant_size);
}
if (pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]) {
const struct radv_shader *traversal_shader = pipeline->base.base.shaders[MESA_SHADER_INTERSECTION];
combine_config(config, &traversal_shader->config);
push_constant_size = MAX2(push_constant_size, traversal_shader->info.push_constant_size);
}
postprocess_rt_config(config, &pdev->info, pdev->rt_wave_size);
pipeline->prolog->max_waves = radv_get_max_waves(device, config, &pipeline->prolog->info);
pipeline->prolog->info.push_constant_size = push_constant_size;
}
void