nak: properly adjust max_warps_per_sm according to thread blocks

We can only have a limited amount of blocks resident per SM, so we should
take this into account inside `max_warps_per_sm()` as well.

Totals from 21402 (1.76% of 1212873) affected shaders:
Max warps/SM: 963056 -> 642704 (-33.26%)
This commit is contained in:
Karol Herbst 2026-04-29 12:51:07 +02:00 committed by Karol Herbst
parent a81aa90832
commit a09af6ce29
6 changed files with 30 additions and 7 deletions

View file

@ -200,6 +200,7 @@ pub extern "C" fn nak_compiler_create(
let nak = Box::new(nak_compiler {
sm: dev.sm,
warps_per_sm: dev.max_warps_per_mp,
blocks_per_sm: dev.max_blocks_per_mp,
max_shared_mem: u32::from(
dev.sm_smem_sizes_kB[usize::from(dev.sm_smem_size_count) - 1],
) * 1024,
@ -452,7 +453,12 @@ fn nak_compile_shader_internal(
Some(unsafe { &*fs_key })
};
let sm = ShaderModelInfo::new(nak.sm, nak.warps_per_sm, nak.max_shared_mem);
let sm = ShaderModelInfo::new(
nak.sm,
nak.warps_per_sm,
nak.blocks_per_sm,
nak.max_shared_mem,
);
let mut s = nak_shader_from_nir(nak, nir, &sm);
if DEBUG.print() {

View file

@ -37,8 +37,12 @@ impl RunSingleton {
let run = Runner::new(dev_id);
let sm_nr = run.dev_info().sm;
let sm =
ShaderModelInfo::new(sm_nr, run.dev_info().max_warps_per_mp, 0);
let sm = ShaderModelInfo::new(
sm_nr,
run.dev_info().max_warps_per_mp,
0,
0,
);
RunSingleton { sm, run }
})
}

View file

@ -9774,14 +9774,21 @@ pub trait ShaderModel {
pub struct ShaderModelInfo {
sm: u8,
warps_per_sm: u8,
blocks_per_sm: u8,
shared_mem_per_sm: u32,
}
impl ShaderModelInfo {
pub fn new(sm: u8, warps_per_sm: u8, shared_mem_per_sm: u32) -> Self {
pub fn new(
sm: u8,
warps_per_sm: u8,
blocks_per_sm: u8,
shared_mem_per_sm: u32,
) -> Self {
ShaderModelInfo {
sm,
warps_per_sm,
blocks_per_sm,
shared_mem_per_sm,
}
}
@ -9910,7 +9917,6 @@ pub fn max_warps_per_sm(
shared_mem: u16,
block_size: u16,
) -> u32 {
// TODO: Take local_size and max blocks/SM into account for compute
let total_regs: u32 = 65536;
// GPRs are allocated in multiples of 8
let gprs = max(gprs, 1);
@ -9918,6 +9924,12 @@ pub fn max_warps_per_sm(
let mut max_warps = prev_multiple_of((total_regs / 32) / gprs, 4);
let block_size = u32::from(block_size.next_multiple_of(32));
// Next we limit the warps according to our available blocks
if block_size != 0 {
max_warps =
max_warps.min((block_size * u32::from(sm.blocks_per_sm)) / 32);
}
// Next we limit the warps according to our available shared memory
if shared_mem != 0 && block_size != 0 {
let max_blocks = sm.shared_mem_per_sm / u32::from(shared_mem);

View file

@ -87,7 +87,7 @@ fn disassemble_instrs(instrs: Vec<Instr>, sm: u8) -> Vec<String> {
io: ShaderIoInfo::None,
};
let sm = ShaderModelInfo::new(sm, 0, 0);
let sm = ShaderModelInfo::new(sm, 0, 0, 0);
let s = Shader {
sm: &sm,
info: info,

View file

@ -42,7 +42,7 @@ fn next_occupancy_cliff(sm: &ShaderModelInfo, x: u32) -> u32 {
#[test]
fn test_next_occupancy_cliff() {
for max_hw_warps in [32, 48, 64] {
let sm = ShaderModelInfo::new(75, max_hw_warps, 100 * 1024);
let sm = ShaderModelInfo::new(75, max_hw_warps, 0, 100 * 1024);
for x in 0..255 {
let y = next_occupancy_cliff(&sm, x);
assert!(y >= x);

View file

@ -20,6 +20,7 @@ bool nak_debug_no_ugpr(void);
struct nak_compiler {
uint8_t sm;
uint8_t warps_per_sm;
uint8_t blocks_per_sm;
uint32_t max_shared_mem;
struct nir_shader_compiler_options nir_options;