mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-01 03:48:06 +02:00
aco/sched_vopd: create dot2acc from VOP3P dot2
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40225>
This commit is contained in:
parent
47599b2c38
commit
788aafba2a
2 changed files with 105 additions and 1 deletions
|
|
@ -135,7 +135,7 @@ VOPDInfo
|
|||
get_vopd_info(const SchedILPContext& ctx, const Instruction* instr)
|
||||
{
|
||||
if (instr->format != Format::VOP1 && instr->format != Format::VOP2 &&
|
||||
instr->format != Format::VOP3)
|
||||
instr->format != Format::VOP3 && instr->format != Format::VOP3P)
|
||||
return VOPDInfo();
|
||||
|
||||
VOPDInfo info;
|
||||
|
|
@ -211,6 +211,48 @@ get_vopd_info(const SchedILPContext& ctx, const Instruction* instr)
|
|||
}
|
||||
break;
|
||||
}
|
||||
case aco_opcode::v_dot2_f32_f16:
|
||||
case aco_opcode::v_dot2_f32_bf16: {
|
||||
bool bf16 = instr->opcode == aco_opcode::v_dot2_f32_bf16;
|
||||
/* src2 must be the same as the destination. */
|
||||
if (!instr->operands[2].isOfType(RegType::vgpr) ||
|
||||
instr->operands[2].physReg() != instr->definitions[0].physReg() ||
|
||||
instr->valu().clamp)
|
||||
return VOPDInfo();
|
||||
|
||||
/* One pair of factors must be a vgpr. */
|
||||
if (!instr->operands[0].isOfType(RegType::vgpr) &&
|
||||
!instr->operands[1].isOfType(RegType::vgpr))
|
||||
return VOPDInfo();
|
||||
|
||||
for (unsigned i = 0; i < 3; i++) {
|
||||
if (!instr->operands[i].isConstant() &&
|
||||
(instr->valu().neg_lo[i] || instr->valu().neg_hi[i] || instr->valu().opsel_lo[i] ||
|
||||
!instr->valu().opsel_hi[i]))
|
||||
return VOPDInfo();
|
||||
if (instr->operands[i].isConstant() &&
|
||||
(instr->operands[i].isLiteral() || instr->valu().neg_lo[i] ||
|
||||
instr->valu().neg_hi[i] || instr->valu().opsel_lo[i] != bf16 ||
|
||||
instr->valu().opsel_hi[i] != bf16)) {
|
||||
info.has_literal = true;
|
||||
info.literal = instr->operands[i].constantValue();
|
||||
uint32_t lo = (info.literal >> (instr->valu().opsel_lo[i] * 16)) & 0xffff;
|
||||
uint32_t hi = (info.literal >> (instr->valu().opsel_hi[i] * 16)) & 0xffff;
|
||||
lo ^= instr->valu().neg_lo[i] ? 0x8000 : 0;
|
||||
hi ^= instr->valu().neg_hi[i] ? 0x8000 : 0;
|
||||
info.literal = lo | (hi << 16);
|
||||
}
|
||||
}
|
||||
|
||||
if (!instr->operands[1].isOfType(RegType::vgpr))
|
||||
info.operand_swizzle = 0b10'00'01;
|
||||
|
||||
if (info.has_literal)
|
||||
info.operand_swizzle |= 0b11;
|
||||
|
||||
info.op = bf16 ? aco_opcode::v_dual_dot2acc_f32_bf16 : aco_opcode::v_dual_dot2acc_f32_f16;
|
||||
break;
|
||||
}
|
||||
default: return VOPDInfo();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -341,3 +341,65 @@ BEGIN_TEST(vopd_sched.fma_with_constant)
|
|||
|
||||
finish_schedule_vopd_test();
|
||||
END_TEST
|
||||
|
||||
/* Exercises forming v_dual_dot2acc_f32_{f16,bf16} from VOP3P v_dot2_f32_{f16,bf16}:
 * cases 0 and 1 expect the two dot products to be merged into one dual instruction,
 * cases 2-4 expect v_dot2_f32_f16 to be left unchanged next to a plain v_mov_b32.
 */
BEGIN_TEST(vopd_sched.dot2acc_from_vop3p)
|
||||
if (!setup_cs(NULL, GFX11, CHIP_UNKNOWN, "", 32))
|
||||
return;
|
||||
|
||||
PhysReg reg_v0{256};
|
||||
PhysReg reg_v1{257};
|
||||
PhysReg reg_v2{258};
|
||||
PhysReg reg_v3{259};
|
||||
PhysReg reg_s0{0};
|
||||
|
||||
/* The non-vgpr factor (16-bit constant 0x4000, printed as 2.0, resp. s[0]) is passed as
 * the second factor here; the expected line shows it as the first source of each half.
 */
