nak: add UGPR/GPR lowering for load/store/atom instructions

This tries to handle all combinations we might run into. We rely on
previous optimizations to ensure that the more difficult cases never happen.

As a side benefit, instead of lowering a UGPR to a GPR, the UGPR will now
be moved to the UGPR slot.

Totals from 258010 (21.27% of 1212873) affected shaders:
CodeSize: 3742700224 -> 3576740928 (-4.43%); split: -4.44%, +0.01%
Number of GPRs: 13606055 -> 13496463 (-0.81%); split: -0.86%, +0.05%
SLM Size: 589740 -> 589660 (-0.01%)
Static cycle count: 3271547493 -> 3272550831 (+0.03%); split: -0.47%, +0.50%
Spills to memory: 56180 -> 56136 (-0.08%)
Fills from memory: 56180 -> 56136 (-0.08%)
Spills to reg: 108211 -> 110013 (+1.67%); split: -0.63%, +2.30%
Fills from reg: 99216 -> 100471 (+1.26%); split: -0.30%, +1.56%
Max warps/SM: 9921228 -> 9972060 (+0.51%); split: +0.52%, -0.00%
This commit is contained in:
Karol Herbst 2026-02-24 04:32:57 +01:00 committed by Karol Herbst
parent 24b725a5d2
commit eeadd23c09
2 changed files with 96 additions and 5 deletions

View file

