From 9b587e5ff1da9798a0febd6918f67bda7d536e13 Mon Sep 17 00:00:00 2001 From: squidbus <1249084-squidbus@users.noreply.gitlab.freedesktop.org> Date: Thu, 4 Jun 2026 05:47:34 -0700 Subject: [PATCH] kk: Ensure some vertex lowerings happen on hardware stage Lowerings like ensuring correct point size outputs, vertex position output, and clip space control are expected to happen on the shader stage which becomes the vertex stage on the actual hardware. Reviewed-by: Aitor Camacho Part-of: --- .../compiler/msl_nir_lower_common.c | 9 ++-- src/kosmickrisp/vulkan/kk_shader.c | 46 +++++++++++++------ 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/src/kosmickrisp/compiler/msl_nir_lower_common.c b/src/kosmickrisp/compiler/msl_nir_lower_common.c index a756ac947d6..8699eadad62 100644 --- a/src/kosmickrisp/compiler/msl_nir_lower_common.c +++ b/src/kosmickrisp/compiler/msl_nir_lower_common.c @@ -219,7 +219,8 @@ msl_ensure_depth_write(nir_shader *nir) bool msl_ensure_vertex_position_output(nir_shader *nir) { - assert(nir->info.stage == MESA_SHADER_VERTEX); + assert(nir->info.stage == MESA_SHADER_VERTEX || + nir->info.stage == MESA_SHADER_TESS_EVAL); bool has_position_write = nir->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_POS); @@ -247,7 +248,8 @@ msl_ensure_vertex_position_output(nir_shader *nir) bool msl_ensure_vertex_point_size_output(nir_shader *nir) { - assert(nir->info.stage == MESA_SHADER_VERTEX); + assert(nir->info.stage == MESA_SHADER_VERTEX || + nir->info.stage == MESA_SHADER_TESS_EVAL); bool has_point_size_write = nir->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_PSIZ); @@ -332,7 +334,8 @@ msl_vs_io_types(nir_builder *b, nir_intrinsic_instr *intr, void *data) bool msl_nir_vs_io_types(nir_shader *nir) { - assert(nir->info.stage == MESA_SHADER_VERTEX); + assert(nir->info.stage == MESA_SHADER_VERTEX || + nir->info.stage == MESA_SHADER_TESS_EVAL); return nir_shader_intrinsics_pass(nir, msl_vs_io_types, nir_metadata_all, NULL); } diff --git a/src/kosmickrisp/vulkan/kk_shader.c b/src/kosmickrisp/vulkan/kk_shader.c index b6498a1fbfa..19a5e648ad9 100644 --- a/src/kosmickrisp/vulkan/kk_shader.c +++ b/src/kosmickrisp/vulkan/kk_shader.c @@ -326,17 +326,21 @@ kk_lower_vs_vbo(nir_shader *nir, const struct vk_graphics_pipeline_state *state, NIR_PASS(_, nir, kk_nir_lower_vbo, attributes, robustness2); } +/* Lowering for the stage which ends up as the vertex stage on hardware */ static void -kk_lower_vs(nir_shader *nir, const struct vk_graphics_pipeline_state *state) +kk_lower_hw_vs(nir_shader *nir, const struct vk_graphics_pipeline_state *state) { - NIR_PASS(_, nir, msl_ensure_vertex_position_output); - if (state->ia->primitive_topology == VK_PRIMITIVE_TOPOLOGY_POINT_LIST) + bool is_point = + nir->info.stage == MESA_SHADER_TESS_EVAL + ? nir->info.tess.point_mode + : state->ia->primitive_topology == VK_PRIMITIVE_TOPOLOGY_POINT_LIST; + if (is_point) NIR_PASS(_, nir, msl_ensure_vertex_point_size_output); - - if (state->ia->primitive_topology != VK_PRIMITIVE_TOPOLOGY_POINT_LIST) + else nir_shader_intrinsics_pass(nir, msl_nir_vs_remove_point_size_write, nir_metadata_control_flow, NULL); + NIR_PASS(_, nir, msl_ensure_vertex_position_output); NIR_PASS(_, nir, nir_lower_clip_halfz_dynamic); NIR_PASS(_, nir, msl_nir_vs_io_types); } @@ -481,7 +485,7 @@ kk_lower_fs(struct kk_device *dev, nir_shader *nir, } static void -kk_lower_nir(struct kk_device *dev, nir_shader *nir, +kk_lower_nir(struct kk_device *dev, nir_shader *nir, bool emulated_stage, const struct vk_pipeline_robustness_state *rs, uint32_t set_layout_count, struct vk_descriptor_set_layout *const *set_layouts, @@ -579,8 +583,10 @@ kk_lower_nir(struct kk_device *dev, nir_shader *nir, NIR_PASS(_, nir, nir_opt_constant_folding); /* These passes operate on lowered IO. */ - if (nir->info.stage == MESA_SHADER_VERTEX) { - kk_lower_vs(nir, state); + if ((nir->info.stage == MESA_SHADER_VERTEX || + nir->info.stage == MESA_SHADER_TESS_EVAL) && + !emulated_stage) { + kk_lower_hw_vs(nir, state); } else if (nir->info.stage == MESA_SHADER_FRAGMENT) { kk_lower_fs(dev, nir, state); } @@ -805,9 +811,6 @@ kk_compile_shader(struct kk_device *dev, nir_shader *nir, /* This destroys info so it needs to happen after the gather */ NIR_PASS(_, nir, poly_nir_lower_tes, true); - - NIR_PASS(_, nir, msl_ensure_vertex_position_output); - NIR_PASS(_, nir, msl_nir_vs_io_types); } NIR_PASS(_, nir, kk_nir_lower_poly); @@ -933,7 +936,7 @@ get_empty_nir(struct kk_device *dev, mesa_shader_stage stage, .null_uniform_buffer_descriptor = false, .null_storage_buffer_descriptor = false, }; - kk_lower_nir(dev, nir, &no_robustness, 0u, NULL, state); + kk_lower_nir(dev, nir, false, &no_robustness, 0u, NULL, state); nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir)); return nir; @@ -1271,6 +1274,17 @@ kk_compile_shaders(struct vk_device *device, uint32_t shader_count, nir_shader *nir_shaders[shader_count + 1u]; struct kk_shader *shaders[shader_count + 1u]; + /* Determine if the pipeline contains tessellation stages */ + bool tess = false; + for (uint32_t i = 0u; i < shader_count; ++i) { + const struct vk_shader_compile_info *info = &infos[i]; + if (info->nir->info.stage == MESA_SHADER_TESS_CTRL || + info->nir->info.stage == MESA_SHADER_TESS_EVAL) { + tess = true; + break; + } + } + /* Lower shaders, notably lowering IO. This is a prerequisite for intershader * optimization. */ const struct vk_pipeline_robustness_state *vertex_robustness = &rs_none; @@ -1278,9 +1292,13 @@ kk_compile_shaders(struct vk_device *device, uint32_t shader_count, const struct vk_shader_compile_info *info = &infos[i]; nir_shader *nir = info->nir; + /* For tessellation pipelines, some stages may be emulated in compute */ + bool emulated_stage = tess && (nir->info.stage == MESA_SHADER_VERTEX || + nir->info.stage == MESA_SHADER_TESS_CTRL); + msl_preprocess_nir_workarounds(nir, dev->disabled_workarounds); - kk_lower_nir(dev, nir, info->robustness, info->set_layout_count, - info->set_layouts, state); + kk_lower_nir(dev, nir, emulated_stage, info->robustness, + info->set_layout_count, info->set_layouts, state); if (nir->info.stage == MESA_SHADER_VERTEX) vertex_robustness = info->robustness;