From 9f6be8effb43fcd4ce2fd00045bc6244ddf63529 Mon Sep 17 00:00:00 2001 From: Mike Blumenkrantz Date: Tue, 9 May 2023 11:26:04 -0400 Subject: [PATCH] zink: store and use alu types for ntv defs this adds indexing for ssa/reg defs with the accompanying current type of a given def (inaccurate for objects but whatever), enabling that type to be used directly in order to avoid bitcasts in some places this upends the assumption that all stored srcs are uint type Part-of: --- .../drivers/zink/nir_to_spirv/nir_to_spirv.c | 489 ++++++++++++------ 1 file changed, 317 insertions(+), 172 deletions(-) diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 6f4b19cd06f..32c30172068 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -75,10 +75,12 @@ struct ntv_context { size_t num_entry_ifaces; SpvId *defs; + nir_alu_type *def_types; SpvId *resident_defs; size_t num_defs; SpvId *regs; + nir_alu_type *reg_types; size_t num_regs; struct hash_table *vars; /* nir_variable -> SpvId */ @@ -137,6 +139,31 @@ static SpvId emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src0, SpvId src1, SpvId src2); +static bool +alu_op_is_typeless(nir_op op) +{ + switch (op) { + case nir_op_mov: + case nir_op_vec16: + case nir_op_vec2: + case nir_op_vec3: + case nir_op_vec4: + case nir_op_vec5: + case nir_op_vec8: + case nir_op_bcsel: + return true; + default: + break; + } + return false; +} + +static nir_alu_type +get_nir_alu_type(const struct glsl_type *type) +{ + return nir_alu_type_get_base_type(nir_get_nir_type_for_glsl_base_type(glsl_get_base_type(glsl_without_array_or_matrix(type)))); +} + static SpvId get_bvec_type(struct ntv_context *ctx, int num_components) { @@ -352,25 +379,18 @@ get_alu_type(struct ntv_context *ctx, nir_alu_type type, unsigned num_components if (bit_size == 1) return get_bvec_type(ctx, num_components); + type = 
nir_alu_type_get_base_type(type); switch (nir_alu_type_get_base_type(type)) { case nir_type_bool: - unreachable("bool should have bit-size 1"); + return get_bvec_type(ctx, num_components); case nir_type_int: - case nir_type_int8: - case nir_type_int16: - case nir_type_int64: return get_ivec_type(ctx, bit_size, num_components); case nir_type_uint: - case nir_type_uint8: - case nir_type_uint16: - case nir_type_uint64: return get_uvec_type(ctx, bit_size, num_components); case nir_type_float: - case nir_type_float16: - case nir_type_float64: return get_fvec_type(ctx, bit_size, num_components); default: @@ -1245,20 +1265,21 @@ get_vec_from_bit_size(struct ntv_context *ctx, uint32_t bit_size, uint32_t num_c } static SpvId -get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa) +get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa, nir_alu_type *atype) { assert(ssa->index < ctx->num_defs); assert(ctx->defs[ssa->index] != 0); + *atype = ctx->def_types[ssa->index]; return ctx->defs[ssa->index]; } static void -init_reg(struct ntv_context *ctx, nir_register *reg) +init_reg(struct ntv_context *ctx, nir_register *reg, nir_alu_type atype) { if (ctx->regs[reg->index]) return; - SpvId type = get_vec_from_bit_size(ctx, reg->bit_size, reg->num_components); + SpvId type = get_alu_type(ctx, atype, reg->num_components, reg->bit_size); SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassFunction, type); @@ -1266,45 +1287,48 @@ init_reg(struct ntv_context *ctx, nir_register *reg) SpvStorageClassFunction); ctx->regs[reg->index] = var; + ctx->reg_types[reg->index] = nir_alu_type_get_base_type(atype); } static SpvId -get_var_from_reg(struct ntv_context *ctx, nir_register *reg) +get_var_from_reg(struct ntv_context *ctx, nir_register *reg, nir_alu_type *atype) { assert(reg->index < ctx->num_regs); - init_reg(ctx, reg); + init_reg(ctx, reg, nir_type_uint); assert(ctx->regs[reg->index] != 0); + + *atype = ctx->reg_types[reg->index]; return 
ctx->regs[reg->index]; } static SpvId -get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg) +get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg, nir_alu_type *atype) { assert(reg->reg); assert(!reg->indirect); assert(!reg->base_offset); - SpvId var = get_var_from_reg(ctx, reg->reg); - SpvId type = get_vec_from_bit_size(ctx, reg->reg->bit_size, reg->reg->num_components); + SpvId var = get_var_from_reg(ctx, reg->reg, atype); + SpvId type = get_alu_type(ctx, *atype, reg->reg->num_components, reg->reg->bit_size); return spirv_builder_emit_load(&ctx->builder, type, var); } static SpvId -get_src(struct ntv_context *ctx, nir_src *src) +get_src(struct ntv_context *ctx, nir_src *src, nir_alu_type *atype) { if (src->is_ssa) - return get_src_ssa(ctx, src->ssa); + return get_src_ssa(ctx, src->ssa, atype); else - return get_src_reg(ctx, &src->reg); + return get_src_reg(ctx, &src->reg, atype); } static SpvId -get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src) +get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src, nir_alu_type *atype) { assert(!alu->src[src].negate); assert(!alu->src[src].abs); - SpvId def = get_src(ctx, &alu->src[src].src); + SpvId def = get_src(ctx, &alu->src[src].src, atype); unsigned used_channels = 0; bool need_swizzle = false; @@ -1327,8 +1351,7 @@ get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src) return def; int bit_size = nir_src_bit_size(alu->src[src].src); - SpvId raw_type = bit_size == 1 ? 
spirv_builder_type_bool(&ctx->builder) : - spirv_builder_type_uint(&ctx->builder, bit_size); + SpvId raw_type = get_alu_type(ctx, *atype, 1, bit_size); if (used_channels == 1) { uint32_t indices[] = { alu->src[src].swizzle[0] }; @@ -1369,10 +1392,11 @@ get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src) } static void -store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result) +store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result, nir_alu_type atype) { assert(result != 0); assert(ssa->index < ctx->num_defs); + ctx->def_types[ssa->index] = nir_alu_type_get_base_type(atype); ctx->defs[ssa->index] = result; } @@ -1413,58 +1437,43 @@ bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size, return emit_bitcast(ctx, type, value); } -static void -store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result) +static SpvId +cast_src_to_type(struct ntv_context *ctx, SpvId value, nir_src src, nir_alu_type atype) { - init_reg(ctx, reg->reg); + atype = nir_alu_type_get_base_type(atype); + unsigned num_components = nir_src_num_components(src); + unsigned bit_size = nir_src_bit_size(src); + return emit_bitcast(ctx, get_alu_type(ctx, atype, num_components, bit_size), value); +} + +static void +store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result, nir_alu_type atype) +{ + atype = nir_alu_type_get_base_type(atype); + init_reg(ctx, reg->reg, atype); SpvId var = ctx->regs[reg->reg->index]; + nir_alu_type vtype = ctx->reg_types[reg->reg->index]; + if (atype != vtype) { + assert(vtype != nir_type_bool); + result = emit_bitcast(ctx, get_alu_type(ctx, vtype, reg->reg->num_components, reg->reg->bit_size), result); + } assert(var); spirv_builder_emit_store(&ctx->builder, var, result); } static void -store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result) +store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type atype) { if (dest->is_ssa) - store_ssa_def(ctx, 
&dest->ssa, result); + store_ssa_def(ctx, &dest->ssa, result, atype); else - store_reg_def(ctx, &dest->reg, result); + store_reg_def(ctx, &dest->reg, result, atype); } static void store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type) { - unsigned num_components = nir_dest_num_components(*dest); - unsigned bit_size = nir_dest_bit_size(*dest); - - if (bit_size != 1) { - switch (nir_alu_type_get_base_type(type)) { - case nir_type_bool: - assert("bool should have bit-size 1"); - break; - - case nir_type_uint: - case nir_type_uint8: - case nir_type_uint16: - case nir_type_uint64: - break; /* nothing to do! */ - - case nir_type_int: - case nir_type_int8: - case nir_type_int16: - case nir_type_int64: - case nir_type_float: - case nir_type_float16: - case nir_type_float64: - result = bitcast_to_uvec(ctx, result, bit_size, num_components); - break; - - default: - unreachable("unsupported nir_alu_type"); - } - } - - store_dest_raw(ctx, dest, result); + store_dest_raw(ctx, dest, result, type); } static SpvId @@ -1974,13 +1983,16 @@ alu_instr_src_components(const nir_alu_instr *instr, unsigned src) } static SpvId -get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src, SpvId *raw_value) +get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src, SpvId *raw_value, nir_alu_type *atype) { - *raw_value = get_alu_src_raw(ctx, alu, src); + *raw_value = get_alu_src_raw(ctx, alu, src, atype); unsigned num_components = alu_instr_src_components(alu, src); unsigned bit_size = nir_src_bit_size(alu->src[src].src); - nir_alu_type type = nir_op_infos[alu->op].input_types[src]; + nir_alu_type type = alu_op_is_typeless(alu->op) ? 
*atype : nir_op_infos[alu->op].input_types[src]; + type = nir_alu_type_get_base_type(type); + if (type == *atype) + return *raw_value; if (bit_size == 1) return *raw_value; @@ -1993,7 +2005,7 @@ get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src, SpvId *ra return bitcast_to_ivec(ctx, *raw_value, bit_size, num_components); case nir_type_uint: - return *raw_value; + return bitcast_to_uvec(ctx, *raw_value, bit_size, num_components); case nir_type_float: return bitcast_to_fvec(ctx, *raw_value, bit_size, num_components); @@ -2005,11 +2017,10 @@ get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src, SpvId *ra } static void -store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result, bool force_float) +store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result, nir_alu_type atype) { assert(!alu->dest.saturate); - store_dest(ctx, &alu->dest.dest, result, - force_float ? nir_type_float : nir_op_infos[alu->op].output_type); + store_dest(ctx, &alu->dest.dest, result, atype); } static SpvId @@ -2038,16 +2049,66 @@ needs_derivative_control(nir_alu_instr *alu) static void emit_alu(struct ntv_context *ctx, nir_alu_instr *alu) { + bool is_bcsel = alu->op == nir_op_bcsel; + nir_alu_type stype[NIR_MAX_VEC_COMPONENTS] = {0}; SpvId src[NIR_MAX_VEC_COMPONENTS]; SpvId raw_src[NIR_MAX_VEC_COMPONENTS]; for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) - src[i] = get_alu_src(ctx, alu, i, &raw_src[i]); + src[i] = get_alu_src(ctx, alu, i, &raw_src[i], &stype[i]); + + nir_alu_type typeless_type = stype[is_bcsel]; + if (nir_op_infos[alu->op].num_inputs > 1 && + alu_op_is_typeless(alu->op) && + nir_src_bit_size(alu->src[is_bcsel].src) != 1) { + unsigned uint_count = 0; + unsigned int_count = 0; + unsigned float_count = 0; + for (unsigned i = is_bcsel; i < nir_op_infos[alu->op].num_inputs; i++) { + if (stype[i] == nir_type_bool) + break; + switch (stype[i]) { + case nir_type_uint: + uint_count++; + break; + case 
nir_type_int: + int_count++; + break; + case nir_type_float: + float_count++; + break; + default: + unreachable("this shouldn't happen"); + } + } + if (uint_count > int_count && uint_count > float_count) + typeless_type = nir_type_uint; + else if (int_count > uint_count && int_count > float_count) + typeless_type = nir_type_int; + else if (float_count > uint_count && float_count > int_count) + typeless_type = nir_type_float; + else if (float_count == uint_count || uint_count == int_count) + typeless_type = nir_type_uint; + else if (float_count == int_count) + typeless_type = nir_type_float; + else + typeless_type = nir_type_uint; + assert(typeless_type != nir_type_bool); + for (unsigned i = is_bcsel; i < nir_op_infos[alu->op].num_inputs; i++) { + unsigned num_components = alu_instr_src_components(alu, i); + unsigned bit_size = nir_src_bit_size(alu->src[i].src); + SpvId type = get_alu_type(ctx, typeless_type, num_components, bit_size); + if (stype[i] != typeless_type) { + src[i] = emit_bitcast(ctx, type, src[i]); + } + } + } - SpvId dest_type = get_dest_type(ctx, &alu->dest.dest, - nir_op_infos[alu->op].output_type); - bool force_float = false; unsigned bit_size = nir_dest_bit_size(alu->dest.dest); unsigned num_components = nir_dest_num_components(alu->dest.dest); + nir_alu_type atype = bit_size == 1 ? + nir_type_bool : + (alu_op_is_typeless(alu->op) ? 
typeless_type : nir_op_infos[alu->op].output_type); + SpvId dest_type = get_dest_type(ctx, &alu->dest.dest, atype); if (needs_derivative_control(alu)) spirv_builder_emit_cap(&ctx->builder, SpvCapabilityDerivativeControl); @@ -2144,7 +2205,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu) case nir_op: \ assert(nir_op_infos[alu->op].num_inputs == 1); \ result = emit_builtin_unop(ctx, spirv_op, get_dest_type(ctx, &alu->dest.dest, nir_type_float), src[0]); \ - force_float = true; \ + atype = nir_type_float; \ break; BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs) @@ -2415,7 +2476,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu) if (alu->exact) spirv_builder_emit_decoration(&ctx->builder, result, SpvDecorationNoContraction); - store_alu_result(ctx, alu, result, force_float); + store_alu_result(ctx, alu, result, atype); } static void @@ -2425,16 +2486,19 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const) unsigned num_components = load_const->def.num_components; SpvId components[NIR_MAX_VEC_COMPONENTS]; + nir_alu_type atype; if (bit_size == 1) { for (int i = 0; i < num_components; i++) components[i] = spirv_builder_const_bool(&ctx->builder, load_const->value[i].b); + atype = nir_type_bool; } else { for (int i = 0; i < num_components; i++) { uint64_t tmp = nir_const_value_as_uint(load_const->value[i], bit_size); components[i] = emit_uint_const(ctx, bit_size, tmp); } + atype = nir_type_uint; } if (num_components > 1) { @@ -2443,10 +2507,10 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const) SpvId value = spirv_builder_const_composite(&ctx->builder, type, components, num_components); - store_ssa_def(ctx, &load_const->def, value); + store_ssa_def(ctx, &load_const->def, value, atype); } else { assert(num_components == 1); - store_ssa_def(ctx, &load_const->def, components[0]); + store_ssa_def(ctx, &load_const->def, components[0], atype); } } @@ -2462,7 +2526,8 @@ emit_discard(struct ntv_context *ctx, 
nir_intrinsic_instr *intr) static void emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId ptr = get_src(ctx, intr->src); + nir_alu_type atype; + SpvId ptr = get_src(ctx, intr->src, &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); SpvId type; @@ -2472,8 +2537,10 @@ emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) type = get_image_type(ctx, var, glsl_type_is_sampler(gtype), glsl_get_sampler_dim(gtype) == GLSL_SAMPLER_DIM_BUF); + atype = nir_get_nir_type_for_glsl_base_type(glsl_get_sampler_result_type(gtype)); } else { type = get_glsl_type(ctx, deref->type); + atype = get_nir_alu_type(deref->type); } SpvId result; @@ -2481,18 +2548,15 @@ emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) result = emit_atomic(ctx, SpvOpAtomicLoad, type, ptr, 0, 0); else result = spirv_builder_emit_load(&ctx->builder, type, ptr); - unsigned num_components = nir_dest_num_components(intr->dest); - unsigned bit_size = nir_dest_bit_size(intr->dest); - if (bit_size > 1) - result = bitcast_to_uvec(ctx, result, bit_size, num_components); - store_dest(ctx, &intr->dest, result, nir_type_uint); + store_dest(ctx, &intr->dest, result, atype); } static void emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId ptr = get_src(ctx, &intr->src[0]); - SpvId src = get_src(ctx, &intr->src[1]); + nir_alu_type ptype, stype; + SpvId ptr = get_src(ctx, &intr->src[0], &ptype); + SpvId src = get_src(ctx, &intr->src[1], &stype); const struct glsl_type *gtype = nir_src_as_deref(intr->src[0])->type; SpvId type = get_glsl_type(ctx, gtype); @@ -2508,7 +2572,7 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId member_type; if (glsl_type_is_vector(gtype)) { result_type = get_glsl_basetype(ctx, glsl_get_base_type(gtype)); - member_type = get_uvec_type(ctx, glsl_get_bit_size(gtype), 1); + member_type = get_alu_type(ctx, stype, 1, glsl_get_bit_size(gtype)); } else member_type = result_type = 
get_glsl_type(ctx, glsl_get_array_element(gtype)); SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, @@ -2518,7 +2582,8 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) if (wrmask & BITFIELD_BIT(i)) { SpvId idx = emit_uint_const(ctx, 32, i); SpvId val = spirv_builder_emit_composite_extract(&ctx->builder, member_type, src, &i, 1); - val = emit_bitcast(ctx, result_type, val); + if (stype != ptype) + val = emit_bitcast(ctx, result_type, val); SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type, ptr, &idx, 1); spirv_builder_emit_store(&ctx->builder, member, val); @@ -2533,10 +2598,12 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) src = emit_bitcast(ctx, type, src); /* SampleMask is always an array in spirv, so we need to construct it into one */ result = spirv_builder_emit_composite_construct(&ctx->builder, ctx->sample_mask_type, &src, 1); - } else if (glsl_get_base_type(glsl_without_array(gtype)) == GLSL_TYPE_BOOL) { - result = src; - } else - result = emit_bitcast(ctx, type, src); + } else { + if (ptype == stype) + result = src; + else + result = emit_bitcast(ctx, type, src); + } if (nir_intrinsic_access(intr) & ACCESS_COHERENT) spirv_builder_emit_atomic_store(&ctx->builder, ptr, SpvScopeDevice, 0, result); else @@ -2553,7 +2620,10 @@ emit_load_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassWorkgroup, uint_type); - SpvId offset = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId offset = get_src(ctx, &intr->src[0], &atype); + if (atype == nir_type_float) + offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[0]), 1); SpvId constituents[NIR_MAX_VEC_COMPONENTS]; SpvId shared_block = get_shared_block(ctx, bit_size); /* need to convert array -> vec */ @@ -2574,7 +2644,8 @@ emit_load_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) static void emit_store_shared(struct ntv_context 
*ctx, nir_intrinsic_instr *intr) { - SpvId src = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId src = get_src(ctx, &intr->src[0], &atype); unsigned wrmask = nir_intrinsic_write_mask(intr); unsigned bit_size = nir_src_bit_size(intr->src[0]); @@ -2582,7 +2653,10 @@ emit_store_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassWorkgroup, uint_type); - SpvId offset = get_src(ctx, &intr->src[1]); + nir_alu_type otype; + SpvId offset = get_src(ctx, &intr->src[1], &otype); + if (otype == nir_type_float) + offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[0]), 1); SpvId shared_block = get_shared_block(ctx, bit_size); /* this is a partial write, so we have to loop and do a per-component write */ u_foreach_bit(i, wrmask) { @@ -2590,6 +2664,8 @@ emit_store_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId val = src; if (nir_src_num_components(intr->src[0]) != 1) val = spirv_builder_emit_composite_extract(&ctx->builder, uint_type, src, &i, 1); + if (atype != nir_type_uint) + val = emit_bitcast(ctx, get_alu_type(ctx, nir_type_uint, 1, bit_size), val); SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type, shared_block, &shared_offset, 1); spirv_builder_emit_store(&ctx->builder, member, val); @@ -2606,7 +2682,10 @@ emit_load_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassPrivate, uint_type); - SpvId offset = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId offset = get_src(ctx, &intr->src[0], &atype); + if (atype == nir_type_float) + offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[0]), 1); SpvId constituents[NIR_MAX_VEC_COMPONENTS]; SpvId scratch_block = get_scratch_block(ctx, bit_size); /* need to convert array -> vec */ @@ -2627,7 +2706,8 @@ emit_load_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) static void 
emit_store_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId src = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId src = get_src(ctx, &intr->src[0], &atype); unsigned wrmask = nir_intrinsic_write_mask(intr); unsigned bit_size = nir_src_bit_size(intr->src[0]); @@ -2635,7 +2715,10 @@ emit_store_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassPrivate, uint_type); - SpvId offset = get_src(ctx, &intr->src[1]); + nir_alu_type otype; + SpvId offset = get_src(ctx, &intr->src[1], &otype); + if (otype == nir_type_float) + offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[0]), 1); SpvId scratch_block = get_scratch_block(ctx, bit_size); /* this is a partial write, so we have to loop and do a per-component write */ u_foreach_bit(i, wrmask) { @@ -2643,6 +2726,8 @@ emit_store_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId val = src; if (nir_src_num_components(intr->src[0]) != 1) val = spirv_builder_emit_composite_extract(&ctx->builder, uint_type, src, &i, 1); + if (atype != nir_type_uint) + val = emit_bitcast(ctx, get_alu_type(ctx, nir_type_uint, 1, bit_size), val); SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type, scratch_block, &scratch_offset, 1); spirv_builder_emit_store(&ctx->builder, member, val); @@ -2669,7 +2754,10 @@ emit_load_push_const(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvStorageClassPushConstant, load_type); - SpvId member = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId member = get_src(ctx, &intr->src[0], &atype); + if (atype == nir_type_float) + member = bitcast_to_uvec(ctx, member, nir_src_bit_size(intr->src[0]), 1); /* reuse the offset from ZINK_PUSH_CONST_OFFSET */ SpvId offset = emit_uint_const(ctx, 32, 0); /* OpAccessChain takes an array of indices that drill into a hierarchy based on the type: @@ -2710,7 +2798,8 @@ emit_load_global(struct ntv_context *ctx, 
nir_intrinsic_instr *intr) SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassPhysicalStorageBuffer, dest_type); - SpvId ptr = emit_bitcast(ctx, pointer_type, get_src(ctx, &intr->src[0])); + nir_alu_type atype; + SpvId ptr = emit_bitcast(ctx, pointer_type, get_src(ctx, &intr->src[0], &atype)); SpvId result = spirv_builder_emit_load(&ctx->builder, dest_type, ptr); store_dest(ctx, &intr->dest, result, nir_type_uint); } @@ -2724,8 +2813,9 @@ emit_store_global(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassPhysicalStorageBuffer, dest_type); - SpvId param = get_src(ctx, &intr->src[0]); - SpvId ptr = emit_bitcast(ctx, pointer_type, get_src(ctx, &intr->src[1])); + nir_alu_type atype; + SpvId param = get_src(ctx, &intr->src[0], &atype); + SpvId ptr = emit_bitcast(ctx, pointer_type, get_src(ctx, &intr->src[1], &atype)); spirv_builder_emit_store(&ctx->builder, ptr, param); } @@ -2842,37 +2932,38 @@ emit_interpolate(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId op; spirv_builder_emit_cap(&ctx->builder, SpvCapabilityInterpolationFunction); SpvId src1 = 0; + nir_alu_type atype; switch (intr->intrinsic) { case nir_intrinsic_interp_deref_at_centroid: op = GLSLstd450InterpolateAtCentroid; break; case nir_intrinsic_interp_deref_at_sample: op = GLSLstd450InterpolateAtSample; - src1 = get_src(ctx, &intr->src[1]); + src1 = get_src(ctx, &intr->src[1], &atype); break; case nir_intrinsic_interp_deref_at_offset: op = GLSLstd450InterpolateAtOffset; - src1 = get_src(ctx, &intr->src[1]); + src1 = get_src(ctx, &intr->src[1], &atype); /* The offset operand must be a vector of 2 components of 32-bit floating-point type. 
- InterpolateAtOffset spec */ - src1 = emit_bitcast(ctx, get_fvec_type(ctx, 32, 2), src1); + if (atype != nir_type_float) + src1 = emit_bitcast(ctx, get_fvec_type(ctx, 32, 2), src1); break; default: unreachable("unknown interp op"); } - SpvId ptr = get_src(ctx, &intr->src[0]); + nir_alu_type ptype; + SpvId ptr = get_src(ctx, &intr->src[0], &ptype); SpvId result; + const struct glsl_type *gtype = nir_src_as_deref(intr->src[0])->type; + assert(ptype == get_nir_alu_type(gtype)); if (intr->intrinsic == nir_intrinsic_interp_deref_at_centroid) - result = emit_builtin_unop(ctx, op, get_glsl_type(ctx, nir_src_as_deref(intr->src[0])->type), ptr); + result = emit_builtin_unop(ctx, op, get_glsl_type(ctx, gtype), ptr); else - result = emit_builtin_binop(ctx, op, get_glsl_type(ctx, nir_src_as_deref(intr->src[0])->type), - ptr, src1); - unsigned num_components = nir_dest_num_components(intr->dest); - unsigned bit_size = nir_dest_bit_size(intr->dest); - result = bitcast_to_uvec(ctx, result, bit_size, num_components); - store_dest(ctx, &intr->dest, result, nir_type_uint); + result = emit_builtin_binop(ctx, op, get_glsl_type(ctx, gtype), ptr, src1); + store_dest(ctx, &intr->dest, result, ptype); } static void @@ -2887,19 +2978,25 @@ handle_atomic_op(struct ntv_context *ctx, nir_intrinsic_instr *intr, SpvId ptr, static void emit_deref_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId ptr = get_src(ctx, &intr->src[0]); - SpvId param = get_src(ctx, &intr->src[1]); + nir_alu_type atype; + nir_alu_type ret_type = nir_atomic_op_type(nir_intrinsic_atomic_op(intr)) == nir_type_float ? 
nir_type_float : nir_type_uint; + SpvId ptr = get_src(ctx, &intr->src[0], &atype); + SpvId param = get_src(ctx, &intr->src[1], &atype); + if (atype != ret_type) + param = cast_src_to_type(ctx, param, intr->src[1], ret_type); SpvId param2 = 0; if (nir_src_bit_size(intr->src[1]) == 64) spirv_builder_emit_cap(&ctx->builder, SpvCapabilityInt64Atomics); - if (intr->intrinsic == nir_intrinsic_deref_atomic_swap) - param2 = get_src(ctx, &intr->src[2]); + if (intr->intrinsic == nir_intrinsic_deref_atomic_swap) { + param2 = get_src(ctx, &intr->src[2], &atype); + if (atype != ret_type) + param2 = cast_src_to_type(ctx, param2, intr->src[2], ret_type); + } - nir_alu_type op_type = nir_atomic_op_type(nir_intrinsic_atomic_op(intr)); - handle_atomic_op(ctx, intr, ptr, param, param2, op_type == nir_type_float ? nir_type_float : nir_type_uint32); + handle_atomic_op(ctx, intr, ptr, param, param2, ret_type); } static void @@ -2907,12 +3004,19 @@ emit_shared_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) { unsigned bit_size = nir_src_bit_size(intr->src[1]); SpvId dest_type = get_dest_type(ctx, &intr->dest, nir_type_uint); - SpvId param = get_src(ctx, &intr->src[1]); + nir_alu_type atype; + nir_alu_type ret_type = nir_atomic_op_type(nir_intrinsic_atomic_op(intr)) == nir_type_float ? 
nir_type_float : nir_type_uint; + SpvId param = get_src(ctx, &intr->src[1], &atype); + if (atype != ret_type) + param = cast_src_to_type(ctx, param, intr->src[1], ret_type); SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassWorkgroup, dest_type); - SpvId offset = emit_binop(ctx, SpvOpUDiv, get_uvec_type(ctx, 32, 1), get_src(ctx, &intr->src[0]), emit_uint_const(ctx, 32, bit_size / 8)); + SpvId offset = get_src(ctx, &intr->src[0], &atype); + if (atype != nir_type_uint) + offset = cast_src_to_type(ctx, offset, intr->src[0], nir_type_uint); + offset = emit_binop(ctx, SpvOpUDiv, get_uvec_type(ctx, 32, 1), offset, emit_uint_const(ctx, 32, bit_size / 8)); SpvId shared_block = get_shared_block(ctx, bit_size); SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type, shared_block, &offset, 1); @@ -2920,11 +3024,13 @@ emit_shared_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) spirv_builder_emit_cap(&ctx->builder, SpvCapabilityInt64Atomics); SpvId param2 = 0; - if (intr->intrinsic == nir_intrinsic_shared_atomic_swap) - param2 = get_src(ctx, &intr->src[2]); + if (intr->intrinsic == nir_intrinsic_shared_atomic_swap) { + param2 = get_src(ctx, &intr->src[2], &atype); + if (atype != ret_type) + param2 = cast_src_to_type(ctx, param2, intr->src[2], ret_type); + } - nir_alu_type op_type = nir_atomic_op_type(nir_intrinsic_atomic_op(intr)); - handle_atomic_op(ctx, intr, ptr, param, param2, op_type == nir_type_float ? 
nir_type_float : nir_type_uint32); + handle_atomic_op(ctx, intr, ptr, param, param2, ret_type); } static void @@ -2937,7 +3043,10 @@ emit_get_ssbo_size(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassStorageBuffer, get_bo_struct_type(ctx, var)); - SpvId bo = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId bo = get_src(ctx, &intr->src[0], &atype); + if (atype == nir_type_float) + bo = bitcast_to_uvec(ctx, bo, nir_src_bit_size(intr->src[0]), 1); SpvId indices[] = { bo }; SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type, ctx->ssbos[2], indices, @@ -2966,16 +3075,17 @@ get_image_coords(struct ntv_context *ctx, const struct glsl_type *type, nir_src uint32_t num_coords = glsl_get_sampler_coordinate_components(type); uint32_t src_components = nir_src_num_components(*src); - SpvId spv = get_src(ctx, src); + nir_alu_type atype; + SpvId spv = get_src(ctx, src, &atype); if (num_coords == src_components) return spv; /* need to extract the coord dimensions that the image can use */ - SpvId vec_type = get_uvec_type(ctx, 32, num_coords); + SpvId vec_type = get_alu_type(ctx, atype, num_coords, 32); if (num_coords == 1) return spirv_builder_emit_vector_extract(&ctx->builder, vec_type, spv, 0); uint32_t constituents[4]; - SpvId zero = emit_uint_const(ctx, nir_src_bit_size(*src), 0); + SpvId zero = atype == nir_type_uint ? emit_uint_const(ctx, nir_src_bit_size(*src), 0) : emit_float_const(ctx, nir_src_bit_size(*src), 0); assert(num_coords < ARRAY_SIZE(constituents)); for (unsigned i = 0; i < num_coords; i++) constituents[i] = i < src_components ? 
i : zero; @@ -2985,7 +3095,8 @@ get_image_coords(struct ntv_context *ctx, const struct glsl_type *type, nir_src static void emit_image_deref_store(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId img_var = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); nir_variable *var = nir_deref_instr_get_variable(deref); SpvId img_type = find_image_type(ctx, var); @@ -2993,15 +3104,16 @@ emit_image_deref_store(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId base_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type)); SpvId img = spirv_builder_emit_load(&ctx->builder, img_type, img_var); SpvId coord = get_image_coords(ctx, type, &intr->src[1]); - SpvId texel = get_src(ctx, &intr->src[3]); + SpvId texel = get_src(ctx, &intr->src[3], &atype); + /* texel type must match image type */ + if (atype != nir_get_nir_type_for_glsl_base_type(glsl_get_sampler_result_type(type))) + texel = emit_bitcast(ctx, + spirv_builder_type_vector(&ctx->builder, base_type, 4), + texel); bool use_sample = glsl_get_sampler_dim(type) == GLSL_SAMPLER_DIM_MS || glsl_get_sampler_dim(type) == GLSL_SAMPLER_DIM_SUBPASS_MS; - SpvId sample = use_sample ? get_src(ctx, &intr->src[2]) : 0; + SpvId sample = use_sample ? 
get_src(ctx, &intr->src[2], &atype) : 0; assert(nir_src_bit_size(intr->src[3]) == glsl_base_type_bit_size(glsl_get_sampler_result_type(type))); - /* texel type must match image type */ - texel = emit_bitcast(ctx, - spirv_builder_type_vector(&ctx->builder, base_type, 4), - texel); spirv_builder_emit_image_write(&ctx->builder, img, coord, texel, 0, sample, 0); } @@ -3042,7 +3154,8 @@ static void emit_image_deref_load(struct ntv_context *ctx, nir_intrinsic_instr *intr) { bool sparse = intr->intrinsic == nir_intrinsic_image_deref_sparse_load; - SpvId img_var = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); nir_variable *var = nir_deref_instr_get_variable(deref); bool mediump = (var->data.precision == GLSL_PRECISION_MEDIUM || var->data.precision == GLSL_PRECISION_LOW); @@ -3053,7 +3166,7 @@ emit_image_deref_load(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId coord = get_image_coords(ctx, type, &intr->src[1]); bool use_sample = glsl_get_sampler_dim(type) == GLSL_SAMPLER_DIM_MS || glsl_get_sampler_dim(type) == GLSL_SAMPLER_DIM_SUBPASS_MS; - SpvId sample = use_sample ? get_src(ctx, &intr->src[2]) : 0; + SpvId sample = use_sample ? 
get_src(ctx, &intr->src[2], &atype) : 0; SpvId dest_type = spirv_builder_type_vector(&ctx->builder, base_type, nir_dest_num_components(intr->dest)); SpvId result = spirv_builder_emit_image_read(&ctx->builder, dest_type, @@ -3066,13 +3179,14 @@ emit_image_deref_load(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvDecorationRelaxedPrecision); } - store_dest(ctx, &intr->dest, result, nir_type_float); + store_dest(ctx, &intr->dest, result, nir_get_nir_type_for_glsl_base_type(glsl_get_sampler_result_type(type))); } static void emit_image_deref_size(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId img_var = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); nir_variable *var = nir_deref_instr_get_variable(deref); SpvId img_type = find_image_type(ctx, var); @@ -3091,7 +3205,8 @@ emit_image_deref_size(struct ntv_context *ctx, nir_intrinsic_instr *intr) static void emit_image_deref_samples(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId img_var = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); nir_variable *var = nir_deref_instr_get_variable(deref); SpvId img_type = find_image_type(ctx, var); @@ -3105,14 +3220,15 @@ emit_image_deref_samples(struct ntv_context *ctx, nir_intrinsic_instr *intr) static void emit_image_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId param = get_src(ctx, &intr->src[3]); - SpvId img_var = get_src(ctx, &intr->src[0]); + nir_alu_type atype, ptype; + SpvId param = get_src(ctx, &intr->src[3], &ptype); + SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); nir_variable *var = nir_deref_instr_get_variable(deref); const struct glsl_type *type = glsl_without_array(var->type); bool is_ms; 
type_to_dim(glsl_get_sampler_dim(type), &is_ms); - SpvId sample = is_ms ? get_src(ctx, &intr->src[2]) : emit_uint_const(ctx, 32, 0); + SpvId sample = is_ms ? get_src(ctx, &intr->src[2], &atype) : emit_uint_const(ctx, 32, 0); SpvId coord = get_image_coords(ctx, type, &intr->src[1]); enum glsl_base_type glsl_type = glsl_get_sampler_result_type(type); SpvId base_type = get_glsl_basetype(ctx, glsl_type); @@ -3123,12 +3239,17 @@ emit_image_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) * The type of the value pointed to by Pointer must be the same as Result Type. */ nir_alu_type ntype = nir_get_nir_type_for_glsl_base_type(glsl_type); - SpvId cast_type = get_dest_type(ctx, &intr->dest, ntype); - param = emit_bitcast(ctx, cast_type, param); + if (ptype != ntype) { + SpvId cast_type = get_dest_type(ctx, &intr->dest, ntype); + param = emit_bitcast(ctx, cast_type, param); + } if (intr->intrinsic == nir_intrinsic_image_deref_atomic_swap) { - param2 = get_src(ctx, &intr->src[4]); - param2 = emit_bitcast(ctx, cast_type, param2); + param2 = get_src(ctx, &intr->src[4], &ptype); + if (ptype != ntype) { + SpvId cast_type = get_dest_type(ctx, &intr->dest, ntype); + param2 = emit_bitcast(ctx, cast_type, param2); + } } handle_atomic_op(ctx, intr, texel, param, param2, ntype); @@ -3140,7 +3261,8 @@ emit_ballot(struct ntv_context *ctx, nir_intrinsic_instr *intr) spirv_builder_emit_cap(&ctx->builder, SpvCapabilitySubgroupBallotKHR); spirv_builder_emit_extension(&ctx->builder, "SPV_KHR_shader_ballot"); SpvId type = get_dest_uvec_type(ctx, &intr->dest); - SpvId result = emit_unop(ctx, SpvOpSubgroupBallotKHR, type, get_src(ctx, &intr->src[0])); + nir_alu_type atype; + SpvId result = emit_unop(ctx, SpvOpSubgroupBallotKHR, type, get_src(ctx, &intr->src[0], &atype)); store_dest(ctx, &intr->dest, result, nir_type_uint); } @@ -3149,9 +3271,11 @@ emit_read_first_invocation(struct ntv_context *ctx, nir_intrinsic_instr *intr) { spirv_builder_emit_cap(&ctx->builder, 
SpvCapabilitySubgroupBallotKHR); spirv_builder_emit_extension(&ctx->builder, "SPV_KHR_shader_ballot"); - SpvId type = get_dest_type(ctx, &intr->dest, nir_type_uint); - SpvId result = emit_unop(ctx, SpvOpSubgroupFirstInvocationKHR, type, get_src(ctx, &intr->src[0])); - store_dest(ctx, &intr->dest, result, nir_type_uint); + nir_alu_type atype; + SpvId src = get_src(ctx, &intr->src[0], &atype); + SpvId type = get_dest_type(ctx, &intr->dest, atype); + SpvId result = emit_unop(ctx, SpvOpSubgroupFirstInvocationKHR, type, src); + store_dest(ctx, &intr->dest, result, atype); } static void @@ -3159,11 +3283,13 @@ emit_read_invocation(struct ntv_context *ctx, nir_intrinsic_instr *intr) { spirv_builder_emit_cap(&ctx->builder, SpvCapabilitySubgroupBallotKHR); spirv_builder_emit_extension(&ctx->builder, "SPV_KHR_shader_ballot"); - SpvId type = get_dest_type(ctx, &intr->dest, nir_type_uint); + nir_alu_type atype, itype; + SpvId src = get_src(ctx, &intr->src[0], &atype); + SpvId type = get_dest_type(ctx, &intr->dest, atype); SpvId result = emit_binop(ctx, SpvOpSubgroupReadInvocationKHR, type, - get_src(ctx, &intr->src[0]), - get_src(ctx, &intr->src[1])); - store_dest(ctx, &intr->dest, result, nir_type_uint); + src, + get_src(ctx, &intr->src[1], &itype)); + store_dest(ctx, &intr->dest, result, atype); } static void @@ -3220,8 +3346,9 @@ emit_vote(struct ntv_context *ctx, nir_intrinsic_instr *intr) unreachable("unknown vote intrinsic"); } spirv_builder_emit_cap(&ctx->builder, SpvCapabilityGroupNonUniformVote); - SpvId result = spirv_builder_emit_vote(&ctx->builder, op, get_src(ctx, &intr->src[0])); - store_dest_raw(ctx, &intr->dest, result); + nir_alu_type atype; + SpvId result = spirv_builder_emit_vote(&ctx->builder, op, get_src(ctx, &intr->src[0], &atype)); + store_dest_raw(ctx, &intr->dest, result, nir_type_bool); } static void @@ -3546,13 +3673,17 @@ emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef) undef->def.num_components); store_ssa_def(ctx, &undef->def, - 
spirv_builder_emit_undef(&ctx->builder, type)); + spirv_builder_emit_undef(&ctx->builder, type), + undef->def.bit_size == 1 ? nir_type_bool : nir_type_uint); } static SpvId get_src_float(struct ntv_context *ctx, nir_src *src) { - SpvId def = get_src(ctx, src); + nir_alu_type atype; + SpvId def = get_src(ctx, src, &atype); + if (atype == nir_type_float) + return def; unsigned num_components = nir_src_num_components(*src); unsigned bit_size = nir_src_bit_size(*src); return bitcast_to_fvec(ctx, def, bit_size, num_components); @@ -3561,7 +3692,10 @@ get_src_float(struct ntv_context *ctx, nir_src *src) static SpvId get_src_int(struct ntv_context *ctx, nir_src *src) { - SpvId def = get_src(ctx, src); + nir_alu_type atype; + SpvId def = get_src(ctx, src, &atype); + if (atype == nir_type_int) + return def; unsigned num_components = nir_src_num_components(*src); unsigned bit_size = nir_src_bit_size(*src); return bitcast_to_ivec(ctx, def, bit_size, num_components); @@ -3602,6 +3736,7 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex) const_offset = 0, offset = 0, sample = 0, tex_offset = 0, bindless = 0, min_lod = 0; unsigned coord_components = 0; nir_variable *bindless_var = NULL; + nir_alu_type atype; for (unsigned i = 0; i < tex->num_srcs; i++) { nir_const_value *cv; switch (tex->src[i].src_type) { @@ -3698,7 +3833,7 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex) break; case nir_tex_src_texture_handle: - bindless = get_src(ctx, &tex->src[i].src); + bindless = get_src(ctx, &tex->src[i].src, &atype); bindless_var = nir_deref_instr_get_variable(nir_src_as_deref(tex->src[i].src)); break; @@ -3968,7 +4103,7 @@ emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref) struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var); assert(he); SpvId result = (SpvId)(intptr_t)he->data; - store_dest_raw(ctx, &deref->dest, result); + store_dest_raw(ctx, &deref->dest, result, get_nir_alu_type(deref->type)); } static void @@ -4001,11 +4136,12 @@ 
emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref) SpvStorageClass storage_class = get_storage_class(var); SpvId base, type; + nir_alu_type atype = nir_type_uint; switch (var->data.mode) { case nir_var_mem_ubo: case nir_var_mem_ssbo: - base = get_src(ctx, &deref->parent); + base = get_src(ctx, &deref->parent, &atype); /* this is either the array<buffers> deref or the array<uint> deref */ if (glsl_type_is_struct_or_ifc(deref->type)) { /* array<buffers> */ @@ -4017,7 +4153,7 @@ emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref) case nir_var_function_temp: case nir_var_shader_in: case nir_var_shader_out: - base = get_src(ctx, &deref->parent); + base = get_src(ctx, &deref->parent, &atype); type = get_glsl_type(ctx, deref->type); break; @@ -4037,7 +4173,10 @@ emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref) unreachable("Unsupported nir_variable_mode\n"); } - SpvId index = get_src(ctx, &deref->arr.index); + nir_alu_type itype; + SpvId index = get_src(ctx, &deref->arr.index, &itype); + if (itype == nir_type_float) + index = emit_bitcast(ctx, get_uvec_type(ctx, 32, 1), index); SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, storage_class, @@ -4048,7 +4187,7 @@ emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref) base, &index, 1); /* uint is a bit of a lie here, it's really just an opaque type */ - store_dest(ctx, &deref->dest, result, nir_type_uint); + store_dest(ctx, &deref->dest, result, get_nir_alu_type(deref->type)); } static void @@ -4068,12 +4207,13 @@ emit_deref_struct(struct ntv_context *ctx, nir_deref_instr *deref) storage_class, type); + nir_alu_type atype; SpvId result = spirv_builder_emit_access_chain(&ctx->builder, ptr_type, - get_src(ctx, &deref->parent), + get_src(ctx, &deref->parent, &atype), &index, 1); /* uint is a bit of a lie here, it's really just an opaque type */ - store_dest(ctx, &deref->dest, result, nir_type_uint); + store_dest(ctx, &deref->dest, result, get_nir_alu_type(deref->type)); } static void
@@ -4144,7 +4284,8 @@ static SpvId get_src_bool(struct ntv_context *ctx, nir_src *src) { assert(nir_src_bit_size(*src) == 1); - return get_src(ctx, src); + nir_alu_type atype; + return get_src(ctx, src, &atype); } static void @@ -4728,7 +4869,9 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_ ctx.defs = ralloc_array_size(ctx.mem_ctx, sizeof(SpvId), entry->ssa_alloc); - if (!ctx.defs) + ctx.def_types = ralloc_array_size(ctx.mem_ctx, + sizeof(nir_alu_type), entry->ssa_alloc); + if (!ctx.defs || !ctx.def_types) goto fail; if (sinfo->have_sparse) { spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySparseResidency); @@ -4745,7 +4888,9 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_ nir_index_local_regs(entry); ctx.regs = rzalloc_array_size(ctx.mem_ctx, sizeof(SpvId), entry->reg_alloc); - if (!ctx.regs) + ctx.reg_types = ralloc_array_size(ctx.mem_ctx, + sizeof(nir_alu_type), entry->reg_alloc); + if (!ctx.regs || !ctx.reg_types) goto fail; ctx.num_regs = entry->reg_alloc; @@ -4766,7 +4911,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_ foreach_list_typed(nir_register, reg, node, &entry->registers) { if (reg->bit_size == 1) - init_reg(&ctx, reg); + init_reg(&ctx, reg, nir_type_bool); } nir_foreach_variable_with_modes(var, s, nir_var_shader_temp)