aco/sched_vopd: create dot2acc from VOP3P dot2

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40225>
This commit is contained in:
Georg Lehmann 2026-03-04 16:06:30 +01:00 committed by Marge Bot
parent 47599b2c38
commit 788aafba2a
2 changed files with 105 additions and 1 deletions

View file

@ -135,7 +135,7 @@ VOPDInfo
get_vopd_info(const SchedILPContext& ctx, const Instruction* instr)
{
if (instr->format != Format::VOP1 && instr->format != Format::VOP2 &&
instr->format != Format::VOP3)
instr->format != Format::VOP3 && instr->format != Format::VOP3P)
return VOPDInfo();
VOPDInfo info;
@ -211,6 +211,48 @@ get_vopd_info(const SchedILPContext& ctx, const Instruction* instr)
}
break;
}
/* A VOP3P dot2 is reported as the dual dot2acc opcode chosen at the end of this case, but only
 * if every check below passes; any failing check returns an empty VOPDInfo instead.
 */
case aco_opcode::v_dot2_f32_f16:
case aco_opcode::v_dot2_f32_bf16: {
bool bf16 = instr->opcode == aco_opcode::v_dot2_f32_bf16;
/* src2 must be the same as the destination. Clamp is rejected as well. */
if (!instr->operands[2].isOfType(RegType::vgpr) ||
instr->operands[2].physReg() != instr->definitions[0].physReg() ||
instr->valu().clamp)
return VOPDInfo();
/* One pair of factors must be a vgpr. */
if (!instr->operands[0].isOfType(RegType::vgpr) &&
!instr->operands[1].isOfType(RegType::vgpr))
return VOPDInfo();
for (unsigned i = 0; i < 3; i++) {
/* Non-constant operands are only accepted with no neg modifiers, opsel_lo clear and
 * opsel_hi set.
 */
if (!instr->operands[i].isConstant() &&
(instr->valu().neg_lo[i] || instr->valu().neg_hi[i] || instr->valu().opsel_lo[i] ||
!instr->valu().opsel_hi[i]))
return VOPDInfo();
/* A constant that is a literal, carries a neg modifier, or whose opsel bits are not both
 * equal to bf16 is turned into a 32-bit literal with the modifiers folded in.
 */
if (instr->operands[i].isConstant() &&
(instr->operands[i].isLiteral() || instr->valu().neg_lo[i] ||
instr->valu().neg_hi[i] || instr->valu().opsel_lo[i] != bf16 ||
instr->valu().opsel_hi[i] != bf16)) {
info.has_literal = true;
info.literal = instr->operands[i].constantValue();
/* Pick the 16-bit half selected by each opsel bit. */
uint32_t lo = (info.literal >> (instr->valu().opsel_lo[i] * 16)) & 0xffff;
uint32_t hi = (info.literal >> (instr->valu().opsel_hi[i] * 16)) & 0xffff;
/* Apply neg by flipping bit 15 of the corresponding half. */
lo ^= instr->valu().neg_lo[i] ? 0x8000 : 0;
hi ^= instr->valu().neg_hi[i] ? 0x8000 : 0;
/* Repack: lo in bits 0-15, hi in bits 16-31. */
info.literal = lo | (hi << 16);
}
}
/* NOTE(review): the operand_swizzle encoding is defined outside this hunk. Presumably this
 * value exchanges the two factors so that a non-vgpr src1 ends up first - confirm.
 */
if (!instr->operands[1].isOfType(RegType::vgpr))
info.operand_swizzle = 0b10'00'01;
/* NOTE(review): presumably 0b11 in the low bits makes the first source the literal - confirm
 * against the operand_swizzle definition.
 */
if (info.has_literal)
info.operand_swizzle |= 0b11;
info.op = bf16 ? aco_opcode::v_dual_dot2acc_f32_bf16 : aco_opcode::v_dual_dot2acc_f32_f16;
break;
}
default: return VOPDInfo();
}

View file

@ -341,3 +341,65 @@ BEGIN_TEST(vopd_sched.fma_with_constant)
finish_schedule_vopd_test();
END_TEST
/* Checks which VOP3P v_dot2_f32_f16 / v_dot2_f32_bf16 instructions end up as
 * v_dual_dot2acc_* in the output, and which are left unchanged. Each case lives in its own
 * block and starts with a p_unit_test marker. The expected output is given by the //>> and //!
 * lines.
 */
