zink: dynamically emit non-bool register values using local_vars spirv buffer

this will be useful in a future commit

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22934>
This commit is contained in:
Mike Blumenkrantz 2023-05-09 15:11:33 -04:00 committed by Marge Bot
parent 871afadfe5
commit 096dcdbd01

View file

@ -1252,10 +1252,27 @@ get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
return ctx->defs[ssa->index];
}
static void
init_reg(struct ntv_context *ctx, nir_register *reg)
{
if (ctx->regs[reg->index])
return;
SpvId type = get_vec_from_bit_size(ctx, reg->bit_size, reg->num_components);
SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
SpvStorageClassFunction,
type);
SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
SpvStorageClassFunction);
ctx->regs[reg->index] = var;
}
static SpvId
get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
{
assert(reg->index < ctx->num_regs);
init_reg(ctx, reg);
assert(ctx->regs[reg->index] != 0);
return ctx->regs[reg->index];
}
@ -1399,7 +1416,8 @@ bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
static void
store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
{
SpvId var = get_var_from_reg(ctx, reg->reg);
init_reg(ctx, reg->reg);
SpvId var = ctx->regs[reg->reg->index];
assert(var);
spirv_builder_emit_store(&ctx->builder, var, result);
}
@ -4725,8 +4743,8 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
ctx.num_defs = entry->ssa_alloc;
nir_index_local_regs(entry);
ctx.regs = ralloc_array_size(ctx.mem_ctx,
sizeof(SpvId), entry->reg_alloc);
ctx.regs = rzalloc_array_size(ctx.mem_ctx,
sizeof(SpvId), entry->reg_alloc);
if (!ctx.regs)
goto fail;
ctx.num_regs = entry->reg_alloc;
@ -4745,15 +4763,10 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
/* emit a block only for the variable declarations */
start_block(&ctx, spirv_builder_new_id(&ctx.builder));
spirv_builder_begin_local_vars(&ctx.builder);
foreach_list_typed(nir_register, reg, node, &entry->registers) {
SpvId type = get_vec_from_bit_size(&ctx, reg->bit_size, reg->num_components);
SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
SpvStorageClassFunction,
type);
SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
SpvStorageClassFunction);
ctx.regs[reg->index] = var;
foreach_list_typed(nir_register, reg, node, &entry->registers) {
if (reg->bit_size == 1)
init_reg(&ctx, reg);
}
nir_foreach_variable_with_modes(var, s, nir_var_shader_temp)