aco: vectorize 16bit extracts

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35854>
This commit is contained in:
Georg Lehmann 2025-07-07 13:45:48 +02:00 committed by Marge Bot
parent a045e9a624
commit 7fece5592c
2 changed files with 61 additions and 0 deletions

View file

@ -473,6 +473,8 @@ aco_nir_op_supports_packed_math_16bit(const nir_alu_instr* alu)
case nir_op_imax:
case nir_op_umin:
case nir_op_umax:
case nir_op_extract_u8:
case nir_op_extract_i8:
case nir_op_ishl:
case nir_op_ishr:
case nir_op_ushr: return true;

View file

@ -341,6 +341,49 @@ emit_pk_shift(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Temp dst)
emit_split_vector(ctx, dst, 2);
}
/**
 * Emits code that fills the single-dword VGPR `dst` with two 16-bit values,
 * each taken from one byte of the single-dword `src`.
 *
 * @param ctx   Instruction selection context; code is appended to ctx->block.
 * @param dst   Destination temporary, asserted to have register class v1.
 * @param src   Source temporary of size 1; it is moved to a VGPR if necessary.
 * @param byte0 Index of the source byte feeding the low 16-bit half.
 * @param byte2 Index of the source byte feeding the high 16-bit half.
 * @param sext  Sign-extend the selected bytes when true, zero-extend otherwise.
 *
 * Three strategies are used, from cheapest to most general:
 *  - bytes 0 and 2 with zero extension: a single AND with 0x00ff00ff,
 *  - both byte indices odd: a packed 16-bit right shift by 8, where bit 1 of
 *    each byte index is passed as the opsel bit for the `src` operand,
 *  - everything else: a v_perm_b32 with a computed selector.
 *
 * The result is finally registered as a two-component vector via
 * emit_split_vector().
 */
void
emit_pk_int16_from_8bit(isel_context* ctx, Temp dst, Temp src, unsigned byte0, unsigned byte2,
                        bool sext)
{
   Builder bld(ctx->program, ctx->block);
   assert(src.size() == 1);
   assert(dst.regClass() == v1);
   src = as_vgpr(ctx, src);

   const bool is_plain_mask = !sext && byte0 == 0 && byte2 == 2;
   const bool both_odd = (byte0 & 0x1) != 0 && (byte2 & 0x1) != 0;

   if (is_plain_mask) {
      /* Keep only the low byte of each 16-bit half. */
      Temp byte_mask = bld.copy(bld.def(s1), Operand::c32(0x00ff00ffu));
      bld.vop2(aco_opcode::v_and_b32, Definition(dst), byte_mask, src);
   } else if (both_odd) {
      /* The wanted byte is the upper byte of a 16-bit half: shift it down by 8. */
      aco_opcode shift_op = aco_opcode::v_pk_lshrrev_b16;
      if (sext)
         shift_op = aco_opcode::v_pk_ashrrev_i16;

      const unsigned opsel_lo = byte0 & 0x2;
      const unsigned opsel_hi = byte2 & 0x2;
      bld.vop3p(shift_op, Definition(dst), Operand::c32(8), src, opsel_lo, opsel_hi);
   } else {
      const unsigned src_bytes[2] = {byte0, byte2};
      uint32_t selector = 0;
      /* Stays a constant zero until a shifted copy of src is actually required. */
      Operand shifted = Operand::c32(0);

      for (unsigned half = 0; half < 2; ++half) {
         const unsigned value_shift = half * 16;
         const unsigned ext_shift = value_shift + 8;
         const unsigned sel_byte = src_bytes[half];
         const bool in_upper_half = (sel_byte & 0x2) != 0;

         /* Low byte of this half: the selected source byte itself. */
         selector |= sel_byte << value_shift;

         /* High byte of this half: the extension. */
         if (!sext) {
            selector |= bperm_0 << ext_shift;
            continue;
         }

         if (sel_byte & 0x1) {
            selector |= (in_upper_half ? bperm_b3_sign : bperm_b1_sign) << ext_shift;
            continue;
         }

         /* Even byte with sign extension: only sign selectors for odd bytes are
          * used here, so provide src shifted left by 8 as the other permute
          * input (emitted at most once) and take the sign from there.
          * NOTE(review): relies on v_perm_b32 numbering the bytes of the first
          * data operand as 4..7 -- confirm against the ISA documentation.
          */
         if (shifted.isConstant())
            shifted = bld.vop2(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand::c32(8), src);
         selector |= (in_upper_half ? bperm_b7_sign : bperm_b5_sign) << ext_shift;
      }

      Temp selector_tmp = bld.copy(bld.def(s1), Operand::c32(selector));
      bld.vop3(aco_opcode::v_perm_b32, Definition(dst), shifted, src, selector_tmp);
   }

   emit_split_vector(ctx, dst, 2);
}
void
emit_vop1_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Temp dst)
{
@ -3412,6 +3455,22 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
bool is_signed = instr->op == nir_op_extract_i16 || instr->op == nir_op_extract_i8;
unsigned comp = instr->op == nir_op_extract_u8 || instr->op == nir_op_extract_i8 ? 4 : 2;
uint32_t bits = comp == 4 ? 8 : 16;
if (instr->def.num_components == 2) {
assert(instr->def.bit_size == 16 && bits == 8);
Temp src = get_alu_src_vop3p(ctx, instr->src[0]);
unsigned swizzle[2];
for (unsigned i = 0; i < 2; i++) {
nir_scalar index = nir_scalar_resolved(instr->src[1].src.ssa, instr->src[1].swizzle[i]);
swizzle[i] = (instr->src[0].swizzle[i] & 0x1) * 2 + nir_scalar_as_uint(index);
}
emit_pk_int16_from_8bit(ctx, dst, src, swizzle[0], swizzle[1], is_signed);
break;
}
unsigned index = nir_src_as_uint(instr->src[1].src);
if (bits >= instr->def.bit_size || index * bits >= instr->def.bit_size) {
assert(index == 0);