nak: wire up UGPR Ld/St/Atom encoding

This commit is contained in:
Karol Herbst 2026-05-03 11:46:12 +02:00 committed by Karol Herbst
parent 53bfdb400c
commit e639aa342d
6 changed files with 170 additions and 23 deletions

View file

@ -2992,6 +2992,7 @@ impl<'a> ShaderFromNir<'a> {
dst.clone().into()
},
addr: addr,
uniform_address: Src::ZERO,
cmpr: 0.into(),
data: data,
atom_op: atom_op,
@ -3018,6 +3019,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpAtom {
dst: dst.clone().into(),
addr: addr,
uniform_address: Src::ZERO,
cmpr: cmpr,
data: data,
atom_op: AtomOp::CmpExch(AtomCmpSrc::Separate),
@ -3224,6 +3226,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpLd {
dst: dst.clone().into(),
addr: addr,
uniform_addr: Src::ZERO,
pred: pred,
offset: intrin.base(),
stride: OffsetStride::X1,
@ -3335,6 +3338,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpLd {
dst: dst.clone().into(),
addr: addr,
uniform_addr: Src::ZERO,
pred: true.into(),
offset: intrin.base(),
stride: OffsetStride::X1,
@ -3358,6 +3362,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpLd {
dst: dst.clone().into(),
addr: addr,
uniform_addr: Src::ZERO,
pred: true.into(),
offset: intrin.base(),
stride: intrin.offset_shift_nv().try_into().unwrap(),
@ -3678,6 +3683,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpAtom {
dst: dst.clone().into(),
addr: addr,
uniform_address: Src::ZERO,
cmpr: 0.into(),
data: data,
atom_op: atom_op,
@ -3704,6 +3710,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpAtom {
dst: dst.clone().into(),
addr: addr,
uniform_address: Src::ZERO,
cmpr: cmpr,
data: data,
atom_op: AtomOp::CmpExch(AtomCmpSrc::Separate),
@ -3736,6 +3743,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpSt {
addr: addr,
uniform_addr: Src::ZERO,
data: data,
offset: intrin.base(),
stride: OffsetStride::X1,
@ -3767,6 +3775,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpSt {
addr: addr,
uniform_addr: Src::ZERO,
data: data,
offset: intrin.base(),
stride: OffsetStride::X1,
@ -3788,6 +3797,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpSt {
addr: addr,
uniform_addr: Src::ZERO,
data: data,
offset: intrin.base(),
stride: intrin.offset_shift_nv().try_into().unwrap(),
@ -3907,6 +3917,7 @@ impl<'a> ShaderFromNir<'a> {
mat_size,
mat_count,
addr,
uniform_addr: Src::ZERO,
offset: intrin.base(),
});
self.set_dst(&intrin.def, dst);

View file

@ -154,6 +154,7 @@ impl<'a> TestShaderBuilder<'a> {
self.push_op(OpLd {
dst: dst.clone().into(),
addr: self.data_addr.clone().into(),
uniform_addr: Src::ZERO,
pred: true.into(),
offset: offset.into(),
access: access,
@ -178,6 +179,7 @@ impl<'a> TestShaderBuilder<'a> {
assert!(data.comps() == comps);
self.push_op(OpSt {
addr: self.data_addr.clone().into(),
uniform_addr: Src::ZERO,
data: data.into(),
offset: offset.into(),
access: access,
@ -1734,6 +1736,7 @@ fn test_op_ldsm() {
let offset = b.imul(lane_id.into(), 16.into());
b.push_op(OpSt {
addr: offset.into(),
uniform_addr: Src::ZERO,
data: input.into(),
offset: 0,
access: MemAccess {
@ -1755,6 +1758,7 @@ fn test_op_ldsm() {
mat_size: LdsmSize::M8N8,
mat_count: 4,
addr: addr.into(),
uniform_addr: Src::ZERO,
offset: 0,
});
b.st_test_data(16, MemType::B128, res);

View file

@ -6502,6 +6502,9 @@ pub struct OpLd {
#[src_type(GPR)]
pub addr: Src,
#[src_type(GPR)]
pub uniform_addr: Src,
/// When the predicate is false, the load returns 0
#[src_type(Pred)]
pub pred: Src,
@ -6513,7 +6516,11 @@ pub struct OpLd {
impl DisplayOp for OpLd {
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ld{} [{}{}", self.access, self.addr, self.stride)?;
write!(
f,
"ld{} [{}{}+{}",
self.access, self.addr, self.stride, self.uniform_addr
)?;
if self.offset > 0 {
write!(f, "+{:#x}", self.offset)?;
}
@ -6602,6 +6609,9 @@ pub struct OpLdsm {
#[src_type(SSA)]
pub addr: Src,
#[src_type(SSA)]
pub uniform_addr: Src,
pub offset: i32,
}
@ -6658,6 +6668,9 @@ pub struct OpSt {
#[src_type(SSA)]
pub data: Src,
#[src_type(GPR)]
pub uniform_addr: Src,
pub offset: i32,
pub stride: OffsetStride,
pub access: MemAccess,
@ -6665,7 +6678,11 @@ pub struct OpSt {
impl DisplayOp for OpSt {
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "st{} [{}{}", self.access, self.addr, self.stride)?;
write!(
f,
"st{} [{}{}+{}",
self.access, self.addr, self.stride, self.uniform_addr
)?;
if self.offset > 0 {
write!(f, "+{:#x}", self.offset)?;
}
@ -6711,6 +6728,9 @@ pub struct OpAtom {
#[src_type(GPR)]
pub addr: Src,
#[src_type(GPR)]
pub uniform_address: Src,
#[src_type(GPR)]
pub cmpr: Src,
@ -6743,10 +6763,16 @@ impl DisplayOp for OpAtom {
if !self.addr.is_zero() {
write!(f, "{}{}", self.addr, self.addr_stride)?;
}
if self.addr_offset > 0 {
if !self.uniform_address.is_zero() {
if !self.addr.is_zero() {
write!(f, "+")?;
}
write!(f, "{}", self.uniform_address)?;
}
if self.addr_offset > 0 {
if !self.addr.is_zero() || !self.uniform_address.is_zero() {
write!(f, "+")?;
}
write!(f, "{:#x}", self.addr_offset)?;
}
write!(f, "]")?;

View file

@ -95,6 +95,7 @@ impl LowerCopySwap {
b.push_op(OpLd {
dst: copy.dst,
addr: Src::ZERO,
uniform_addr: Src::ZERO,
pred: true.into(),
offset: addr.try_into().unwrap(),
stride: OffsetStride::X1,
@ -175,6 +176,7 @@ impl LowerCopySwap {
self.slm_size = max(self.slm_size, addr + 4);
b.push_op(OpSt {
addr: Src::ZERO,
uniform_addr: Src::ZERO,
data: copy.src,
offset: addr.try_into().unwrap(),
stride: OffsetStride::X1,

View file

@ -292,6 +292,7 @@ pub fn test_ld_st_atom() {
let r2 = RegRef::new(RegFile::GPR, 2, 1);
let r3 = RegRef::new(RegFile::GPR, 3, 1);
let p4 = RegRef::new(RegFile::Pred, 4, 1);
let ur2 = RegRef::new(RegFile::UGPR, 2, 2);
let order = MemOrder::Strong(MemScope::CTA);
@ -318,6 +319,18 @@ pub fn test_ld_st_atom() {
{
for addr_stride in [OffsetStride::X1, OffsetStride::X8] {
let cta = if sm >= 80 { "sm" } else { "cta" };
let r1_str =
if sm >= 75 && matches!(space, MemSpace::Global(_)) {
"r1.64"
} else {
"r1"
};
let urz = if sm >= 73 {
SrcRef::Reg(ur2).into()
} else {
Src::ZERO
};
let urz_str = if sm >= 73 { "+ur2" } else { "" };
let pri = match space {
MemSpace::Global(_) => MemEvictionPriority::First,
@ -339,6 +352,7 @@ pub fn test_ld_st_atom() {
let instr = OpLd {
dst: Dst::Reg(r0),
addr: SrcRef::Reg(r1).into(),
uniform_addr: urz.clone(),
pred: if matches!(space, MemSpace::Global(_))
&& sm >= 73
{
@ -353,7 +367,7 @@ pub fn test_ld_st_atom() {
let expected = match space {
MemSpace::Global(_) if sm >= 73 => {
format!(
"ldg.e.ef.strong.{cta} r0, [r1+{addr_offset_str}], p4;"
"ldg.e.ef.strong.{cta} r0, [{r1_str}{urz_str}+{addr_offset_str}], p4;"
)
}
MemSpace::Global(_) => {
@ -363,17 +377,20 @@ pub fn test_ld_st_atom() {
}
MemSpace::Shared => {
format!(
"lds r0, [r1{addr_stride}+{addr_offset_str}];"
"lds r0, [{r1_str}{addr_stride}{urz_str}+{addr_offset_str}];"
)
}
MemSpace::Local => {
format!("ldl r0, [r1+{addr_offset_str}];")
format!(
"ldl r0, [{r1_str}{urz_str}+{addr_offset_str}];"
)
}
};
c.push(instr, expected);
let instr = OpSt {
addr: SrcRef::Reg(r1).into(),
uniform_addr: urz.clone(),
data: SrcRef::Reg(r2).into(),
offset: addr_offset,
access: access.clone(),
@ -382,16 +399,18 @@ pub fn test_ld_st_atom() {
let expected = match space {
MemSpace::Global(_) => {
format!(
"stg.e.ef.strong.{cta} [r1+{addr_offset_str}], r2;"
"stg.e.ef.strong.{cta} [{r1_str}{urz_str}+{addr_offset_str}], r2;"
)
}
MemSpace::Shared => {
format!(
"sts [r1{addr_stride}+{addr_offset_str}], r2;"
"sts [{r1_str}{addr_stride}{urz_str}+{addr_offset_str}], r2;"
)
}
MemSpace::Local => {
format!("stl [r1+{addr_offset_str}], r2;")
format!(
"stl [{r1_str}{urz_str}+{addr_offset_str}], r2;"
)
}
};
c.push(instr, expected);
@ -405,6 +424,7 @@ pub fn test_ld_st_atom() {
Dst::None
},
addr: SrcRef::Reg(r1).into(),
uniform_address: urz.clone(),
data: SrcRef::Reg(r2).into(),
atom_op: AtomOp::Add,
cmpr: SrcRef::Reg(r3).into(),
@ -429,7 +449,7 @@ pub fn test_ld_st_atom() {
};
let dst =
if use_dst { "pt, r0, " } else { "" };
format!("{op}.e.add.ef{atom_type_str}.strong.{cta} {dst}[r1+{addr_offset_str}], r2;")
format!("{op}.e.add.ef{atom_type_str}.strong.{cta} {dst}[{r1_str}{urz_str}+{addr_offset_str}], r2;")
}
MemSpace::Shared => {
if atom_type.is_float() {
@ -439,7 +459,7 @@ pub fn test_ld_st_atom() {
continue;
}
let dst = if use_dst { "r0" } else { "rz" };
format!("atoms.add{atom_type_str} {dst}, [r1{addr_stride}+{addr_offset_str}], r2;")
format!("atoms.add{atom_type_str} {dst}, [{r1_str}{addr_stride}{urz_str}+{addr_offset_str}], r2;")
}
MemSpace::Local => continue,
};

View file

@ -3170,15 +3170,21 @@ impl SM70Op for OpLd {
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
let has_ugpr = e.sm >= 73;
match self.access.space {
MemSpace::Global(_) => {
e.set_opcode(0x381);
assert_eq!(self.stride, OffsetStride::X1);
if e.sm >= 73 {
if has_ugpr {
e.set_opcode(0x981);
e.set_reg_addr(24..32, &self.addr, 90);
e.set_ureg_addr(32, &self.uniform_addr, 72);
e.set_rev_pred_src(64..67, 67, &self.pred);
} else {
assert!(self.pred.is_true());
e.set_opcode(0x381);
e.set_reg_addr(24..32, &self.addr, 72);
}
e.set_pred_dst(81..84, &Dst::None);
e.set_mem_access(&self.access);
}
@ -3186,6 +3192,10 @@ impl SM70Op for OpLd {
assert!(self.pred.is_true());
assert_eq!(self.stride, OffsetStride::X1);
e.set_opcode(0x983);
e.set_reg_src(24..32, &self.addr);
if has_ugpr {
e.set_ureg_src(32, &self.uniform_addr);
}
e.set_field(84..87, 1_u8);
e.set_mem_type(73..76, self.access.mem_type);
@ -3199,6 +3209,10 @@ impl SM70Op for OpLd {
e.set_opcode(0x984);
assert!(self.pred.is_true());
e.set_reg_src(24..32, &self.addr);
if has_ugpr {
e.set_ureg_src(32, &self.uniform_addr);
}
e.set_mem_type(73..76, self.access.mem_type);
assert!(self.access.order == MemOrder::Strong(MemScope::CTA));
assert!(
@ -3213,8 +3227,11 @@ impl SM70Op for OpLd {
}
e.set_dst(&self.dst);
e.set_reg_addr(24..32, &self.addr, 72);
e.set_field(40..64, self.offset);
// We always enable UGPR mode because the meaning of the .E bit
// depends on whether UGPR mode is set; with UGPR mode enabled,
// the .E bit consistently applies to the UGPR source.
e.set_bit(91, has_ugpr);
}
}
@ -3315,15 +3332,30 @@ impl SM70Op for OpSt {
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
let has_ugpr = e.sm >= 75;
match self.access.space {
MemSpace::Global(_) => {
e.set_opcode(0x386);
assert_eq!(self.stride, OffsetStride::X1);
if has_ugpr {
e.set_opcode(0x986);
e.set_reg_addr(24..32, &self.addr, 90);
e.set_ureg_addr(64, &self.uniform_addr, 72);
} else {
e.set_opcode(0x386);
e.set_reg_addr(24..32, &self.addr, 72);
}
e.set_mem_access(&self.access);
}
MemSpace::Local => {
e.set_opcode(0x387);
assert_eq!(self.stride, OffsetStride::X1);
if has_ugpr {
e.set_opcode(0x987);
e.set_reg_src(24..32, &self.addr);
e.set_ureg_src(64, &self.uniform_addr);
} else {
e.set_opcode(0x387);
e.set_reg_src(24..32, &self.addr);
}
e.set_field(84..87, 1_u8);
e.set_mem_type(73..76, self.access.mem_type);
@ -3334,7 +3366,14 @@ impl SM70Op for OpSt {
);
}
MemSpace::Shared => {
e.set_opcode(0x388);
if has_ugpr {
e.set_opcode(0x988);
e.set_reg_src(24..32, &self.addr);
e.set_ureg_src(64, &self.uniform_addr);
} else {
e.set_opcode(0x388);
e.set_reg_src(24..32, &self.addr);
}
e.set_mem_type(73..76, self.access.mem_type);
assert!(self.access.order == MemOrder::Strong(MemScope::CTA));
@ -3348,9 +3387,12 @@ impl SM70Op for OpSt {
}
}
e.set_reg_addr(24..32, &self.addr, 72);
e.set_reg_src(32..40, &self.data);
e.set_field(40..64, self.offset);
// We always enable UGPR mode because the meaning of the .E bit
// depends on whether UGPR mode is set; with UGPR mode enabled,
// the .E bit consistently applies to the UGPR source.
e.set_bit(91, has_ugpr);
}
}
@ -3425,6 +3467,7 @@ impl SM70Op for OpAtom {
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
let has_ugpr = e.sm >= 75;
match self.mem_space {
MemSpace::Global(_) => {
if self.dst.is_none() {
@ -3435,24 +3478,56 @@ impl SM70Op for OpAtom {
}
e.set_reg_src(32..40, &self.data);
e.set_field(40..64, self.addr_offset);
e.set_atom_op(87..90, self.atom_op);
if has_ugpr {
e.set_reg_addr(24..32, &self.addr, 90);
e.set_ureg_addr(64, &self.uniform_address, 72);
e.set_bit(91, true);
} else {
e.set_reg_addr(24..32, &self.addr, 72);
assert!(self.uniform_address.is_zero());
}
} else if let AtomOp::CmpExch(cmp_src) = self.atom_op {
e.set_opcode(0x3a9);
assert!(cmp_src == AtomCmpSrc::Separate);
assert!(self.uniform_address.is_zero());
e.set_reg_addr(24..32, &self.addr, 72);
e.set_reg_src(32..40, &self.cmpr);
e.set_field(40..64, self.addr_offset);
e.set_reg_src(64..72, &self.data);
e.set_pred_dst(81..84, &Dst::None);
} else {
if e.sm >= 90 && self.atom_type.is_float() {
e.set_opcode(0x3a3);
e.set_opcode(0x9a3);
} else if has_ugpr {
e.set_opcode(0x9a8);
} else {
e.set_opcode(0x3a8);
}
if e.sm >= 100 {
e.set_reg_addr(24..32, &self.addr, 63);
e.set_ureg_addr(64, &self.uniform_address, 72);
} else if has_ugpr {
e.set_reg_addr(24..32, &self.addr, 70);
e.set_ureg_addr(64, &self.uniform_address, 72);
} else {
e.set_reg_addr(24..32, &self.addr, 72);
assert!(self.uniform_address.is_zero());
};
if e.sm >= 100 {
e.set_field(40..63, self.addr_offset);
} else {
e.set_field(40..64, self.addr_offset);
};
e.set_reg_src(32..40, &self.data);
e.set_pred_dst(81..84, &Dst::None);
e.set_atom_op(87..91, self.atom_op);
e.set_bit(91, has_ugpr);
}
e.set_mem_order(&self.mem_order);
@ -3465,10 +3540,17 @@ impl SM70Op for OpAtom {
e.set_opcode(0x38d);
assert!(cmp_src == AtomCmpSrc::Separate);
assert!(self.uniform_address.is_zero());
e.set_reg_src(32..40, &self.cmpr);
e.set_reg_src(64..72, &self.data);
} else {
e.set_opcode(0x38c);
if has_ugpr {
e.set_opcode(0x98c);
e.set_ureg_src(64, &self.uniform_address);
e.set_bit(91, true);
} else {
e.set_opcode(0x38c);
}
e.set_reg_src(32..40, &self.data);
assert!(
@ -3483,6 +3565,8 @@ impl SM70Op for OpAtom {
e.set_atom_op(87..91, self.atom_op);
}
e.set_reg_src(24..32, &self.addr);
e.set_field(40..64, self.addr_offset);
assert!(e.sm >= 75 || self.addr_stride == OffsetStride::X1);
e.set_field(78..80, self.addr_stride.encode_sm75());
@ -3494,8 +3578,6 @@ impl SM70Op for OpAtom {
}
e.set_dst(&self.dst);
e.set_reg_addr(24..32, &self.addr, 72);
e.set_field(40..64, self.addr_offset);
e.set_atom_type(self.atom_type, false);
}
}
@ -4218,6 +4300,7 @@ impl SM70Op for OpLdsm {
e.set_opcode(0x83b);
e.set_dst(&self.dst);
e.set_reg_src(24..32, &self.addr);
e.set_ureg_src(32, &self.uniform_addr);
e.set_field(40..64, self.offset);
e.set_field(
72..74,
@ -4238,6 +4321,7 @@ impl SM70Op for OpLdsm {
// LdsmSize::M8N32 => 3,
},
);
e.set_bit(91, !self.uniform_addr.is_zero());
}
}