BEGIN_TEST(vopd_sched.dot2acc_from_vop3p)
if (!setup_cs(NULL, GFX11, CHIP_UNKNOWN, "", 32))
return;
PhysReg reg_v0{256};
PhysReg reg_v1{257};
PhysReg reg_v2{258};
PhysReg reg_v3{259};
PhysReg reg_s0{0};
/* An inline constant factor and an sgpr factor: both dot2 become dot2acc, with the
 * constant/sgpr factor printed first.
 * NOTE(review): the two trailing vop3p() arguments are presumably opsel_lo and opsel_hi - confirm.
 */
//>> p_unit_test 0
//! v1: %0:v[1] = v_dual_dot2acc_f32_bf16 %0:s[0], %0:v[3], %0:v[1] :: v1: %0:v[0] = v_dual_dot2acc_f32_f16 2.0, %0:v[2], %0:v[0]
bld.pseudo(aco_opcode::p_unit_test, Operand::zero());
bld.vop3p(aco_opcode::v_dot2_f32_f16, Definition(reg_v0, v1), Operand(reg_v2, v1),
Operand::c16(0x4000), Operand(reg_v0, v1), 0, 0b101);
bld.vop3p(aco_opcode::v_dot2_f32_bf16, Definition(reg_v1, v1), Operand(reg_v3, v1),
Operand(reg_s0, s1), Operand(reg_v1, v1), 0, 0b111);
/* A 32-bit literal factor: the expected literal has the low 16 bits replicated into both
 * halves. The sgpr factor is already the first operand of the bf16 dot2 here.
 */
bld.reset(program->create_and_insert_block());
//>> p_unit_test 1
//! v1: %0:v[1] = v_dual_dot2acc_f32_bf16 %0:s[0], %0:v[3], %0:v[1] :: v1: %0:v[0] = v_dual_dot2acc_f32_f16 0x70ad70ad, %0:v[2], %0:v[0]
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(1));
bld.vop3p(aco_opcode::v_dot2_f32_f16, Definition(reg_v0, v1), Operand(reg_v2, v1),
Operand::literal32(0x70AD), Operand(reg_v0, v1), 0, 0b101);
bld.vop3p(aco_opcode::v_dot2_f32_bf16, Definition(reg_v1, v1), Operand(reg_s0, s1),
Operand(reg_v3, v1), Operand(reg_v1, v1), 0, 0b111);
/* Needs at least one vgpr factor: with an sgpr and a constant factor the dot2 stays VOP3P. */
bld.reset(program->create_and_insert_block());
//>> p_unit_test 2
//! v1: %0:v[1] = v_mov_b32 0
//! v1: %0:v[0] = v_dot2_f32_f16 %0:s[0], 2.0.xx, %0:v[0]
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(2));
bld.vop3p(aco_opcode::v_dot2_f32_f16, Definition(reg_v0, v1), Operand(reg_s0, s1),
Operand::c16(0x4000), Operand(reg_v0, v1), 0, 0b101);
bld.vop1(aco_opcode::v_mov_b32, Definition(reg_v1, v1), Operand::c32(0));
/* Modifiers on register operands are not allowed: neg_lo on a vgpr factor keeps the dot2 as
 * VOP3P.
 */
bld.reset(program->create_and_insert_block());
//>> p_unit_test 3
//! v1: %0:v[0] = v_dot2_f32_f16 %0:v[1]*[-1,1], %0:v[2], %0:v[0]
//! v1: %0:v[1] = v_mov_b32 0
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(3));
bld.vop3p(aco_opcode::v_dot2_f32_f16, Definition(reg_v0, v1), Operand(reg_v1, v1),
Operand(reg_v2, v1), Operand(reg_v0, v1), 0, 0b111)
->valu()
.neg_lo[0] = true;
bld.vop1(aco_opcode::v_mov_b32, Definition(reg_v1, v1), Operand::c32(0));
/* Definition must be the same as the last operand: here the accumulator is v3 but the
 * destination is v0, so the dot2 stays VOP3P.
 */
bld.reset(program->create_and_insert_block());
//>> p_unit_test 4
//! v1: %0:v[0] = v_dot2_f32_f16 %0:v[1], %0:v[2], %0:v[3]
//! v1: %0:v[1] = v_mov_b32 0
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(4));
bld.vop3p(aco_opcode::v_dot2_f32_f16, Definition(reg_v0, v1), Operand(reg_v1, v1),
Operand(reg_v2, v1), Operand(reg_v3, v1), 0, 0b111);
bld.vop1(aco_opcode::v_mov_b32, Definition(reg_v1, v1), Operand::c32(0));
finish_schedule_vopd_test();
END_TEST