From f70b7d10c2e6435702641db4dd98e34841bf64e0 Mon Sep 17 00:00:00 2001
From: Mel Henning
Date: Tue, 1 Apr 2025 20:25:50 -0400
Subject: [PATCH] nak: Fix sm90+ atomg/redg encoding and add a test for ld, st, atom

Part-of:
---
 src/nouveau/compiler/nak/ir.rs             |   9 ++
 src/nouveau/compiler/nak/nvdisasm_tests.rs | 137 +++++++++++++++++++++
 src/nouveau/compiler/nak/sm70_encode.rs    |  68 +++++++---
 3 files changed, 200 insertions(+), 14 deletions(-)

diff --git a/src/nouveau/compiler/nak/ir.rs b/src/nouveau/compiler/nak/ir.rs
index 373c913a8cb..a4ab1deb114 100644
--- a/src/nouveau/compiler/nak/ir.rs
+++ b/src/nouveau/compiler/nak/ir.rs
@@ -2523,6 +2523,15 @@ impl AtomType {
             AtomType::U64 | AtomType::I64 | AtomType::F64 => 64,
         }
     }
+
+    pub fn is_float(&self) -> bool {
+        match self {
+            AtomType::F16x2 | AtomType::F32 | AtomType::F64 => true,
+            AtomType::U32 | AtomType::I32 | AtomType::U64 | AtomType::I64 => {
+                false
+            }
+        }
+    }
 }
 
 impl fmt::Display for AtomType {
diff --git a/src/nouveau/compiler/nak/nvdisasm_tests.rs b/src/nouveau/compiler/nak/nvdisasm_tests.rs
index 1585871d2d8..f1c120c83b8 100644
--- a/src/nouveau/compiler/nak/nvdisasm_tests.rs
+++ b/src/nouveau/compiler/nak/nvdisasm_tests.rs
@@ -168,3 +168,140 @@ pub fn test_nop() {
         c.check(sm);
     }
 }
