diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c index 01b5be7f6df..040e9c8339b 100644 --- a/src/gallium/drivers/zink/zink_compiler.c +++ b/src/gallium/drivers/zink/zink_compiler.c @@ -313,6 +313,17 @@ update_so_info(struct zink_shader *sh, } } +static bool +last_vertex_stage(struct zink_shader *zs) +{ + assert(zs->nir->info.stage != MESA_SHADER_FRAGMENT); + if (zs->has_geometry_shader) + return zs->nir->info.stage == MESA_SHADER_GEOMETRY; + if (zs->has_tess_shader) + return zs->nir->info.stage == MESA_SHADER_TESS_EVAL; + return true; +} + VkShaderModule zink_shader_compile(struct zink_screen *screen, struct zink_shader *zs, struct zink_shader_key *key, unsigned char *shader_slot_map, unsigned char *shader_slots_reserved) @@ -321,26 +332,18 @@ zink_shader_compile(struct zink_screen *screen, struct zink_shader *zs, struct z void *streamout = NULL; nir_shader *nir = zs->nir; /* TODO: use a separate mem ctx here for ralloc */ - if (zs->has_geometry_shader) { - if (zs->nir->info.stage == MESA_SHADER_GEOMETRY) { - streamout = &zs->streamout; - NIR_PASS_V(nir, nir_lower_clip_halfz); - } - } else if (zs->has_tess_shader) { - if (zs->nir->info.stage == MESA_SHADER_TESS_EVAL) { - streamout = &zs->streamout; + if (zs->nir->info.stage != MESA_SHADER_FRAGMENT) { + if (last_vertex_stage(zs)) { + if (zs->streamout.so_info_slots) + streamout = &zs->streamout; + + nir = nir_shader_clone(NULL, zs->nir); NIR_PASS_V(nir, nir_lower_clip_halfz); } } else { - streamout = &zs->streamout; - NIR_PASS_V(nir, nir_lower_clip_halfz); - } - if (!zs->streamout.so_info_slots) - streamout = NULL; - if (zs->nir->info.stage == MESA_SHADER_FRAGMENT) { - nir = nir_shader_clone(NULL, nir); if (!zink_fs_key(key)->samples && nir->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_SAMPLE_MASK)) { + nir = nir_shader_clone(NULL, zs->nir); /* VK will always use gl_SampleMask[] values even if sample count is 0, * so we need to skip this write here to mimic GL's behavior of ignoring it */ @@ -376,7 +379,7 @@ zink_shader_compile(struct zink_screen *screen, struct zink_shader *zs, struct z if (vkCreateShaderModule(screen->dev, &smci, NULL, &mod) != VK_SUCCESS) mod = VK_NULL_HANDLE; - if (zs->nir->info.stage == MESA_SHADER_FRAGMENT) + if (nir != zs->nir) ralloc_free(nir); /* TODO: determine if there's any reason to cache spirv output? */