nak: Fix sm90+ atomg/redg encoding

Also add a disassembler test covering ld, st, and atom.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34334>
This commit is contained in:
Mel Henning 2025-04-01 20:25:50 -04:00 committed by Marge Bot
parent 869452aaf0
commit f70b7d10c2
3 changed files with 200 additions and 14 deletions

View file

@ -2523,6 +2523,15 @@ impl AtomType {
AtomType::U64 | AtomType::I64 | AtomType::F64 => 64,
}
}
/// Reports whether this atomic type is one of the floating-point variants
/// (`F16x2`, `F32` or `F64`). Every integer variant yields `false`.
pub fn is_float(&self) -> bool {
    // Kept exhaustive on purpose: adding a variant to `AtomType` must force
    // a decision here instead of silently falling into a default arm.
    match *self {
        // Integer types, unsigned and signed.
        AtomType::U32 | AtomType::I32 | AtomType::U64 | AtomType::I64 => {
            false
        }
        // Floating-point types.
        AtomType::F16x2 | AtomType::F32 | AtomType::F64 => true,
    }
}
}
impl fmt::Display for AtomType {

View file

@ -168,3 +168,140 @@ pub fn test_nop() {
c.check(sm);
}
}
/// Checks the expected disassembly text of `OpLd`, `OpSt` and `OpAtom`.
///
/// For every SM in `SM_LIST`, loads, stores and atomic adds are queued on a
/// `DisasmCheck` for the global, shared and local memory spaces, once with a
/// positive and once with a negative address offset, each paired with the
/// text it is expected to disassemble to. Atomics are only queued for the
/// combinations that have an expected string below.
#[test]
pub fn test_ld_st_atom() {
    let r0 = RegRef::new(RegFile::GPR, 0, 1);
    let r1 = RegRef::new(RegFile::GPR, 1, 1);
    let r2 = RegRef::new(RegFile::GPR, 2, 1);
    let r3 = RegRef::new(RegFile::GPR, 3, 1);

    let order = MemOrder::Strong(MemScope::CTA);

    // Each atomic type together with the type suffix expected in the
    // disassembly of an atomic that uses it.
    let atom_types = [
        (AtomType::F16x2, ".f16x2.rn"),
        (AtomType::U32, ""),
        (AtomType::I32, ".s32"),
        (AtomType::F32, ".f32.ftz.rn"),
        (AtomType::U64, ".64"),
        (AtomType::I64, ".s64"),
        (AtomType::F64, ".f64.rn"),
    ];

    let spaces = [
        MemSpace::Global(MemAddrType::A64),
        MemSpace::Shared,
        MemSpace::Local,
    ];

    for sm in SM_LIST {
        let mut checker = DisasmCheck::new();

        // Scope suffix expected after `.strong.`; it only depends on the SM,
        // so it is computed once per SM.
        let cta = if sm >= 80 { "sm" } else { "cta" };

        for space in spaces {
            // Only global accesses are given a non-default eviction
            // priority; it shows up as `.ef` in the expected text.
            let pri = match space {
                MemSpace::Global(_) => MemEvictionPriority::First,
                MemSpace::Shared | MemSpace::Local => {
                    MemEvictionPriority::Normal
                }
            };

            for (addr_offset, addr_offset_str) in
                [(0x12, "0x12"), (-1, "-0x1")]
            {
                let access = MemAccess {
                    mem_type: MemType::B32,
                    space,
                    order,
                    eviction_priority: pri,
                };

                // Load.
                let load = OpLd {
                    dst: Dst::Reg(r0),
                    addr: SrcRef::Reg(r1).into(),
                    offset: addr_offset,
                    access: access.clone(),
                };
                let load_text = match space {
                    MemSpace::Global(_) => {
                        format!(
                            "ldg.e.ef.strong.{cta} r0, [r1+{addr_offset_str}];"
                        )
                    }
                    MemSpace::Shared => {
                        format!("lds r0, [r1+{addr_offset_str}];")
                    }
                    MemSpace::Local => {
                        format!("ldl r0, [r1+{addr_offset_str}];")
                    }
                };
                checker.push(load, load_text);

                // Store. This is the last use of `access`, so it is moved.
                let store = OpSt {
                    addr: SrcRef::Reg(r1).into(),
                    data: SrcRef::Reg(r2).into(),
                    offset: addr_offset,
                    access,
                };
                let store_text = match space {
                    MemSpace::Global(_) => {
                        format!(
                            "stg.e.ef.strong.{cta} [r1+{addr_offset_str}], r2;"
                        )
                    }
                    MemSpace::Shared => {
                        format!("sts [r1+{addr_offset_str}], r2;")
                    }
                    MemSpace::Local => {
                        format!("stl [r1+{addr_offset_str}], r2;")
                    }
                };
                checker.push(store, store_text);

                // Atomic add, with and without a destination register.
                for (atom_type, atom_type_str) in atom_types {
                    for use_dst in [true, false] {
                        let atom = OpAtom {
                            dst: if use_dst { Dst::Reg(r0) } else { Dst::None },
                            addr: SrcRef::Reg(r1).into(),
                            data: SrcRef::Reg(r2).into(),
                            atom_op: AtomOp::Add,
                            cmpr: SrcRef::Reg(r3).into(),
                            atom_type,
                            addr_offset,
                            mem_space: space,
                            mem_order: order,
                            mem_eviction_priority: pri,
                        };

                        let atom_text = match space {
                            MemSpace::Global(_) => {
                                // Without a destination the expected
                                // mnemonic is a reduction, whose name
                                // differs from SM90 on.
                                let op = match (use_dst, sm >= 90) {
                                    (true, _) => "atomg",
                                    (false, true) => "redg",
                                    (false, false) => "red",
                                };
                                let dst = if use_dst { "pt, r0, " } else { "" };
                                format!("{op}.e.add.ef{atom_type_str}.strong.{cta} {dst}[r1+{addr_offset_str}], r2;")
                            }
                            // Float and 64-bit shared atomic adds are not
                            // exercised by this test.
                            MemSpace::Shared
                                if atom_type.is_float()
                                    || atom_type.bits() == 64 =>
                            {
                                continue
                            }
                            MemSpace::Shared => {
                                let dst = if use_dst { "r0" } else { "rz" };
                                format!("atoms.add{atom_type_str} {dst}, [r1+{addr_offset_str}], r2;")
                            }
                            // No atomics are queued for local memory.
                            MemSpace::Local => continue,
                        };
                        checker.push(atom, atom_text);
                    }
                }
            }
        }

        checker.check(sm);
    }
}