+
+#[test]
+pub fn test_ld_st_atom() {
+    let r0 = RegRef::new(RegFile::GPR, 0, 1);
+    let r1 = RegRef::new(RegFile::GPR, 1, 1);
+    let r2 = RegRef::new(RegFile::GPR, 2, 1);
+    let r3 = RegRef::new(RegFile::GPR, 3, 1);
+
+    let order = MemOrder::Strong(MemScope::CTA);
+
+    let atom_types = [
+        (AtomType::F16x2, ".f16x2.rn"),
+        (AtomType::U32, ""),
+        (AtomType::I32, ".s32"),
+        (AtomType::F32, ".f32.ftz.rn"),
+        (AtomType::U64, ".64"),
+        (AtomType::I64, ".s64"),
+        (AtomType::F64, ".f64.rn"),
+    ];
+
+    let spaces = [
+        MemSpace::Global(MemAddrType::A64),
+        MemSpace::Shared,
+        MemSpace::Local,
+    ];
+
+    for sm in SM_LIST {
+        let mut c = DisasmCheck::new();
+        for space in spaces {
+            for (addr_offset, addr_offset_str) in [(0x12, "0x12"), (-1, "-0x1")]
+            {
+                let cta = if sm >= 80 { "sm" } else { "cta" };
+
+                let pri = match space {
+                    MemSpace::Global(_) => MemEvictionPriority::First,
+                    MemSpace::Shared => MemEvictionPriority::Normal,
+                    MemSpace::Local => MemEvictionPriority::Normal,
+                };
+                let access = MemAccess {
+                    mem_type: MemType::B32,
+                    space,
+                    order: order,
+                    eviction_priority: pri,
+                };
+
+                let instr = OpLd {
+                    dst: Dst::Reg(r0),
+                    addr: SrcRef::Reg(r1).into(),
+                    offset: addr_offset,
+                    access: access.clone(),
+                };
+                let expected = match space {
+                    MemSpace::Global(_) => {
+                        format!(
+                            "ldg.e.ef.strong.{cta} r0, [r1+{addr_offset_str}];"
+                        )
+                    }
+                    MemSpace::Shared => {
+                        format!("lds r0, [r1+{addr_offset_str}];")
+                    }
+                    MemSpace::Local => {
+                        format!("ldl r0, [r1+{addr_offset_str}];")
+                    }
+                };
+                c.push(instr, expected);
+
+                let instr = OpSt {
+                    addr: SrcRef::Reg(r1).into(),
+                    data: SrcRef::Reg(r2).into(),
+                    offset: addr_offset,
+                    access: access.clone(),
+                };
+                let expected = match space {
+                    MemSpace::Global(_) => {
+                        format!(
+                            "stg.e.ef.strong.{cta} [r1+{addr_offset_str}], r2;"
+                        )
+                    }
+                    MemSpace::Shared => {
+                        format!("sts [r1+{addr_offset_str}], r2;")
+                    }
+                    MemSpace::Local => {
+                        format!("stl [r1+{addr_offset_str}], r2;")
+                    }
+                };
+                c.push(instr, expected);
+
+                for (atom_type, atom_type_str) in atom_types {
+                    for use_dst in [true, false] {
+                        let instr = OpAtom {
+                            dst: if use_dst { Dst::Reg(r0) } else { Dst::None },
+                            addr: SrcRef::Reg(r1).into(),
+                            data: SrcRef::Reg(r2).into(),
+                            atom_op: AtomOp::Add,
+                            cmpr: SrcRef::Reg(r3).into(),
+                            atom_type,
+
+                            addr_offset,
+
+                            mem_space: space,
+                            mem_order: order,
+                            mem_eviction_priority: pri,
+                        };
+
+                        let expected = match space {
+                            MemSpace::Global(_) => {
+                                let op = if use_dst {
+                                    "atomg"
+                                } else if sm >= 90 {
+                                    "redg"
+                                } else {
+                                    "red"
+                                };
+                                let dst = if use_dst { "pt, r0, " } else { "" };
+                                format!("{op}.e.add.ef{atom_type_str}.strong.{cta} {dst}[r1+{addr_offset_str}], r2;")
+                            }
+                            MemSpace::Shared => {
+                                if atom_type.is_float() {
+                                    continue;
+                                }
+                                if atom_type.bits() == 64 {
+                                    continue;
+                                }
+                                let dst = if use_dst { "r0" } else { "rz" };
+                                format!("atoms.add{atom_type_str} {dst}, [r1+{addr_offset_str}], r2;")
+                            }
+                            MemSpace::Local => continue,
+                        };
+
+                        c.push(instr, expected);
+                    }
+                }
+            }
+        }
+        c.check(sm);
+    }
+}
diff --git a/src/nouveau/compiler/nak/sm70_encode.rs b/src/nouveau/compiler/nak/sm70_encode.rs
index ce18748326a..57de2130ee5 100644
--- a/src/nouveau/compiler/nak/sm70_encode.rs
+++ b/src/nouveau/compiler/nak/sm70_encode.rs
@@ -2966,18 +2966,46 @@ impl SM70Encoder<'_> {
     }
 
     fn set_atom_type(&mut self, atom_type: AtomType) {
-        self.set_field(
-            73..76,
-            match atom_type {
-                AtomType::U32 => 0_u8,
-                AtomType::I32 => 1_u8,
-                AtomType::U64 => 2_u8,
-                AtomType::F32 => 3_u8,
-                AtomType::F16x2 => 4_u8,
-                AtomType::I64 => 5_u8,
-                AtomType::F64 => 6_u8,
-            },
-        );
+        if self.sm >= 90 {
+            // Float/int is differentiated by opcode
+            self.set_field(
+                73..77,
+                match atom_type {
+                    AtomType::F16x2 => 0_u8,
+                    // f16x4 => 1
+                    // f16x8 => 2
+                    // bf16x2 => 3
+                    // bf16x4 => 4
+                    // bf16x8 => 5
+                    AtomType::F32 => 9_u8, // .ftz
+                    // f32x2.ftz => 10
+                    // f32x4.ftz => 11
+                    // f32x1 => 12
+                    // f32x2 => 13
+                    // f32x4 => 14
+                    AtomType::F64 => 15_u8,
+
+                    AtomType::U32 => 0,
+                    AtomType::I32 => 1,
+                    AtomType::U64 => 2,
+                    AtomType::I64 => 3,
+                    // u128 => 4,
+                },
+            );
+        } else {
+            self.set_field(
+                73..76,
+                match atom_type {
+                    AtomType::U32 => 0_u8,
+                    AtomType::I32 => 1_u8,
+                    AtomType::U64 => 2_u8,
+                    AtomType::F32 => 3_u8,
+                    AtomType::F16x2 => 4_u8,
+                    AtomType::I64 => 5_u8,
+                    AtomType::F64 => 6_u8,
+                },
+            );
+        }
     }
 }
 
@@ -2990,7 +3018,11 @@ impl SM70Op for OpAtom {
         match self.mem_space {
             MemSpace::Global(_) => {
                 if self.dst.is_none() {
-                    e.set_opcode(0x98e);
+                    if e.sm >= 90 && self.atom_type.is_float() {
+                        e.set_opcode(0x9a6);
+                    } else {
+                        e.set_opcode(0x98e);
+                    }
 
                     e.set_reg_src(32..40, self.data);
                     e.set_atom_op(87..90, self.atom_op);
@@ -3001,7 +3033,11 @@ impl SM70Op for OpAtom {
                     e.set_reg_src(32..40, self.cmpr);
                     e.set_reg_src(64..72, self.data);
                 } else {
-                    e.set_opcode(0x3a8);
+                    if e.sm >= 90 && self.atom_type.is_float() {
+                        e.set_opcode(0x3a3);
+                    } else {
+                        e.set_opcode(0x3a8);
+                    }
 
                     e.set_reg_src(32..40, self.data);
                     e.set_atom_op(87..91, self.atom_op);
@@ -3037,6 +3073,10 @@ impl SM70Op for OpAtom {
                             || self.atom_op == AtomOp::Exch,
                         "64-bit Shared atomics only support CmpExch or Exch"
                     );
+                    assert!(
+                        !self.atom_type.is_float(),
+                        "Shared atomics don't support float"
+                    );
                     e.set_atom_op(87..91, self.atom_op);
                 }
 