diff --git a/src/gallium/drivers/zink/zink_batch.c b/src/gallium/drivers/zink/zink_batch.c index fc4f7f6daba..1ee394fa427 100644 --- a/src/gallium/drivers/zink/zink_batch.c +++ b/src/gallium/drivers/zink/zink_batch.c @@ -21,8 +21,13 @@ zink_batch_release(struct zink_screen *screen, struct zink_batch *batch) zink_framebuffer_reference(screen, &batch->fb, NULL); set_foreach(batch->programs, entry) { - struct zink_gfx_program *prog = (struct zink_gfx_program*)entry->key; - zink_gfx_program_reference(screen, &prog, NULL); + if (batch->batch_id == ZINK_COMPUTE_BATCH_ID) { + struct zink_compute_program *comp = (struct zink_compute_program*)entry->key; + zink_compute_program_reference(screen, &comp, NULL); + } else { + struct zink_gfx_program *prog = (struct zink_gfx_program*)entry->key; + zink_gfx_program_reference(screen, &prog, NULL); + } } _mesa_set_clear(batch->programs, NULL); diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c index df47c6c0742..d69c514025a 100644 --- a/src/gallium/drivers/zink/zink_compiler.c +++ b/src/gallium/drivers/zink/zink_compiler.c @@ -542,13 +542,20 @@ zink_shader_free(struct zink_context *ctx, struct zink_shader *shader) { struct zink_screen *screen = zink_screen(ctx->base.screen); set_foreach(shader->programs, entry) { - struct zink_gfx_program *prog = (void*)entry->key; - _mesa_hash_table_remove_key(ctx->program_cache, prog->shaders); - prog->shaders[pipe_shader_type_from_mesa(shader->nir->info.stage)] = NULL; - if (shader->nir->info.stage == MESA_SHADER_TESS_EVAL && shader->generated) + if (shader->nir->info.stage == MESA_SHADER_COMPUTE) { + struct zink_compute_program *comp = (void*)entry->key; + _mesa_hash_table_remove_key(ctx->compute_program_cache, &comp->shader->shader_id); + comp->shader = NULL; + zink_compute_program_reference(screen, &comp, NULL); + } else { + struct zink_gfx_program *prog = (void*)entry->key; + _mesa_hash_table_remove_key(ctx->program_cache, prog->shaders); + prog->shaders[pipe_shader_type_from_mesa(shader->nir->info.stage)] = NULL; + if (shader->nir->info.stage == MESA_SHADER_TESS_EVAL && shader->generated) /* automatically destroy generated tcs shaders when tes is destroyed */ zink_shader_free(ctx, shader->generated); - zink_gfx_program_reference(screen, &prog, NULL); + zink_gfx_program_reference(screen, &prog, NULL); + } } _mesa_set_destroy(shader->programs, NULL); free(shader->streamout.so_info_slots); diff --git a/src/gallium/drivers/zink/zink_context.c b/src/gallium/drivers/zink/zink_context.c index 7762b1fbf14..047927d5582 100644 --- a/src/gallium/drivers/zink/zink_context.c +++ b/src/gallium/drivers/zink/zink_context.c @@ -81,6 +81,7 @@ zink_context_destroy(struct pipe_context *pctx) u_upload_destroy(pctx->stream_uploader); slab_destroy_child(&ctx->transfer_pool); _mesa_hash_table_destroy(ctx->program_cache, NULL); + _mesa_hash_table_destroy(ctx->compute_program_cache, NULL); _mesa_hash_table_destroy(ctx->render_pass_cache, NULL); FREE(ctx); } @@ -1614,6 +1615,7 @@ zink_context_create(struct pipe_screen *pscreen, void *priv, unsigned flags) goto fail; ctx->gfx_pipeline_state.dirty = true; + ctx->compute_pipeline_state.dirty = true; ctx->base.screen = pscreen; ctx->base.priv = priv; @@ -1706,10 +1708,13 @@ zink_context_create(struct pipe_screen *pscreen, void *priv, unsigned flags) ctx->program_cache = _mesa_hash_table_create(NULL, hash_gfx_program, equals_gfx_program); + ctx->compute_program_cache = _mesa_hash_table_create(NULL, + _mesa_hash_uint, + _mesa_key_uint_equal); ctx->render_pass_cache = _mesa_hash_table_create(NULL, hash_render_pass_state, equals_render_pass_state); - if (!ctx->program_cache || !ctx->render_pass_cache) + if (!ctx->program_cache || !ctx->compute_program_cache || !ctx->render_pass_cache) goto fail; const uint8_t data[] = { 0 }; diff --git a/src/gallium/drivers/zink/zink_context.h b/src/gallium/drivers/zink/zink_context.h index 95af638bec7..ffb55454f7c 100644 --- a/src/gallium/drivers/zink/zink_context.h +++ b/src/gallium/drivers/zink/zink_context.h @@ -124,6 +124,11 @@ struct zink_context { struct hash_table *program_cache; struct zink_gfx_program *curr_program; + struct zink_shader *compute_stage; + struct zink_compute_pipeline_state compute_pipeline_state; + struct hash_table *compute_program_cache; + struct zink_compute_program *curr_compute; + unsigned dirty_shader_stages : 6; /* mask of changed shader stages */ bool last_vertex_stage_dirty; diff --git a/src/gallium/drivers/zink/zink_draw.c b/src/gallium/drivers/zink/zink_draw.c index 26df8741742..f11c8e91fa9 100644 --- a/src/gallium/drivers/zink/zink_draw.c +++ b/src/gallium/drivers/zink/zink_draw.c @@ -161,6 +161,29 @@ zink_bind_vertex_buffers(struct zink_batch *batch, struct zink_context *ctx) buffers, buffer_offsets); } +static struct zink_compute_program * +get_compute_program(struct zink_context *ctx) +{ + if (ctx->dirty_shader_stages) { + struct hash_entry *entry = _mesa_hash_table_search(ctx->compute_program_cache, + &ctx->compute_stage->shader_id); + if (!entry) { + struct zink_compute_program *comp; + comp = zink_create_compute_program(ctx, ctx->compute_stage); + entry = _mesa_hash_table_insert(ctx->compute_program_cache, &comp->shader->shader_id, comp); + if (!entry) + return NULL; + } + if (entry->data != ctx->curr_compute) + ctx->compute_pipeline_state.dirty = true; + ctx->curr_compute = entry->data; + ctx->dirty_shader_stages &= (1 << PIPE_SHADER_COMPUTE); + } + + assert(ctx->curr_compute); + return ctx->curr_compute; +} + static struct zink_gfx_program * get_gfx_program(struct zink_context *ctx) { @@ -185,7 +208,8 @@ get_gfx_program(struct zink_context *ctx) return NULL; } ctx->curr_program = entry->data; - ctx->dirty_shader_stages = 0; + unsigned bits = u_bit_consecutive(PIPE_SHADER_VERTEX, 5); + ctx->dirty_shader_stages &= ~bits; } assert(ctx->curr_program); diff --git a/src/gallium/drivers/zink/zink_pipeline.c b/src/gallium/drivers/zink/zink_pipeline.c index 0808dbdf910..15cc7065111 100644 --- a/src/gallium/drivers/zink/zink_pipeline.c +++ b/src/gallium/drivers/zink/zink_pipeline.c @@ -189,3 +189,29 @@ zink_create_gfx_pipeline(struct zink_screen *screen, return pipeline; } + +VkPipeline +zink_create_compute_pipeline(struct zink_screen *screen, struct zink_compute_program *comp, struct zink_compute_pipeline_state *state) +{ + VkComputePipelineCreateInfo pci = {}; + pci.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + pci.flags = VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT; + pci.layout = comp->layout; + + VkPipelineShaderStageCreateInfo stage = {}; + stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; + stage.module = comp->module->shader; + stage.pName = "main"; + + pci.stage = stage; + + VkPipeline pipeline; + if (vkCreateComputePipelines(screen->dev, VK_NULL_HANDLE, 1, &pci, + NULL, &pipeline) != VK_SUCCESS) { + debug_printf("vkCreateComputePipelines failed\n"); + return VK_NULL_HANDLE; + } + + return pipeline; +} diff --git a/src/gallium/drivers/zink/zink_pipeline.h b/src/gallium/drivers/zink/zink_pipeline.h index 830f479fa91..18c40fd2c53 100644 --- a/src/gallium/drivers/zink/zink_pipeline.h +++ b/src/gallium/drivers/zink/zink_pipeline.h @@ -31,6 +31,7 @@ struct zink_blend_state; struct zink_depth_stencil_alpha_state; struct zink_gfx_program; +struct zink_compute_program; struct zink_rasterizer_state; struct zink_render_pass; struct zink_screen; @@ -67,10 +68,19 @@ struct zink_gfx_pipeline_state { bool dirty; }; +struct zink_compute_pipeline_state { + /* Pre-hashed value for table lookup, invalid when zero. + * Members after this point are not included in pipeline state hash key */ + uint32_t hash; + bool dirty; +}; + VkPipeline zink_create_gfx_pipeline(struct zink_screen *screen, struct zink_gfx_program *prog, struct zink_gfx_pipeline_state *state, VkPrimitiveTopology primitive_topology); +VkPipeline +zink_create_compute_pipeline(struct zink_screen *screen, struct zink_compute_program *comp, struct zink_compute_pipeline_state *state); #endif diff --git a/src/gallium/drivers/zink/zink_program.c b/src/gallium/drivers/zink/zink_program.c index 918792faeb9..5e1577f1d51 100644 --- a/src/gallium/drivers/zink/zink_program.c +++ b/src/gallium/drivers/zink/zink_program.c @@ -40,12 +40,23 @@ struct gfx_pipeline_cache_entry { VkPipeline pipeline; }; +struct compute_pipeline_cache_entry { + struct zink_compute_pipeline_state state; + VkPipeline pipeline; +}; + void debug_describe_zink_gfx_program(char *buf, const struct zink_gfx_program *ptr) { sprintf(buf, "zink_gfx_program"); } +void +debug_describe_zink_compute_program(char *buf, const struct zink_compute_program *ptr) +{ + sprintf(buf, "zink_compute_program"); +} + static void debug_describe_zink_shader_module(char *buf, const struct zink_shader_module *ptr) { @@ -111,7 +122,7 @@ create_desc_set_layout(VkDevice dev, if (!shader) continue; - VkShaderStageFlagBits stage_flags = zink_shader_stage(i); + VkShaderStageFlagBits stage_flags = zink_shader_stage(pipe_shader_type_from_mesa(shader->nir->info.stage)); for (int j = 0; j < shader->num_bindings; j++) { assert(num_bindings < ARRAY_SIZE(bindings)); bindings[num_bindings].binding = shader->bindings[j].binding; @@ -141,7 +152,7 @@ create_desc_set_layout(VkDevice dev, } static VkPipelineLayout -create_pipeline_layout(VkDevice dev, VkDescriptorSetLayout dsl) +create_gfx_pipeline_layout(VkDevice dev, VkDescriptorSetLayout dsl) { assert(dsl != VK_NULL_HANDLE); @@ -168,6 +179,26 @@ create_pipeline_layout(VkDevice dev, VkDescriptorSetLayout dsl) return layout; } +static VkPipelineLayout +create_compute_pipeline_layout(VkDevice dev, VkDescriptorSetLayout dsl) +{ + assert(dsl != VK_NULL_HANDLE); + + VkPipelineLayoutCreateInfo plci = {}; + plci.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; + + plci.pSetLayouts = &dsl; + plci.setLayoutCount = 1; + + VkPipelineLayout layout; + if (vkCreatePipelineLayout(dev, &plci, NULL, &layout) != VK_SUCCESS) { + debug_printf("vkCreatePipelineLayout failed!\n"); + return VK_NULL_HANDLE; + } + + return layout; +} + static void shader_key_vs_gen(struct zink_context *ctx, struct zink_shader *zs, struct zink_shader *shaders[ZINK_SHADER_COUNT], struct zink_shader_key *key) @@ -352,7 +383,8 @@ update_shader_modules(struct zink_context *ctx, struct zink_shader *stages[ZINK_ zink_shader_module_reference(zink_screen(ctx->base.screen), &prog->modules[type], ctx->curr_program->modules[type]); prog->shaders[type] = stages[type]; } - ctx->dirty_shader_stages = 0; + unsigned clean = u_bit_consecutive(PIPE_SHADER_VERTEX, 5);; + ctx->dirty_shader_stages &= ~clean; } static uint32_t @@ -439,7 +471,7 @@ zink_create_gfx_program(struct zink_context *ctx, if (!prog->dsl) goto fail; - prog->layout = create_pipeline_layout(screen->dev, prog->dsl); + prog->layout = create_gfx_pipeline_layout(screen->dev, prog->dsl); if (!prog->layout) goto fail; @@ -451,6 +483,81 @@ fail: return NULL; } +static uint32_t +hash_compute_pipeline_state(const void *key) +{ + return _mesa_hash_data(key, offsetof(struct zink_compute_pipeline_state, hash)); +} + +static bool +equals_compute_pipeline_state(const void *a, const void *b) +{ + return memcmp(a, b, offsetof(struct zink_compute_pipeline_state, hash)) == 0; +} + +struct zink_compute_program * +zink_create_compute_program(struct zink_context *ctx, struct zink_shader *shader) +{ + struct zink_screen *screen = zink_screen(ctx->base.screen); + struct zink_compute_program *comp = CALLOC_STRUCT(zink_compute_program); + if (!comp) + goto fail; + + pipe_reference_init(&comp->reference, 1); + + if (!ctx->curr_compute || !ctx->curr_compute->shader_cache) { + /* TODO: cs shader keys placeholder for now */ + comp->shader_cache = CALLOC_STRUCT(zink_shader_cache); + pipe_reference_init(&comp->shader_cache->reference, 1); + comp->shader_cache->shader_cache = _mesa_hash_table_create(NULL, _mesa_hash_u32, _mesa_key_u32_equal); + } else + zink_shader_cache_reference(zink_screen(ctx->base.screen), &comp->shader_cache, ctx->curr_compute->shader_cache); + + if (ctx->dirty_shader_stages & (1 << PIPE_SHADER_COMPUTE)) { + struct hash_entry *he = _mesa_hash_table_search(comp->shader_cache->shader_cache, &shader->shader_id); + if (he) + comp->module = he->data; + else { + comp->module = CALLOC_STRUCT(zink_shader_module); + assert(comp->module); + pipe_reference_init(&comp->module->reference, 1); + comp->module->shader = zink_shader_compile(screen, shader, NULL, NULL, NULL); + assert(comp->module->shader); + _mesa_hash_table_insert(comp->shader_cache->shader_cache, &shader->shader_id, comp->module); + } + } else + comp->module = ctx->curr_compute->module; + + struct zink_shader_module *zm = NULL; + zink_shader_module_reference(zink_screen(ctx->base.screen), &zm, comp->module); + ctx->dirty_shader_stages &= ~(1 << PIPE_SHADER_COMPUTE); + + comp->pipelines = _mesa_hash_table_create(NULL, hash_compute_pipeline_state, + equals_compute_pipeline_state); + + _mesa_set_add(shader->programs, comp); + zink_compute_program_reference(screen, NULL, comp); + comp->shader = shader; + + struct zink_shader *stages[ZINK_SHADER_COUNT] = {}; + stages[0] = shader; + comp->dsl = create_desc_set_layout(screen->dev, stages, + &comp->num_descriptors); + if (!comp->dsl) + goto fail; + + comp->layout = create_compute_pipeline_layout(screen->dev, comp->dsl); + if (!comp->layout) + goto fail; + + return comp; + +fail: + if (comp) + zink_destroy_compute_program(screen, comp); + return NULL; +} + static void gfx_program_remove_shader(struct zink_gfx_program *prog, struct zink_shader *shader) { @@ -492,6 +599,33 @@ zink_destroy_gfx_program(struct zink_screen *screen, FREE(prog); } +void +zink_destroy_compute_program(struct zink_screen *screen, + struct zink_compute_program *comp) +{ + if (comp->layout) + vkDestroyPipelineLayout(screen->dev, comp->layout, NULL); + + if (comp->dsl) + vkDestroyDescriptorSetLayout(screen->dev, comp->dsl, NULL); + + if (comp->shader) + _mesa_set_remove_key(comp->shader->programs, comp); + if (comp->module) + zink_shader_module_reference(screen, &comp->module, NULL); + + hash_table_foreach(comp->pipelines, entry) { + struct compute_pipeline_cache_entry *pc_entry = entry->data; + + vkDestroyPipeline(screen->dev, pc_entry->pipeline, NULL); + free(pc_entry); + } + _mesa_hash_table_destroy(comp->pipelines, NULL); + zink_shader_cache_reference(screen, &comp->shader_cache, NULL); + + FREE(comp); +} + static VkPrimitiveTopology primitive_topology(enum pipe_prim_type mode) { @@ -574,6 +708,39 @@ zink_get_gfx_pipeline(struct zink_screen *screen, return ((struct gfx_pipeline_cache_entry *)(entry->data))->pipeline; } +VkPipeline +zink_get_compute_pipeline(struct zink_screen *screen, + struct zink_compute_program *comp, + struct zink_compute_pipeline_state *state) +{ + struct hash_entry *entry = NULL; + + if (state->dirty) { + state->hash = hash_compute_pipeline_state(state); + state->dirty = false; + } + entry = _mesa_hash_table_search_pre_hashed(comp->pipelines, state->hash, state); + + if (!entry) { + VkPipeline pipeline = zink_create_compute_pipeline(screen, comp, state); + + if (pipeline == VK_NULL_HANDLE) + return VK_NULL_HANDLE; + + struct compute_pipeline_cache_entry *pc_entry = CALLOC_STRUCT(compute_pipeline_cache_entry); + if (!pc_entry) + return VK_NULL_HANDLE; + + memcpy(&pc_entry->state, state, sizeof(*state)); + pc_entry->pipeline = pipeline; + + entry = _mesa_hash_table_insert_pre_hashed(comp->pipelines, state->hash, state, pc_entry); + assert(entry); + } + + return ((struct compute_pipeline_cache_entry *)(entry->data))->pipeline; +} + static void * zink_create_vs_state(struct pipe_context *pctx, @@ -592,8 +759,10 @@ static void bind_stage(struct zink_context *ctx, enum pipe_shader_type stage, struct zink_shader *shader) { - assert(stage < PIPE_SHADER_COMPUTE); - ctx->gfx_stages[stage] = shader; + if (stage == PIPE_SHADER_COMPUTE) + ctx->compute_stage = shader; + else + ctx->gfx_stages[stage] = shader; ctx->dirty_shader_stages |= 1 << stage; } @@ -697,6 +866,25 @@ zink_delete_shader_state(struct pipe_context *pctx, void *cso) zink_shader_free(zink_context(pctx), cso); } +static void * +zink_create_cs_state(struct pipe_context *pctx, + const struct pipe_compute_state *shader) +{ + struct nir_shader *nir; + if (shader->ir_type != PIPE_SHADER_IR_NIR) + nir = zink_tgsi_to_nir(pctx->screen, shader->prog); + else + nir = (struct nir_shader *)shader->prog; + + return zink_shader_create(zink_screen(pctx->screen), nir, NULL); +} + +static void +zink_bind_cs_state(struct pipe_context *pctx, + void *cso) +{ + bind_stage(zink_context(pctx), PIPE_SHADER_COMPUTE, cso); +} void zink_program_init(struct zink_context *ctx) @@ -720,4 +908,8 @@ zink_program_init(struct zink_context *ctx) ctx->base.create_tes_state = zink_create_tes_state; ctx->base.bind_tes_state = zink_bind_tes_state; ctx->base.delete_tes_state = zink_delete_shader_state; + + ctx->base.create_compute_state = zink_create_cs_state; + ctx->base.bind_compute_state = zink_bind_cs_state; + ctx->base.delete_compute_state = zink_delete_shader_state; } diff --git a/src/gallium/drivers/zink/zink_program.h b/src/gallium/drivers/zink/zink_program.h index 9d726f67fe3..60f9173e457 100644 --- a/src/gallium/drivers/zink/zink_program.h +++ b/src/gallium/drivers/zink/zink_program.h @@ -69,6 +69,17 @@ struct zink_gfx_program { struct hash_table *pipelines[11]; // number of draw modes we support }; +struct zink_compute_program { + struct pipe_reference reference; + + struct zink_shader_module *module; + struct zink_shader *shader; + struct zink_shader_cache *shader_cache; + VkDescriptorSetLayout dsl; + VkPipelineLayout layout; + unsigned num_descriptors; + struct hash_table *pipelines; +}; void zink_update_gfx_program(struct zink_context *ctx, struct zink_gfx_program *prog); @@ -105,4 +116,31 @@ zink_gfx_program_reference(struct zink_screen *screen, zink_destroy_gfx_program(screen, old_dst); if (dst) *dst = src; } + +struct zink_compute_program * +zink_create_compute_program(struct zink_context *ctx, struct zink_shader *shader); +void +zink_destroy_compute_program(struct zink_screen *screen, + struct zink_compute_program *comp); + +void +debug_describe_zink_compute_program(char* buf, const struct zink_compute_program *ptr); + +static inline void +zink_compute_program_reference(struct zink_screen *screen, + struct zink_compute_program **dst, + struct zink_compute_program *src) +{ + struct zink_compute_program *old_dst = dst ? *dst : NULL; + + if (pipe_reference_described(old_dst ? &old_dst->reference : NULL, &src->reference, + (debug_reference_descriptor)debug_describe_zink_compute_program)) + zink_destroy_compute_program(screen, old_dst); + if (dst) *dst = src; +} + +VkPipeline +zink_get_compute_pipeline(struct zink_screen *screen, + struct zink_compute_program *comp, + struct zink_compute_pipeline_state *state); #endif