kk: Ensure some vertex lowerings happen on hardware stage

Lowerings like ensuring correct point size outputs, vertex position
output, and clip space control are expected to happen on the shader
stage which becomes the vertex stage on the actual hardware.

Reviewed-by: Aitor Camacho <aitor@lunarg.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/42020>
This commit is contained in:
squidbus 2026-06-04 05:47:34 -07:00 committed by Marge Bot
parent 859a7f5436
commit 9b587e5ff1
2 changed files with 38 additions and 17 deletions

View file

@ -219,7 +219,8 @@ msl_ensure_depth_write(nir_shader *nir)
bool
msl_ensure_vertex_position_output(nir_shader *nir)
{
assert(nir->info.stage == MESA_SHADER_VERTEX);
assert(nir->info.stage == MESA_SHADER_VERTEX ||
nir->info.stage == MESA_SHADER_TESS_EVAL);
bool has_position_write =
nir->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_POS);
@ -247,7 +248,8 @@ msl_ensure_vertex_position_output(nir_shader *nir)
bool
msl_ensure_vertex_point_size_output(nir_shader *nir)
{
assert(nir->info.stage == MESA_SHADER_VERTEX);
assert(nir->info.stage == MESA_SHADER_VERTEX ||
nir->info.stage == MESA_SHADER_TESS_EVAL);
bool has_point_size_write =
nir->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_PSIZ);
@ -332,7 +334,8 @@ msl_vs_io_types(nir_builder *b, nir_intrinsic_instr *intr, void *data)
bool
msl_nir_vs_io_types(nir_shader *nir)
{
assert(nir->info.stage == MESA_SHADER_VERTEX);
assert(nir->info.stage == MESA_SHADER_VERTEX ||
nir->info.stage == MESA_SHADER_TESS_EVAL);
return nir_shader_intrinsics_pass(nir, msl_vs_io_types, nir_metadata_all,
NULL);
}

View file

@ -326,17 +326,21 @@ kk_lower_vs_vbo(nir_shader *nir, const struct vk_graphics_pipeline_state *state,
NIR_PASS(_, nir, kk_nir_lower_vbo, attributes, robustness2);
}
/* Lowering for the stage which ends up as the vertex stage on hardware */
static void
kk_lower_vs(nir_shader *nir, const struct vk_graphics_pipeline_state *state)
kk_lower_hw_vs(nir_shader *nir, const struct vk_graphics_pipeline_state *state)
{
NIR_PASS(_, nir, msl_ensure_vertex_position_output);
if (state->ia->primitive_topology == VK_PRIMITIVE_TOPOLOGY_POINT_LIST)
bool is_point =
nir->info.stage == MESA_SHADER_TESS_EVAL
? nir->info.tess.point_mode
: state->ia->primitive_topology == VK_PRIMITIVE_TOPOLOGY_POINT_LIST;
if (is_point)
NIR_PASS(_, nir, msl_ensure_vertex_point_size_output);
if (state->ia->primitive_topology != VK_PRIMITIVE_TOPOLOGY_POINT_LIST)
else
nir_shader_intrinsics_pass(nir, msl_nir_vs_remove_point_size_write,
nir_metadata_control_flow, NULL);
NIR_PASS(_, nir, msl_ensure_vertex_position_output);
NIR_PASS(_, nir, nir_lower_clip_halfz_dynamic);
NIR_PASS(_, nir, msl_nir_vs_io_types);
}
@ -481,7 +485,7 @@ kk_lower_fs(struct kk_device *dev, nir_shader *nir,
}
static void
kk_lower_nir(struct kk_device *dev, nir_shader *nir,
kk_lower_nir(struct kk_device *dev, nir_shader *nir, bool emulated_stage,
const struct vk_pipeline_robustness_state *rs,
uint32_t set_layout_count,
struct vk_descriptor_set_layout *const *set_layouts,
@ -579,8 +583,10 @@ kk_lower_nir(struct kk_device *dev, nir_shader *nir,
NIR_PASS(_, nir, nir_opt_constant_folding);
/* These passes operate on lowered IO. */
if (nir->info.stage == MESA_SHADER_VERTEX) {
kk_lower_vs(nir, state);
if ((nir->info.stage == MESA_SHADER_VERTEX ||
nir->info.stage == MESA_SHADER_TESS_EVAL) &&
!emulated_stage) {
kk_lower_hw_vs(nir, state);
} else if (nir->info.stage == MESA_SHADER_FRAGMENT) {
kk_lower_fs(dev, nir, state);
}
@ -805,9 +811,6 @@ kk_compile_shader(struct kk_device *dev, nir_shader *nir,
/* This destroys info so it needs to happen after the gather */
NIR_PASS(_, nir, poly_nir_lower_tes, true);
NIR_PASS(_, nir, msl_ensure_vertex_position_output);
NIR_PASS(_, nir, msl_nir_vs_io_types);
}
NIR_PASS(_, nir, kk_nir_lower_poly);
@ -933,7 +936,7 @@ get_empty_nir(struct kk_device *dev, mesa_shader_stage stage,
.null_uniform_buffer_descriptor = false,
.null_storage_buffer_descriptor = false,
};
kk_lower_nir(dev, nir, &no_robustness, 0u, NULL, state);
kk_lower_nir(dev, nir, false, &no_robustness, 0u, NULL, state);
nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
return nir;
@ -1271,6 +1274,17 @@ kk_compile_shaders(struct vk_device *device, uint32_t shader_count,
nir_shader *nir_shaders[shader_count + 1u];
struct kk_shader *shaders[shader_count + 1u];
/* Determine if the pipeline contains tessellation stages */
bool tess = false;
for (uint32_t i = 0u; i < shader_count; ++i) {
const struct vk_shader_compile_info *info = &infos[i];
if (info->nir->info.stage == MESA_SHADER_TESS_CTRL ||
info->nir->info.stage == MESA_SHADER_TESS_EVAL) {
tess = true;
break;
}
}
/* Lower shaders, notably lowering IO. This is a prerequisite for intershader
* optimization. */
const struct vk_pipeline_robustness_state *vertex_robustness = &rs_none;
@ -1278,9 +1292,13 @@ kk_compile_shaders(struct vk_device *device, uint32_t shader_count,
const struct vk_shader_compile_info *info = &infos[i];
nir_shader *nir = info->nir;
/* For tessellation pipelines, some stages may be emulated in compute */
bool emulated_stage = tess && (nir->info.stage == MESA_SHADER_VERTEX ||
nir->info.stage == MESA_SHADER_TESS_CTRL);
msl_preprocess_nir_workarounds(nir, dev->disabled_workarounds);
kk_lower_nir(dev, nir, info->robustness, info->set_layout_count,
info->set_layouts, state);
kk_lower_nir(dev, nir, emulated_stage, info->robustness,
info->set_layout_count, info->set_layouts, state);
if (nir->info.stage == MESA_SHADER_VERTEX)
vertex_robustness = info->robustness;