diff --git a/src/amd/vulkan/radv_pipeline_graphics.c b/src/amd/vulkan/radv_pipeline_graphics.c
index fdfc0d23916..494bae90ad1 100644
--- a/src/amd/vulkan/radv_pipeline_graphics.c
+++ b/src/amd/vulkan/radv_pipeline_graphics.c
@@ -3568,43 +3568,43 @@ radv_emit_vgt_vertex_reuse(const struct radv_device *device, struct radeon_cmdbu
 }
 
 static struct radv_vgt_shader_key
-radv_pipeline_generate_vgt_shader_key(const struct radv_device *device, const struct radv_graphics_pipeline *pipeline)
+radv_get_vgt_shader_key(const struct radv_device *device, struct radv_shader **shaders,
+                        const struct radv_shader *gs_copy_shader)
 {
-   const struct radv_shader *last_vgt_shader = radv_get_last_vgt_shader(pipeline);
    uint8_t hs_size = 64, gs_size = 64, vs_size = 64;
+   struct radv_shader *last_vgt_shader = NULL;
    struct radv_vgt_shader_key key;
 
    memset(&key, 0, sizeof(key));
 
-   if (radv_pipeline_has_stage(pipeline, MESA_SHADER_TESS_CTRL))
-      hs_size = pipeline->base.shaders[MESA_SHADER_TESS_CTRL]->info.wave_size;
-
-   if (pipeline->base.shaders[MESA_SHADER_GEOMETRY]) {
-      vs_size = gs_size = pipeline->base.shaders[MESA_SHADER_GEOMETRY]->info.wave_size;
-      if (radv_pipeline_has_gs_copy_shader(&pipeline->base))
-         vs_size = pipeline->base.gs_copy_shader->info.wave_size;
-   } else if (pipeline->base.shaders[MESA_SHADER_TESS_EVAL])
-      vs_size = pipeline->base.shaders[MESA_SHADER_TESS_EVAL]->info.wave_size;
-   else if (pipeline->base.shaders[MESA_SHADER_VERTEX])
-      vs_size = pipeline->base.shaders[MESA_SHADER_VERTEX]->info.wave_size;
-   else if (pipeline->base.shaders[MESA_SHADER_MESH])
-      vs_size = gs_size = pipeline->base.shaders[MESA_SHADER_MESH]->info.wave_size;
-
-   if (last_vgt_shader->info.is_ngg) {
-      assert(!radv_pipeline_has_gs_copy_shader(&pipeline->base));
-      gs_size = vs_size;
+   if (shaders[MESA_SHADER_GEOMETRY]) {
+      last_vgt_shader = shaders[MESA_SHADER_GEOMETRY];
+   } else if (shaders[MESA_SHADER_TESS_EVAL]) {
+      last_vgt_shader = shaders[MESA_SHADER_TESS_EVAL];
+   } else if (shaders[MESA_SHADER_VERTEX]) {
+      last_vgt_shader = shaders[MESA_SHADER_VERTEX];
+   } else {
+      assert(shaders[MESA_SHADER_MESH]);
+      last_vgt_shader = shaders[MESA_SHADER_MESH];
    }
 
-   key.tess = radv_pipeline_has_stage(pipeline, MESA_SHADER_TESS_CTRL);
-   key.gs = radv_pipeline_has_stage(pipeline, MESA_SHADER_GEOMETRY);
+   vs_size = gs_size = last_vgt_shader->info.wave_size;
+   if (gs_copy_shader)
+      vs_size = gs_copy_shader->info.wave_size;
+
+   if (shaders[MESA_SHADER_TESS_CTRL])
+      hs_size = shaders[MESA_SHADER_TESS_CTRL]->info.wave_size;
+
+   key.tess = !!shaders[MESA_SHADER_TESS_CTRL];
+   key.gs = !!shaders[MESA_SHADER_GEOMETRY];
 
    if (last_vgt_shader->info.is_ngg) {
       key.ngg = 1;
       key.ngg_passthrough = last_vgt_shader->info.is_ngg_passthrough;
       key.ngg_streamout = last_vgt_shader->info.so.num_outputs > 0;
    }
 
-   if (radv_pipeline_has_stage(pipeline, MESA_SHADER_MESH)) {
+   if (shaders[MESA_SHADER_MESH]) {
       key.mesh = 1;
-      key.mesh_scratch_ring = pipeline->base.shaders[MESA_SHADER_MESH]->info.ms.needs_ms_scratch_ring;
+      key.mesh_scratch_ring = shaders[MESA_SHADER_MESH]->info.ms.needs_ms_scratch_ring;
    }
 
    key.hs_wave32 = hs_size == 32;
@@ -3763,7 +3763,8 @@ radv_pipeline_emit_pm4(const struct radv_device *device, struct radv_graphics_pi
    cs->buf = malloc(4 * (cs->max_dw + ctx_cs->max_dw));
    ctx_cs->buf = cs->buf + cs->max_dw;
 
-   struct radv_vgt_shader_key vgt_shader_key = radv_pipeline_generate_vgt_shader_key(device, pipeline);
+   const struct radv_vgt_shader_key vgt_shader_key =
+      radv_get_vgt_shader_key(device, pipeline->base.shaders, pipeline->base.gs_copy_shader);
 
    radv_emit_blend_state(ctx_cs, ps, blend->spi_shader_col_format, blend->cb_shader_mask);
    radv_emit_vgt_gs_mode(device, ctx_cs, pipeline->base.shaders[pipeline->last_vgt_api_stage]);
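
With the pipeline parameter gone, the VGT shader key can be computed by any caller that owns a per-stage radv_shader array plus the optional GS copy shader, not only radv_pipeline_emit_pm4(). A minimal hypothetical caller sketch follows; the wrapper name and its parameters are illustrative assumptions, only radv_get_vgt_shader_key() and the types it uses come from this patch.

/* Hypothetical helper, not part of the patch: builds the key for a VS(+GS)
 * combination without going through a radv_graphics_pipeline.
 */
static struct radv_vgt_shader_key
radv_get_vgt_shader_key_from_stages(const struct radv_device *device, struct radv_shader *vs,
                                    struct radv_shader *gs, const struct radv_shader *gs_copy_shader)
{
   struct radv_shader *shaders[MESA_VULKAN_SHADER_STAGES] = {NULL};

   shaders[MESA_SHADER_VERTEX] = vs;   /* last pre-raster stage when gs is NULL */
   shaders[MESA_SHADER_GEOMETRY] = gs; /* optional */

   /* The key is derived purely from the shader array and the GS copy shader. */
   return radv_get_vgt_shader_key(device, shaders, gs_copy_shader);
}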