@ -6468,6 +6468,17 @@ pub enum OffsetStride {
X16 = 4,
}
impl OffsetStride {
pub fn shift(&self) -> u32 {
match self {
Self::X1 => 0,
Self::X4 => 2,
Self::X8 => 3,
Self::X16 => 4,
}
}
}
impl fmt::Display for OffsetStride {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {

View file

@ -10,6 +10,7 @@ use crate::sm70::ShaderModel70;
use bitview::*;
use rustc_hash::FxHashMap;
use std::mem;
use std::ops::Range;
/// A per-op trait that implements Volta+ opcode semantics
@ -774,6 +775,60 @@ fn op_gpr(op: &impl DstsAsSlice) -> RegFile {
}
}
/// Sorts out which register file the `addr` / `uniform_addr` source pair of
/// a memory instruction lives in.
///
/// * If `addr` is a UGPR register, it is moved into `uniform_addr` when that
///   slot is zero and the stride is `X1` (or there is no stride at all).
///   Otherwise `addr` is passed through `copy_src_if_uniform`.
/// * If `uniform_addr` is a GPR register, it is moved into `addr` when `addr`
///   is zero. Otherwise it is folded into `addr` with an integer add and
///   `uniform_addr` is reset to zero. Before that add, a non-`X1` stride is
///   applied to `addr` with an explicit shift and `*stride` is set to `X1`.
///
/// `stride` is `None` for callers that have no offset stride.
///
/// # Panics
///
/// * `uniform_addr` is a GPR, `addr` is zero, and the stride is neither `X1`
///   nor `None`.
/// * The two addresses have to be added but one of them is not an SSA value.
/// * A non-`X1` stride has to be applied to an `addr` that is not exactly one
///   component wide.
fn legalize_load_store_address(
    b: &mut LegalizeBuilder,
    addr: &mut Src,
    uniform_addr: &mut Src,
    stride: Option<&mut OffsetStride>,
) {
    let stride_x1_or_none = matches!(stride, Some(OffsetStride::X1) | None);

    // Step 1: a UGPR sitting in the non-uniform slot.
    if addr.is_ugpr_reg() {
        if stride_x1_or_none && uniform_addr.is_zero() {
            // The uniform slot is free and no scaling applies, so the value
            // can simply change slots.
            *uniform_addr = mem::replace(addr, Src::ZERO);
        } else {
            b.copy_src_if_uniform(addr);
        }
    }

    // Step 2: a GPR sitting in the uniform slot. Nothing to do otherwise.
    if !uniform_addr.is_gpr_reg() {
        return;
    }

    if addr.is_zero() {
        // The non-uniform slot is free; the value changes slots unscaled.
        assert!(stride_x1_or_none);
        *addr = mem::replace(uniform_addr, Src::ZERO);
        return;
    }

    // Both slots are occupied: merge everything into `addr`.
    let uniform_ssa = uniform_addr.as_ssa().unwrap();
    let mut addr_ssa = addr.as_ssa().unwrap();
    let addr_comps = addr_ssa.comps();

    match stride {
        Some(stride) if *stride != OffsetStride::X1 => {
            // Apply the stride by hand so it can be dropped from the
            // instruction before the uniform part is added in.
            assert_eq!(addr_comps, 1);
            let shift_amount = b.copy(stride.shift().into());
            *addr = b.shl(addr.clone(), shift_amount.into()).into();
            addr_ssa = addr.as_ssa().unwrap();
            *stride = OffsetStride::X1;
        }
        _ => {}
    }

    let sum: Src = if uniform_ssa.comps() == 2 {
        if addr_comps != 2 {
            // The uniform address has two components but `addr` does not:
            // pair the first component of `addr` with a zero component.
            let high = b.copy(0.into());
            *addr = [addr_ssa[0], high].into();
        }
        b.iadd64(addr.clone(), uniform_addr.clone(), Src::ZERO).into()
    } else {
        b.iadd(addr.clone(), uniform_addr.clone(), Src::ZERO).into()
    };

    *addr = sum;
    *uniform_addr = 0.into();
}
//
// Implementations of SM70Op for each op we support on Volta+
//
@ -3165,7 +3220,12 @@ impl SM70Op for OpSuAtom {
impl SM70Op for OpLd {
// Legalizes the sources of OpLd.
// NOTE(review): this hunk is rendered without +/- markers. Its header
// (-3165,7 +3220,12) implies exactly one removed line, presumably the
// copy_src_if_uniform(addr) call directly below -- confirm against the raw
// diff before reading this body as the final code.
fn legalize(&mut self, b: &mut LegalizeBuilder) {
b.copy_src_if_uniform(&mut self.addr);
// Resolve the addr / uniform_addr pair; the helper may reset self.stride
// to X1 if it has to apply the stride itself.
legalize_load_store_address(
b,
&mut self.addr,
&mut self.uniform_addr,
Some(&mut self.stride),
);
b.copy_src_if_uniform(&mut self.pred);
}
@ -3327,8 +3387,13 @@ impl SM70Op for OpLdc {
impl SM70Op for OpSt {
// Legalizes the sources of OpSt.
// NOTE(review): this hunk is rendered without +/- markers. Its header
// (-3327,8 +3387,13) implies exactly one removed line, presumably the
// copy_src_if_uniform(addr) call directly below -- confirm against the raw
// diff before reading this body as the final code.
fn legalize(&mut self, b: &mut LegalizeBuilder) {
b.copy_src_if_uniform(&mut self.addr);
b.copy_src_if_uniform(&mut self.data);
// Resolve the addr / uniform_addr pair; the helper may reset self.stride
// to X1 if it has to apply the stride itself.
legalize_load_store_address(
b,
&mut self.addr,
&mut self.uniform_addr,
Some(&mut self.stride),
);
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
@ -3461,9 +3526,19 @@ impl SM70Encoder<'_> {
impl SM70Op for OpAtom {
// Legalizes the sources of OpAtom.
// NOTE(review): this hunk is rendered without +/- markers. Its header
// (-3461,9 +3526,19) implies exactly one removed line among the three
// unconditional copy_src_if_uniform calls below, presumably the addr one --
// confirm against the raw diff before reading this body as the final code.
fn legalize(&mut self, b: &mut LegalizeBuilder) {
b.copy_src_if_uniform(&mut self.addr);
b.copy_src_if_uniform(&mut self.cmpr);
b.copy_src_if_uniform(&mut self.data);
// CmpExch only gets the plain copy_src_if_uniform treatment for addr and
// cmpr; every other atomic op goes through the shared address helper.
if matches!(self.atom_op, AtomOp::CmpExch(_)) {
b.copy_src_if_uniform(&mut self.addr);
b.copy_src_if_uniform(&mut self.cmpr);
} else {
legalize_load_store_address(
b,
&mut self.addr,
&mut self.uniform_address,
Some(&mut self.addr_stride),
);
}
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
@ -4291,7 +4366,12 @@ impl SM70Op for OpHmma {
impl SM70Op for OpLdsm {
// Legalizes the sources of OpLdsm.
// NOTE(review): this hunk is rendered without +/- markers. Its header
// (-4291,7 +4366,12) implies exactly one removed line, presumably the
// copy_src_if_uniform(addr) call directly below -- confirm against the raw
// diff before reading this body as the final code.
fn legalize(&mut self, b: &mut LegalizeBuilder) {
b.copy_src_if_uniform(&mut self.addr);
// No stride is passed for this op, so the helper gets None.
legalize_load_store_address(
b,
&mut self.addr,
&mut self.uniform_addr,
None,
);
}
fn encode(&self, e: &mut SM70Encoder<'_>) {