diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index f7121913e08..42b47809af9 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -2386,7 +2386,7 @@ emit_load_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) uint_type); nir_alu_type atype; SpvId offset = get_src(ctx, &intr->src[0], &atype); - if (atype == nir_type_float) + if (atype != nir_type_uint) offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[0]), 1); SpvId constituents[NIR_MAX_VEC_COMPONENTS]; SpvId scratch_block = get_scratch_block(ctx, bit_size); @@ -2419,8 +2419,8 @@ emit_store_scratch(struct ntv_context *ctx, nir_intrinsic_instr *intr) uint_type); nir_alu_type otype; SpvId offset = get_src(ctx, &intr->src[1], &otype); - if (otype == nir_type_float) - offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[0]), 1); + if (otype != nir_type_uint) + offset = bitcast_to_uvec(ctx, offset, nir_src_bit_size(intr->src[1]), 1); SpvId scratch_block = get_scratch_block(ctx, bit_size); /* this is a partial write, so we have to loop and do a per-component write */ u_foreach_bit(i, wrmask) {