diff --git a/src/gallium/drivers/zink/zink_context.h b/src/gallium/drivers/zink/zink_context.h index 825593e26d7..f6ab6f16d07 100644 --- a/src/gallium/drivers/zink/zink_context.h +++ b/src/gallium/drivers/zink/zink_context.h @@ -174,6 +174,7 @@ struct zink_context { bool pipeline_changed[2]; //gfx, compute struct zink_shader *gfx_stages[ZINK_SHADER_COUNT]; + struct zink_shader *last_vertex_stage; struct zink_gfx_pipeline_state gfx_pipeline_state; enum pipe_prim_type gfx_prim_mode; struct hash_table *program_cache; diff --git a/src/gallium/drivers/zink/zink_program.c b/src/gallium/drivers/zink/zink_program.c index 3339884823d..7f3a0d1fcea 100644 --- a/src/gallium/drivers/zink/zink_program.c +++ b/src/gallium/drivers/zink/zink_program.c @@ -471,6 +471,13 @@ zink_create_gfx_program(struct zink_context *ctx, goto fail; } + if (stages[PIPE_SHADER_GEOMETRY]) + prog->last_vertex_stage = stages[PIPE_SHADER_GEOMETRY]; + else if (stages[PIPE_SHADER_TESS_EVAL]) + prog->last_vertex_stage = stages[PIPE_SHADER_TESS_EVAL]; + else + prog->last_vertex_stage = stages[PIPE_SHADER_VERTEX]; + struct mesa_sha1 sctx; _mesa_sha1_init(&sctx); for (int i = 0; i < ZINK_SHADER_COUNT; ++i) { @@ -906,7 +913,12 @@ static void zink_bind_vs_state(struct pipe_context *pctx, void *cso) { - bind_stage(zink_context(pctx), PIPE_SHADER_VERTEX, cso); + struct zink_context *ctx = zink_context(pctx); + bind_stage(ctx, PIPE_SHADER_VERTEX, cso); + if (!ctx->gfx_stages[PIPE_SHADER_GEOMETRY] && + !ctx->gfx_stages[PIPE_SHADER_TESS_EVAL]) { + ctx->last_vertex_stage = cso; + } } static void @@ -925,6 +937,14 @@ zink_bind_gs_state(struct pipe_context *pctx, ctx->dirty_shader_stages |= BITFIELD_BIT(PIPE_SHADER_VERTEX) | BITFIELD_BIT(PIPE_SHADER_TESS_EVAL); bind_stage(ctx, PIPE_SHADER_GEOMETRY, cso); + if (cso) + ctx->last_vertex_stage = cso; + else { + if (ctx->gfx_stages[PIPE_SHADER_TESS_EVAL]) + ctx->last_vertex_stage = ctx->gfx_stages[PIPE_SHADER_TESS_EVAL]; + else + ctx->last_vertex_stage = ctx->gfx_stages[PIPE_SHADER_VERTEX]; + } } static void @@ -948,6 +968,12 @@ zink_bind_tes_state(struct pipe_context *pctx, ctx->dirty_shader_stages |= BITFIELD_BIT(PIPE_SHADER_VERTEX); } bind_stage(ctx, PIPE_SHADER_TESS_EVAL, cso); + if (!ctx->gfx_stages[PIPE_SHADER_GEOMETRY]) { + if (cso) + ctx->last_vertex_stage = cso; + else + ctx->last_vertex_stage = ctx->gfx_stages[PIPE_SHADER_VERTEX]; + } } static void * diff --git a/src/gallium/drivers/zink/zink_program.h b/src/gallium/drivers/zink/zink_program.h index 087a0fa0124..a7dbe855d40 100644 --- a/src/gallium/drivers/zink/zink_program.h +++ b/src/gallium/drivers/zink/zink_program.h @@ -95,6 +95,8 @@ struct zink_gfx_program { struct zink_shader_module *default_variants[ZINK_SHADER_COUNT][2]; //[default, no streamout] const void *default_variant_key[ZINK_SHADER_COUNT]; + struct zink_shader *last_vertex_stage; + struct zink_shader *shaders[ZINK_SHADER_COUNT]; struct hash_table *pipelines[11]; // number of draw modes we support };