//>> p_unit_test 0
|
||||
//! v1: %0:v[1] = v_dual_dot2acc_f32_bf16 %0:s[0], %0:v[3], %0:v[1] :: v1: %0:v[0] = v_dual_dot2acc_f32_f16 2.0, %0:v[2], %0:v[0]
|
||||
bld.pseudo(aco_opcode::p_unit_test, Operand::zero());
|
||||
bld.vop3p(aco_opcode::v_dot2_f32_f16, Definition(reg_v0, v1), Operand(reg_v2, v1),
|
||||
Operand::c16(0x4000), Operand(reg_v0, v1), 0, 0b101);
|
||||
bld.vop3p(aco_opcode::v_dot2_f32_bf16, Definition(reg_v1, v1), Operand(reg_v3, v1),
|
||||
Operand(reg_s0, s1), Operand(reg_v1, v1), 0, 0b111);
|
||||
|
||||
bld.reset(program->create_and_insert_block());
|
||||
/* Literal factor: 0x70AD is expected to show up as 0x70ad70ad in the dual instruction. */
//>> p_unit_test 1
|
||||
//! v1: %0:v[1] = v_dual_dot2acc_f32_bf16 %0:s[0], %0:v[3], %0:v[1] :: v1: %0:v[0] = v_dual_dot2acc_f32_f16 0x70ad70ad, %0:v[2], %0:v[0]
|
||||
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(1));
|
||||
bld.vop3p(aco_opcode::v_dot2_f32_f16, Definition(reg_v0, v1), Operand(reg_v2, v1),
|
||||
Operand::literal32(0x70AD), Operand(reg_v0, v1), 0, 0b101);
|
||||
bld.vop3p(aco_opcode::v_dot2_f32_bf16, Definition(reg_v1, v1), Operand(reg_s0, s1),
|
||||
Operand(reg_v3, v1), Operand(reg_v1, v1), 0, 0b111);
|
||||
|
||||
/* Needs at least one vgpr factor: here the factors are s[0] and a constant. */
|
||||
bld.reset(program->create_and_insert_block());
|
||||
//>> p_unit_test 2
|
||||
//! v1: %0:v[1] = v_mov_b32 0
|
||||
//! v1: %0:v[0] = v_dot2_f32_f16 %0:s[0], 2.0.xx, %0:v[0]
|
||||
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(2));
|
||||
bld.vop3p(aco_opcode::v_dot2_f32_f16, Definition(reg_v0, v1), Operand(reg_s0, s1),
|
||||
Operand::c16(0x4000), Operand(reg_v0, v1), 0, 0b101);
|
||||
bld.vop1(aco_opcode::v_mov_b32, Definition(reg_v1, v1), Operand::c32(0));
|
||||
|
||||
/* No modifiers are allowed: neg_lo on the first factor keeps the v_dot2_f32_f16. */
|
||||
bld.reset(program->create_and_insert_block());
|
||||
//>> p_unit_test 3
|
||||
//! v1: %0:v[0] = v_dot2_f32_f16 %0:v[1]*[-1,1], %0:v[2], %0:v[0]
|
||||
//! v1: %0:v[1] = v_mov_b32 0
|
||||
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(3));
|
||||
bld.vop3p(aco_opcode::v_dot2_f32_f16, Definition(reg_v0, v1), Operand(reg_v1, v1),
|
||||
Operand(reg_v2, v1), Operand(reg_v0, v1), 0, 0b111)
|
||||
->valu()
|
||||
.neg_lo[0] = true;
|
||||
bld.vop1(aco_opcode::v_mov_b32, Definition(reg_v1, v1), Operand::c32(0));
|
||||
|
||||
/* Definition must be the same as the last operand: here it is v[0] versus v[3]. */
|
||||
bld.reset(program->create_and_insert_block());
|
||||
//>> p_unit_test 4
|
||||
//! v1: %0:v[0] = v_dot2_f32_f16 %0:v[1], %0:v[2], %0:v[3]
|
||||
//! v1: %0:v[1] = v_mov_b32 0
|
||||
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(4));
|
||||
bld.vop3p(aco_opcode::v_dot2_f32_f16, Definition(reg_v0, v1), Operand(reg_v1, v1),
|
||||
Operand(reg_v2, v1), Operand(reg_v3, v1), 0, 0b111);
|
||||
bld.vop1(aco_opcode::v_mov_b32, Definition(reg_v1, v1), Operand::c32(0));
|
||||
|
||||
finish_schedule_vopd_test();
|
||||
END_TEST
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue