nak: add LDS/STS/ATOM address shift encoding

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39709>
This commit is contained in:
Karol Herbst 2025-09-17 14:28:39 +02:00
parent 18bf6fb96d
commit 20aa072ee5
9 changed files with 199 additions and 95 deletions

View file

@ -2941,6 +2941,7 @@ impl<'a> ShaderFromNir<'a> {
atom_op: atom_op,
atom_type: atom_type,
addr_offset: intrin.base(),
addr_stride: OffsetStride::X1,
mem_space: MemSpace::Global(MemAddrType::A64),
mem_order: MemOrder::Strong(MemScope::GPU),
mem_eviction_priority: MemEvictionPriority::Normal, // Note: no intrinic access
@ -2966,6 +2967,7 @@ impl<'a> ShaderFromNir<'a> {
atom_op: AtomOp::CmpExch(AtomCmpSrc::Separate),
atom_type: atom_type,
addr_offset: intrin.base(),
addr_stride: OffsetStride::X1,
mem_space: MemSpace::Global(MemAddrType::A64),
mem_order: MemOrder::Strong(MemScope::GPU),
mem_eviction_priority: MemEvictionPriority::Normal, // Note: no intrinic access
@ -3075,6 +3077,7 @@ impl<'a> ShaderFromNir<'a> {
dst: dst.clone().into(),
addr: addr,
offset: intrin.base(),
stride: OffsetStride::X1,
access: access,
});
self.set_dst(&intrin.def, dst);
@ -3184,6 +3187,7 @@ impl<'a> ShaderFromNir<'a> {
dst: dst.clone().into(),
addr: addr,
offset: intrin.base(),
stride: OffsetStride::X1,
access: access,
});
self.set_dst(&intrin.def, dst);
@ -3205,6 +3209,7 @@ impl<'a> ShaderFromNir<'a> {
dst: dst.clone().into(),
addr: addr,
offset: intrin.base(),
stride: intrin.offset_shift_nv().try_into().unwrap(),
access: access,
});
self.set_dst(&intrin.def, dst);
@ -3527,6 +3532,7 @@ impl<'a> ShaderFromNir<'a> {
atom_op: atom_op,
atom_type: atom_type,
addr_offset: intrin.base(),
addr_stride: intrin.offset_shift_nv().try_into().unwrap(),
mem_space: MemSpace::Shared,
mem_order: MemOrder::Strong(MemScope::CTA),
mem_eviction_priority: MemEvictionPriority::Normal,
@ -3552,6 +3558,7 @@ impl<'a> ShaderFromNir<'a> {
atom_op: AtomOp::CmpExch(AtomCmpSrc::Separate),
atom_type: atom_type,
addr_offset: intrin.base(),
addr_stride: intrin.offset_shift_nv().try_into().unwrap(),
mem_space: MemSpace::Shared,
mem_order: MemOrder::Strong(MemScope::CTA),
mem_eviction_priority: MemEvictionPriority::Normal,
@ -3580,6 +3587,7 @@ impl<'a> ShaderFromNir<'a> {
addr: addr,
data: data,
offset: intrin.base(),
stride: OffsetStride::X1,
access: access,
});
}
@ -3610,6 +3618,7 @@ impl<'a> ShaderFromNir<'a> {
addr: addr,
data: data,
offset: intrin.base(),
stride: OffsetStride::X1,
access: access,
});
}
@ -3630,6 +3639,7 @@ impl<'a> ShaderFromNir<'a> {
addr: addr,
data: data,
offset: intrin.base(),
stride: intrin.offset_shift_nv().try_into().unwrap(),
access: access,
});
}

View file

@ -156,6 +156,7 @@ impl<'a> TestShaderBuilder<'a> {
addr: self.data_addr.clone().into(),
offset: offset.into(),
access: access,
stride: OffsetStride::X1,
});
dst
}
@ -179,6 +180,7 @@ impl<'a> TestShaderBuilder<'a> {
data: data.into(),
offset: offset.into(),
access: access,
stride: OffsetStride::X1,
});
}
@ -1739,6 +1741,7 @@ fn test_op_ldsm() {
order: MemOrder::Strong(MemScope::CTA),
eviction_priority: MemEvictionPriority::Normal,
},
stride: OffsetStride::X1,
});
b.push_op(OpMemBar {
scope: MemScope::CTA,

View file

@ -6417,6 +6417,40 @@ impl DisplayOp for OpSuStGa {
}
impl_display_for_op!(OpSuStGa);
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OffsetStride {
X1 = 0,
X4 = 2,
X8 = 3,
X16 = 4,
}
impl fmt::Display for OffsetStride {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
Self::X1 => return Ok(()),
Self::X4 => ".x4",
Self::X8 => ".x8",
Self::X16 => ".x16",
};
write!(f, "{s}")
}
}
impl TryFrom<u8> for OffsetStride {
type Error = &'static str;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => Ok(Self::X1),
2 => Ok(Self::X4),
3 => Ok(Self::X8),
4 => Ok(Self::X16),
_ => Err("Unknown LdSt shift value {value}"),
}
}
}
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpLd {
@ -6426,12 +6460,13 @@ pub struct OpLd {
pub addr: Src,
pub offset: i32,
pub stride: OffsetStride,
pub access: MemAccess,
}
impl DisplayOp for OpLd {
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ld{} [{}", self.access, self.addr)?;
write!(f, "ld{} [{}{}", self.access, self.addr, self.stride)?;
if self.offset > 0 {
write!(f, "+{:#x}", self.offset)?;
}
@ -6577,12 +6612,13 @@ pub struct OpSt {
pub data: Src,
pub offset: i32,
pub stride: OffsetStride,
pub access: MemAccess,
}
impl DisplayOp for OpSt {
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "st{} [{}", self.access, self.addr)?;
write!(f, "st{} [{}{}", self.access, self.addr, self.stride)?;
if self.offset > 0 {
write!(f, "+{:#x}", self.offset)?;
}
@ -6638,6 +6674,7 @@ pub struct OpAtom {
pub atom_type: AtomType,
pub addr_offset: i32,
pub addr_stride: OffsetStride,
pub mem_space: MemSpace,
pub mem_order: MemOrder,
@ -6657,7 +6694,7 @@ impl DisplayOp for OpAtom {
)?;
write!(f, " [")?;
if !self.addr.is_zero() {
write!(f, "{}", self.addr)?;
write!(f, "{}{}", self.addr, self.addr_stride)?;
}
if self.addr_offset > 0 {
if !self.addr.is_zero() {

View file

@ -96,6 +96,7 @@ impl LowerCopySwap {
dst: copy.dst,
addr: Src::ZERO,
offset: addr.try_into().unwrap(),
stride: OffsetStride::X1,
access: access,
});
}
@ -175,6 +176,7 @@ impl LowerCopySwap {
addr: Src::ZERO,
data: copy.src,
offset: addr.try_into().unwrap(),
stride: OffsetStride::X1,
access: access,
});
}

View file

@ -315,105 +315,124 @@ pub fn test_ld_st_atom() {
for space in spaces {
for (addr_offset, addr_offset_str) in [(0x12, "0x12"), (-1, "-0x1")]
{
let cta = if sm >= 80 { "sm" } else { "cta" };
for addr_stride in [OffsetStride::X1, OffsetStride::X8] {
let cta = if sm >= 80 { "sm" } else { "cta" };
let pri = match space {
MemSpace::Global(_) => MemEvictionPriority::First,
MemSpace::Shared => MemEvictionPriority::Normal,
MemSpace::Local => MemEvictionPriority::Normal,
};
let access = MemAccess {
mem_type: MemType::B32,
space,
order: order,
eviction_priority: pri,
};
let instr = OpLd {
dst: Dst::Reg(r0),
addr: SrcRef::Reg(r1).into(),
offset: addr_offset,
access: access.clone(),
};
let expected = match space {
MemSpace::Global(_) => {
format!(
"ldg.e.ef.strong.{cta} r0, [r1+{addr_offset_str}];"
)
let pri = match space {
MemSpace::Global(_) => MemEvictionPriority::First,
MemSpace::Shared => MemEvictionPriority::Normal,
MemSpace::Local => MemEvictionPriority::Normal,
};
if space != MemSpace::Shared
&& addr_stride != OffsetStride::X1
{
continue;
}
MemSpace::Shared => {
format!("lds r0, [r1+{addr_offset_str}];")
}
MemSpace::Local => {
format!("ldl r0, [r1+{addr_offset_str}];")
}
};
c.push(instr, expected);
let access = MemAccess {
mem_type: MemType::B32,
space,
order: order,
eviction_priority: pri,
};
let instr = OpSt {
addr: SrcRef::Reg(r1).into(),
data: SrcRef::Reg(r2).into(),
offset: addr_offset,
access: access.clone(),
};
let expected = match space {
MemSpace::Global(_) => {
format!(
"stg.e.ef.strong.{cta} [r1+{addr_offset_str}], r2;"
)
}
MemSpace::Shared => {
format!("sts [r1+{addr_offset_str}], r2;")
}
MemSpace::Local => {
format!("stl [r1+{addr_offset_str}], r2;")
}
};
c.push(instr, expected);
let instr = OpLd {
dst: Dst::Reg(r0),
addr: SrcRef::Reg(r1).into(),
offset: addr_offset,
access: access.clone(),
stride: addr_stride,
};
let expected = match space {
MemSpace::Global(_) => {
format!(
"ldg.e.ef.strong.{cta} r0, [r1+{addr_offset_str}];"
)
}
MemSpace::Shared => {
format!(
"lds r0, [r1{addr_stride}+{addr_offset_str}];"
)
}
MemSpace::Local => {
format!("ldl r0, [r1+{addr_offset_str}];")
}
};
c.push(instr, expected);
for (atom_type, atom_type_str) in atom_types {
for use_dst in [true, false] {
let instr = OpAtom {
dst: if use_dst { Dst::Reg(r0) } else { Dst::None },
addr: SrcRef::Reg(r1).into(),
data: SrcRef::Reg(r2).into(),
atom_op: AtomOp::Add,
cmpr: SrcRef::Reg(r3).into(),
atom_type,
let instr = OpSt {
addr: SrcRef::Reg(r1).into(),
data: SrcRef::Reg(r2).into(),
offset: addr_offset,
access: access.clone(),
stride: addr_stride,
};
let expected = match space {
MemSpace::Global(_) => {
format!(
"stg.e.ef.strong.{cta} [r1+{addr_offset_str}], r2;"
)
}
MemSpace::Shared => {
format!(
"sts [r1{addr_stride}+{addr_offset_str}], r2;"
)
}
MemSpace::Local => {
format!("stl [r1+{addr_offset_str}], r2;")
}
};
c.push(instr, expected);
addr_offset,
mem_space: space,
mem_order: order,
mem_eviction_priority: pri,
};
let expected = match space {
MemSpace::Global(_) => {
let op = if use_dst {
"atomg"
} else if sm >= 90 {
"redg"
for (atom_type, atom_type_str) in atom_types {
for use_dst in [true, false] {
let instr = OpAtom {
dst: if use_dst {
Dst::Reg(r0)
} else {
"red"
};
let dst = if use_dst { "pt, r0, " } else { "" };
format!("{op}.e.add.ef{atom_type_str}.strong.{cta} {dst}[r1+{addr_offset_str}], r2;")
}
MemSpace::Shared => {
if atom_type.is_float() {
continue;
}
if atom_type.bits() == 64 {
continue;
}
let dst = if use_dst { "r0" } else { "rz" };
format!("atoms.add{atom_type_str} {dst}, [r1+{addr_offset_str}], r2;")
}
MemSpace::Local => continue,
};
Dst::None
},
addr: SrcRef::Reg(r1).into(),
data: SrcRef::Reg(r2).into(),
atom_op: AtomOp::Add,
cmpr: SrcRef::Reg(r3).into(),
atom_type,
c.push(instr, expected);
addr_offset,
addr_stride: addr_stride,
mem_space: space,
mem_order: order,
mem_eviction_priority: pri,
};
let expected = match space {
MemSpace::Global(_) => {
let op = if use_dst {
"atomg"
} else if sm >= 90 {
"redg"
} else {
"red"
};
let dst =
if use_dst { "pt, r0, " } else { "" };
format!("{op}.e.add.ef{atom_type_str}.strong.{cta} {dst}[r1+{addr_offset_str}], r2;")
}
MemSpace::Shared => {
if atom_type.is_float() {
continue;
}
if atom_type.bits() == 64 {
continue;
}
let dst = if use_dst { "r0" } else { "rz" };
format!("atoms.add{atom_type_str} {dst}, [r1{addr_stride}+{addr_offset_str}], r2;")
}
MemSpace::Local => continue,
};
c.push(instr, expected);
}
}
}
}

View file

@ -2309,6 +2309,7 @@ impl SM20Op for OpLd {
}
fn encode(&self, e: &mut SM20Encoder<'_>) {
assert_eq!(self.stride, OffsetStride::X1);
match self.access.space {
MemSpace::Global(addr_type) => {
e.set_opcode(SM20Unit::Mem, 0x20);
@ -2388,6 +2389,7 @@ impl SM20Op for OpSt {
}
fn encode(&self, e: &mut SM20Encoder<'_>) {
assert_eq!(self.stride, OffsetStride::X1);
match self.access.space {
MemSpace::Global(addr_type) => {
e.set_opcode(SM20Unit::Mem, 0x24);
@ -2472,6 +2474,7 @@ impl SM20Op for OpAtom {
panic!("SM20 only supports global atomics");
};
assert!(addr_type == MemAddrType::A64);
assert_eq!(self.addr_stride, OffsetStride::X1);
if self.dst.is_none() {
e.set_opcode(SM20Unit::Mem, 0x1);

View file

@ -2549,6 +2549,7 @@ impl SM32Op for OpLd {
}
fn encode(&self, e: &mut SM32Encoder<'_>) {
assert_eq!(self.stride, OffsetStride::X1);
// Missing:
// 0x7c8 for indirect const load
match self.access.space {
@ -2633,6 +2634,7 @@ impl SM32Op for OpSt {
}
fn encode(&self, e: &mut SM32Encoder<'_>) {
assert_eq!(self.stride, OffsetStride::X1);
match self.access.space {
MemSpace::Global(_) => {
e.set_opcode(0xe00, 0);
@ -2739,6 +2741,7 @@ impl SM32Op for OpAtom {
}
fn encode(&self, e: &mut SM32Encoder<'_>) {
assert_eq!(self.addr_stride, OffsetStride::X1);
match self.mem_space {
MemSpace::Global(addr_type) => {
if let AtomOp::CmpExch(cmp_src) = self.atom_op {

View file

@ -2597,6 +2597,7 @@ impl SM50Op for OpLd {
}
fn encode(&self, e: &mut SM50Encoder<'_>) {
assert_eq!(self.stride, OffsetStride::X1);
e.set_opcode(match self.access.space {
MemSpace::Global(_) => 0xeed0,
MemSpace::Local => 0xef40,
@ -2652,6 +2653,7 @@ impl SM50Op for OpSt {
}
fn encode(&self, e: &mut SM50Encoder<'_>) {
assert_eq!(self.stride, OffsetStride::X1);
e.set_opcode(match self.access.space {
MemSpace::Global(_) => 0xeed8,
MemSpace::Local => 0xef50,
@ -2707,6 +2709,7 @@ impl SM50Op for OpAtom {
}
fn encode(&self, e: &mut SM50Encoder<'_>) {
assert_eq!(self.addr_stride, OffsetStride::X1);
match self.mem_space {
MemSpace::Global(addr_type) => {
if self.dst.is_none() {

View file

@ -377,6 +377,17 @@ impl ALUSrc {
}
}
impl OffsetStride {
fn encode_sm75(&self) -> u8 {
match self {
Self::X1 => 0,
Self::X4 => 1,
Self::X8 => 2,
Self::X16 => 3,
}
}
}
impl SM70Encoder<'_> {
fn set_swizzle(&mut self, range: Range<usize>, swizzle: SrcSwizzle) {
assert!(range.len() == 2);
@ -3049,10 +3060,12 @@ impl SM70Op for OpLd {
match self.access.space {
MemSpace::Global(_) => {
e.set_opcode(0x381);
assert_eq!(self.stride, OffsetStride::X1);
e.set_pred_dst(81..84, &Dst::None);
e.set_mem_access(&self.access);
}
MemSpace::Local => {
assert_eq!(self.stride, OffsetStride::X1);
e.set_opcode(0x983);
e.set_field(84..87, 1_u8);
@ -3073,6 +3086,8 @@ impl SM70Op for OpLd {
== MemEvictionPriority::Normal
);
assert!(e.sm >= 75 || self.stride == OffsetStride::X1);
e.set_field(78..80, self.stride.encode_sm75());
e.set_bit(87, false); // !.ZD - Returns a predicate?
}
}
@ -3182,10 +3197,12 @@ impl SM70Op for OpSt {
match self.access.space {
MemSpace::Global(_) => {
e.set_opcode(0x386);
assert_eq!(self.stride, OffsetStride::X1);
e.set_mem_access(&self.access);
}
MemSpace::Local => {
e.set_opcode(0x387);
assert_eq!(self.stride, OffsetStride::X1);
e.set_field(84..87, 1_u8);
e.set_mem_type(73..76, self.access.mem_type);
@ -3204,6 +3221,9 @@ impl SM70Op for OpSt {
self.access.eviction_priority
== MemEvictionPriority::Normal
);
assert!(e.sm >= 75 || self.stride == OffsetStride::X1);
e.set_field(78..80, self.stride.encode_sm75());
}
}
@ -3322,6 +3342,7 @@ impl SM70Op for OpAtom {
e.set_mem_order(&self.mem_order);
e.set_eviction_priority(&self.mem_eviction_priority);
assert_eq!(self.addr_stride, OffsetStride::X1);
}
MemSpace::Local => panic!("Atomics do not support local"),
MemSpace::Shared => {
@ -3347,6 +3368,9 @@ impl SM70Op for OpAtom {
e.set_atom_op(87..91, self.atom_op);
}
assert!(e.sm >= 75 || self.addr_stride == OffsetStride::X1);
e.set_field(78..80, self.addr_stride.encode_sm75());
assert!(self.mem_order == MemOrder::Strong(MemScope::CTA));
assert!(
self.mem_eviction_priority == MemEvictionPriority::Normal