View file

@ -2966,18 +2966,46 @@ impl SM70Encoder<'_> {
}
/// Encodes `atom_type` into the atomic type field of the instruction.
///
/// The field layout depends on the target SM:
///
/// * SM90 and later: a 4-bit field at bits 73..77. Float and integer types
///   are numbered independently (both start at 0) because float vs. int is
///   differentiated by the opcode rather than by this field.
/// * Earlier SMs: a 3-bit field at bits 73..76 with one numbering shared by
///   all types.
///
/// The field is written exactly once. An unconditional write of the
/// pre-SM90 encoding ahead of the SM check would be redundant on older SMs
/// and would use the wrong numbering on SM90+.
fn set_atom_type(&mut self, atom_type: AtomType) {
    if self.sm >= 90 {
        // Float/int is differentiated by opcode
        self.set_field(
            73..77,
            match atom_type {
                // Float numbering.
                AtomType::F16x2 => 0_u8,
                // f16x4 => 1
                // f16x8 => 2
                // bf16x2 => 3
                // bf16x4 => 4
                // bf16x8 => 5
                AtomType::F32 => 9_u8, // .ftz
                // f32x2.ftz => 10
                // f32x4.ftz => 11
                // f32x1 => 12
                // f32x2 => 13
                // f32x4 => 14
                AtomType::F64 => 15_u8,

                // Integer numbering.
                AtomType::U32 => 0_u8,
                AtomType::I32 => 1_u8,
                AtomType::U64 => 2_u8,
                AtomType::I64 => 3_u8,
                // u128 => 4,
            },
        );
    } else {
        self.set_field(
            73..76,
            match atom_type {
                AtomType::U32 => 0_u8,
                AtomType::I32 => 1_u8,
                AtomType::U64 => 2_u8,
                AtomType::F32 => 3_u8,
                AtomType::F16x2 => 4_u8,
                AtomType::I64 => 5_u8,
                AtomType::F64 => 6_u8,
            },
        );
    }
}
}
@ -2990,7 +3018,11 @@ impl SM70Op for OpAtom {
match self.mem_space {
MemSpace::Global(_) => {
if self.dst.is_none() {
e.set_opcode(0x98e);
if e.sm >= 90 && self.atom_type.is_float() {
e.set_opcode(0x9a6);
} else {
e.set_opcode(0x98e);
}
e.set_reg_src(32..40, self.data);
e.set_atom_op(87..90, self.atom_op);
@ -3001,7 +3033,11 @@ impl SM70Op for OpAtom {
e.set_reg_src(32..40, self.cmpr);
e.set_reg_src(64..72, self.data);
} else {
e.set_opcode(0x3a8);
if e.sm >= 90 && self.atom_type.is_float() {
e.set_opcode(0x3a3);
} else {
e.set_opcode(0x3a8);
}
e.set_reg_src(32..40, self.data);
e.set_atom_op(87..91, self.atom_op);
@ -3037,6 +3073,10 @@ impl SM70Op for OpAtom {
|| self.atom_op == AtomOp::Exch,
"64-bit Shared atomics only support CmpExch or Exch"
);
assert!(
!self.atom_type.is_float(),
"Shared atomics don't support float"
);
e.set_atom_op(87..91, self.atom_op);
}