nak: Wire up coop matrix opcodes

v2: rebase and scheduling (Karol)
    remove Ldsm and Movm (Karol)
    add support for saturated cmat_muladd

Signed-off-by: Mary Guillemard <mary.guillemard@collabora.com>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32777>
This commit is contained in:
Mary Guillemard 2024-12-24 14:09:58 +01:00 committed by Marge Bot
parent 90438bae51
commit f99db217a7
6 changed files with 300 additions and 23 deletions

View file

@ -3728,6 +3728,70 @@ impl<'a> ShaderFromNir<'a> {
let dst = b.isetp(IntCmpType::I32, IntCmpOp::Ne, src, 0.into());
self.set_dst(&intrin.def, dst.into());
}
// Lower the NV cooperative-matrix multiply-add intrinsic to either an
// OpHmma (float matrix types) or an OpImma (integer matrix types),
// selected by the cmat_type carried in the intrinsic's flags.
nir_intrinsic_cmat_muladd_nv => {
// NOTE(review): this reinterprets the raw value returned by
// intrin.flags() as nak_nir_cmat_mul_add_flags.  That presumably relies
// on both having the same size and layout -- confirm, and record the
// invariant as a `SAFETY:` comment.
let flags: nak_nir_cmat_mul_add_flags =
unsafe { std::mem::transmute(intrin.flags()) };
// Sources 0, 1 and 2 are passed to the op in this order as the A, B and
// C operands.
let cmat_a = self.get_src(&srcs[0]);
let cmat_b = self.get_src(&srcs[1]);
let cmat_c = self.get_src(&srcs[2]);
let dst_bit_size = intrin.def.bit_size();
// The destination vector length is the total number of result bits
// (components * bit size) divided by 32, rounded up.
// NOTE(review): this reads the `bit_size` field directly while
// dst_bit_size above goes through the `bit_size()` accessor; consider
// using one form consistently.
let dst = b.alloc_ssa_vec(
RegFile::GPR,
(intrin.def.num_components() * intrin.def.bit_size)
.div_ceil(32),
);
// Computed for every matrix type but only consumed by the HMMA path.
let dst_type = FloatType::from_bits(dst_bit_size.into());
match flags.cmat_type() {
// Float path: emits an HMMA.
NAK_CMAT_TYPE_M16N8K8_FLOAT
| NAK_CMAT_TYPE_M16N8K16_FLOAT => {
// The outer arm already restricts cmat_type to these two values, so
// the fallback below cannot be reached.
let mat_size = match flags.cmat_type() {
NAK_CMAT_TYPE_M16N8K8_FLOAT => HmmaSize::M16N8K8,
NAK_CMAT_TYPE_M16N8K16_FLOAT => HmmaSize::M16N8K16,
val => unreachable!("unsupported HMMA type: {val}"),
};
// Only f16 A/B inputs without saturation are accepted here.
assert_eq!(flags.a_type(), GLSL_TYPE_FLOAT16);
assert_eq!(flags.b_type(), GLSL_TYPE_FLOAT16);
assert!(!flags.sat());
b.push_op(OpHmma {
dst: dst.clone().into(),
dst_type: dst_type,
src_type: FloatType::F16,
mat_size: mat_size,
srcs: [cmat_a.into(), cmat_b.into(), cmat_c.into()],
});
}
// Integer path: emits an IMMA.
NAK_CMAT_TYPE_M8N8K16_INT | NAK_CMAT_TYPE_M16N8K32_INT => {
// A and B are each independently signed or unsigned 8-bit.
let a_type = match flags.a_type() {
GLSL_TYPE_UINT8 => IntType::U8,
GLSL_TYPE_INT8 => IntType::I8,
val => unreachable!("Invalid a_type: {val}"),
};
let b_type = match flags.b_type() {
GLSL_TYPE_UINT8 => IntType::U8,
GLSL_TYPE_INT8 => IntType::I8,
val => unreachable!("Invalid b_type: {val}"),
};
// As above, the outer arm guarantees one of these two values.
let mat_size = match flags.cmat_type() {
NAK_CMAT_TYPE_M8N8K16_INT => ImmaSize::M8N8K16,
NAK_CMAT_TYPE_M16N8K32_INT => ImmaSize::M16N8K32,
val => unreachable!("unsupported IMMA type: {val}"),
};
// Unlike the HMMA path, the saturation flag is forwarded to the op.
b.push_op(OpImma {
dst: dst.clone().into(),
mat_size,
srcs: [cmat_a.into(), cmat_b.into(), cmat_c.into()],
src_types: [a_type, b_type],
saturate: flags.sat(),
});
}
val => unreachable!("Unknown cmat_type {val}"),
}
self.set_dst(&intrin.def, dst.into());
}
_ => panic!(
"Unsupported intrinsic instruction: {}",
intrin.info().name()

View file

@ -3350,6 +3350,108 @@ impl DisplayOp for OpHMul2 {
}
impl_display_for_op!(OpHMul2);
/// Matrix shape (MxNxK) of an integer matrix multiply-add (`imma`).
///
/// The [`fmt::Display`] impl renders a shape as the lowercase, dot-prefixed
/// modifier that is appended to the mnemonic when an op is printed, e.g.
/// `.m16n8k32`.
#[derive(Clone, Copy, Eq, PartialEq)]
// NOTE(review): presumably present because not every shape is constructed
// yet -- confirm and drop once all variants are used.
#[allow(dead_code)]
pub enum ImmaSize {
    M8N8K16,
    M8N8K32,
    M16N8K16,
    M16N8K32,
    M16N8K64,
}

impl fmt::Display for ImmaSize {
    /// Writes the shape modifier (including its leading dot) to `f`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let modifier = match self {
            Self::M8N8K16 => ".m8n8k16",
            Self::M8N8K32 => ".m8n8k32",
            Self::M16N8K16 => ".m16n8k16",
            Self::M16N8K32 => ".m16n8k32",
            Self::M16N8K64 => ".m16n8k64",
        };
        f.write_str(modifier)
    }
}
/// Integer matrix multiply-add op, printed with the `imma` mnemonic.
///
/// The struct is `#[repr(C)]`, so the declaration order of the fields is
/// part of its layout and must not be changed casually.
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpImma {
/// Destination of the operation (a vector destination).
#[dst_type(Vec)]
pub dst: Dst,
/// Shape of the matrix operation.
pub mat_size: ImmaSize,
/// Integer types of `srcs[0]` and `srcs[1]`, in that order.
pub src_types: [IntType; 2],
/// Whether the op is printed and encoded with the `.sat` modifier.
pub saturate: bool,
/// The three SSA source operands.
#[src_type(SSA)]
pub srcs: [Src; 3],
}
impl DisplayOp for OpImma {
    /// Prints the op as `imma<size><type0><type1>[.sat] <src0> <src1> <src2>`.
    ///
    /// The `.sat` modifier is only emitted when `saturate` is set.
    fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let [a_type, b_type] = &self.src_types;
        let [src_a, src_b, src_c] = &self.srcs;
        let sat = match self.saturate {
            true => ".sat",
            false => "",
        };
        write!(
            f,
            "imma{}{}{}{sat} {} {} {}",
            self.mat_size, a_type, b_type, src_a, src_b, src_c,
        )
    }
}
impl_display_for_op!(OpImma);
/// Matrix shape (MxNxK) of a floating-point matrix multiply-add (`hmma`).
///
/// The [`fmt::Display`] impl renders a shape as the lowercase, dot-prefixed
/// modifier that is appended to the mnemonic when an op is printed, e.g.
/// `.m16n8k8`.
#[derive(Clone, Copy, Eq, PartialEq)]
// NOTE(review): presumably present because not every shape is constructed
// yet -- confirm and drop once all variants are used.
#[allow(dead_code)]
pub enum HmmaSize {
    M16N8K16,
    M16N8K8,
    M16N8K4,
}

