From 096dcdbd01b0bfb8aac4fab6c72ae449043ff972 Mon Sep 17 00:00:00 2001 From: Mike Blumenkrantz Date: Tue, 9 May 2023 15:11:33 -0400 Subject: [PATCH] zink: dynamically emit non-bool register values using local_vars spirv buffer this will be useful in a future commit Part-of: --- .../drivers/zink/nir_to_spirv/nir_to_spirv.c | 35 +++++++++++++------ 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index ea851ad4b87..6f4b19cd06f 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -1252,10 +1252,27 @@ get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa) return ctx->defs[ssa->index]; } +static void +init_reg(struct ntv_context *ctx, nir_register *reg) +{ + if (ctx->regs[reg->index]) + return; + + SpvId type = get_vec_from_bit_size(ctx, reg->bit_size, reg->num_components); + SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, + SpvStorageClassFunction, + type); + SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type, + SpvStorageClassFunction); + + ctx->regs[reg->index] = var; +} + static SpvId get_var_from_reg(struct ntv_context *ctx, nir_register *reg) { assert(reg->index < ctx->num_regs); + init_reg(ctx, reg); assert(ctx->regs[reg->index] != 0); return ctx->regs[reg->index]; } @@ -1399,7 +1416,8 @@ bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size, static void store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result) { - SpvId var = get_var_from_reg(ctx, reg->reg); + init_reg(ctx, reg->reg); + SpvId var = ctx->regs[reg->reg->index]; assert(var); spirv_builder_emit_store(&ctx->builder, var, result); } @@ -4725,8 +4743,8 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_ ctx.num_defs = entry->ssa_alloc; nir_index_local_regs(entry); - ctx.regs = ralloc_array_size(ctx.mem_ctx, - sizeof(SpvId), entry->reg_alloc); + ctx.regs = rzalloc_array_size(ctx.mem_ctx, + sizeof(SpvId), entry->reg_alloc); if (!ctx.regs) goto fail; ctx.num_regs = entry->reg_alloc; @@ -4745,15 +4763,10 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_ /* emit a block only for the variable declarations */ start_block(&ctx, spirv_builder_new_id(&ctx.builder)); spirv_builder_begin_local_vars(&ctx.builder); - foreach_list_typed(nir_register, reg, node, &entry->registers) { - SpvId type = get_vec_from_bit_size(&ctx, reg->bit_size, reg->num_components); - SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder, - SpvStorageClassFunction, - type); - SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type, - SpvStorageClassFunction); - ctx.regs[reg->index] = var; + foreach_list_typed(nir_register, reg, node, &entry->registers) { + if (reg->bit_size == 1) + init_reg(&ctx, reg); } nir_foreach_variable_with_modes(var, s, nir_var_shader_temp)