aco: vectorize conversions from 8bit to 16bit

Massively helps emulated fp8 performance.

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35854>
This commit is contained in:
Georg Lehmann 2025-07-07 13:46:11 +02:00 committed by Marge Bot
parent 7fece5592c
commit 92d433c54a
2 changed files with 27 additions and 20 deletions

View file

@ -478,6 +478,8 @@ aco_nir_op_supports_packed_math_16bit(const nir_alu_instr* alu)
case nir_op_ishl:
case nir_op_ishr:
case nir_op_ushr: return true;
case nir_op_u2u16:
case nir_op_i2i16: return alu->src[0].src.ssa->bit_size == 8;
default: return false;
}
}

View file

@ -2964,33 +2964,38 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
}
case nir_op_i2i8:
case nir_op_i2i16:
case nir_op_i2i32: {
if (dst.type() == RegType::sgpr && instr->src[0].src.ssa->bit_size < 32) {
/* no need to do the extract in get_alu_src() */
sgpr_extract_mode mode = instr->def.bit_size > instr->src[0].src.ssa->bit_size
? sgpr_extract_sext
: sgpr_extract_undef;
extract_8_16_bit_sgpr_element(ctx, dst, &instr->src[0], mode);
} else {
const unsigned input_bitsize = instr->src[0].src.ssa->bit_size;
const unsigned output_bitsize = instr->def.bit_size;
convert_int(ctx, bld, get_alu_src(ctx, instr->src[0]), input_bitsize, output_bitsize,
output_bitsize > input_bitsize, dst);
}
break;
}
case nir_op_i2i32:
case nir_op_u2u8:
case nir_op_u2u16:
case nir_op_u2u32: {
const unsigned input_bitsize = instr->src[0].src.ssa->bit_size;
const unsigned output_bitsize = instr->def.bit_size;
bool sext =
instr->op == nir_op_i2i8 || instr->op == nir_op_i2i16 || instr->op == nir_op_i2i32;
bool trunc = output_bitsize <= input_bitsize;
if (instr->def.num_components == 2) {
assert(output_bitsize == 16 && input_bitsize == 8);
assert((instr->src[0].swizzle[0] & ~0x3) == (instr->src[0].swizzle[1] & ~0x3));
Temp src = get_ssa_temp(ctx, instr->src[0].src.ssa);
if (src.bytes() >= 4)
src = emit_extract_vector(ctx, src, instr->src[0].swizzle[0] & ~0x3, v1);
emit_pk_int16_from_8bit(ctx, dst, src, instr->src[0].swizzle[0] & 0x3,
instr->src[0].swizzle[1] & 0x3, sext);
break;
}
if (dst.type() == RegType::sgpr && instr->src[0].src.ssa->bit_size < 32) {
/* no need to do the extract in get_alu_src() */
sgpr_extract_mode mode = instr->def.bit_size > instr->src[0].src.ssa->bit_size
? sgpr_extract_zext
: sgpr_extract_undef;
sgpr_extract_mode mode = trunc ? sgpr_extract_undef
: sext ? sgpr_extract_sext
: sgpr_extract_zext;
extract_8_16_bit_sgpr_element(ctx, dst, &instr->src[0], mode);
} else {
convert_int(ctx, bld, get_alu_src(ctx, instr->src[0]), instr->src[0].src.ssa->bit_size,
instr->def.bit_size, false, dst);
convert_int(ctx, bld, get_alu_src(ctx, instr->src[0]), input_bitsize, output_bitsize,
sext && !trunc, dst);
}
break;
}