diff --git a/src/kosmickrisp/compiler/msl_nir_lower_common.c b/src/kosmickrisp/compiler/msl_nir_lower_common.c index a756ac947d6..8699eadad62 100644 --- a/src/kosmickrisp/compiler/msl_nir_lower_common.c +++ b/src/kosmickrisp/compiler/msl_nir_lower_common.c @@ -219,7 +219,8 @@ msl_ensure_depth_write(nir_shader *nir) bool msl_ensure_vertex_position_output(nir_shader *nir) { - assert(nir->info.stage == MESA_SHADER_VERTEX); + assert(nir->info.stage == MESA_SHADER_VERTEX || + nir->info.stage == MESA_SHADER_TESS_EVAL); bool has_position_write = nir->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_POS); @@ -247,7 +248,8 @@ msl_ensure_vertex_position_output(nir_shader *nir) bool msl_ensure_vertex_point_size_output(nir_shader *nir) { - assert(nir->info.stage == MESA_SHADER_VERTEX); + assert(nir->info.stage == MESA_SHADER_VERTEX || + nir->info.stage == MESA_SHADER_TESS_EVAL); bool has_point_size_write = nir->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_PSIZ); @@ -332,7 +334,8 @@ msl_vs_io_types(nir_builder *b, nir_intrinsic_instr *intr, void *data) bool msl_nir_vs_io_types(nir_shader *nir) { - assert(nir->info.stage == MESA_SHADER_VERTEX); + assert(nir->info.stage == MESA_SHADER_VERTEX || + nir->info.stage == MESA_SHADER_TESS_EVAL); return nir_shader_intrinsics_pass(nir, msl_vs_io_types, nir_metadata_all, NULL); } diff --git a/src/kosmickrisp/vulkan/kk_shader.c b/src/kosmickrisp/vulkan/kk_shader.c index b6498a1fbfa..19a5e648ad9 100644 --- a/src/kosmickrisp/vulkan/kk_shader.c +++ b/src/kosmickrisp/vulkan/kk_shader.c @@ -326,17 +326,21 @@ kk_lower_vs_vbo(nir_shader *nir, const struct vk_graphics_pipeline_state *state, NIR_PASS(_, nir, kk_nir_lower_vbo, attributes, robustness2); } +/* Lowering for the stage which ends up as the vertex stage on hardware */ static void -kk_lower_vs(nir_shader *nir, const struct vk_graphics_pipeline_state *state) +kk_lower_hw_vs(nir_shader *nir, const struct vk_graphics_pipeline_state *state) { - NIR_PASS(_, nir, msl_ensure_vertex_position_output); - if (state->ia->primitive_topology == VK_PRIMITIVE_TOPOLOGY_POINT_LIST) + bool is_point = + nir->info.stage == MESA_SHADER_TESS_EVAL + ? nir->info.tess.point_mode + : state->ia->primitive_topology == VK_PRIMITIVE_TOPOLOGY_POINT_LIST; + if (is_point) NIR_PASS(_, nir, msl_ensure_vertex_point_size_output); - - if (state->ia->primitive_topology != VK_PRIMITIVE_TOPOLOGY_POINT_LIST) + else nir_shader_intrinsics_pass(nir, msl_nir_vs_remove_point_size_write, nir_metadata_control_flow, NULL); + NIR_PASS(_, nir, msl_ensure_vertex_position_output); NIR_PASS(_, nir, nir_lower_clip_halfz_dynamic); NIR_PASS(_, nir, msl_nir_vs_io_types); } @@ -481,7 +485,7 @@ kk_lower_fs(struct kk_device *dev, nir_shader *nir, } static void -kk_lower_nir(struct kk_device *dev, nir_shader *nir, +kk_lower_nir(struct kk_device *dev, nir_shader *nir, bool emulated_stage, const struct vk_pipeline_robustness_state *rs, uint32_t set_layout_count, struct vk_descriptor_set_layout *const *set_layouts, @@ -579,8 +583,10 @@ kk_lower_nir(struct kk_device *dev, nir_shader *nir, NIR_PASS(_, nir, nir_opt_constant_folding); /* These passes operate on lowered IO. */ - if (nir->info.stage == MESA_SHADER_VERTEX) { - kk_lower_vs(nir, state); + if ((nir->info.stage == MESA_SHADER_VERTEX || + nir->info.stage == MESA_SHADER_TESS_EVAL) && + !emulated_stage) { + kk_lower_hw_vs(nir, state); } else if (nir->info.stage == MESA_SHADER_FRAGMENT) { kk_lower_fs(dev, nir, state); } @@ -805,9 +811,6 @@ kk_compile_shader(struct kk_device *dev, nir_shader *nir, /* This destroys info so it needs to happen after the gather */ NIR_PASS(_, nir, poly_nir_lower_tes, true); - - NIR_PASS(_, nir, msl_ensure_vertex_position_output); - NIR_PASS(_, nir, msl_nir_vs_io_types); } NIR_PASS(_, nir, kk_nir_lower_poly); @@ -933,7 +936,7 @@ get_empty_nir(struct kk_device *dev, mesa_shader_stage stage, .null_uniform_buffer_descriptor = false, .null_storage_buffer_descriptor = false, }; - kk_lower_nir(dev, nir, &no_robustness, 0u, NULL, state); + kk_lower_nir(dev, nir, false, &no_robustness, 0u, NULL, state); nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir)); return nir; @@ -1271,6 +1274,17 @@ kk_compile_shaders(struct vk_device *device, uint32_t shader_count, nir_shader *nir_shaders[shader_count + 1u]; struct kk_shader *shaders[shader_count + 1u]; + /* Determine if the pipeline contains tessellation stages */ + bool tess = false; + for (uint32_t i = 0u; i < shader_count; ++i) { + const struct vk_shader_compile_info *info = &infos[i]; + if (info->nir->info.stage == MESA_SHADER_TESS_CTRL || + info->nir->info.stage == MESA_SHADER_TESS_EVAL) { + tess = true; + break; + } + } + /* Lower shaders, notably lowering IO. This is a prerequisite for intershader * optimization. */ const struct vk_pipeline_robustness_state *vertex_robustness = &rs_none; @@ -1278,9 +1292,13 @@ kk_compile_shaders(struct vk_device *device, uint32_t shader_count, const struct vk_shader_compile_info *info = &infos[i]; nir_shader *nir = info->nir; + /* For tessellation pipelines, some stages may be emulated in compute */ + bool emulated_stage = tess && (nir->info.stage == MESA_SHADER_VERTEX || + nir->info.stage == MESA_SHADER_TESS_CTRL); + msl_preprocess_nir_workarounds(nir, dev->disabled_workarounds); - kk_lower_nir(dev, nir, info->robustness, info->set_layout_count, - info->set_layouts, state); + kk_lower_nir(dev, nir, emulated_stage, info->robustness, + info->set_layout_count, info->set_layouts, state); if (nir->info.stage == MESA_SHADER_VERTEX) vertex_robustness = info->robustness;