aco: fix VOP3P assembly, VN and validation

aco/opcodes: rename v_pk_fma_mix* -> v_fma_mix*
and add modifier capabilities for VOP3P.

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6680>
This commit is contained in:
Daniel Schürmann 2020-09-03 11:59:00 +01:00 committed by Marge Bot
parent 2bde971f46
commit 2caba08c1a
4 changed files with 43 additions and 29 deletions

View file

@ -612,7 +612,7 @@ void emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction*
encoding |= opcode << 16;
encoding |= (vop3->clamp ? 1 : 0) << 15;
encoding |= vop3->opsel_lo << 11;
encoding |= (vop3->opsel_hi & 0x4) ? 1 : 0 << 14;
encoding |= ((vop3->opsel_hi & 0x4) ? 1 : 0) << 14;
for (unsigned i = 0; i < 3; i++)
encoding |= vop3->neg_hi[i] << (8+i);
encoding |= (0xFF & instr->definitions[0].physReg());
@ -620,7 +620,7 @@ void emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction*
encoding = 0;
for (unsigned i = 0; i < instr->operands.size(); i++)
encoding |= instr->operands[i].physReg() << (i * 9);
encoding |= vop3->opsel_hi & 0x3 << 27;
encoding |= (vop3->opsel_hi & 0x3) << 27;
for (unsigned i = 0; i < 3; i++)
encoding |= vop3->neg_lo[i] << (29+i);
out.push_back(encoding);

View file

@ -903,33 +903,34 @@ for i in range(8):
# VOPP instructions: packed 16bit instructions - 1 or 2 inputs and 1 output
VOPP = {
(0x00, "v_pk_mad_i16"),
(0x01, "v_pk_mul_lo_u16"),
(0x02, "v_pk_add_i16"),
(0x03, "v_pk_sub_i16"),
(0x04, "v_pk_lshlrev_b16"),
(0x05, "v_pk_lshrrev_b16"),
(0x06, "v_pk_ashrrev_i16"),
(0x07, "v_pk_max_i16"),
(0x08, "v_pk_min_i16"),
(0x09, "v_pk_mad_u16"),
(0x0a, "v_pk_add_u16"),
(0x0b, "v_pk_sub_u16"),
(0x0c, "v_pk_max_u16"),
(0x0d, "v_pk_min_u16"),
(0x0e, "v_pk_fma_f16"),
(0x0f, "v_pk_add_f16"),
(0x10, "v_pk_mul_f16"),
(0x11, "v_pk_min_f16"),
(0x12, "v_pk_max_f16"),
(0x20, "v_pk_fma_mix_f32"), # v_mad_mix_f32 in VEGA ISA, v_fma_mix_f32 in RDNA ISA
(0x21, "v_pk_fma_mixlo_f16"), # v_mad_mixlo_f16 in VEGA ISA, v_fma_mixlo_f16 in RDNA ISA
(0x22, "v_pk_fma_mixhi_f16"), # v_mad_mixhi_f16 in VEGA ISA, v_fma_mixhi_f16 in RDNA ISA
# opcode, name, input/output modifiers
(0x00, "v_pk_mad_i16", False),
(0x01, "v_pk_mul_lo_u16", False),
(0x02, "v_pk_add_i16", False),
(0x03, "v_pk_sub_i16", False),
(0x04, "v_pk_lshlrev_b16", False),
(0x05, "v_pk_lshrrev_b16", False),
(0x06, "v_pk_ashrrev_i16", False),
(0x07, "v_pk_max_i16", False),
(0x08, "v_pk_min_i16", False),
(0x09, "v_pk_mad_u16", False),
(0x0a, "v_pk_add_u16", False),
(0x0b, "v_pk_sub_u16", False),
(0x0c, "v_pk_max_u16", False),
(0x0d, "v_pk_min_u16", False),
(0x0e, "v_pk_fma_f16", True),
(0x0f, "v_pk_add_f16", True),
(0x10, "v_pk_mul_f16", True),
(0x11, "v_pk_min_f16", True),
(0x12, "v_pk_max_f16", True),
(0x20, "v_fma_mix_f32", True), # v_mad_mix_f32 in VEGA ISA, v_fma_mix_f32 in RDNA ISA
(0x21, "v_fma_mixlo_f16", True), # v_mad_mixlo_f16 in VEGA ISA, v_fma_mixlo_f16 in RDNA ISA
(0x22, "v_fma_mixhi_f16", True), # v_mad_mixhi_f16 in VEGA ISA, v_fma_mixhi_f16 in RDNA ISA
}
# note that these are only supported on gfx9+ so we'll need to distinguish between gfx8 and gfx9 here
# (gfx6, gfx7, gfx8, gfx9, gfx10, name) = (-1, -1, -1, code, code, name)
for (code, name) in VOPP:
opcode(name, -1, code, code, Format.VOP3P)
for (code, name, modifiers) in VOPP:
opcode(name, -1, code, code, Format.VOP3P, modifiers, modifiers)
# VINTERP instructions:

View file

@ -244,6 +244,18 @@ struct InstrPred {
return false;
return true;
}
case Format::VOP3P: {
VOP3P_instruction* a3P = static_cast<VOP3P_instruction*>(a);
VOP3P_instruction* b3P = static_cast<VOP3P_instruction*>(b);
for (unsigned i = 0; i < 3; i++) {
if (a3P->neg_lo[i] != b3P->neg_lo[i] ||
a3P->neg_hi[i] != b3P->neg_hi[i])
return false;
}
return a3P->opsel_lo == b3P->opsel_lo &&
a3P->opsel_hi == b3P->opsel_hi &&
a3P->clamp == b3P->clamp;
}
case Format::PSEUDO_REDUCTION: {
Pseudo_reduction_instruction *aR = static_cast<Pseudo_reduction_instruction*>(a);
Pseudo_reduction_instruction *bR = static_cast<Pseudo_reduction_instruction*>(b);

View file

@ -240,12 +240,13 @@ bool validate_ir(Program* program)
instr->format == Format::VOP1 ||
instr->format == Format::VOP2 ||
instr->format == Format::VOPC ||
(instr->isVOP3() && program->chip_class >= GFX10),
(instr->isVOP3() && program->chip_class >= GFX10) ||
(instr->format == Format::VOP3P && program->chip_class >= GFX10),
"Literal applied on wrong instruction format", instr.get());
check(literal.isUndefined() || (literal.size() == op.size() && literal.constantValue() == op.constantValue()), "Only 1 Literal allowed", instr.get());
literal = op;
check(!instr->isVALU() || instr->isVOP3() || i == 0 || i == 2, "Wrong source position for Literal argument", instr.get());
check(instr->isSALU() || instr->isVOP3() || instr->format == Format::VOP3P || i == 0 || i == 2, "Wrong source position for Literal argument", instr.get());
}
/* check num sgprs for VALU */
@ -257,7 +258,7 @@ bool validate_ir(Program* program)
if (program->chip_class >= GFX10 && !is_shift64)
const_bus_limit = 2;
uint32_t scalar_mask = instr->isVOP3() ? 0x7 : 0x5;
uint32_t scalar_mask = instr->isVOP3() || instr->format == Format::VOP3P ? 0x7 : 0x5;
if (instr->isSDWA())
scalar_mask = program->chip_class >= GFX9 ? 0x7 : 0x4;