nak: properly adjust max_warps_per_sm according to shared memory usage

There are a couple of weird edge cases where occupancy is hurt quite badly,
so with this it should be easier to find them and figure something out.

This will also allow us to identify whether reducing shared memory usage
has any positive impact on occupancy.

Totals from 254 (0.02% of 1212873) affected shaders:
Max warps/SM: 9832 -> 6749 (-31.36%)
This commit is contained in:
Karol Herbst 2026-04-28 21:35:47 +02:00 committed by Karol Herbst
parent 8da4108f0b
commit a81aa90832
6 changed files with 55 additions and 16 deletions

View file

@ -200,6 +200,9 @@ pub extern "C" fn nak_compiler_create(
let nak = Box::new(nak_compiler {
sm: dev.sm,
warps_per_sm: dev.max_warps_per_mp,
max_shared_mem: u32::from(
dev.sm_smem_sizes_kB[usize::from(dev.sm_smem_size_count) - 1],
) * 1024,
nir_options: nir_options(dev),
});
@ -449,7 +452,7 @@ fn nak_compile_shader_internal(
Some(unsafe { &*fs_key })
};
let sm = ShaderModelInfo::new(nak.sm, nak.warps_per_sm);
let sm = ShaderModelInfo::new(nak.sm, nak.warps_per_sm, nak.max_shared_mem);
let mut s = nak_shader_from_nir(nak, nir, &sm);
if DEBUG.print() {

View file

@ -38,7 +38,7 @@ impl RunSingleton {
let run = Runner::new(dev_id);
let sm_nr = run.dev_info().sm;
let sm =
ShaderModelInfo::new(sm_nr, run.dev_info().max_warps_per_mp);
ShaderModelInfo::new(sm_nr, run.dev_info().max_warps_per_mp, 0);
RunSingleton { sm, run }
})
}

View file

