zink: create compute programs from compute shaders directly

this simplifies the whole compute shader/program architecture and
also compiles compute shaders when apps maybe expect them to be compiled

Reviewed-by: Dave Airlie <airlied@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/18197>
This commit is contained in:
Mike Blumenkrantz 2022-08-11 16:50:54 -04:00 committed by Marge Bot
parent 2d46cc76c7
commit 4cb4bb555e
6 changed files with 78 additions and 99 deletions

View file

@ -3211,36 +3211,27 @@ zink_shader_finalize(struct pipe_screen *pscreen, void *nirptr)
void void
zink_shader_free(struct zink_context *ctx, struct zink_shader *shader) zink_shader_free(struct zink_context *ctx, struct zink_shader *shader)
{ {
assert(shader->nir->info.stage != MESA_SHADER_COMPUTE);
set_foreach(shader->programs, entry) { set_foreach(shader->programs, entry) {
if (shader->nir->info.stage == MESA_SHADER_COMPUTE) { struct zink_gfx_program *prog = (void*)entry->key;
struct zink_compute_program *comp = (void*)entry->key; gl_shader_stage stage = shader->nir->info.stage;
if (!comp->base.removed) { assert(stage < ZINK_GFX_SHADER_COUNT);
_mesa_hash_table_remove_key(&ctx->compute_program_cache, comp->shader); if (!prog->base.removed && (stage != MESA_SHADER_TESS_CTRL || !shader->is_generated)) {
comp->base.removed = true; unsigned stages_present = prog->stages_present;
} if (prog->shaders[MESA_SHADER_TESS_CTRL] && prog->shaders[MESA_SHADER_TESS_CTRL]->is_generated)
comp->shader = NULL; stages_present &= ~BITFIELD_BIT(MESA_SHADER_TESS_CTRL);
zink_compute_program_reference(ctx, &comp, NULL); struct hash_table *ht = &ctx->program_cache[zink_program_cache_stages(stages_present)];
} else { struct hash_entry *he = _mesa_hash_table_search(ht, prog->shaders);
struct zink_gfx_program *prog = (void*)entry->key; assert(he);
gl_shader_stage stage = shader->nir->info.stage; _mesa_hash_table_remove(ht, he);
assert(stage < ZINK_GFX_SHADER_COUNT); prog->base.removed = true;
if (!prog->base.removed && (stage != MESA_SHADER_TESS_CTRL || !shader->is_generated)) {
unsigned stages_present = prog->stages_present;
if (prog->shaders[MESA_SHADER_TESS_CTRL] && prog->shaders[MESA_SHADER_TESS_CTRL]->is_generated)
stages_present &= ~BITFIELD_BIT(MESA_SHADER_TESS_CTRL);
struct hash_table *ht = &ctx->program_cache[zink_program_cache_stages(stages_present)];
struct hash_entry *he = _mesa_hash_table_search(ht, prog->shaders);
assert(he);
_mesa_hash_table_remove(ht, he);
prog->base.removed = true;
}
if (stage != MESA_SHADER_TESS_CTRL || !shader->is_generated)
prog->shaders[stage] = NULL;
/* only remove generated tcs during parent tes destruction */
if (stage == MESA_SHADER_TESS_EVAL && shader->generated)
prog->shaders[MESA_SHADER_TESS_CTRL] = NULL;
zink_gfx_program_reference(ctx, &prog, NULL);
} }
if (stage != MESA_SHADER_TESS_CTRL || !shader->is_generated)
prog->shaders[stage] = NULL;
/* only remove generated tcs during parent tes destruction */
if (stage == MESA_SHADER_TESS_EVAL && shader->generated)
prog->shaders[MESA_SHADER_TESS_CTRL] = NULL;
zink_gfx_program_reference(ctx, &prog, NULL);
} }
if (shader->nir->info.stage == MESA_SHADER_TESS_EVAL && shader->generated) { if (shader->nir->info.stage == MESA_SHADER_TESS_EVAL && shader->generated) {
/* automatically destroy generated tcs shaders when tes is destroyed */ /* automatically destroy generated tcs shaders when tes is destroyed */

View file

@ -102,10 +102,6 @@ zink_context_destroy(struct pipe_context *pctx)
pg->removed = true; pg->removed = true;
} }
} }
hash_table_foreach(&ctx->compute_program_cache, entry) {
struct zink_program *pg = entry->data;
pg->removed = true;
}
if (ctx->blitter) if (ctx->blitter)
util_blitter_destroy(ctx->blitter); util_blitter_destroy(ctx->blitter);
@ -162,7 +158,6 @@ zink_context_destroy(struct pipe_context *pctx)
slab_destroy_child(&ctx->transfer_pool); slab_destroy_child(&ctx->transfer_pool);
for (unsigned i = 0; i < ARRAY_SIZE(ctx->program_cache); i++) for (unsigned i = 0; i < ARRAY_SIZE(ctx->program_cache); i++)
_mesa_hash_table_clear(&ctx->program_cache[i], NULL); _mesa_hash_table_clear(&ctx->program_cache[i], NULL);
_mesa_hash_table_clear(&ctx->compute_program_cache, NULL);
_mesa_hash_table_destroy(ctx->render_pass_cache, NULL); _mesa_hash_table_destroy(ctx->render_pass_cache, NULL);
slab_destroy_child(&ctx->transfer_pool_unsync); slab_destroy_child(&ctx->transfer_pool_unsync);
@ -4536,7 +4531,6 @@ zink_context_create(struct pipe_screen *pscreen, void *priv, unsigned flags)
ctx->gfx_pipeline_state.shader_keys.key[MESA_SHADER_TESS_CTRL].size = sizeof(struct zink_tcs_key); ctx->gfx_pipeline_state.shader_keys.key[MESA_SHADER_TESS_CTRL].size = sizeof(struct zink_tcs_key);
ctx->gfx_pipeline_state.shader_keys.key[MESA_SHADER_GEOMETRY].size = sizeof(struct zink_vs_key_base); ctx->gfx_pipeline_state.shader_keys.key[MESA_SHADER_GEOMETRY].size = sizeof(struct zink_vs_key_base);
ctx->gfx_pipeline_state.shader_keys.key[MESA_SHADER_FRAGMENT].size = sizeof(struct zink_fs_key); ctx->gfx_pipeline_state.shader_keys.key[MESA_SHADER_FRAGMENT].size = sizeof(struct zink_fs_key);
_mesa_hash_table_init(&ctx->compute_program_cache, ctx, _mesa_hash_pointer, _mesa_key_pointer_equal);
_mesa_hash_table_init(&ctx->framebuffer_cache, ctx, hash_framebuffer_imageless, equals_framebuffer_imageless); _mesa_hash_table_init(&ctx->framebuffer_cache, ctx, hash_framebuffer_imageless, equals_framebuffer_imageless);
if (!zink_init_render_pass(ctx)) if (!zink_init_render_pass(ctx))
goto fail; goto fail;

