From a8d90c8ed55e77344bcf277934a5ff2fa52d3e15 Mon Sep 17 00:00:00 2001 From: Mike Blumenkrantz Date: Tue, 9 Nov 2021 10:07:49 -0500 Subject: [PATCH] zink: implement cs uniform inlining this implements shader variants for compute Reviewed-by: Dave Airlie Part-of: --- src/gallium/drivers/zink/zink_context.c | 7 +- src/gallium/drivers/zink/zink_context.h | 1 - src/gallium/drivers/zink/zink_draw.cpp | 6 ++ src/gallium/drivers/zink/zink_pipeline.h | 7 ++ src/gallium/drivers/zink/zink_program.c | 111 +++++++++++++++++++++-- src/gallium/drivers/zink/zink_program.h | 10 +- 6 files changed, 127 insertions(+), 15 deletions(-) diff --git a/src/gallium/drivers/zink/zink_context.c b/src/gallium/drivers/zink/zink_context.c index 75b376369c9..defca8d80e0 100644 --- a/src/gallium/drivers/zink/zink_context.c +++ b/src/gallium/drivers/zink/zink_context.c @@ -1047,18 +1047,17 @@ zink_set_inlinable_constants(struct pipe_context *pctx, struct zink_shader_key *key = NULL; if (shader == PIPE_SHADER_COMPUTE) { - inlinable_uniforms = ctx->compute_inlinable_uniforms; + key = &ctx->compute_pipeline_state.key; } else { key = &ctx->gfx_pipeline_state.shader_keys.key[shader]; - inlinable_uniforms = key->base.inlined_uniform_values; } + inlinable_uniforms = key->base.inlined_uniform_values; if (!(ctx->inlinable_uniforms_valid_mask & bit) || memcmp(inlinable_uniforms, values, num_values * 4)) { memcpy(inlinable_uniforms, values, num_values * 4); ctx->dirty_shader_stages |= bit; ctx->inlinable_uniforms_valid_mask |= bit; - if (key) - key->inline_uniforms = true; + key->inline_uniforms = true; } } diff --git a/src/gallium/drivers/zink/zink_context.h b/src/gallium/drivers/zink/zink_context.h index 8301e90241f..65aaf02c00f 100644 --- a/src/gallium/drivers/zink/zink_context.h +++ b/src/gallium/drivers/zink/zink_context.h @@ -199,7 +199,6 @@ struct zink_context { unsigned shader_has_inlinable_uniforms_mask; unsigned inlinable_uniforms_valid_mask; - uint32_t compute_inlinable_uniforms[MAX_INLINABLE_UNIFORMS]; struct pipe_constant_buffer ubos[PIPE_SHADER_TYPES][PIPE_MAX_CONSTANT_BUFFERS]; struct pipe_shader_buffer ssbos[PIPE_SHADER_TYPES][PIPE_MAX_SHADER_BUFFERS]; diff --git a/src/gallium/drivers/zink/zink_draw.cpp b/src/gallium/drivers/zink/zink_draw.cpp index a80e37b67c0..1548f64229c 100644 --- a/src/gallium/drivers/zink/zink_draw.cpp +++ b/src/gallium/drivers/zink/zink_draw.cpp @@ -882,6 +882,12 @@ zink_launch_grid(struct pipe_context *pctx, const struct pipe_grid_info *info) zink_batch_reference_program(&ctx->batch, &ctx->curr_compute->base); } + if (ctx->dirty_shader_stages & BITFIELD_BIT(PIPE_SHADER_COMPUTE)) { + /* update inlinable constants */ + zink_update_compute_program(ctx); + ctx->dirty_shader_stages &= ~BITFIELD_BIT(PIPE_SHADER_COMPUTE); + } + if (prev_pipeline != pipeline || BATCH_CHANGED) VKCTX(CmdBindPipeline)(batch->state->cmdbuf, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); if (BATCH_CHANGED) { diff --git a/src/gallium/drivers/zink/zink_pipeline.h b/src/gallium/drivers/zink/zink_pipeline.h index 4acc6c44285..04c6a05d25e 100644 --- a/src/gallium/drivers/zink/zink_pipeline.h +++ b/src/gallium/drivers/zink/zink_pipeline.h @@ -92,10 +92,17 @@ struct zink_compute_pipeline_state { /* Pre-hashed value for table lookup, invalid when zero. * Members after this point are not included in pipeline state hash key */ uint32_t hash; + uint32_t final_hash; bool dirty; bool use_local_size; uint32_t local_size[3]; + uint32_t module_hash; + VkShaderModule module; + bool module_changed; + + struct zink_shader_key key; + VkPipeline pipeline; }; diff --git a/src/gallium/drivers/zink/zink_program.c b/src/gallium/drivers/zink/zink_program.c index fa5461dc990..92e7baca53d 100644 --- a/src/gallium/drivers/zink/zink_program.c +++ b/src/gallium/drivers/zink/zink_program.c @@ -155,7 +155,7 @@ destroy_shader_cache(struct zink_screen *screen, struct list_head *sc) } static void -update_shader_modules(struct zink_context *ctx, +update_gfx_shader_modules(struct zink_context *ctx, struct zink_screen *screen, struct zink_gfx_program *prog, uint32_t mask, struct zink_gfx_pipeline_state *state) @@ -245,7 +245,87 @@ equals_gfx_pipeline_state(const void *a, const void *b) void zink_update_gfx_program(struct zink_context *ctx, struct zink_gfx_program *prog) { - update_shader_modules(ctx, zink_screen(ctx->base.screen), prog, ctx->dirty_shader_stages & prog->stages_present, &ctx->gfx_pipeline_state); + update_gfx_shader_modules(ctx, zink_screen(ctx->base.screen), prog, ctx->dirty_shader_stages & prog->stages_present, &ctx->gfx_pipeline_state); +} + +static bool +uniforms_match(const struct zink_shader_module *zm, uint32_t *uniforms, unsigned num_uniforms) +{ + assert(zm->num_uniforms == num_uniforms); + return !memcmp(zm->key, uniforms, zm->num_uniforms * sizeof(uint32_t)); +} + +static uint32_t +cs_module_hash(const struct zink_shader_module *zm) +{ + return _mesa_hash_data(zm->key, zm->num_uniforms * sizeof(uint32_t)); +} + +static void +update_cs_shader_module(struct zink_context *ctx, struct zink_compute_program *comp) +{ + struct zink_shader *zs = comp->shader; + VkShaderModule mod; + struct zink_shader_module *zm = NULL; + unsigned base_size = 0; + struct zink_shader_key *key = &ctx->compute_pipeline_state.key; + + if (ctx && zs->nir->info.num_inlinable_uniforms && + ctx->inlinable_uniforms_valid_mask & BITFIELD64_BIT(PIPE_SHADER_COMPUTE)) { + if (comp->inlined_variant_count < ZINK_MAX_INLINED_VARIANTS) + base_size = zs->nir->info.num_inlinable_uniforms; + else + key->inline_uniforms = false; + } + + if (base_size) { + struct zink_shader_module *iter, *next; + LIST_FOR_EACH_ENTRY_SAFE(iter, next, &comp->shader_cache, list) { + if (!uniforms_match(iter, key->base.inlined_uniform_values, base_size)) + continue; + list_delinit(&iter->list); + zm = iter; + break; + } + } else { + zm = comp->module; + } + + if (!zm) { + zm = malloc(sizeof(struct zink_shader_module) + base_size * sizeof(uint32_t)); + if (!zm) { + return; + } + mod = zink_shader_compile(zink_screen(ctx->base.screen), zs, comp->shader->nir, key); + if (!mod) { + FREE(zm); + return; + } + zm->shader = mod; + list_inithead(&zm->list); + zm->num_uniforms = base_size; + zm->key_size = 0; + assert(base_size); + memcpy(zm->key, key->base.inlined_uniform_values, base_size * sizeof(uint32_t)); + zm->hash = cs_module_hash(zm); + zm->default_variant = false; + comp->inlined_variant_count++; + } + if (zm->num_uniforms) + list_add(&zm->list, &comp->shader_cache); + if (comp->curr == zm) + return; + ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash; + comp->curr = zm; + ctx->compute_pipeline_state.module_hash = zm->hash; + ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash; + ctx->compute_pipeline_state.module_changed = true; +} + +void +zink_update_compute_program(struct zink_context *ctx) +{ + update_cs_shader_module(ctx, ctx->curr_compute); } VkPipelineLayout @@ -418,7 +498,10 @@ zink_program_update_compute_pipeline_state(struct zink_context *ctx, struct zink static bool equals_compute_pipeline_state(const void *a, const void *b) { - return memcmp(a, b, offsetof(struct zink_compute_pipeline_state, hash)) == 0; + const struct zink_compute_pipeline_state *sa = a; + const struct zink_compute_pipeline_state *sb = b; + return !memcmp(a, b, offsetof(struct zink_compute_pipeline_state, hash)) && + sa->module == sb->module; } struct zink_compute_program * @@ -432,12 +515,13 @@ zink_create_compute_program(struct zink_context *ctx, struct zink_shader *shader pipe_reference_init(&comp->base.reference, 1); comp->base.is_compute = true; - comp->module = CALLOC_STRUCT(zink_shader_module); + comp->curr = comp->module = CALLOC_STRUCT(zink_shader_module); assert(comp->module); comp->module->shader = zink_shader_compile(screen, shader, shader->nir, NULL); assert(comp->module->shader); + list_inithead(&comp->shader_cache); - comp->pipelines = _mesa_hash_table_create(NULL, hash_compute_pipeline_state, + comp->pipelines = _mesa_hash_table_create(NULL, NULL, equals_compute_pipeline_state); _mesa_set_add(shader->programs, comp); @@ -736,13 +820,16 @@ zink_get_compute_pipeline(struct zink_screen *screen, { struct hash_entry *entry = NULL; - if (!state->dirty) + if (!state->dirty && !state->module_changed) return state->pipeline; if (state->dirty) { + if (state->pipeline) //avoid on first hash + state->final_hash ^= state->hash; state->hash = hash_compute_pipeline_state(state); state->dirty = false; + state->final_hash ^= state->hash; } - entry = _mesa_hash_table_search_pre_hashed(comp->pipelines, state->hash, state); + entry = _mesa_hash_table_search_pre_hashed(comp->pipelines, state->final_hash, state); if (!entry) { util_queue_fence_wait(&comp->base.cache_fence); @@ -758,7 +845,7 @@ zink_get_compute_pipeline(struct zink_screen *screen, memcpy(&pc_entry->state, state, sizeof(*state)); pc_entry->pipeline = pipeline; - entry = _mesa_hash_table_insert_pre_hashed(comp->pipelines, state->hash, pc_entry, pc_entry); + entry = _mesa_hash_table_insert_pre_hashed(comp->pipelines, state->final_hash, pc_entry, pc_entry); assert(entry); } @@ -777,6 +864,11 @@ bind_stage(struct zink_context *ctx, enum pipe_shader_type stage, ctx->shader_has_inlinable_uniforms_mask &= ~(1 << stage); if (stage == PIPE_SHADER_COMPUTE) { + if (ctx->compute_stage) { + ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash; + ctx->compute_pipeline_state.module = VK_NULL_HANDLE; + ctx->compute_pipeline_state.module_hash = 0; + } if (shader && shader != ctx->compute_stage) { struct hash_entry *entry = _mesa_hash_table_search(&ctx->compute_program_cache, shader); if (entry) { @@ -789,6 +881,9 @@ bind_stage(struct zink_context *ctx, enum pipe_shader_type stage, ctx->curr_compute = comp; zink_batch_reference_program(&ctx->batch, &ctx->curr_compute->base); } + ctx->compute_pipeline_state.module_hash = ctx->curr_compute->curr->hash; + ctx->compute_pipeline_state.module = ctx->curr_compute->curr->shader; + ctx->compute_pipeline_state.final_hash ^= ctx->compute_pipeline_state.module_hash; } else if (!shader) ctx->curr_compute = NULL; ctx->compute_stage = shader; diff --git a/src/gallium/drivers/zink/zink_program.h b/src/gallium/drivers/zink/zink_program.h index 111207f9541..78ac5f048ea 100644 --- a/src/gallium/drivers/zink/zink_program.h +++ b/src/gallium/drivers/zink/zink_program.h @@ -116,7 +116,12 @@ struct zink_gfx_program { struct zink_compute_program { struct zink_program base; - struct zink_shader_module *module; + struct zink_shader_module *curr; + + struct zink_shader_module *module; //base + struct list_head shader_cache; //inline uniforms + unsigned inlined_variant_count; + struct zink_shader *shader; struct hash_table *pipelines; }; @@ -272,7 +277,8 @@ zink_pipeline_layout_create(struct zink_screen *screen, struct zink_program *pg, void zink_program_update_compute_pipeline_state(struct zink_context *ctx, struct zink_compute_program *comp, const uint block[3]); - +void +zink_update_compute_program(struct zink_context *ctx); VkPipeline zink_get_compute_pipeline(struct zink_screen *screen, struct zink_compute_program *comp,