nak/hw_runner: add ldsm tests

Reviewed-by: Mary Guillemard <mary@mary.zone>
Acked-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36363>
This commit is contained in:
Karol Herbst 2025-05-08 09:56:26 +02:00
parent ad8df8407e
commit 016159096f

View file

@ -14,7 +14,7 @@ use compiler::bindings::MESA_SHADER_COMPUTE;
use compiler::cfg::CFGBuilder;
use nak_bindings::*;
use rustc_hash::FxBuildHasher;
use std::mem::offset_of;
use std::mem::{offset_of, size_of};
use std::str::FromStr;
use std::sync::OnceLock;
@ -1652,3 +1652,99 @@ pub fn test_gpr_limit_from_local_size() {
});
}
}
#[test]
fn test_op_ldsm() {
let run = RunSingleton::get();
if run.sm.sm() < 75 {
return;
}
let mut b = TestShaderBuilder::new(run.sm.as_ref());
// First load the test data and store it inside shared memory. Each thread handles 8 elements.
let input = b.ld_test_data(0, MemType::B128)[0];
let lane_id = b.alloc_ssa(RegFile::GPR);
b.push_op(OpS2R {
dst: lane_id.into(),
idx: NAK_SV_LANE_ID,
});
let offset = b.imul(lane_id.into(), 16.into());
b.push_op(OpSt {
addr: offset.into(),
data: input.into(),
offset: 0.into(),
access: MemAccess {
mem_type: MemType::B128,
space: MemSpace::Shared,
order: MemOrder::Strong(MemScope::CTA),
eviction_priority: MemEvictionPriority::Normal,
},
});
b.push_op(OpMemBar {
scope: MemScope::CTA,
});
let res = b.alloc_ssa_vec(RegFile::GPR, 4);
let addr = b.imul(lane_id.into(), 16.into());
b.push_op(OpLdsm {
dst: res.clone().into(),
mat_size: LdsmSize::M8N8,
mat_count: 4,
addr: addr.into(),
offset: 0.into(),
});
b.st_test_data(16, MemType::B128, res.into());
type Data = [u16; 4 * 2 * 2];
b.set_shared_size((size_of::<Data>() * 32 / 2) as u16);
let bin = b.compile();
let mut data = Vec::<[u16; 4 * 2 * 2]>::new();
for i in 0..32 {
let val = i * 4 * 2;
data.push([
val,
val + 1,
val + 2,
val + 3,
val + 4,
val + 5,
val + 6,
val + 7,
0,
0,
0,
0,
0,
0,
0,
0,
]);
}
println!("{data:?}");
run.run.run(&bin, &mut data).unwrap();
println!("{data:?}");
// the results contains the following values:
// 0..32 : matrix 1
// 32..64 : matrix 2
// 64..96 : matrix 3
// 96..128: matrix 4
//
// and each thread loads from the address from thread (thread_id >> 4)
// plus the offset (thread_id & 0x3) * 2
for i in 0..32 {
assert_eq!(i * 2 + 64 * 0 + 0, data[i][8].into());
assert_eq!(i * 2 + 64 * 0 + 1, data[i][9].into());
assert_eq!(i * 2 + 64 * 1 + 0, data[i][10].into());
assert_eq!(i * 2 + 64 * 1 + 1, data[i][11].into());
assert_eq!(i * 2 + 64 * 2 + 0, data[i][12].into());
assert_eq!(i * 2 + 64 * 2 + 1, data[i][13].into());
assert_eq!(i * 2 + 64 * 3 + 0, data[i][14].into());
assert_eq!(i * 2 + 64 * 3 + 1, data[i][15].into());
}
}