View file

@ -910,7 +910,7 @@ zink_launch_grid(struct pipe_context *pctx, const struct pipe_grid_info *info)
zink_select_launch_grid(ctx); zink_select_launch_grid(ctx);
} }
if (BITSET_TEST(ctx->compute_stage->nir->info.system_values_read, SYSTEM_VALUE_WORK_DIM)) if (BITSET_TEST(ctx->curr_compute->shader->nir->info.system_values_read, SYSTEM_VALUE_WORK_DIM))
VKCTX(CmdPushConstants)(batch->state->cmdbuf, ctx->curr_compute->base.layout, VK_SHADER_STAGE_COMPUTE_BIT, VKCTX(CmdPushConstants)(batch->state->cmdbuf, ctx->curr_compute->base.layout, VK_SHADER_STAGE_COMPUTE_BIT,
offsetof(struct zink_cs_push_constant, work_dim), sizeof(uint32_t), offsetof(struct zink_cs_push_constant, work_dim), sizeof(uint32_t),
&info->work_dim); &info->work_dim);

View file

@ -727,17 +727,18 @@ equals_compute_pipeline_state(const void *a, const void *b)
sa->module == sb->module; sa->module == sb->module;
} }
struct zink_compute_program * static struct zink_compute_program *
zink_create_compute_program(struct zink_context *ctx, struct zink_shader *shader) create_compute_program(struct zink_context *ctx, nir_shader *nir)
{ {
struct zink_screen *screen = zink_screen(ctx->base.screen); struct zink_screen *screen = zink_screen(ctx->base.screen);
struct zink_compute_program *comp = create_program(ctx, true); struct zink_compute_program *comp = create_program(ctx, true);
if (!comp) if (!comp)
goto fail; goto fail;
comp->shader = zink_shader_create(screen, nir, NULL);
comp->curr = comp->module = CALLOC_STRUCT(zink_shader_module); comp->curr = comp->module = CALLOC_STRUCT(zink_shader_module);
assert(comp->module); assert(comp->module);
comp->module->shader = zink_shader_compile(screen, shader, shader->nir, NULL); comp->module->shader = zink_shader_compile(screen, comp->shader, comp->shader->nir, NULL);
assert(comp->module->shader); assert(comp->module->shader);
util_dynarray_init(&comp->shader_cache[0], NULL); util_dynarray_init(&comp->shader_cache[0], NULL);
util_dynarray_init(&comp->shader_cache[1], NULL); util_dynarray_init(&comp->shader_cache[1], NULL);
@ -745,9 +746,7 @@ zink_create_compute_program(struct zink_context *ctx, struct zink_shader *shader
comp->pipelines = _mesa_hash_table_create(NULL, NULL, comp->pipelines = _mesa_hash_table_create(NULL, NULL,
equals_compute_pipeline_state); equals_compute_pipeline_state);
_mesa_set_add(shader->programs, comp); memcpy(comp->base.sha1, comp->shader->base.sha1, sizeof(comp->shader->base.sha1));
comp->shader = shader;
memcpy(comp->base.sha1, shader->base.sha1, sizeof(shader->base.sha1));
if (!zink_descriptor_program_init(ctx, &comp->base)) if (!zink_descriptor_program_init(ctx, &comp->base))
goto fail; goto fail;
@ -774,7 +773,7 @@ zink_program_get_descriptor_usage(struct zink_context *ctx, gl_shader_stage stag
zs = ctx->gfx_stages[stage]; zs = ctx->gfx_stages[stage];
break; break;
case MESA_SHADER_COMPUTE: { case MESA_SHADER_COMPUTE: {
zs = ctx->compute_stage; zs = ctx->curr_compute->shader;
break; break;
} }
default: default:
@ -810,7 +809,7 @@ zink_program_descriptor_is_buffer(struct zink_context *ctx, gl_shader_stage stag
zs = ctx->gfx_stages[stage]; zs = ctx->gfx_stages[stage];
break; break;
case MESA_SHADER_COMPUTE: { case MESA_SHADER_COMPUTE: {
zs = ctx->compute_stage; zs = ctx->curr_compute->shader;
break; break;
} }
default: default:
@ -994,7 +993,7 @@ zink_get_compute_pipeline(struct zink_screen *screen,
} }
static inline void static inline void
bind_stage(struct zink_context *ctx, gl_shader_stage stage, bind_gfx_stage(struct zink_context *ctx, gl_shader_stage stage,
struct zink_shader *shader) struct zink_shader *shader)
{ {
if (shader && shader->nir->info.num_inlinable_uniforms) if (shader && shader->nir->info.num_inlinable_uniforms)
@ -1002,49 +1001,20 @@ bind_stage(struct zink_context *ctx, gl_shader_stage stage,
else else
ctx->shader_has_inlinable_uniforms_mask &= ~(1 << stage); ctx->shader_has_inlinable_uniforms_mask &= ~(1 << stage);
if (stage == MESA_SHADER_COMPUTE) { if (ctx->gfx_stages[stage])
if (ctx->compute_stage) { ctx->gfx_hash ^= ctx->gfx_stages[stage]->hash;
ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash; ctx->gfx_stages[stage] = shader;
ctx->compute_pipeline_state.module = VK_NULL_HANDLE; ctx->gfx_dirty = ctx->gfx_stages[MESA_SHADER_FRAGMENT] && ctx->gfx_stages[MESA_SHADER_VERTEX];
ctx->compute_pipeline_state.module_hash = 0; ctx->gfx_pipeline_state.modules_changed = true;
} if (shader) {
if (shader && shader != ctx->compute_stage) { ctx->shader_stages |= BITFIELD_BIT(stage);
struct hash_entry *entry = _mesa_hash_table_search(&ctx->compute_program_cache, shader); ctx->gfx_hash ^= ctx->gfx_stages[stage]->hash;
if (entry) {
ctx->compute_pipeline_state.dirty = true;
ctx->curr_compute = entry->data;
} else {
struct zink_compute_program *comp = zink_create_compute_program(ctx, shader);
_mesa_hash_table_insert(&ctx->compute_program_cache, comp->shader, comp);
ctx->compute_pipeline_state.dirty = true;
ctx->curr_compute = comp;
zink_batch_reference_program(&ctx->batch, &ctx->curr_compute->base);
}
ctx->compute_pipeline_state.module_hash = ctx->curr_compute->curr->hash;
ctx->compute_pipeline_state.module = ctx->curr_compute->curr->shader;
ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
if (ctx->compute_pipeline_state.key.base.nonseamless_cube_mask)
ctx->dirty_shader_stages |= BITFIELD_BIT(MESA_SHADER_COMPUTE);
} else if (!shader)
ctx->curr_compute = NULL;
ctx->compute_stage = shader;
zink_select_launch_grid(ctx);
} else { } else {
if (ctx->gfx_stages[stage]) ctx->gfx_pipeline_state.modules[stage] = VK_NULL_HANDLE;
ctx->gfx_hash ^= ctx->gfx_stages[stage]->hash; if (ctx->curr_program)
ctx->gfx_stages[stage] = shader; ctx->gfx_pipeline_state.final_hash ^= ctx->curr_program->last_variant_hash;
ctx->gfx_dirty = ctx->gfx_stages[MESA_SHADER_FRAGMENT] && ctx->gfx_stages[MESA_SHADER_VERTEX]; ctx->curr_program = NULL;
ctx->gfx_pipeline_state.modules_changed = true; ctx->shader_stages &= ~BITFIELD_BIT(stage);
if (shader) {
ctx->shader_stages |= BITFIELD_BIT(stage);
ctx->gfx_hash ^= ctx->gfx_stages[stage]->hash;
} else {
ctx->gfx_pipeline_state.modules[stage] = VK_NULL_HANDLE;
if (ctx->curr_program)
ctx->gfx_pipeline_state.final_hash ^= ctx->curr_program->last_variant_hash;
ctx->curr_program = NULL;
ctx->shader_stages &= ~BITFIELD_BIT(stage);
}
} }
} }
@ -1096,7 +1066,7 @@ zink_bind_vs_state(struct pipe_context *pctx,
struct zink_context *ctx = zink_context(pctx); struct zink_context *ctx = zink_context(pctx);
if (!cso && !ctx->gfx_stages[MESA_SHADER_VERTEX]) if (!cso && !ctx->gfx_stages[MESA_SHADER_VERTEX])
return; return;
bind_stage(ctx, MESA_SHADER_VERTEX, cso); bind_gfx_stage(ctx, MESA_SHADER_VERTEX, cso);
bind_last_vertex_stage(ctx); bind_last_vertex_stage(ctx);
if (cso) { if (cso) {
struct zink_shader *zs = cso; struct zink_shader *zs = cso;
@ -1132,7 +1102,7 @@ zink_bind_fs_state(struct pipe_context *pctx,
struct zink_context *ctx = zink_context(pctx); struct zink_context *ctx = zink_context(pctx);
if (!cso && !ctx->gfx_stages[MESA_SHADER_FRAGMENT]) if (!cso && !ctx->gfx_stages[MESA_SHADER_FRAGMENT])
return; return;
bind_stage(ctx, MESA_SHADER_FRAGMENT, cso); bind_gfx_stage(ctx, MESA_SHADER_FRAGMENT, cso);
ctx->fbfetch_outputs = 0; ctx->fbfetch_outputs = 0;
if (cso) { if (cso) {
nir_shader *nir = ctx->gfx_stages[MESA_SHADER_FRAGMENT]->nir; nir_shader *nir = ctx->gfx_stages[MESA_SHADER_FRAGMENT]->nir;
@ -1155,7 +1125,7 @@ zink_bind_gs_state(struct pipe_context *pctx,
if (!cso && !ctx->gfx_stages[MESA_SHADER_GEOMETRY]) if (!cso && !ctx->gfx_stages[MESA_SHADER_GEOMETRY])
return; return;
bool had_points = ctx->gfx_stages[MESA_SHADER_GEOMETRY] ? ctx->gfx_stages[MESA_SHADER_GEOMETRY]->nir->info.gs.output_primitive == SHADER_PRIM_POINTS : false; bool had_points = ctx->gfx_stages[MESA_SHADER_GEOMETRY] ? ctx->gfx_stages[MESA_SHADER_GEOMETRY]->nir->info.gs.output_primitive == SHADER_PRIM_POINTS : false;
bind_stage(ctx, MESA_SHADER_GEOMETRY, cso); bind_gfx_stage(ctx, MESA_SHADER_GEOMETRY, cso);
bind_last_vertex_stage(ctx); bind_last_vertex_stage(ctx);
if (cso) { if (cso) {
if (!had_points && ctx->last_vertex_stage->nir->info.gs.output_primitive == SHADER_PRIM_POINTS) if (!had_points && ctx->last_vertex_stage->nir->info.gs.output_primitive == SHADER_PRIM_POINTS)
@ -1170,7 +1140,7 @@ static void
zink_bind_tcs_state(struct pipe_context *pctx, zink_bind_tcs_state(struct pipe_context *pctx,
void *cso) void *cso)
{ {
bind_stage(zink_context(pctx), MESA_SHADER_TESS_CTRL, cso); bind_gfx_stage(zink_context(pctx), MESA_SHADER_TESS_CTRL, cso);
} }
static void static void
@ -1187,7 +1157,7 @@ zink_bind_tes_state(struct pipe_context *pctx,
ctx->gfx_stages[MESA_SHADER_TESS_CTRL] = NULL; ctx->gfx_stages[MESA_SHADER_TESS_CTRL] = NULL;
} }
} }
bind_stage(ctx, MESA_SHADER_TESS_EVAL, cso); bind_gfx_stage(ctx, MESA_SHADER_TESS_EVAL, cso);
bind_last_vertex_stage(ctx); bind_last_vertex_stage(ctx);
} }
@ -1201,14 +1171,43 @@ zink_create_cs_state(struct pipe_context *pctx,
else else
nir = (struct nir_shader *)shader->prog; nir = (struct nir_shader *)shader->prog;
return zink_shader_create(zink_screen(pctx->screen), nir, NULL); return create_compute_program(zink_context(pctx), nir);
} }
static void static void
zink_bind_cs_state(struct pipe_context *pctx, zink_bind_cs_state(struct pipe_context *pctx,
void *cso) void *cso)
{ {
bind_stage(zink_context(pctx), MESA_SHADER_COMPUTE, cso); struct zink_context *ctx = zink_context(pctx);
struct zink_compute_program *comp = cso;
if (comp && comp->shader->nir->info.num_inlinable_uniforms)
ctx->shader_has_inlinable_uniforms_mask |= 1 << MESA_SHADER_COMPUTE;
else
ctx->shader_has_inlinable_uniforms_mask &= ~(1 << MESA_SHADER_COMPUTE);
if (ctx->curr_compute) {
zink_batch_reference_program(&ctx->batch, &ctx->curr_compute->base);
ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
ctx->compute_pipeline_state.module = VK_NULL_HANDLE;
ctx->compute_pipeline_state.module_hash = 0;
}
ctx->compute_pipeline_state.dirty = true;
ctx->curr_compute = comp;
if (comp && comp != ctx->curr_compute) {
ctx->compute_pipeline_state.module_hash = ctx->curr_compute->curr->hash;
ctx->compute_pipeline_state.module = ctx->curr_compute->curr->shader;
ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash;
if (ctx->compute_pipeline_state.key.base.nonseamless_cube_mask)
ctx->dirty_shader_stages |= BITFIELD_BIT(MESA_SHADER_COMPUTE);
}
zink_select_launch_grid(ctx);
}
static void
zink_delete_cs_shader_state(struct pipe_context *pctx, void *cso)
{
struct zink_compute_program *comp = cso;
zink_compute_program_reference(zink_context(pctx), &comp, NULL);
} }
void void
@ -1269,7 +1268,7 @@ zink_program_init(struct zink_context *ctx)
ctx->base.create_compute_state = zink_create_cs_state; ctx->base.create_compute_state = zink_create_cs_state;
ctx->base.bind_compute_state = zink_bind_cs_state; ctx->base.bind_compute_state = zink_bind_cs_state;
ctx->base.delete_compute_state = zink_delete_shader_state; ctx->base.delete_compute_state = zink_delete_cs_shader_state;
if (zink_screen(ctx->base.screen)->info.have_EXT_vertex_input_dynamic_state) if (zink_screen(ctx->base.screen)->info.have_EXT_vertex_input_dynamic_state)
_mesa_set_init(&ctx->gfx_inputs, ctx, hash_gfx_input_dynamic, equals_gfx_input_dynamic); _mesa_set_init(&ctx->gfx_inputs, ctx, hash_gfx_input_dynamic, equals_gfx_input_dynamic);

View file

@ -161,8 +161,6 @@ zink_gfx_program_reference(struct zink_context *ctx,
return ret; return ret;
} }
struct zink_compute_program *
zink_create_compute_program(struct zink_context *ctx, struct zink_shader *shader);
void void
zink_destroy_compute_program(struct zink_context *ctx, zink_destroy_compute_program(struct zink_context *ctx,
struct zink_compute_program *comp); struct zink_compute_program *comp);

View file

@ -576,7 +576,6 @@ struct zink_shader_info {
bool have_vulkan_memory_model; bool have_vulkan_memory_model;
}; };
struct zink_shader { struct zink_shader {
struct util_live_shader base; struct util_live_shader base;
uint32_t hash; uint32_t hash;
@ -1376,9 +1375,7 @@ struct zink_context {
struct zink_descriptor_data dd; struct zink_descriptor_data dd;
struct zink_shader *compute_stage;
struct zink_compute_pipeline_state compute_pipeline_state; struct zink_compute_pipeline_state compute_pipeline_state;
struct hash_table compute_program_cache;
struct zink_compute_program *curr_compute; struct zink_compute_program *curr_compute;
unsigned shader_stages : ZINK_GFX_SHADER_COUNT; /* mask of bound gfx shader stages */ unsigned shader_stages : ZINK_GFX_SHADER_COUNT; /* mask of bound gfx shader stages */