diff --git a/src/gallium/drivers/zink/zink_program.c b/src/gallium/drivers/zink/zink_program.c index 547b51634a4..519e7d0d386 100644 --- a/src/gallium/drivers/zink/zink_program.c +++ b/src/gallium/drivers/zink/zink_program.c @@ -722,6 +722,16 @@ equals_compute_pipeline_state(const void *a, const void *b) sa->module == sb->module; } +static bool +equals_compute_pipeline_state_local_size(const void *a, const void *b) +{ + const struct zink_compute_pipeline_state *sa = a; + const struct zink_compute_pipeline_state *sb = b; + return !memcmp(a, b, offsetof(struct zink_compute_pipeline_state, hash)) && + !memcmp(sa->local_size, sb->local_size, sizeof(sa->local_size)) && + sa->module == sb->module; +} + static struct zink_compute_program * create_compute_program(struct zink_context *ctx, nir_shader *nir) { @@ -742,7 +752,9 @@ create_compute_program(struct zink_context *ctx, nir_shader *nir) nir->info.workgroup_size[1] || nir->info.workgroup_size[2]); - _mesa_hash_table_init(&comp->pipelines, comp, NULL, equals_compute_pipeline_state); + _mesa_hash_table_init(&comp->pipelines, comp, NULL, comp->use_local_size ? + equals_compute_pipeline_state_local_size : + equals_compute_pipeline_state); memcpy(comp->base.sha1, comp->shader->base.sha1, sizeof(comp->shader->base.sha1));