diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c index 7697d50a11f..f4e1008ac0e 100644 --- a/src/gallium/drivers/zink/zink_compiler.c +++ b/src/gallium/drivers/zink/zink_compiler.c @@ -3211,36 +3211,27 @@ zink_shader_finalize(struct pipe_screen *pscreen, void *nirptr) void zink_shader_free(struct zink_context *ctx, struct zink_shader *shader) { + assert(shader->nir->info.stage != MESA_SHADER_COMPUTE); set_foreach(shader->programs, entry) { - if (shader->nir->info.stage == MESA_SHADER_COMPUTE) { - struct zink_compute_program *comp = (void*)entry->key; - if (!comp->base.removed) { - _mesa_hash_table_remove_key(&ctx->compute_program_cache, comp->shader); - comp->base.removed = true; - } - comp->shader = NULL; - zink_compute_program_reference(ctx, &comp, NULL); - } else { - struct zink_gfx_program *prog = (void*)entry->key; - gl_shader_stage stage = shader->nir->info.stage; - assert(stage < ZINK_GFX_SHADER_COUNT); - if (!prog->base.removed && (stage != MESA_SHADER_TESS_CTRL || !shader->is_generated)) { - unsigned stages_present = prog->stages_present; - if (prog->shaders[MESA_SHADER_TESS_CTRL] && prog->shaders[MESA_SHADER_TESS_CTRL]->is_generated) - stages_present &= ~BITFIELD_BIT(MESA_SHADER_TESS_CTRL); - struct hash_table *ht = &ctx->program_cache[zink_program_cache_stages(stages_present)]; - struct hash_entry *he = _mesa_hash_table_search(ht, prog->shaders); - assert(he); - _mesa_hash_table_remove(ht, he); - prog->base.removed = true; - } - if (stage != MESA_SHADER_TESS_CTRL || !shader->is_generated) - prog->shaders[stage] = NULL; - /* only remove generated tcs during parent tes destruction */ - if (stage == MESA_SHADER_TESS_EVAL && shader->generated) - prog->shaders[MESA_SHADER_TESS_CTRL] = NULL; - zink_gfx_program_reference(ctx, &prog, NULL); + struct zink_gfx_program *prog = (void*)entry->key; + gl_shader_stage stage = shader->nir->info.stage; + assert(stage < ZINK_GFX_SHADER_COUNT); + if (!prog->base.removed && (stage != MESA_SHADER_TESS_CTRL || !shader->is_generated)) { + unsigned stages_present = prog->stages_present; + if (prog->shaders[MESA_SHADER_TESS_CTRL] && prog->shaders[MESA_SHADER_TESS_CTRL]->is_generated) + stages_present &= ~BITFIELD_BIT(MESA_SHADER_TESS_CTRL); + struct hash_table *ht = &ctx->program_cache[zink_program_cache_stages(stages_present)]; + struct hash_entry *he = _mesa_hash_table_search(ht, prog->shaders); + assert(he); + _mesa_hash_table_remove(ht, he); + prog->base.removed = true; } + if (stage != MESA_SHADER_TESS_CTRL || !shader->is_generated) + prog->shaders[stage] = NULL; + /* only remove generated tcs during parent tes destruction */ + if (stage == MESA_SHADER_TESS_EVAL && shader->generated) + prog->shaders[MESA_SHADER_TESS_CTRL] = NULL; + zink_gfx_program_reference(ctx, &prog, NULL); } if (shader->nir->info.stage == MESA_SHADER_TESS_EVAL && shader->generated) { /* automatically destroy generated tcs shaders when tes is destroyed */ diff --git a/src/gallium/drivers/zink/zink_context.c b/src/gallium/drivers/zink/zink_context.c index 54a8561ea9b..c3331d86e72 100644 --- a/src/gallium/drivers/zink/zink_context.c +++ b/src/gallium/drivers/zink/zink_context.c @@ -102,10 +102,6 @@ zink_context_destroy(struct pipe_context *pctx) pg->removed = true; } } - hash_table_foreach(&ctx->compute_program_cache, entry) { - struct zink_program *pg = entry->data; - pg->removed = true; - } if (ctx->blitter) util_blitter_destroy(ctx->blitter); @@ -162,7 +158,6 @@ zink_context_destroy(struct pipe_context *pctx) slab_destroy_child(&ctx->transfer_pool); for (unsigned i = 0; i < ARRAY_SIZE(ctx->program_cache); i++) _mesa_hash_table_clear(&ctx->program_cache[i], NULL); - _mesa_hash_table_clear(&ctx->compute_program_cache, NULL); _mesa_hash_table_destroy(ctx->render_pass_cache, NULL); slab_destroy_child(&ctx->transfer_pool_unsync); @@ -4536,7 +4531,6 @@ zink_context_create(struct pipe_screen *pscreen, void *priv, unsigned flags) ctx->gfx_pipeline_state.shader_keys.key[MESA_SHADER_TESS_CTRL].size = sizeof(struct zink_tcs_key); ctx->gfx_pipeline_state.shader_keys.key[MESA_SHADER_GEOMETRY].size = sizeof(struct zink_vs_key_base); ctx->gfx_pipeline_state.shader_keys.key[MESA_SHADER_FRAGMENT].size = sizeof(struct zink_fs_key); - _mesa_hash_table_init(&ctx->compute_program_cache, ctx, _mesa_hash_pointer, _mesa_key_pointer_equal); _mesa_hash_table_init(&ctx->framebuffer_cache, ctx, hash_framebuffer_imageless, equals_framebuffer_imageless); if (!zink_init_render_pass(ctx)) goto fail; diff --git a/src/gallium/drivers/zink/zink_draw.cpp b/src/gallium/drivers/zink/zink_draw.cpp index ed753eb8c4f..647cbdfdf32 100644 --- a/src/gallium/drivers/zink/zink_draw.cpp +++ b/src/gallium/drivers/zink/zink_draw.cpp @@ -910,7 +910,7 @@ zink_launch_grid(struct pipe_context *pctx, const struct pipe_grid_info *info) zink_select_launch_grid(ctx); } - if (BITSET_TEST(ctx->compute_stage->nir->info.system_values_read, SYSTEM_VALUE_WORK_DIM)) + if (BITSET_TEST(ctx->curr_compute->shader->nir->info.system_values_read, SYSTEM_VALUE_WORK_DIM)) VKCTX(CmdPushConstants)(batch->state->cmdbuf, ctx->curr_compute->base.layout, VK_SHADER_STAGE_COMPUTE_BIT, offsetof(struct zink_cs_push_constant, work_dim), sizeof(uint32_t), &info->work_dim); diff --git a/src/gallium/drivers/zink/zink_program.c b/src/gallium/drivers/zink/zink_program.c index e03fdb522dd..89137569256 100644 --- a/src/gallium/drivers/zink/zink_program.c +++ b/src/gallium/drivers/zink/zink_program.c @@ -727,17 +727,18 @@ equals_compute_pipeline_state(const void *a, const void *b) sa->module == sb->module; } -struct zink_compute_program * -zink_create_compute_program(struct zink_context *ctx, struct zink_shader *shader) +static struct zink_compute_program * +create_compute_program(struct zink_context *ctx, nir_shader *nir) { struct zink_screen *screen = zink_screen(ctx->base.screen); struct zink_compute_program *comp = create_program(ctx, true); if (!comp) goto fail; + comp->shader = zink_shader_create(screen, nir, NULL); comp->curr = comp->module = CALLOC_STRUCT(zink_shader_module); assert(comp->module); - comp->module->shader = zink_shader_compile(screen, shader, shader->nir, NULL); + comp->module->shader = zink_shader_compile(screen, comp->shader, comp->shader->nir, NULL); assert(comp->module->shader); util_dynarray_init(&comp->shader_cache[0], NULL); util_dynarray_init(&comp->shader_cache[1], NULL); @@ -745,9 +746,7 @@ zink_create_compute_program(struct zink_context *ctx, struct zink_shader *shader comp->pipelines = _mesa_hash_table_create(NULL, NULL, equals_compute_pipeline_state); - _mesa_set_add(shader->programs, comp); - comp->shader = shader; - memcpy(comp->base.sha1, shader->base.sha1, sizeof(shader->base.sha1)); + memcpy(comp->base.sha1, comp->shader->base.sha1, sizeof(comp->shader->base.sha1)); if (!zink_descriptor_program_init(ctx, &comp->base)) goto fail; @@ -774,7 +773,7 @@ zink_program_get_descriptor_usage(struct zink_context *ctx, gl_shader_stage stag zs = ctx->gfx_stages[stage]; break; case MESA_SHADER_COMPUTE: { - zs = ctx->compute_stage; + zs = ctx->curr_compute->shader; break; } default: @@ -810,7 +809,7 @@ zink_program_descriptor_is_buffer(struct zink_context *ctx, gl_shader_stage stag zs = ctx->gfx_stages[stage]; break; case MESA_SHADER_COMPUTE: { - zs = ctx->compute_stage; + zs = ctx->curr_compute->shader; break; } default: @@ -994,7 +993,7 @@ zink_get_compute_pipeline(struct zink_screen *screen, } static inline void -bind_stage(struct zink_context *ctx, gl_shader_stage stage, +bind_gfx_stage(struct zink_context *ctx, gl_shader_stage stage, struct zink_shader *shader) { if (shader && shader->nir->info.num_inlinable_uniforms) @@ -1002,49 +1001,20 @@ bind_stage(struct zink_context *ctx, gl_shader_stage stage, else ctx->shader_has_inlinable_uniforms_mask &= ~(1 << stage); - if (stage == MESA_SHADER_COMPUTE) { - if (ctx->compute_stage) { - ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash; - ctx->compute_pipeline_state.module = VK_NULL_HANDLE; - ctx->compute_pipeline_state.module_hash = 0; - } - if (shader && shader != ctx->compute_stage) { - struct hash_entry *entry = _mesa_hash_table_search(&ctx->compute_program_cache, shader); - if (entry) { - ctx->compute_pipeline_state.dirty = true; - ctx->curr_compute = entry->data; - } else { - struct zink_compute_program *comp = zink_create_compute_program(ctx, shader); - _mesa_hash_table_insert(&ctx->compute_program_cache, comp->shader, comp); - ctx->compute_pipeline_state.dirty = true; - ctx->curr_compute = comp; - zink_batch_reference_program(&ctx->batch, &ctx->curr_compute->base); - } - ctx->compute_pipeline_state.module_hash = ctx->curr_compute->curr->hash; - ctx->compute_pipeline_state.module = ctx->curr_compute->curr->shader; - ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash; - if (ctx->compute_pipeline_state.key.base.nonseamless_cube_mask) - ctx->dirty_shader_stages |= BITFIELD_BIT(MESA_SHADER_COMPUTE); - } else if (!shader) - ctx->curr_compute = NULL; - ctx->compute_stage = shader; - zink_select_launch_grid(ctx); + if (ctx->gfx_stages[stage]) + ctx->gfx_hash ^= ctx->gfx_stages[stage]->hash; + ctx->gfx_stages[stage] = shader; + ctx->gfx_dirty = ctx->gfx_stages[MESA_SHADER_FRAGMENT] && ctx->gfx_stages[MESA_SHADER_VERTEX]; + ctx->gfx_pipeline_state.modules_changed = true; + if (shader) { + ctx->shader_stages |= BITFIELD_BIT(stage); + ctx->gfx_hash ^= ctx->gfx_stages[stage]->hash; } else { - if (ctx->gfx_stages[stage]) - ctx->gfx_hash ^= ctx->gfx_stages[stage]->hash; - ctx->gfx_stages[stage] = shader; - ctx->gfx_dirty = ctx->gfx_stages[MESA_SHADER_FRAGMENT] && ctx->gfx_stages[MESA_SHADER_VERTEX]; - ctx->gfx_pipeline_state.modules_changed = true; - if (shader) { - ctx->shader_stages |= BITFIELD_BIT(stage); - ctx->gfx_hash ^= ctx->gfx_stages[stage]->hash; - } else { - ctx->gfx_pipeline_state.modules[stage] = VK_NULL_HANDLE; - if (ctx->curr_program) - ctx->gfx_pipeline_state.final_hash ^= ctx->curr_program->last_variant_hash; - ctx->curr_program = NULL; - ctx->shader_stages &= ~BITFIELD_BIT(stage); - } + ctx->gfx_pipeline_state.modules[stage] = VK_NULL_HANDLE; + if (ctx->curr_program) + ctx->gfx_pipeline_state.final_hash ^= ctx->curr_program->last_variant_hash; + ctx->curr_program = NULL; + ctx->shader_stages &= ~BITFIELD_BIT(stage); } } @@ -1096,7 +1066,7 @@ zink_bind_vs_state(struct pipe_context *pctx, struct zink_context *ctx = zink_context(pctx); if (!cso && !ctx->gfx_stages[MESA_SHADER_VERTEX]) return; - bind_stage(ctx, MESA_SHADER_VERTEX, cso); + bind_gfx_stage(ctx, MESA_SHADER_VERTEX, cso); bind_last_vertex_stage(ctx); if (cso) { struct zink_shader *zs = cso; @@ -1132,7 +1102,7 @@ zink_bind_fs_state(struct pipe_context *pctx, struct zink_context *ctx = zink_context(pctx); if (!cso && !ctx->gfx_stages[MESA_SHADER_FRAGMENT]) return; - bind_stage(ctx, MESA_SHADER_FRAGMENT, cso); + bind_gfx_stage(ctx, MESA_SHADER_FRAGMENT, cso); ctx->fbfetch_outputs = 0; if (cso) { nir_shader *nir = ctx->gfx_stages[MESA_SHADER_FRAGMENT]->nir; @@ -1155,7 +1125,7 @@ zink_bind_gs_state(struct pipe_context *pctx, if (!cso && !ctx->gfx_stages[MESA_SHADER_GEOMETRY]) return; bool had_points = ctx->gfx_stages[MESA_SHADER_GEOMETRY] ? ctx->gfx_stages[MESA_SHADER_GEOMETRY]->nir->info.gs.output_primitive == SHADER_PRIM_POINTS : false; - bind_stage(ctx, MESA_SHADER_GEOMETRY, cso); + bind_gfx_stage(ctx, MESA_SHADER_GEOMETRY, cso); bind_last_vertex_stage(ctx); if (cso) { if (!had_points && ctx->last_vertex_stage->nir->info.gs.output_primitive == SHADER_PRIM_POINTS) @@ -1170,7 +1140,7 @@ static void zink_bind_tcs_state(struct pipe_context *pctx, void *cso) { - bind_stage(zink_context(pctx), MESA_SHADER_TESS_CTRL, cso); + bind_gfx_stage(zink_context(pctx), MESA_SHADER_TESS_CTRL, cso); } static void @@ -1187,7 +1157,7 @@ zink_bind_tes_state(struct pipe_context *pctx, ctx->gfx_stages[MESA_SHADER_TESS_CTRL] = NULL; } } - bind_stage(ctx, MESA_SHADER_TESS_EVAL, cso); + bind_gfx_stage(ctx, MESA_SHADER_TESS_EVAL, cso); bind_last_vertex_stage(ctx); } @@ -1201,14 +1171,43 @@ zink_create_cs_state(struct pipe_context *pctx, else nir = (struct nir_shader *)shader->prog; - return zink_shader_create(zink_screen(pctx->screen), nir, NULL); + return create_compute_program(zink_context(pctx), nir); } static void zink_bind_cs_state(struct pipe_context *pctx, void *cso) { - bind_stage(zink_context(pctx), MESA_SHADER_COMPUTE, cso); + struct zink_context *ctx = zink_context(pctx); + struct zink_compute_program *comp = cso; + if (comp && comp->shader->nir->info.num_inlinable_uniforms) + ctx->shader_has_inlinable_uniforms_mask |= 1 << MESA_SHADER_COMPUTE; + else + ctx->shader_has_inlinable_uniforms_mask &= ~(1 << MESA_SHADER_COMPUTE); + + if (ctx->curr_compute) { + zink_batch_reference_program(&ctx->batch, &ctx->curr_compute->base); + ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash; + ctx->compute_pipeline_state.module = VK_NULL_HANDLE; + ctx->compute_pipeline_state.module_hash = 0; + } + ctx->compute_pipeline_state.dirty = true; + ctx->curr_compute = comp; + if (comp && comp != ctx->curr_compute) { + ctx->compute_pipeline_state.module_hash = ctx->curr_compute->curr->hash; + ctx->compute_pipeline_state.module = ctx->curr_compute->curr->shader; + ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash; + if (ctx->compute_pipeline_state.key.base.nonseamless_cube_mask) + ctx->dirty_shader_stages |= BITFIELD_BIT(MESA_SHADER_COMPUTE); + } + zink_select_launch_grid(ctx); +} + +static void +zink_delete_cs_shader_state(struct pipe_context *pctx, void *cso) +{ + struct zink_compute_program *comp = cso; + zink_compute_program_reference(zink_context(pctx), &comp, NULL); } void @@ -1269,7 +1268,7 @@ zink_program_init(struct zink_context *ctx) ctx->base.create_compute_state = zink_create_cs_state; ctx->base.bind_compute_state = zink_bind_cs_state; - ctx->base.delete_compute_state = zink_delete_shader_state; + ctx->base.delete_compute_state = zink_delete_cs_shader_state; if (zink_screen(ctx->base.screen)->info.have_EXT_vertex_input_dynamic_state) _mesa_set_init(&ctx->gfx_inputs, ctx, hash_gfx_input_dynamic, equals_gfx_input_dynamic); diff --git a/src/gallium/drivers/zink/zink_program.h b/src/gallium/drivers/zink/zink_program.h index d69f70aad31..378718db94d 100644 --- a/src/gallium/drivers/zink/zink_program.h +++ b/src/gallium/drivers/zink/zink_program.h @@ -161,8 +161,6 @@ zink_gfx_program_reference(struct zink_context *ctx, return ret; } -struct zink_compute_program * -zink_create_compute_program(struct zink_context *ctx, struct zink_shader *shader); void zink_destroy_compute_program(struct zink_context *ctx, struct zink_compute_program *comp); diff --git a/src/gallium/drivers/zink/zink_types.h b/src/gallium/drivers/zink/zink_types.h index e22c321fccc..4d5d3852fe3 100644 --- a/src/gallium/drivers/zink/zink_types.h +++ b/src/gallium/drivers/zink/zink_types.h @@ -576,7 +576,6 @@ struct zink_shader_info { bool have_vulkan_memory_model; }; - struct zink_shader { struct util_live_shader base; uint32_t hash; @@ -1376,9 +1375,7 @@ struct zink_context { struct zink_descriptor_data dd; - struct zink_shader *compute_stage; struct zink_compute_pipeline_state compute_pipeline_state; - struct hash_table compute_program_cache; struct zink_compute_program *curr_compute; unsigned shader_stages : ZINK_GFX_SHADER_COUNT; /* mask of bound gfx shader stages */