diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c index b309ca1d9c6..a918ef8981f 100644 --- a/src/gallium/drivers/zink/zink_compiler.c +++ b/src/gallium/drivers/zink/zink_compiler.c @@ -4646,8 +4646,10 @@ zink_shader_free(struct zink_screen *screen, struct zink_shader *shader) /* only remove generated tcs during parent tes destruction */ if (stage == MESA_SHADER_TESS_EVAL && shader->non_fs.generated_tcs) prog->shaders[MESA_SHADER_TESS_CTRL] = NULL; - if (stage != MESA_SHADER_FRAGMENT && shader->non_fs.generated_gs) - prog->shaders[MESA_SHADER_GEOMETRY] = NULL; + for (unsigned int i = 0; i < ARRAY_SIZE(shader->non_fs.generated_gs); i++) { + if (stage != MESA_SHADER_FRAGMENT && shader->non_fs.generated_gs[i]) + prog->shaders[MESA_SHADER_GEOMETRY] = NULL; + } zink_gfx_program_reference(screen, &prog, NULL); } if (shader->nir->info.stage == MESA_SHADER_TESS_EVAL && @@ -4656,11 +4658,13 @@ zink_shader_free(struct zink_screen *screen, struct zink_shader *shader) zink_shader_free(screen, shader->non_fs.generated_tcs); shader->non_fs.generated_tcs = NULL; } - if (shader->nir->info.stage != MESA_SHADER_FRAGMENT && - shader->non_fs.generated_gs) { - /* automatically destroy generated gs shaders when owner is destroyed */ - zink_shader_free(screen, shader->non_fs.generated_gs); - shader->non_fs.generated_gs = NULL; + for (unsigned int i = 0; i < ARRAY_SIZE(shader->non_fs.generated_gs); i++) { + if (shader->nir->info.stage != MESA_SHADER_FRAGMENT && + shader->non_fs.generated_gs[i]) { + /* automatically destroy generated gs shaders when owner is destroyed */ + zink_shader_free(screen, shader->non_fs.generated_gs[i]); + shader->non_fs.generated_gs[i] = NULL; + } } _mesa_set_destroy(shader->programs, NULL); util_queue_fence_wait(&shader->precompile.fence); diff --git a/src/gallium/drivers/zink/zink_program.c b/src/gallium/drivers/zink/zink_program.c index bf8e7cf6659..9d4ae72e011 100644 --- a/src/gallium/drivers/zink/zink_program.c +++ b/src/gallium/drivers/zink/zink_program.c @@ -1562,6 +1562,22 @@ zink_get_compute_pipeline(struct zink_screen *screen, return state->pipeline; } +static void +bind_gfx_stage(struct zink_context *ctx, gl_shader_stage stage, struct zink_shader *shader); + +static void +unbind_generated_gs(struct zink_context *ctx, gl_shader_stage stage, struct zink_shader *shader) +{ + for (int i = 0; i < ARRAY_SIZE(shader->non_fs.generated_gs); i++) { + if (ctx->gfx_stages[stage]->non_fs.generated_gs[i] && + ctx->gfx_stages[MESA_SHADER_GEOMETRY] == + ctx->gfx_stages[stage]->non_fs.generated_gs[i]) { + assert(stage != MESA_SHADER_GEOMETRY); /* let's not keep recursing! */ + bind_gfx_stage(ctx, MESA_SHADER_GEOMETRY, NULL); + } + } +} + static void bind_gfx_stage(struct zink_context *ctx, gl_shader_stage stage, struct zink_shader *shader) { @@ -1573,15 +1589,10 @@ bind_gfx_stage(struct zink_context *ctx, gl_shader_stage stage, struct zink_shad if (ctx->gfx_stages[stage]) { ctx->gfx_hash ^= ctx->gfx_stages[stage]->hash; - /* unbind the generated GS */ - if (stage != MESA_SHADER_FRAGMENT && - ctx->gfx_stages[stage]->non_fs.generated_gs && - ctx->gfx_stages[MESA_SHADER_GEOMETRY] == - ctx->gfx_stages[stage]->non_fs.generated_gs) { - assert(stage != MESA_SHADER_GEOMETRY); /* let's not keep recursing! */ - bind_gfx_stage(ctx, MESA_SHADER_GEOMETRY, NULL); - } + if (stage != MESA_SHADER_FRAGMENT) + unbind_generated_gs(ctx, stage, shader); } + ctx->gfx_stages[stage] = shader; ctx->gfx_dirty = ctx->gfx_stages[MESA_SHADER_FRAGMENT] && ctx->gfx_stages[MESA_SHADER_VERTEX]; ctx->gfx_pipeline_state.modules_changed = true; @@ -2200,10 +2211,11 @@ zink_set_primitive_emulation_keys(struct zink_context *ctx) ctx->gfx_stages[MESA_SHADER_TESS_EVAL] ? MESA_SHADER_TESS_EVAL : MESA_SHADER_VERTEX; - if (!ctx->gfx_stages[MESA_SHADER_GEOMETRY]) { + if (!ctx->gfx_stages[MESA_SHADER_GEOMETRY] || + (ctx->gfx_stages[MESA_SHADER_GEOMETRY]->nir->info.gs.input_primitive != ctx->gfx_pipeline_state.gfx_prim_mode)) { assert(!screen->optimal_keys); - if (!ctx->gfx_stages[prev_vertex_stage]->non_fs.generated_gs) { + if (!ctx->gfx_stages[prev_vertex_stage]->non_fs.generated_gs[ctx->gfx_pipeline_state.gfx_prim_mode]) { nir_shader *nir = nir_create_passthrough_gs( &screen->nir_options, ctx->gfx_stages[prev_vertex_stage]->nir, @@ -2211,12 +2223,12 @@ zink_set_primitive_emulation_keys(struct zink_context *ctx) (lower_line_stipple || lower_line_smooth) ? 2 : 1); struct zink_shader *shader = zink_shader_create(screen, nir, NULL); - ctx->gfx_stages[prev_vertex_stage]->non_fs.generated_gs = shader; + ctx->gfx_stages[prev_vertex_stage]->non_fs.generated_gs[ctx->gfx_pipeline_state.gfx_prim_mode] = shader; shader->non_fs.is_generated = true; } bind_gfx_stage(ctx, MESA_SHADER_GEOMETRY, - ctx->gfx_stages[prev_vertex_stage]->non_fs.generated_gs); + ctx->gfx_stages[prev_vertex_stage]->non_fs.generated_gs[ctx->gfx_pipeline_state.gfx_prim_mode]); } } else if (ctx->gfx_stages[MESA_SHADER_GEOMETRY] && ctx->gfx_stages[MESA_SHADER_GEOMETRY]->non_fs.is_generated) diff --git a/src/gallium/drivers/zink/zink_types.h b/src/gallium/drivers/zink/zink_types.h index 430f7cb3fd3..eb447f5a898 100644 --- a/src/gallium/drivers/zink/zink_types.h +++ b/src/gallium/drivers/zink/zink_types.h @@ -758,7 +758,7 @@ struct zink_shader { union { struct { struct zink_shader *generated_tcs; // a generated shader that this shader "owns"; only valid in the tes stage - struct zink_shader *generated_gs; // a generated shader that this shader "owns" + struct zink_shader *generated_gs[PIPE_PRIM_MAX]; // generated shaders that this shader "owns" bool is_generated; // if this is a driver-created shader (e.g., tcs) } non_fs;