@ -9774,11 +9774,16 @@ pub trait ShaderModel {
/// Device parameters describing the shader model that is being compiled for.
///
/// All fields are stored exactly as passed to `ShaderModelInfo::new`.
pub struct ShaderModelInfo {
/// Shader model (SM) number, e.g. 75 in the occupancy-cliff test.
sm: u8,
/// Hardware limit on warps per SM; `max_warps_per_sm` never returns more
/// than this.
warps_per_sm: u8,
/// Shared memory available per SM. `max_warps_per_sm` divides this by the
/// per-block shared memory usage, so both must use the same unit (the
/// compiler entry point passes a kB value multiplied by 1024).
shared_mem_per_sm: u32,
}
impl ShaderModelInfo {
pub fn new(sm: u8, warps_per_sm: u8) -> Self {
ShaderModelInfo { sm, warps_per_sm }
/// Creates a `ShaderModelInfo` from the SM number, the maximum number of
/// warps per SM and the amount of shared memory per SM.
///
/// The arguments are stored as given; no validation happens here.
pub fn new(sm: u8, warps_per_sm: u8, shared_mem_per_sm: u32) -> Self {
ShaderModelInfo {
sm,
warps_per_sm,
shared_mem_per_sm,
}
}
}
@ -9899,13 +9904,26 @@ pub fn gpr_limit_from_local_size(local_size: &[u16; 3]) -> u32 {
min(out, 255)
}
pub fn max_warps_per_sm(sm: &ShaderModelInfo, gprs: u32) -> u32 {
// TODO: Take local_size and shared mem limit into account for compute
/// Estimates how many warps of a shader can be resident on one SM.
///
/// * `sm` - device limits; `shared_mem_per_sm` and `warps_per_sm` are read.
/// * `gprs` - GPRs needed per thread. Raised to at least 1 and then rounded
///   up to a multiple of 8 before use.
/// * `shared_mem` - shared memory used by one block, in the same unit as
///   `sm.shared_mem_per_sm`. Passing 0 skips the shared memory limit.
/// * `block_size` - threads per block, rounded up to a multiple of 32.
///   Passing 0 skips the shared memory limit.
///
/// Returns the minimum of the register-limited warp count, the shared-memory
/// limited warp count (only when both `shared_mem` and `block_size` are
/// non-zero) and `sm.warps_per_sm`.
pub fn max_warps_per_sm(
sm: &ShaderModelInfo,
gprs: u32,
shared_mem: u16,
block_size: u16,
) -> u32 {
// TODO: Take local_size and max blocks/SM into account for compute
// Register file size used for the estimate; hard-coded, not read from `sm`.
let total_regs: u32 = 65536;
// GPRs are allocated in multiples of 8
let gprs = max(gprs, 1);
let gprs = gprs.next_multiple_of(8);
// Register limit: a warp is 32 threads, so total_regs / 32 is the number of
// per-thread registers available across all warps; dividing by the
// per-thread allocation gives the warp count. `prev_multiple_of(.., 4)` is
// defined elsewhere - presumably it rounds down to a multiple of 4.
let max_warps = prev_multiple_of((total_regs / 32) / gprs, 4);
let mut max_warps = prev_multiple_of((total_regs / 32) / gprs, 4);
// Round the block up to whole warps; a zero block size stays zero.
let block_size = u32::from(block_size.next_multiple_of(32));
// Next we limit the warps according to our available shared memory
if shared_mem != 0 && block_size != 0 {
// Number of blocks whose shared memory fits on one SM (rounded down).
let max_blocks = sm.shared_mem_per_sm / u32::from(shared_mem);
// Each of those blocks contributes block_size / 32 warps.
max_warps = max_warps.min((max_blocks * block_size) / 32)
}
// Never report more than the hardware warp limit.
min(max_warps, sm.warps_per_sm.into())
}
@ -10071,10 +10089,21 @@ impl Shader<'_> {
self.info.writes_global_mem = writes_global_mem;
self.info.uses_fp64 = uses_fp64;
self.info.max_warps_per_sm = max_warps_per_sm(
self.sm,
self.info.num_gprs as u32 + self.sm.hw_reserved_gprs(),
);
if let ShaderStageInfo::Compute(compute) = &self.info.stage {
self.info.max_warps_per_sm = max_warps_per_sm(
self.sm,
self.info.num_gprs as u32 + self.sm.hw_reserved_gprs(),
compute.smem_size,
compute.local_size.iter().product(),
);
} else {
self.info.max_warps_per_sm = max_warps_per_sm(
self.sm,
self.info.num_gprs as u32 + self.sm.hw_reserved_gprs(),
0,
0,
);
}
if self.sm.sm() >= 50 {
if let ShaderStageInfo::Vertex(vertex_info) = &mut self.info.stage {

View file

@ -87,7 +87,7 @@ fn disassemble_instrs(instrs: Vec<Instr>, sm: u8) -> Vec<String> {
io: ShaderIoInfo::None,
};
let sm = ShaderModelInfo::new(sm, 0);
let sm = ShaderModelInfo::new(sm, 0, 0);
let s = Shader {
sm: &sm,
info: info,

View file

@ -28,7 +28,7 @@ const TARGET_FREE: i32 = 4;
/// least `x` GPRs.
fn next_occupancy_cliff(sm: &ShaderModelInfo, x: u32) -> u32 {
let total_regs: u32 = 65536;
let threads = max_warps_per_sm(sm, x) * 32;
let threads = max_warps_per_sm(sm, x, 0, 0) * 32;
// This function doesn't actually model the maximum number of registers
// correctly - callers need to worry about that separately. We do,
@ -42,12 +42,18 @@ fn next_occupancy_cliff(sm: &ShaderModelInfo, x: u32) -> u32 {
/// Checks `next_occupancy_cliff` against `max_warps_per_sm` for every GPR
/// count in 0..255 and for hardware warp limits of 32, 48 and 64.
///
/// Shared memory usage and block size are passed as 0, so only the register
/// limit and the hardware warp limit influence the compared values.
#[test]
fn test_next_occupancy_cliff() {
for max_hw_warps in [32, 48, 64] {
let sm = ShaderModelInfo::new(75, max_hw_warps);
let sm = ShaderModelInfo::new(75, max_hw_warps, 100 * 1024);
for x in 0..255 {
let y = next_occupancy_cliff(&sm, x);
// The cliff is never below the requested GPR count.
assert!(y >= x);
assert_eq!(max_warps_per_sm(&sm, x), max_warps_per_sm(&sm, y));
assert!(max_warps_per_sm(&sm, y) > max_warps_per_sm(&sm, y + 1));
// Using y GPRs instead of x must not change the warp count...
assert_eq!(
max_warps_per_sm(&sm, x, 0, 0),
max_warps_per_sm(&sm, y, 0, 0)
);
// ...while a single additional GPR beyond y must lower it.
assert!(
max_warps_per_sm(&sm, y, 0, 0)
> max_warps_per_sm(&sm, y + 1, 0, 0)
);
}
}
}

View file

@ -20,6 +20,7 @@ bool nak_debug_no_ugpr(void);
/* Compiler state; the Rust side fills it in within nak_compiler_create(). */
struct nak_compiler {
/* Shader model number, copied from the device info (dev.sm). */
uint8_t sm;
/* Maximum warps per SM, copied from dev.max_warps_per_mp. */
uint8_t warps_per_sm;
/* Last entry of the device's shared memory size table (kB) times 1024. */
uint32_t max_shared_mem;
/* NIR compiler options derived from the device info. */
struct nir_shader_compiler_options nir_options;
};