impl fmt::Display for HmmaSize {
    /// Writes the shape modifier (including its leading dot) to `f`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let modifier = match self {
            Self::M16N8K16 => ".m16n8k16",
            Self::M16N8K8 => ".m16n8k8",
            Self::M16N8K4 => ".m16n8k4",
        };
        f.write_str(modifier)
    }
}
/// Floating-point matrix multiply-add op, printed with the `hmma` mnemonic.
///
/// The struct is `#[repr(C)]`, so the declaration order of the fields is
/// part of its layout and must not be changed casually.
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHmma {
/// Destination of the operation (a vector destination).
#[dst_type(Vec)]
pub dst: Dst,
/// Shape of the matrix operation.
pub mat_size: HmmaSize,
/// Float type of the sources.  Not included in the printed form of the op.
pub src_type: FloatType,
/// Float type of the destination; printed right after the shape.
pub dst_type: FloatType,
/// The three SSA source operands.
#[src_type(SSA)]
pub srcs: [Src; 3],
}
impl DisplayOp for OpHmma {
    /// Prints the op as `hmma<size><dst_type> <src0> <src1> <src2>`.
    ///
    /// Only the destination type is printed; `src_type` does not appear in
    /// the output.
    fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let [src_a, src_b, src_c] = &self.srcs;
        write!(
            f,
            "hmma{}{} {} {} {}",
            self.mat_size, self.dst_type, src_a, src_b, src_c,
        )
    }
}
impl_display_for_op!(OpHmma);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHFma2 {
@ -7639,6 +7741,8 @@ pub enum Op {
HMul2(OpHMul2),
HSet2(OpHSet2),
HSetP2(OpHSetP2),
Imma(OpImma),
Hmma(OpHmma),
HMnMx2(OpHMnMx2),
BMsk(OpBMsk),
BRev(OpBRev),
@ -7805,6 +7909,9 @@ impl Op {
| Op::DMul(_)
| Op::DSetP(_) => false,
// Matrix Multiply Add
Op::Imma(_) | Op::Hmma(_) => false,
// Integer ALU
Op::BRev(_) | Op::Flo(_) | Op::PopC(_) => false,
Op::IMad(_) | Op::IMul(_) => sm >= 70,

View file

@ -182,6 +182,9 @@ pub fn side_effect_type(op: &Op) -> SideEffect {
| Op::LdTram(_)
| Op::MemBar(_) => SideEffect::Memory,
// Matrix ops
Op::Imma(_) | Op::Hmma(_) => SideEffect::None,
// Control-flow ops
Op::BClear(_)
| Op::Break(_)
@ -319,6 +322,8 @@ pub fn estimate_variable_latency(sm: u8, op: &Op) -> u32 {
| Op::S2R(_)
| Op::Match(_) => 16,
Op::Hmma(_) | Op::Imma(_) => 22,
_ => panic!("Unknown variable latency op {op}"),
}
}

View file

@ -3832,6 +3832,96 @@ impl SM70Op for OpMatch {
}
}
impl SM70Op for OpImma {
    /// Defers to the shared `legalize_ext_instr` helper.
    fn legalize(&mut self, b: &mut LegalizeBuilder) {
        legalize_ext_instr(self, b);
    }

    /// Encodes the op with opcode `0x237`.
    ///
    /// # Panics
    ///
    /// Panics if the target is older than SM 7.5, if a shape other than
    /// `M8N8K16` is used on a target older than SM 8.0, or if the source
    /// types have a bit width the selected shape does not accept.
    fn encode(&self, e: &mut SM70Encoder<'_>) {
        assert!(e.sm >= 75);

        e.set_opcode(0x237);
        e.set_dst(&self.dst);

        let [src_a, src_b, src_c] = &self.srcs;
        e.set_reg_src(24..32, src_a);
        e.set_reg_src(32..40, src_b);
        e.set_reg_src(64..72, src_c);
        e.set_bit(74, true); // SRC1.COL

        // Only the smallest shape is accepted before SM 8.0.
        assert!(self.mat_size == ImmaSize::M8N8K16 || e.sm >= 80);
        let size_enc = match self.mat_size {
            ImmaSize::M8N8K16 => 0u8,
            ImmaSize::M8N8K32 => 2u8,
            ImmaSize::M16N8K16 => 4u8,
            ImmaSize::M16N8K32 => 5u8,
            ImmaSize::M16N8K64 => 6u8,
        };
        // The shape code is split across two separate bit ranges.
        e.set_field2(75..76, 85..87, size_enc);

        let [a_type, b_type] = &self.src_types;
        e.set_bit(76, a_type.is_signed());
        e.set_bit(78, b_type.is_signed());
        e.set_bit(82, self.saturate);

        // Each shape only accepts particular source bit widths.
        match self.mat_size {
            ImmaSize::M8N8K32 | ImmaSize::M16N8K64 => {
                assert_eq!(a_type.bits(), 4);
                assert_eq!(b_type.bits(), 4);
            }
            ImmaSize::M16N8K32 => {
                assert!(matches!(self.src_types[0].bits(), 4 | 8));
                assert!(matches!(self.src_types[1].bits(), 4 | 8));
            }
            ImmaSize::M8N8K16 | ImmaSize::M16N8K16 => {
                assert_eq!(a_type.bits(), 8);
                assert_eq!(b_type.bits(), 8);
            }
        }

        // Flag 4-bit sources, one bit per operand.
        e.set_bit(83, a_type.bits() == 4);
        e.set_bit(84, b_type.bits() == 4);
    }
}
impl SM70Op for OpHmma {
    /// Defers to the shared `legalize_ext_instr` helper.
    fn legalize(&mut self, b: &mut LegalizeBuilder) {
        legalize_ext_instr(self, b);
    }

    /// Encodes the op with opcode `0x23c`.
    ///
    /// # Panics
    ///
    /// Panics if the target is older than SM 7.5, if `M16N8K4` is used on a
    /// target older than SM 8.0, if the destination type is neither `F16`
    /// nor `F32`, or if the source type is anything other than `F16`.
    fn encode(&self, e: &mut SM70Encoder<'_>) {
        assert!(e.sm >= 75);

        e.set_opcode(0x23c);
        e.set_dst(&self.dst);

        let [src_a, src_b, src_c] = &self.srcs;
        e.set_reg_src(24..32, src_a);
        e.set_reg_src(32..40, src_b);
        e.set_reg_src(64..72, src_c);

        // M16N8K4 is only accepted from SM 8.0 onwards.
        assert!(self.mat_size != HmmaSize::M16N8K4 || e.sm >= 80);
        let size_enc = match self.mat_size {
            HmmaSize::M16N8K8 => 0u8,
            HmmaSize::M16N8K16 => 1u8,
            HmmaSize::M16N8K4 => 2u8,
        };
        // The shape code is split across two separate bit ranges.
        e.set_field2(75..76, 78..79, size_enc);

        // Bit 76 selects an F32 destination; when clear the destination is
        // F16, the only other type the assert admits.
        assert!(matches!(self.dst_type, FloatType::F16 | FloatType::F32));
        e.set_bit(76, self.dst_type == FloatType::F32);

        let src_type_enc = match self.src_type {
            FloatType::F16 => 0u8,
            // FloatType::BF16 => 1u8,
            // FloatType::TF32 => 2u8,
            _ => unreachable!("unsupported src type!"),
        };
        e.set_field(82..84, src_type_enc);
    }
}
macro_rules! as_sm70_op_match {
($op: expr) => {
match $op {
@ -3920,6 +4010,8 @@ macro_rules! as_sm70_op_match {
Op::OutFinal(op) => op,
Op::Vote(op) => op,
Op::Match(op) => op,
Op::Hmma(op) => op,
Op::Imma(op) => op,
_ => panic!("Unsupported op: {}", $op),
}
};

View file

@ -104,16 +104,11 @@ impl RegLatencySM75 {
Op::HMnMx2(_) => RedirectedFP16, // not in docs
// left in for documentation purposes
// Op::Hmma(h) => {
// match h.mat_size {
// HmmaSize::M16N8K4 => match h.dst_type {
// FloatType::F16 => RedirectedHMMA_884_F16,
// _ => RedirectedHMMA_884_F32
// }
// HmmaSize::M16N8K8 => RedirectedHMMA_1688,
// HmmaSize::M16N8K16 => RedirectedHMMA_16816,
// }
// }
Op::Hmma(h) => match (h.mat_size, h.dst_type) {
(HmmaSize::M16N8K8, _) => RedirectedHMMA_1688,
(HmmaSize::M16N8K16, _) => RedirectedHMMA_16816,
_ => panic!("Illegal HMMA in reg category {}", h),
},
Op::Ipa(_) => Decoupled,
Op::MuFu(_) => Decoupled,
@ -168,8 +163,7 @@ impl RegLatencySM75 {
// PMTRIG => CoupledDisp64
// CSMTEST => CoupledAlu,
Op::Bar(_) => Decoupled,
// Remove when Imma added
//Op::Imma(_) => IMMA,
Op::Imma(_) => IMMA,
Op::IDp4(_) => CoupledFMA,
Op::BClear(_) => Decoupled,
Op::Bra(_) => Decoupled,

View file

@ -144,16 +144,21 @@ impl RegLatencySM80 {
Op::HSet2(_) | Op::HSetP2(_) | Op::HMnMx2(_) => FP16_Alu,
// left in for documentation purposes
//Op::Hmma(h) => {
//match h.mat_size {
// HmmaSize::M16N8K4 => match h.dst_type {
// FloatType::F16 => MMA_1x_Collect,
// _ => MMA_2x_Collect,
// }
// HmmaSize::M16N8K8 => MMA_1x_Collect,
// HmmaSize::M16N8K16 => MMA_2x_Collect,
// }
//}
Op::Hmma(h) => match (h.mat_size, h.dst_type, h.src_type) {
(HmmaSize::M16N8K16, FloatType::F32, FloatType::F16) => {
MMA_2x_collect
}
// (HmmaSize::M16N8K16, FloatType::F32, FloatType::BF16) => MMA_2x_collect,
// (HmmaSize::M16N8K8, FloatType::F32, FloatType::TF32) => MMA_2x_collect,
(HmmaSize::M16N8K8, FloatType::F32, FloatType::F16) => {
MMA_1x_collect
}
// (HmmaSize::M16N8K8, FloatType::F32, FloatType::BF16) => MMA_1x_collect,
// (HmmaSize::M16N8K4, FloatType::F32, FloatType::TF32) => MMA_1x_collect,
(HmmaSize::M16N8K16, FloatType::F16, _) => MMA_2x_collect,
(HmmaSize::M16N8K8, FloatType::F16, _) => MMA_1x_collect,
_ => panic!("Illegal HMMA in reg category {}", h),
},
Op::Ipa(_) => DecoupledAgu,
Op::MuFu(_) => Decoupled,
@ -202,7 +207,17 @@ impl RegLatencySM80 {
// CSMTEST => CoupledAlu,
Op::Bar(_) => DecoupledAgu,
// The IMMA category depends on the matrix shape and the source type
//Op::Imma(_) => IMMA,
Op::Imma(i) => match (i.mat_size, i.src_types[0]) {
(ImmaSize::M16N8K64, _) => MMA_2x_collect,
(ImmaSize::M16N8K32, IntType::I8 | IntType::U8) => {
MMA_2x_collect
}
// (ImmaSize::M16N8K32, IntType::I4 | IntType::U4) => MMA_1x_collect,
(ImmaSize::M16N8K16, _) => MMA_1x_collect,
(ImmaSize::M8N8K32, _) => IMMA_88,
(ImmaSize::M8N8K16, _) => IMMA_88,
_ => panic!("Illegal IMMA in reg category {}", i),
},
Op::IDp4(_) => CoupledFMA,
Op::BClear(_) => Decoupled,
Op::Bra(_) => Decoupled,