diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 6f4b19cd06f..32c30172068 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -75,10 +75,12 @@ struct ntv_context { size_t num_entry_ifaces; SpvId *defs; + nir_alu_type *def_types; SpvId *resident_defs; size_t num_defs; SpvId *regs; + nir_alu_type *reg_types; size_t num_regs; struct hash_table *vars; /* nir_variable -> SpvId */ @@ -137,6 +139,31 @@ static SpvId emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src0, SpvId src1, SpvId src2); +static bool +alu_op_is_typeless(nir_op op) +{ + switch (op) { + case nir_op_mov: + case nir_op_vec16: + case nir_op_vec2: + case nir_op_vec3: + case nir_op_vec4: + case nir_op_vec5: + case nir_op_vec8: + case nir_op_bcsel: + return true; + default: + break; + } + return false; +} + +static nir_alu_type +get_nir_alu_type(const struct glsl_type *type) +{ + return nir_alu_type_get_base_type(nir_get_nir_type_for_glsl_base_type(glsl_get_base_type(glsl_without_array_or_matrix(type)))); +} + static SpvId get_bvec_type(struct ntv_context *ctx, int num_components) { @@ -352,25 +379,18 @@ get_alu_type(struct ntv_context *ctx, nir_alu_type type, unsigned num_components if (bit_size == 1) return get_bvec_type(ctx, num_components); + type = nir_alu_type_get_base_type(type); switch (nir_alu_type_get_base_type(type)) { case nir_type_bool: - unreachable("bool should have bit-size 1"); + return get_bvec_type(ctx, num_components); case nir_type_int: - case nir_type_int8: - case nir_type_int16: - case nir_type_int64: return get_ivec_type(ctx, bit_size, num_components); case nir_type_uint: - case nir_type_uint8: - case nir_type_uint16: - case nir_type_uint64: return get_uvec_type(ctx, bit_size, num_components); case nir_type_float: - case nir_type_float16: - case nir_type_float64: return get_fvec_type(ctx, bit_size, num_components); 
default: @@ -1245,20 +1265,21 @@ get_vec_from_bit_size(struct ntv_context *ctx, uint32_t bit_size, uint32_t num_c } static SpvId -get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa) +get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa, nir_alu_type *atype) { assert(ssa->index < ctx->num_defs); assert(ctx->defs[ssa->index] != 0); + *atype = ctx->def_types[ssa->index]; return ctx->defs[ssa->index]; } static void -init_reg(struct ntv_context *ctx, nir_register *reg) +init_reg(struct ntv_context *ctx, nir_register *reg, nir_alu_type atype) { if (ctx->regs[reg->index]) return; - SpvId type = get_vec_from_bit_size(ctx, reg->bit_size, reg->num_components); + SpvId type = get_alu_type(ctx, atype, reg->num_components, reg->bit_size); SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassFunction, type); @@ -1266,45 +1287,48 @@ init_reg(struct ntv_context *ctx, nir_register *reg) SpvStorageClassFunction); ctx->regs[reg->index] = var; + ctx->reg_types[reg->index] = nir_alu_type_get_base_type(atype); } static SpvId -get_var_from_reg(struct ntv_context *ctx, nir_register *reg) +get_var_from_reg(struct ntv_context *ctx, nir_register *reg, nir_alu_type *atype) { assert(reg->index < ctx->num_regs); - init_reg(ctx, reg); + init_reg(ctx, reg, nir_type_uint); assert(ctx->regs[reg->index] != 0); + + *atype = ctx->reg_types[reg->index]; return ctx->regs[reg->index]; } static SpvId -get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg) +get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg, nir_alu_type *atype) { assert(reg->reg); assert(!reg->indirect); assert(!reg->base_offset); - SpvId var = get_var_from_reg(ctx, reg->reg); - SpvId type = get_vec_from_bit_size(ctx, reg->reg->bit_size, reg->reg->num_components); + SpvId var = get_var_from_reg(ctx, reg->reg, atype); + SpvId type = get_alu_type(ctx, *atype, reg->reg->num_components, reg->reg->bit_size); return spirv_builder_emit_load(&ctx->builder, type, var); } static SpvId 
-get_src(struct ntv_context *ctx, nir_src *src) +get_src(struct ntv_context *ctx, nir_src *src, nir_alu_type *atype) { if (src->is_ssa) - return get_src_ssa(ctx, src->ssa); + return get_src_ssa(ctx, src->ssa, atype); else - return get_src_reg(ctx, &src->reg); + return get_src_reg(ctx, &src->reg, atype); } static SpvId -get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src) +get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src, nir_alu_type *atype) { assert(!alu->src[src].negate); assert(!alu->src[src].abs); - SpvId def = get_src(ctx, &alu->src[src].src); + SpvId def = get_src(ctx, &alu->src[src].src, atype); unsigned used_channels = 0; bool need_swizzle = false; @@ -1327,8 +1351,7 @@ get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src) return def; int bit_size = nir_src_bit_size(alu->src[src].src); - SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) : - spirv_builder_type_uint(&ctx->builder, bit_size); + SpvId raw_type = get_alu_type(ctx, *atype, 1, bit_size); if (used_channels == 1) { uint32_t indices[] = { alu->src[src].swizzle[0] }; @@ -1369,10 +1392,11 @@ get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src) } static void -store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result) +store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result, nir_alu_type atype) { assert(result != 0); assert(ssa->index < ctx->num_defs); + ctx->def_types[ssa->index] = nir_alu_type_get_base_type(atype); ctx->defs[ssa->index] = result; } @@ -1413,58 +1437,43 @@ bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size, return emit_bitcast(ctx, type, value); } -static void -store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result) +static SpvId +cast_src_to_type(struct ntv_context *ctx, SpvId value, nir_src src, nir_alu_type atype) { - init_reg(ctx, reg->reg); + atype = nir_alu_type_get_base_type(atype); + unsigned num_components 
= nir_src_num_components(src); + unsigned bit_size = nir_src_bit_size(src); + return emit_bitcast(ctx, get_alu_type(ctx, atype, num_components, bit_size), value); +} + +static void +store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result, nir_alu_type atype) +{ + atype = nir_alu_type_get_base_type(atype); + init_reg(ctx, reg->reg, atype); SpvId var = ctx->regs[reg->reg->index]; + nir_alu_type vtype = ctx->reg_types[reg->reg->index]; + if (atype != vtype) { + assert(vtype != nir_type_bool); + result = emit_bitcast(ctx, get_alu_type(ctx, vtype, reg->reg->num_components, reg->reg->bit_size), result); + } assert(var); spirv_builder_emit_store(&ctx->builder, var, result); } static void -store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result) +store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type atype) { if (dest->is_ssa) - store_ssa_def(ctx, &dest->ssa, result); + store_ssa_def(ctx, &dest->ssa, result, atype); else - store_reg_def(ctx, &dest->reg, result); + store_reg_def(ctx, &dest->reg, result, atype); } static void store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type) { - unsigned num_components = nir_dest_num_components(*dest); - unsigned bit_size = nir_dest_bit_size(*dest); - - if (bit_size != 1) { - switch (nir_alu_type_get_base_type(type)) { - case nir_type_bool: - assert("bool should have bit-size 1"); - break; - - case nir_type_uint: - case nir_type_uint8: - case nir_type_uint16: - case nir_type_uint64: - break; /* nothing to do! 
*/ - - case nir_type_int: - case nir_type_int8: - case nir_type_int16: - case nir_type_int64: - case nir_type_float: - case nir_type_float16: - case nir_type_float64: - result = bitcast_to_uvec(ctx, result, bit_size, num_components); - break; - - default: - unreachable("unsupported nir_alu_type"); - } - } - - store_dest_raw(ctx, dest, result); + store_dest_raw(ctx, dest, result, type); } static SpvId @@ -1974,13 +1983,16 @@ alu_instr_src_components(const nir_alu_instr *instr, unsigned src) } static SpvId -get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src, SpvId *raw_value) +get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src, SpvId *raw_value, nir_alu_type *atype) { - *raw_value = get_alu_src_raw(ctx, alu, src); + *raw_value = get_alu_src_raw(ctx, alu, src, atype); unsigned num_components = alu_instr_src_components(alu, src); unsigned bit_size = nir_src_bit_size(alu->src[src].src); - nir_alu_type type = nir_op_infos[alu->op].input_types[src]; + nir_alu_type type = alu_op_is_typeless(alu->op) ? *atype : nir_op_infos[alu->op].input_types[src]; + type = nir_alu_type_get_base_type(type); + if (type == *atype) + return *raw_value; if (bit_size == 1) return *raw_value; @@ -1993,7 +2005,7 @@ get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src, SpvId *ra return bitcast_to_ivec(ctx, *raw_value, bit_size, num_components); case nir_type_uint: - return *raw_value; + return bitcast_to_uvec(ctx, *raw_value, bit_size, num_components); case nir_type_float: return bitcast_to_fvec(ctx, *raw_value, bit_size, num_components); @@ -2005,11 +2017,10 @@ get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src, SpvId *ra } static void -store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result, bool force_float) +store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result, nir_alu_type atype) { assert(!alu->dest.saturate); - store_dest(ctx, &alu->dest.dest, result, - force_float ? 
nir_type_float : nir_op_infos[alu->op].output_type); + store_dest(ctx, &alu->dest.dest, result, atype); } static SpvId @@ -2038,16 +2049,66 @@ needs_derivative_control(nir_alu_instr *alu) static void emit_alu(struct ntv_context *ctx, nir_alu_instr *alu) { + bool is_bcsel = alu->op == nir_op_bcsel; + nir_alu_type stype[NIR_MAX_VEC_COMPONENTS] = {0}; SpvId src[NIR_MAX_VEC_COMPONENTS]; SpvId raw_src[NIR_MAX_VEC_COMPONENTS]; for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) - src[i] = get_alu_src(ctx, alu, i, &raw_src[i]); + src[i] = get_alu_src(ctx, alu, i, &raw_src[i], &stype[i]); + + nir_alu_type typeless_type = stype[is_bcsel]; + if (nir_op_infos[alu->op].num_inputs > 1 && + alu_op_is_typeless(alu->op) && + nir_src_bit_size(alu->src[is_bcsel].src) != 1) { + unsigned uint_count = 0; + unsigned int_count = 0; + unsigned float_count = 0; + for (unsigned i = is_bcsel; i < nir_op_infos[alu->op].num_inputs; i++) { + if (stype[i] == nir_type_bool) + break; + switch (stype[i]) { + case nir_type_uint: + uint_count++; + break; + case nir_type_int: + int_count++; + break; + case nir_type_float: + float_count++; + break; + default: + unreachable("this shouldn't happen"); + } + } + if (uint_count > int_count && uint_count > float_count) + typeless_type = nir_type_uint; + else if (int_count > uint_count && int_count > float_count) + typeless_type = nir_type_int; + else if (float_count > uint_count && float_count > int_count) + typeless_type = nir_type_float; + else if (float_count == uint_count || uint_count == int_count) + typeless_type = nir_type_uint; + else if (float_count == int_count) + typeless_type = nir_type_float; + else + typeless_type = nir_type_uint; + assert(typeless_type != nir_type_bool); + for (unsigned i = is_bcsel; i < nir_op_infos[alu->op].num_inputs; i++) { + unsigned num_components = alu_instr_src_components(alu, i); + unsigned bit_size = nir_src_bit_size(alu->src[i].src); + SpvId type = get_alu_type(ctx, typeless_type, num_components, 
bit_size); + if (stype[i] != typeless_type) { + src[i] = emit_bitcast(ctx, type, src[i]); + } + } + } - SpvId dest_type = get_dest_type(ctx, &alu->dest.dest, - nir_op_infos[alu->op].output_type); - bool force_float = false; unsigned bit_size = nir_dest_bit_size(alu->dest.dest); unsigned num_components = nir_dest_num_components(alu->dest.dest); + nir_alu_type atype = bit_size == 1 ? + nir_type_bool : + (alu_op_is_typeless(alu->op) ? typeless_type : nir_op_infos[alu->op].output_type); + SpvId dest_type = get_dest_type(ctx, &alu->dest.dest, atype); if (needs_derivative_control(alu)) spirv_builder_emit_cap(&ctx->builder, SpvCapabilityDerivativeControl); @@ -2144,7 +2205,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu) case nir_op: \ assert(nir_op_infos[alu->op].num_inputs == 1); \ result = emit_builtin_unop(ctx, spirv_op, get_dest_type(ctx, &alu->dest.dest, nir_type_float), src[0]); \ - force_float = true; \ + atype = nir_type_float; \ break; BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs) @@ -2415,7 +2476,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu) if (alu->exact) spirv_builder_emit_decoration(&ctx->builder, result, SpvDecorationNoContraction); - store_alu_result(ctx, alu, result, force_float); + store_alu_result(ctx, alu, result, atype); } static void @@ -2425,16 +2486,19 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const) unsigned num_components = load_const->def.num_components; SpvId components[NIR_MAX_VEC_COMPONENTS]; + nir_alu_type atype; if (bit_size == 1) { for (int i = 0; i < num_components; i++) components[i] = spirv_builder_const_bool(&ctx->builder, load_const->value[i].b); + atype = nir_type_bool; } else { for (int i = 0; i < num_components; i++) { uint64_t tmp = nir_const_value_as_uint(load_const->value[i], bit_size); components[i] = emit_uint_const(ctx, bit_size, tmp); } + atype = nir_type_uint; } if (num_components > 1) { @@ -2443,10 +2507,10 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr 
*load_const) SpvId value = spirv_builder_const_composite(&ctx->builder, type, components, num_components); - store_ssa_def(ctx, &load_const->def, value); + store_ssa_def(ctx, &load_const->def, value, atype); } else { assert(num_components == 1); - store_ssa_def(ctx, &load_const->def, components[0]); + store_ssa_def(ctx, &load_const->def, components[0], atype); } } @@ -2462,7 +2526,8 @@ emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr) static void emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId ptr = get_src(ctx, intr->src); + nir_alu_type atype; + SpvId ptr = get_src(ctx, intr->src, &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); SpvId type; @@ -2472,8 +2537,10 @@ emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) type = get_image_type(ctx, var, glsl_type_is_sampler(gtype), glsl_get_sampler_dim(gtype) == GLSL_SAMPLER_DIM_BUF); + atype = nir_get_nir_type_for_glsl_base_type(glsl_get_sampler_result_type(gtype)); } else { type = get_glsl_type(ctx, deref->type); + atype = get_nir_alu_type(deref->type); } SpvId result; @@ -2481,18 +2548,15 @@ emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) result = emit_atomic(ctx, SpvOpAtomicLoad, type, ptr, 0, 0); else result = spirv_builder_emit_load(&ctx->builder, type, ptr); - unsigned num_components = nir_dest_num_components(intr->dest); - unsigned bit_size = nir_dest_bit_size(intr->dest); - if (bit_size > 1) - result = bitcast_to_uvec(ctx, result, bit_size, num_components); - store_dest(ctx, &intr->dest, result, nir_type_uint); + store_dest(ctx, &intr->dest, result, atype); } static void emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId ptr = get_src(ctx, &intr->src[0]); - SpvId src = get_src(ctx, &intr->src[1]); + nir_alu_type ptype, stype; + SpvId ptr = get_src(ctx, &intr->src[0], &ptype); + SpvId src = get_src(ctx, &intr->src[1], &stype); const struct glsl_type *gtype = 
nir_src_as_deref(intr->src[0])->type; SpvId type = get_glsl_type(ctx, gtype); @@ -2508,7 +2572,7 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId member_type; if (glsl_type_is_vector(gtype)) { result_type = get_glsl_basetype(ctx, glsl_get_base_type(gtype)); - member_type = get_uvec_type(ctx, glsl_get_bit_size(gtype), 1); + member_type = get_alu_type(ctx, stype, 1, glsl_get_bit_size(gtype)); } else member_type = result_type = get_glsl_type(ctx, glsl_get_array_element(gtype)); SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, @@ -2518,7 +2582,8 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) if (wrmask & BITFIELD_BIT(i)) { SpvId idx = emit_uint_const(ctx, 32, i); SpvId val = spirv_builder_emit_composite_extract(&ctx->builder, member_type, src, &i, 1); - val = emit_bitcast(ctx, result_type, val); + if (stype != ptype) + val = emit_bitcast(ctx, result_type, val); SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type, ptr, &idx, 1); spirv_builder_emit_store(&ctx->builder, member, val); @@ -2533,10 +2598,12 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) src = emit_bitcast(ctx, type, src); /* SampleMask is always an array in spirv, so we need to construct it into one */ result = spirv_builder_emit_composite_construct(&ctx->builder, ctx->sample_mask_type, &src, 1); - } else if (glsl_get_base_type(glsl_without_array(gtype)) == GLSL_TYPE_BOOL) { - result = src; - } else - result = emit_bitcast(ctx, type, src); + } else { + if (ptype == stype) + result = src; + else + result = emit_bitcast(ctx, type, src); + } if (nir_intrinsic_access(intr) & ACCESS_COHERENT) spirv_builder_emit_atomic_store(&ctx->builder, ptr, SpvScopeDevice, 0, result); else @@ -2553,7 +2620,10 @@ emit_load_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassWorkgroup, uint_type); - SpvId offset = get_src(ctx, &intr->src[0]); 
+ nir_alu_type atype; + SpvId offset = get_src(ctx, &intr->src[0], &atype); + if (atype == nir_type_float) + offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[0]), 1); SpvId constituents[NIR_MAX_VEC_COMPONENTS]; SpvId shared_block = get_shared_block(ctx, bit_size); /* need to convert array -> vec */ @@ -2574,7 +2644,8 @@ emit_load_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) static void emit_store_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId src = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId src = get_src(ctx, &intr->src[0], &atype); unsigned wrmask = nir_intrinsic_write_mask(intr); unsigned bit_size = nir_src_bit_size(intr->src[0]); @@ -2582,7 +2653,10 @@ emit_store_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassWorkgroup, uint_type); - SpvId offset = get_src(ctx, &intr->src[1]); + nir_alu_type otype; + SpvId offset = get_src(ctx, &intr->src[1], &otype); + if (otype == nir_type_float) + offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[0]), 1); SpvId shared_block = get_shared_block(ctx, bit_size); /* this is a partial write, so we have to loop and do a per-component write */ u_foreach_bit(i, wrmask) { @@ -2590,6 +2664,8 @@ emit_store_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId val = src; if (nir_src_num_components(intr->src[0]) != 1) val = spirv_builder_emit_composite_extract(&ctx->builder, uint_type, src, &i, 1); + if (atype != nir_type_uint) + val = emit_bitcast(ctx, get_alu_type(ctx, nir_type_uint, 1, bit_size), val); SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type, shared_block, &shared_offset, 1); spirv_builder_emit_store(&ctx->builder, member, val); @@ -2606,7 +2682,10 @@ emit_load_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassPrivate, uint_type); - SpvId offset = 
get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId offset = get_src(ctx, &intr->src[0], &atype); + if (atype == nir_type_float) + offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[0]), 1); SpvId constituents[NIR_MAX_VEC_COMPONENTS]; SpvId scratch_block = get_scratch_block(ctx, bit_size); /* need to convert array -> vec */ @@ -2627,7 +2706,8 @@ emit_load_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) static void emit_store_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId src = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId src = get_src(ctx, &intr->src[0], &atype); unsigned wrmask = nir_intrinsic_write_mask(intr); unsigned bit_size = nir_src_bit_size(intr->src[0]); @@ -2635,7 +2715,10 @@ emit_store_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassPrivate, uint_type); - SpvId offset = get_src(ctx, &intr->src[1]); + nir_alu_type otype; + SpvId offset = get_src(ctx, &intr->src[1], &otype); + if (otype == nir_type_float) + offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[0]), 1); SpvId scratch_block = get_scratch_block(ctx, bit_size); /* this is a partial write, so we have to loop and do a per-component write */ u_foreach_bit(i, wrmask) { @@ -2643,6 +2726,8 @@ emit_store_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId val = src; if (nir_src_num_components(intr->src[0]) != 1) val = spirv_builder_emit_composite_extract(&ctx->builder, uint_type, src, &i, 1); + if (atype != nir_type_uint) + val = emit_bitcast(ctx, get_alu_type(ctx, nir_type_uint, 1, bit_size), val); SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type, scratch_block, &scratch_offset, 1); spirv_builder_emit_store(&ctx->builder, member, val); @@ -2669,7 +2754,10 @@ emit_load_push_const(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvStorageClassPushConstant, load_type); - SpvId member = get_src(ctx, 
&intr->src[0]); + nir_alu_type atype; + SpvId member = get_src(ctx, &intr->src[0], &atype); + if (atype == nir_type_float) + member = bitcast_to_uvec(ctx, member, nir_src_bit_size(intr->src[0]), 1); /* reuse the offset from ZINK_PUSH_CONST_OFFSET */ SpvId offset = emit_uint_const(ctx, 32, 0); /* OpAccessChain takes an array of indices that drill into a hierarchy based on the type: @@ -2710,7 +2798,8 @@ emit_load_global(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassPhysicalStorageBuffer, dest_type); - SpvId ptr = emit_bitcast(ctx, pointer_type, get_src(ctx, &intr->src[0])); + nir_alu_type atype; + SpvId ptr = emit_bitcast(ctx, pointer_type, get_src(ctx, &intr->src[0], &atype)); SpvId result = spirv_builder_emit_load(&ctx->builder, dest_type, ptr); store_dest(ctx, &intr->dest, result, nir_type_uint); } @@ -2724,8 +2813,9 @@ emit_store_global(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassPhysicalStorageBuffer, dest_type); - SpvId param = get_src(ctx, &intr->src[0]); - SpvId ptr = emit_bitcast(ctx, pointer_type, get_src(ctx, &intr->src[1])); + nir_alu_type atype; + SpvId param = get_src(ctx, &intr->src[0], &atype); + SpvId ptr = emit_bitcast(ctx, pointer_type, get_src(ctx, &intr->src[1], &atype)); spirv_builder_emit_store(&ctx->builder, ptr, param); } @@ -2842,37 +2932,38 @@ emit_interpolate(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId op; spirv_builder_emit_cap(&ctx->builder, SpvCapabilityInterpolationFunction); SpvId src1 = 0; + nir_alu_type atype; switch (intr->intrinsic) { case nir_intrinsic_interp_deref_at_centroid: op = GLSLstd450InterpolateAtCentroid; break; case nir_intrinsic_interp_deref_at_sample: op = GLSLstd450InterpolateAtSample; - src1 = get_src(ctx, &intr->src[1]); + src1 = get_src(ctx, &intr->src[1], &atype); break; case nir_intrinsic_interp_deref_at_offset: op = 
GLSLstd450InterpolateAtOffset; - src1 = get_src(ctx, &intr->src[1]); + src1 = get_src(ctx, &intr->src[1], &atype); /* The offset operand must be a vector of 2 components of 32-bit floating-point type. - InterpolateAtOffset spec */ - src1 = emit_bitcast(ctx, get_fvec_type(ctx, 32, 2), src1); + if (atype != nir_type_float) + src1 = emit_bitcast(ctx, get_fvec_type(ctx, 32, 2), src1); break; default: unreachable("unknown interp op"); } - SpvId ptr = get_src(ctx, &intr->src[0]); + nir_alu_type ptype; + SpvId ptr = get_src(ctx, &intr->src[0], &ptype); SpvId result; + const struct glsl_type *gtype = nir_src_as_deref(intr->src[0])->type; + assert(ptype == get_nir_alu_type(gtype)); if (intr->intrinsic == nir_intrinsic_interp_deref_at_centroid) - result = emit_builtin_unop(ctx, op, get_glsl_type(ctx, nir_src_as_deref(intr->src[0])->type), ptr); + result = emit_builtin_unop(ctx, op, get_glsl_type(ctx, gtype), ptr); else - result = emit_builtin_binop(ctx, op, get_glsl_type(ctx, nir_src_as_deref(intr->src[0])->type), - ptr, src1); - unsigned num_components = nir_dest_num_components(intr->dest); - unsigned bit_size = nir_dest_bit_size(intr->dest); - result = bitcast_to_uvec(ctx, result, bit_size, num_components); - store_dest(ctx, &intr->dest, result, nir_type_uint); + result = emit_builtin_binop(ctx, op, get_glsl_type(ctx, gtype), ptr, src1); + store_dest(ctx, &intr->dest, result, ptype); } static void @@ -2887,19 +2978,25 @@ handle_atomic_op(struct ntv_context *ctx, nir_intrinsic_instr *intr, SpvId ptr, static void emit_deref_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId ptr = get_src(ctx, &intr->src[0]); - SpvId param = get_src(ctx, &intr->src[1]); + nir_alu_type atype; + nir_alu_type ret_type = nir_atomic_op_type(nir_intrinsic_atomic_op(intr)) == nir_type_float ? 
nir_type_float : nir_type_uint; + SpvId ptr = get_src(ctx, &intr->src[0], &atype); + SpvId param = get_src(ctx, &intr->src[1], &atype); + if (atype != ret_type) + param = cast_src_to_type(ctx, param, intr->src[1], ret_type); SpvId param2 = 0; if (nir_src_bit_size(intr->src[1]) == 64) spirv_builder_emit_cap(&ctx->builder, SpvCapabilityInt64Atomics); - if (intr->intrinsic == nir_intrinsic_deref_atomic_swap) - param2 = get_src(ctx, &intr->src[2]); + if (intr->intrinsic == nir_intrinsic_deref_atomic_swap) { + param2 = get_src(ctx, &intr->src[2], &atype); + if (atype != ret_type) + param2 = cast_src_to_type(ctx, param2, intr->src[2], ret_type); + } - nir_alu_type op_type = nir_atomic_op_type(nir_intrinsic_atomic_op(intr)); - handle_atomic_op(ctx, intr, ptr, param, param2, op_type == nir_type_float ? nir_type_float : nir_type_uint32); + handle_atomic_op(ctx, intr, ptr, param, param2, ret_type); } static void @@ -2907,12 +3004,19 @@ emit_shared_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) { unsigned bit_size = nir_src_bit_size(intr->src[1]); SpvId dest_type = get_dest_type(ctx, &intr->dest, nir_type_uint); - SpvId param = get_src(ctx, &intr->src[1]); + nir_alu_type atype; + nir_alu_type ret_type = nir_atomic_op_type(nir_intrinsic_atomic_op(intr)) == nir_type_float ? 
nir_type_float : nir_type_uint; + SpvId param = get_src(ctx, &intr->src[1], &atype); + if (atype != ret_type) + param = cast_src_to_type(ctx, param, intr->src[1], ret_type); SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassWorkgroup, dest_type); - SpvId offset = emit_binop(ctx, SpvOpUDiv, get_uvec_type(ctx, 32, 1), get_src(ctx, &intr->src[0]), emit_uint_const(ctx, 32, bit_size / 8)); + SpvId offset = get_src(ctx, &intr->src[0], &atype); + if (atype != nir_type_uint) + offset = cast_src_to_type(ctx, offset, intr->src[0], nir_type_uint); + offset = emit_binop(ctx, SpvOpUDiv, get_uvec_type(ctx, 32, 1), offset, emit_uint_const(ctx, 32, bit_size / 8)); SpvId shared_block = get_shared_block(ctx, bit_size); SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type, shared_block, &offset, 1); @@ -2920,11 +3024,13 @@ emit_shared_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) spirv_builder_emit_cap(&ctx->builder, SpvCapabilityInt64Atomics); SpvId param2 = 0; - if (intr->intrinsic == nir_intrinsic_shared_atomic_swap) - param2 = get_src(ctx, &intr->src[2]); + if (intr->intrinsic == nir_intrinsic_shared_atomic_swap) { + param2 = get_src(ctx, &intr->src[2], &atype); + if (atype != ret_type) + param2 = cast_src_to_type(ctx, param2, intr->src[2], ret_type); + } - nir_alu_type op_type = nir_atomic_op_type(nir_intrinsic_atomic_op(intr)); - handle_atomic_op(ctx, intr, ptr, param, param2, op_type == nir_type_float ? 
nir_type_float : nir_type_uint32); + handle_atomic_op(ctx, intr, ptr, param, param2, ret_type); } static void @@ -2937,7 +3043,10 @@ emit_get_ssbo_size(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassStorageBuffer, get_bo_struct_type(ctx, var)); - SpvId bo = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId bo = get_src(ctx, &intr->src[0], &atype); + if (atype == nir_type_float) + bo = bitcast_to_uvec(ctx, bo, nir_src_bit_size(intr->src[0]), 1); SpvId indices[] = { bo }; SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type, ctx->ssbos[2], indices, @@ -2966,16 +3075,17 @@ get_image_coords(struct ntv_context *ctx, const struct glsl_type *type, nir_src uint32_t num_coords = glsl_get_sampler_coordinate_components(type); uint32_t src_components = nir_src_num_components(*src); - SpvId spv = get_src(ctx, src); + nir_alu_type atype; + SpvId spv = get_src(ctx, src, &atype); if (num_coords == src_components) return spv; /* need to extract the coord dimensions that the image can use */ - SpvId vec_type = get_uvec_type(ctx, 32, num_coords); + SpvId vec_type = get_alu_type(ctx, atype, num_coords, 32); if (num_coords == 1) return spirv_builder_emit_vector_extract(&ctx->builder, vec_type, spv, 0); uint32_t constituents[4]; - SpvId zero = emit_uint_const(ctx, nir_src_bit_size(*src), 0); + SpvId zero = atype == nir_type_uint ? emit_uint_const(ctx, nir_src_bit_size(*src), 0) : emit_float_const(ctx, nir_src_bit_size(*src), 0); assert(num_coords < ARRAY_SIZE(constituents)); for (unsigned i = 0; i < num_coords; i++) constituents[i] = i < src_components ? 
i : zero; @@ -2985,7 +3095,8 @@ get_image_coords(struct ntv_context *ctx, const struct glsl_type *type, nir_src static void emit_image_deref_store(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId img_var = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); nir_variable *var = nir_deref_instr_get_variable(deref); SpvId img_type = find_image_type(ctx, var); @@ -2993,15 +3104,16 @@ emit_image_deref_store(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId base_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type)); SpvId img = spirv_builder_emit_load(&ctx->builder, img_type, img_var); SpvId coord = get_image_coords(ctx, type, &intr->src[1]); - SpvId texel = get_src(ctx, &intr->src[3]); + SpvId texel = get_src(ctx, &intr->src[3], &atype); + /* texel type must match image type */ + if (atype != nir_get_nir_type_for_glsl_base_type(glsl_get_sampler_result_type(type))) + texel = emit_bitcast(ctx, + spirv_builder_type_vector(&ctx->builder, base_type, 4), + texel); bool use_sample = glsl_get_sampler_dim(type) == GLSL_SAMPLER_DIM_MS || glsl_get_sampler_dim(type) == GLSL_SAMPLER_DIM_SUBPASS_MS; - SpvId sample = use_sample ? get_src(ctx, &intr->src[2]) : 0; + SpvId sample = use_sample ? 
get_src(ctx, &intr->src[2], &atype) : 0; assert(nir_src_bit_size(intr->src[3]) == glsl_base_type_bit_size(glsl_get_sampler_result_type(type))); - /* texel type must match image type */ - texel = emit_bitcast(ctx, - spirv_builder_type_vector(&ctx->builder, base_type, 4), - texel); spirv_builder_emit_image_write(&ctx->builder, img, coord, texel, 0, sample, 0); } @@ -3042,7 +3154,8 @@ static void emit_image_deref_load(struct ntv_context *ctx, nir_intrinsic_instr *intr) { bool sparse = intr->intrinsic == nir_intrinsic_image_deref_sparse_load; - SpvId img_var = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); nir_variable *var = nir_deref_instr_get_variable(deref); bool mediump = (var->data.precision == GLSL_PRECISION_MEDIUM || var->data.precision == GLSL_PRECISION_LOW); @@ -3053,7 +3166,7 @@ emit_image_deref_load(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId coord = get_image_coords(ctx, type, &intr->src[1]); bool use_sample = glsl_get_sampler_dim(type) == GLSL_SAMPLER_DIM_MS || glsl_get_sampler_dim(type) == GLSL_SAMPLER_DIM_SUBPASS_MS; - SpvId sample = use_sample ? get_src(ctx, &intr->src[2]) : 0; + SpvId sample = use_sample ? 
get_src(ctx, &intr->src[2], &atype) : 0; SpvId dest_type = spirv_builder_type_vector(&ctx->builder, base_type, nir_dest_num_components(intr->dest)); SpvId result = spirv_builder_emit_image_read(&ctx->builder, dest_type, @@ -3066,13 +3179,14 @@ emit_image_deref_load(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvDecorationRelaxedPrecision); } - store_dest(ctx, &intr->dest, result, nir_type_float); + store_dest(ctx, &intr->dest, result, nir_get_nir_type_for_glsl_base_type(glsl_get_sampler_result_type(type))); } static void emit_image_deref_size(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId img_var = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); nir_variable *var = nir_deref_instr_get_variable(deref); SpvId img_type = find_image_type(ctx, var); @@ -3091,7 +3205,8 @@ emit_image_deref_size(struct ntv_context *ctx, nir_intrinsic_instr *intr) static void emit_image_deref_samples(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId img_var = get_src(ctx, &intr->src[0]); + nir_alu_type atype; + SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); nir_variable *var = nir_deref_instr_get_variable(deref); SpvId img_type = find_image_type(ctx, var); @@ -3105,14 +3220,15 @@ emit_image_deref_samples(struct ntv_context *ctx, nir_intrinsic_instr *intr) static void emit_image_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) { - SpvId param = get_src(ctx, &intr->src[3]); - SpvId img_var = get_src(ctx, &intr->src[0]); + nir_alu_type atype, ptype; + SpvId param = get_src(ctx, &intr->src[3], &ptype); + SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); nir_variable *var = nir_deref_instr_get_variable(deref); const struct glsl_type *type = glsl_without_array(var->type); bool is_ms; 
type_to_dim(glsl_get_sampler_dim(type), &is_ms); - SpvId sample = is_ms ? get_src(ctx, &intr->src[2]) : emit_uint_const(ctx, 32, 0); + SpvId sample = is_ms ? get_src(ctx, &intr->src[2], &atype) : emit_uint_const(ctx, 32, 0); SpvId coord = get_image_coords(ctx, type, &intr->src[1]); enum glsl_base_type glsl_type = glsl_get_sampler_result_type(type); SpvId base_type = get_glsl_basetype(ctx, glsl_type); @@ -3123,12 +3239,17 @@ emit_image_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) * The type of the value pointed to by Pointer must be the same as Result Type. */ nir_alu_type ntype = nir_get_nir_type_for_glsl_base_type(glsl_type); - SpvId cast_type = get_dest_type(ctx, &intr->dest, ntype); - param = emit_bitcast(ctx, cast_type, param); + if (ptype != ntype) { + SpvId cast_type = get_dest_type(ctx, &intr->dest, ntype); + param = emit_bitcast(ctx, cast_type, param); + } if (intr->intrinsic == nir_intrinsic_image_deref_atomic_swap) { - param2 = get_src(ctx, &intr->src[4]); - param2 = emit_bitcast(ctx, cast_type, param2); + param2 = get_src(ctx, &intr->src[4], &ptype); + if (ptype != ntype) { + SpvId cast_type = get_dest_type(ctx, &intr->dest, ntype); + param2 = emit_bitcast(ctx, cast_type, param2); + } } handle_atomic_op(ctx, intr, texel, param, param2, ntype); @@ -3140,7 +3261,8 @@ emit_ballot(struct ntv_context *ctx, nir_intrinsic_instr *intr) spirv_builder_emit_cap(&ctx->builder, SpvCapabilitySubgroupBallotKHR); spirv_builder_emit_extension(&ctx->builder, "SPV_KHR_shader_ballot"); SpvId type = get_dest_uvec_type(ctx, &intr->dest); - SpvId result = emit_unop(ctx, SpvOpSubgroupBallotKHR, type, get_src(ctx, &intr->src[0])); + nir_alu_type atype; + SpvId result = emit_unop(ctx, SpvOpSubgroupBallotKHR, type, get_src(ctx, &intr->src[0], &atype)); store_dest(ctx, &intr->dest, result, nir_type_uint); } @@ -3149,9 +3271,11 @@ emit_read_first_invocation(struct ntv_context *ctx, nir_intrinsic_instr *intr) { spirv_builder_emit_cap(&ctx->builder, 
SpvCapabilitySubgroupBallotKHR); spirv_builder_emit_extension(&ctx->builder, "SPV_KHR_shader_ballot"); - SpvId type = get_dest_type(ctx, &intr->dest, nir_type_uint); - SpvId result = emit_unop(ctx, SpvOpSubgroupFirstInvocationKHR, type, get_src(ctx, &intr->src[0])); - store_dest(ctx, &intr->dest, result, nir_type_uint); + nir_alu_type atype; + SpvId src = get_src(ctx, &intr->src[0], &atype); + SpvId type = get_dest_type(ctx, &intr->dest, atype); + SpvId result = emit_unop(ctx, SpvOpSubgroupFirstInvocationKHR, type, src); + store_dest(ctx, &intr->dest, result, atype); } static void @@ -3159,11 +3283,13 @@ emit_read_invocation(struct ntv_context *ctx, nir_intrinsic_instr *intr) { spirv_builder_emit_cap(&ctx->builder, SpvCapabilitySubgroupBallotKHR); spirv_builder_emit_extension(&ctx->builder, "SPV_KHR_shader_ballot"); - SpvId type = get_dest_type(ctx, &intr->dest, nir_type_uint); + nir_alu_type atype, itype; + SpvId src = get_src(ctx, &intr->src[0], &atype); + SpvId type = get_dest_type(ctx, &intr->dest, atype); SpvId result = emit_binop(ctx, SpvOpSubgroupReadInvocationKHR, type, - get_src(ctx, &intr->src[0]), - get_src(ctx, &intr->src[1])); - store_dest(ctx, &intr->dest, result, nir_type_uint); + src, + get_src(ctx, &intr->src[1], &itype)); + store_dest(ctx, &intr->dest, result, atype); } static void @@ -3220,8 +3346,9 @@ emit_vote(struct ntv_context *ctx, nir_intrinsic_instr *intr) unreachable("unknown vote intrinsic"); } spirv_builder_emit_cap(&ctx->builder, SpvCapabilityGroupNonUniformVote); - SpvId result = spirv_builder_emit_vote(&ctx->builder, op, get_src(ctx, &intr->src[0])); - store_dest_raw(ctx, &intr->dest, result); + nir_alu_type atype; + SpvId result = spirv_builder_emit_vote(&ctx->builder, op, get_src(ctx, &intr->src[0], &atype)); + store_dest_raw(ctx, &intr->dest, result, nir_type_bool); } static void @@ -3546,13 +3673,17 @@ emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef) undef->def.num_components); store_ssa_def(ctx, &undef->def, - 
spirv_builder_emit_undef(&ctx->builder, type)); + spirv_builder_emit_undef(&ctx->builder, type), + undef->def.bit_size == 1 ? nir_type_bool : nir_type_uint); } static SpvId get_src_float(struct ntv_context *ctx, nir_src *src) { - SpvId def = get_src(ctx, src); + nir_alu_type atype; + SpvId def = get_src(ctx, src, &atype); + if (atype == nir_type_float) + return def; unsigned num_components = nir_src_num_components(*src); unsigned bit_size = nir_src_bit_size(*src); return bitcast_to_fvec(ctx, def, bit_size, num_components); @@ -3561,7 +3692,10 @@ get_src_float(struct ntv_context *ctx, nir_src *src) static SpvId get_src_int(struct ntv_context *ctx, nir_src *src) { - SpvId def = get_src(ctx, src); + nir_alu_type atype; + SpvId def = get_src(ctx, src, &atype); + if (atype == nir_type_int) + return def; unsigned num_components = nir_src_num_components(*src); unsigned bit_size = nir_src_bit_size(*src); return bitcast_to_ivec(ctx, def, bit_size, num_components); @@ -3602,6 +3736,7 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex) const_offset = 0, offset = 0, sample = 0, tex_offset = 0, bindless = 0, min_lod = 0; unsigned coord_components = 0; nir_variable *bindless_var = NULL; + nir_alu_type atype; for (unsigned i = 0; i < tex->num_srcs; i++) { nir_const_value *cv; switch (tex->src[i].src_type) { @@ -3698,7 +3833,7 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex) break; case nir_tex_src_texture_handle: - bindless = get_src(ctx, &tex->src[i].src); + bindless = get_src(ctx, &tex->src[i].src, &atype); bindless_var = nir_deref_instr_get_variable(nir_src_as_deref(tex->src[i].src)); break; @@ -3968,7 +4103,7 @@ emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref) struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var); assert(he); SpvId result = (SpvId)(intptr_t)he->data; - store_dest_raw(ctx, &deref->dest, result); + store_dest_raw(ctx, &deref->dest, result, get_nir_alu_type(deref->type)); } static void @@ -4001,11 +4136,12 @@ 
emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref) SpvStorageClass storage_class = get_storage_class(var); SpvId base, type; + nir_alu_type atype = nir_type_uint; switch (var->data.mode) { case nir_var_mem_ubo: case nir_var_mem_ssbo: - base = get_src(ctx, &deref->parent); + base = get_src(ctx, &deref->parent, &atype); /* this is either the array<buffers> deref or the array<uint> deref */ if (glsl_type_is_struct_or_ifc(deref->type)) { /* array<buffers> */ @@ -4017,7 +4153,7 @@ emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref) case nir_var_function_temp: case nir_var_shader_in: case nir_var_shader_out: - base = get_src(ctx, &deref->parent); + base = get_src(ctx, &deref->parent, &atype); type = get_glsl_type(ctx, deref->type); break; @@ -4037,7 +4173,10 @@ emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref) unreachable("Unsupported nir_variable_mode\n"); } - SpvId index = get_src(ctx, &deref->arr.index); + nir_alu_type itype; + SpvId index = get_src(ctx, &deref->arr.index, &itype); + if (itype == nir_type_float) + index = emit_bitcast(ctx, get_uvec_type(ctx, 32, 1), index); SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, storage_class, @@ -4048,7 +4187,7 @@ emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref) base, &index, 1); /* uint is a bit of a lie here, it's really just an opaque type */ - store_dest(ctx, &deref->dest, result, nir_type_uint); + store_dest(ctx, &deref->dest, result, get_nir_alu_type(deref->type)); } static void @@ -4068,12 +4207,13 @@ emit_deref_struct(struct ntv_context *ctx, nir_deref_instr *deref) storage_class, type); + nir_alu_type atype; SpvId result = spirv_builder_emit_access_chain(&ctx->builder, ptr_type, - get_src(ctx, &deref->parent), + get_src(ctx, &deref->parent, &atype), &index, 1); /* uint is a bit of a lie here, it's really just an opaque type */ - store_dest(ctx, &deref->dest, result, nir_type_uint); + store_dest(ctx, &deref->dest, result, get_nir_alu_type(deref->type)); } static void 
@@ -4144,7 +4284,8 @@ static SpvId get_src_bool(struct ntv_context *ctx, nir_src *src) { assert(nir_src_bit_size(*src) == 1); - return get_src(ctx, src); + nir_alu_type atype; + return get_src(ctx, src, &atype); } static void @@ -4728,7 +4869,9 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_ ctx.defs = ralloc_array_size(ctx.mem_ctx, sizeof(SpvId), entry->ssa_alloc); - if (!ctx.defs) + ctx.def_types = ralloc_array_size(ctx.mem_ctx, + sizeof(nir_alu_type), entry->ssa_alloc); + if (!ctx.defs || !ctx.def_types) goto fail; if (sinfo->have_sparse) { spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySparseResidency); @@ -4745,7 +4888,9 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_ nir_index_local_regs(entry); ctx.regs = rzalloc_array_size(ctx.mem_ctx, sizeof(SpvId), entry->reg_alloc); - if (!ctx.regs) + ctx.reg_types = ralloc_array_size(ctx.mem_ctx, + sizeof(nir_alu_type), entry->reg_alloc); + if (!ctx.regs || !ctx.reg_types) goto fail; ctx.num_regs = entry->reg_alloc; @@ -4766,7 +4911,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_ foreach_list_typed(nir_register, reg, node, &entry->registers) { if (reg->bit_size == 1) - init_reg(&ctx, reg); + init_reg(&ctx, reg, nir_type_bool); } nir_foreach_variable_with_modes(var, s, nir_var_shader_temp)