mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-05 11:48:06 +02:00
nak: Wire up coop matrix opcodes
v2: rebase and scheduling (Karol)
remove Ldsm and Movm (Karol)
add support for saturated cmat_muladd
Signed-off-by: Mary Guillemard <mary.guillemard@collabora.com>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32777>
This commit is contained in:
parent
90438bae51
commit
f99db217a7
6 changed files with 300 additions and 23 deletions
|
|
@ -3728,6 +3728,70 @@ impl<'a> ShaderFromNir<'a> {
|
|||
let dst = b.isetp(IntCmpType::I32, IntCmpOp::Ne, src, 0.into());
|
||||
self.set_dst(&intrin.def, dst.into());
|
||||
}
|
||||
nir_intrinsic_cmat_muladd_nv => {
|
||||
let flags: nak_nir_cmat_mul_add_flags =
|
||||
unsafe { std::mem::transmute(intrin.flags()) };
|
||||
let cmat_a = self.get_src(&srcs[0]);
|
||||
let cmat_b = self.get_src(&srcs[1]);
|
||||
let cmat_c = self.get_src(&srcs[2]);
|
||||
let dst_bit_size = intrin.def.bit_size();
|
||||
let dst = b.alloc_ssa_vec(
|
||||
RegFile::GPR,
|
||||
(intrin.def.num_components() * intrin.def.bit_size)
|
||||
.div_ceil(32),
|
||||
);
|
||||
let dst_type = FloatType::from_bits(dst_bit_size.into());
|
||||
match flags.cmat_type() {
|
||||
NAK_CMAT_TYPE_M16N8K8_FLOAT
|
||||
| NAK_CMAT_TYPE_M16N8K16_FLOAT => {
|
||||
let mat_size = match flags.cmat_type() {
|
||||
NAK_CMAT_TYPE_M16N8K8_FLOAT => HmmaSize::M16N8K8,
|
||||
NAK_CMAT_TYPE_M16N8K16_FLOAT => HmmaSize::M16N8K16,
|
||||
val => unreachable!("unsupported HMMA type: {val}"),
|
||||
};
|
||||
|
||||
assert_eq!(flags.a_type(), GLSL_TYPE_FLOAT16);
|
||||
assert_eq!(flags.b_type(), GLSL_TYPE_FLOAT16);
|
||||
assert!(!flags.sat());
|
||||
b.push_op(OpHmma {
|
||||
dst: dst.clone().into(),
|
||||
dst_type: dst_type,
|
||||
src_type: FloatType::F16,
|
||||
mat_size: mat_size,
|
||||
srcs: [cmat_a.into(), cmat_b.into(), cmat_c.into()],
|
||||
});
|
||||
}
|
||||
NAK_CMAT_TYPE_M8N8K16_INT | NAK_CMAT_TYPE_M16N8K32_INT => {
|
||||
let a_type = match flags.a_type() {
|
||||
GLSL_TYPE_UINT8 => IntType::U8,
|
||||
GLSL_TYPE_INT8 => IntType::I8,
|
||||
val => unreachable!("Invalid a_type: {val}"),
|
||||
};
|
||||
let b_type = match flags.b_type() {
|
||||
GLSL_TYPE_UINT8 => IntType::U8,
|
||||
GLSL_TYPE_INT8 => IntType::I8,
|
||||
val => unreachable!("Invalid b_type: {val}"),
|
||||
};
|
||||
|
||||
let mat_size = match flags.cmat_type() {
|
||||
NAK_CMAT_TYPE_M8N8K16_INT => ImmaSize::M8N8K16,
|
||||
NAK_CMAT_TYPE_M16N8K32_INT => ImmaSize::M16N8K32,
|
||||
val => unreachable!("unsupported IMMA type: {val}"),
|
||||
};
|
||||
|
||||
b.push_op(OpImma {
|
||||
dst: dst.clone().into(),
|
||||
mat_size,
|
||||
srcs: [cmat_a.into(), cmat_b.into(), cmat_c.into()],
|
||||
src_types: [a_type, b_type],
|
||||
saturate: flags.sat(),
|
||||
});
|
||||
}
|
||||
val => unreachable!("Unknown cmat_type {val}"),
|
||||
}
|
||||
|
||||
self.set_dst(&intrin.def, dst.into());
|
||||
}
|
||||
_ => panic!(
|
||||
"Unsupported intrinsic instruction: {}",
|
||||
intrin.info().name()
|
||||
|
|
|
|||
|
|
@ -3350,6 +3350,108 @@ impl DisplayOp for OpHMul2 {
|
|||
}
|
||||
impl_display_for_op!(OpHMul2);
|
||||
|
||||
/// Matrix shape (`M x N x K`) selector for the integer matrix
/// multiply-accumulate op (`OpImma`).
///
/// The `fmt::Display` impl renders each shape as a lowercase, dot-prefixed
/// instruction suffix such as `.m16n8k32`.
///
/// The SM70 encoder asserts the following source widths per shape:
/// * `M8N8K16`, `M16N8K16`: both sources are 8-bit,
/// * `M16N8K32`: each source is 4-bit or 8-bit,
/// * `M8N8K32`, `M16N8K64`: both sources are 4-bit.
///
/// It also asserts that every shape other than `M8N8K16` is only used on
/// SM80 or newer.
// NOTE(review): `dead_code` is presumably allowed because not every shape is
// constructed yet (the NIR lowering only builds `M8N8K16` and `M16N8K32`);
// confirm before removing the attribute.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[allow(dead_code)]
pub enum ImmaSize {
    M8N8K16,
    M8N8K32,
    M16N8K16,
    M16N8K32,
    M16N8K64,
}
|
||||
|
||||
impl fmt::Display for ImmaSize {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
ImmaSize::M8N8K16 => write!(f, ".m8n8k16"),
|
||||
ImmaSize::M8N8K32 => write!(f, ".m8n8k32"),
|
||||
ImmaSize::M16N8K16 => write!(f, ".m16n8k16"),
|
||||
ImmaSize::M16N8K32 => write!(f, ".m16n8k32"),
|
||||
ImmaSize::M16N8K64 => write!(f, ".m16n8k64"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(SrcsAsSlice, DstsAsSlice)]
|
||||
pub struct OpImma {
|
||||
#[dst_type(Vec)]
|
||||
pub dst: Dst,
|
||||
|
||||
pub mat_size: ImmaSize,
|
||||
pub src_types: [IntType; 2],
|
||||
pub saturate: bool,
|
||||
|
||||
#[src_type(SSA)]
|
||||
pub srcs: [Src; 3],
|
||||
}
|
||||
|
||||
impl DisplayOp for OpImma {
|
||||
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let sat = if self.saturate { ".sat" } else { "" };
|
||||
write!(
|
||||
f,
|
||||
"imma{}{}{}{sat} {} {} {}",
|
||||
self.mat_size,
|
||||
self.src_types[0],
|
||||
self.src_types[1],
|
||||
self.srcs[0],
|
||||
self.srcs[1],
|
||||
self.srcs[2],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl_display_for_op!(OpImma);
|
||||
|
||||
/// Matrix shape (`M x N x K`) selector for the floating-point matrix
/// multiply-accumulate op (`OpHmma`).
///
/// The `fmt::Display` impl renders each shape as a lowercase, dot-prefixed
/// instruction suffix such as `.m16n8k16`.
///
/// The SM70 encoder asserts that `M16N8K4` is only used on SM80 or newer;
/// the other two shapes are accepted from SM75 on.
// NOTE(review): `dead_code` is presumably allowed because not every shape is
// constructed yet (the NIR lowering only builds `M16N8K8` and `M16N8K16`);
// confirm before removing the attribute.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[allow(dead_code)]
pub enum HmmaSize {
    M16N8K16,
    M16N8K8,
    M16N8K4,
}
|
||||
|
||||
impl fmt::Display for HmmaSize {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
HmmaSize::M16N8K16 => write!(f, ".m16n8k16"),
|
||||
HmmaSize::M16N8K8 => write!(f, ".m16n8k8"),
|
||||
HmmaSize::M16N8K4 => write!(f, ".m16n8k4"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(SrcsAsSlice, DstsAsSlice)]
|
||||
pub struct OpHmma {
|
||||
#[dst_type(Vec)]
|
||||
pub dst: Dst,
|
||||
|
||||
pub mat_size: HmmaSize,
|
||||
pub src_type: FloatType,
|
||||
pub dst_type: FloatType,
|
||||
|
||||
#[src_type(SSA)]
|
||||
pub srcs: [Src; 3],
|
||||
}
|
||||
|
||||
impl DisplayOp for OpHmma {
|
||||
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"hmma{}{} {} {} {}",
|
||||
self.mat_size,
|
||||
self.dst_type,
|
||||
self.srcs[0],
|
||||
self.srcs[1],
|
||||
self.srcs[2],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl_display_for_op!(OpHmma);
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(SrcsAsSlice, DstsAsSlice)]
|
||||
pub struct OpHFma2 {
|
||||
|
|
@ -7639,6 +7741,8 @@ pub enum Op {
|
|||
HMul2(OpHMul2),
|
||||
HSet2(OpHSet2),
|
||||
HSetP2(OpHSetP2),
|
||||
Imma(OpImma),
|
||||
Hmma(OpHmma),
|
||||
HMnMx2(OpHMnMx2),
|
||||
BMsk(OpBMsk),
|
||||
BRev(OpBRev),
|
||||
|
|
@ -7805,6 +7909,9 @@ impl Op {
|
|||
| Op::DMul(_)
|
||||
| Op::DSetP(_) => false,
|
||||
|
||||
// Matrix Multiply Add
|
||||
Op::Imma(_) | Op::Hmma(_) => false,
|
||||
|
||||
// Integer ALU
|
||||
Op::BRev(_) | Op::Flo(_) | Op::PopC(_) => false,
|
||||
Op::IMad(_) | Op::IMul(_) => sm >= 70,
|
||||
|
|
|
|||
|
|
@ -182,6 +182,9 @@ pub fn side_effect_type(op: &Op) -> SideEffect {
|
|||
| Op::LdTram(_)
|
||||
| Op::MemBar(_) => SideEffect::Memory,
|
||||
|
||||
// Matrix ops
|
||||
Op::Imma(_) | Op::Hmma(_) => SideEffect::None,
|
||||
|
||||
// Control-flow ops
|
||||
Op::BClear(_)
|
||||
| Op::Break(_)
|
||||
|
|
@ -319,6 +322,8 @@ pub fn estimate_variable_latency(sm: u8, op: &Op) -> u32 {
|
|||
| Op::S2R(_)
|
||||
| Op::Match(_) => 16,
|
||||
|
||||
Op::Hmma(_) | Op::Imma(_) => 22,
|
||||
|
||||
_ => panic!("Unknown variable latency op {op}"),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3832,6 +3832,96 @@ impl SM70Op for OpMatch {
|
|||
}
|
||||
}
|
||||
|
||||
impl SM70Op for OpImma {
|
||||
fn legalize(&mut self, b: &mut LegalizeBuilder) {
|
||||
legalize_ext_instr(self, b);
|
||||
}
|
||||
|
||||
fn encode(&self, e: &mut SM70Encoder<'_>) {
|
||||
assert!(e.sm >= 75);
|
||||
|
||||
e.set_opcode(0x237);
|
||||
e.set_dst(&self.dst);
|
||||
e.set_reg_src(24..32, &self.srcs[0]);
|
||||
e.set_reg_src(32..40, &self.srcs[1]);
|
||||
e.set_reg_src(64..72, &self.srcs[2]);
|
||||
e.set_bit(74, true); // SRC1.COL
|
||||
|
||||
assert!(self.mat_size == ImmaSize::M8N8K16 || e.sm >= 80);
|
||||
e.set_field2(
|
||||
75..76,
|
||||
85..87,
|
||||
match self.mat_size {
|
||||
ImmaSize::M8N8K16 => 0u8,
|
||||
ImmaSize::M8N8K32 => 2u8,
|
||||
ImmaSize::M16N8K16 => 4u8,
|
||||
ImmaSize::M16N8K32 => 5u8,
|
||||
ImmaSize::M16N8K64 => 6u8,
|
||||
},
|
||||
);
|
||||
|
||||
e.set_bit(76, self.src_types[0].is_signed());
|
||||
e.set_bit(78, self.src_types[1].is_signed());
|
||||
e.set_bit(82, self.saturate);
|
||||
|
||||
match self.mat_size {
|
||||
ImmaSize::M8N8K32 | ImmaSize::M16N8K64 => {
|
||||
assert_eq!(self.src_types[0].bits(), 4);
|
||||
assert_eq!(self.src_types[1].bits(), 4);
|
||||
}
|
||||
ImmaSize::M16N8K32 => {
|
||||
assert!(matches!(self.src_types[0].bits(), 4 | 8));
|
||||
assert!(matches!(self.src_types[1].bits(), 4 | 8));
|
||||
}
|
||||
ImmaSize::M8N8K16 | ImmaSize::M16N8K16 => {
|
||||
assert_eq!(self.src_types[0].bits(), 8);
|
||||
assert_eq!(self.src_types[1].bits(), 8);
|
||||
}
|
||||
}
|
||||
e.set_bit(83, self.src_types[0].bits() == 4);
|
||||
e.set_bit(84, self.src_types[1].bits() == 4);
|
||||
}
|
||||
}
|
||||
|
||||
impl SM70Op for OpHmma {
|
||||
fn legalize(&mut self, b: &mut LegalizeBuilder) {
|
||||
legalize_ext_instr(self, b);
|
||||
}
|
||||
|
||||
fn encode(&self, e: &mut SM70Encoder<'_>) {
|
||||
assert!(e.sm >= 75);
|
||||
|
||||
e.set_opcode(0x23c);
|
||||
e.set_dst(&self.dst);
|
||||
e.set_reg_src(24..32, &self.srcs[0]);
|
||||
e.set_reg_src(32..40, &self.srcs[1]);
|
||||
e.set_reg_src(64..72, &self.srcs[2]);
|
||||
|
||||
assert!(self.mat_size != HmmaSize::M16N8K4 || e.sm >= 80);
|
||||
e.set_field2(
|
||||
75..76,
|
||||
78..79,
|
||||
match self.mat_size {
|
||||
HmmaSize::M16N8K8 => 0u8,
|
||||
HmmaSize::M16N8K16 => 1u8,
|
||||
HmmaSize::M16N8K4 => 2u8,
|
||||
},
|
||||
);
|
||||
|
||||
assert!(matches!(self.dst_type, FloatType::F16 | FloatType::F32));
|
||||
e.set_bit(76, self.dst_type == FloatType::F32);
|
||||
e.set_field(
|
||||
82..84,
|
||||
match self.src_type {
|
||||
FloatType::F16 => 0u8,
|
||||
// FloatType::BF16 => 1u8,
|
||||
// FloatType::TF32 => 2u8,
|
||||
_ => unreachable!("unsupported src type!"),
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! as_sm70_op_match {
|
||||
($op: expr) => {
|
||||
match $op {
|
||||
|
|
@ -3920,6 +4010,8 @@ macro_rules! as_sm70_op_match {
|
|||
Op::OutFinal(op) => op,
|
||||
Op::Vote(op) => op,
|
||||
Op::Match(op) => op,
|
||||
Op::Hmma(op) => op,
|
||||
Op::Imma(op) => op,
|
||||
_ => panic!("Unsupported op: {}", $op),
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -104,16 +104,11 @@ impl RegLatencySM75 {
|
|||
|
||||
Op::HMnMx2(_) => RedirectedFP16, // not in docs
|
||||
// left in for documentation purposes
|
||||
// Op::Hmma(h) => {
|
||||
// match h.mat_size {
|
||||
// HmmaSize::M16N8K4 => match h.dst_type {
|
||||
// FloatType::F16 => RedirectedHMMA_884_F16,
|
||||
// _ => RedirectedHMMA_884_F32
|
||||
// }
|
||||
// HmmaSize::M16N8K8 => RedirectedHMMA_1688,
|
||||
// HmmaSize::M16N8K16 => RedirectedHMMA_16816,
|
||||
// }
|
||||
// }
|
||||
Op::Hmma(h) => match (h.mat_size, h.dst_type) {
|
||||
(HmmaSize::M16N8K8, _) => RedirectedHMMA_1688,
|
||||
(HmmaSize::M16N8K16, _) => RedirectedHMMA_16816,
|
||||
_ => panic!("Illegal HMMA in reg category {}", h),
|
||||
},
|
||||
Op::Ipa(_) => Decoupled,
|
||||
Op::MuFu(_) => Decoupled,
|
||||
|
||||
|
|
@ -168,8 +163,7 @@ impl RegLatencySM75 {
|
|||
// PMTRIG => CoupledDisp64
|
||||
// CSMTEST => CoupledAlu,
|
||||
Op::Bar(_) => Decoupled,
|
||||
// Remove when Imma added
|
||||
//Op::Imma(_) => IMMA,
|
||||
Op::Imma(_) => IMMA,
|
||||
Op::IDp4(_) => CoupledFMA,
|
||||
Op::BClear(_) => Decoupled,
|
||||
Op::Bra(_) => Decoupled,
|
||||
|
|
|
|||
|
|
@ -144,16 +144,21 @@ impl RegLatencySM80 {
|
|||
|
||||
Op::HSet2(_) | Op::HSetP2(_) | Op::HMnMx2(_) => FP16_Alu,
|
||||
// left in for documentation purposes
|
||||
//Op::Hmma(h) => {
|
||||
//match h.mat_size {
|
||||
// HmmaSize::M16N8K4 => match h.dst_type {
|
||||
// FloatType::F16 => MMA_1x_Collect,
|
||||
// _ => MMA_2x_Collect,
|
||||
// }
|
||||
// HmmaSize::M16N8K8 => MMA_1x_Collect,
|
||||
// HmmaSize::M16N8K16 => MMA_2x_Collect,
|
||||
// }
|
||||
//}
|
||||
Op::Hmma(h) => match (h.mat_size, h.dst_type, h.src_type) {
|
||||
(HmmaSize::M16N8K16, FloatType::F32, FloatType::F16) => {
|
||||
MMA_2x_collect
|
||||
}
|
||||
// (HmmaSize::M16N8K16, FloatType::F32, FloatType::BF16) => MMA_2x_collect,
|
||||
// (HmmaSize::M16N8K8, FloatType::F32, FloatType::TF32) => MMA_2x_collect,
|
||||
(HmmaSize::M16N8K8, FloatType::F32, FloatType::F16) => {
|
||||
MMA_1x_collect
|
||||
}
|
||||
// (HmmaSize::M16N8K8, FloatType::F32, FloatType::BF16) => MMA_1x_collect,
|
||||
// (HmmaSize::M16N8K4, FloatType::F32, FloatType::TF32) => MMA_1x_collect,
|
||||
(HmmaSize::M16N8K16, FloatType::F16, _) => MMA_2x_collect,
|
||||
(HmmaSize::M16N8K8, FloatType::F16, _) => MMA_1x_collect,
|
||||
_ => panic!("Illegal HMMA in reg category {}", h),
|
||||
},
|
||||
Op::Ipa(_) => DecoupledAgu,
|
||||
Op::MuFu(_) => Decoupled,
|
||||
|
||||
|
|
@ -202,7 +207,17 @@ impl RegLatencySM80 {
|
|||
// CSMTEST => CoupledAlu,
|
||||
Op::Bar(_) => DecoupledAgu,
|
||||
// Remove when Imma added
|
||||
//Op::Imma(_) => IMMA,
|
||||
Op::Imma(i) => match (i.mat_size, i.src_types[0]) {
|
||||
(ImmaSize::M16N8K64, _) => MMA_2x_collect,
|
||||
(ImmaSize::M16N8K32, IntType::I8 | IntType::U8) => {
|
||||
MMA_2x_collect
|
||||
}
|
||||
// (ImmaSize::M16N8K32, IntType::I4 | IntType::U4) => MMA_1x_collect,
|
||||
(ImmaSize::M16N8K16, _) => MMA_1x_collect,
|
||||
(ImmaSize::M8N8K32, _) => IMMA_88,
|
||||
(ImmaSize::M8N8K16, _) => IMMA_88,
|
||||
_ => panic!("Illegal IMMA in reg category {}", i),
|
||||
},
|
||||
Op::IDp4(_) => CoupledFMA,
|
||||
Op::BClear(_) => Decoupled,
|
||||
Op::Bra(_) => Decoupled,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue