kraid: RA per-byte

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/42200>
This commit is contained in:
Faith Ekstrand 2026-06-05 10:23:18 -04:00 committed by Marge Bot
parent 4ed955b7f8
commit 4f1c72e665
3 changed files with 144 additions and 41 deletions

View file

@ -111,6 +111,10 @@ impl From<&SmallConstant> for FAURef {
/// write is simply masked.
#[derive(Clone, Copy, PartialEq)]
pub enum RegRange {
Byte0,
Byte1,
Byte2,
Byte3,
Half0,
Half1,
Regs(u8),
@ -125,6 +129,10 @@ pub struct RegRef {
impl fmt::Display for RegRef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.range {
RegRange::Byte0 => write!(f, "r{}.b0", self.idx),
RegRange::Byte1 => write!(f, "r{}.b1", self.idx),
RegRange::Byte2 => write!(f, "r{}.b2", self.idx),
RegRange::Byte3 => write!(f, "r{}.b3", self.idx),
RegRange::Half0 => write!(f, "r{}.h0", self.idx),
RegRange::Half1 => write!(f, "r{}.h1", self.idx),
RegRange::Regs(n) => {
@ -141,6 +149,10 @@ impl fmt::Display for RegRef {
impl RegRef {
pub fn bytes(&self) -> u8 {
match self.range {
RegRange::Byte0
| RegRange::Byte1
| RegRange::Byte2
| RegRange::Byte3 => 1,
RegRange::Half0 | RegRange::Half1 => 2,
RegRange::Regs(n) => n * 4,
}
@ -148,8 +160,10 @@ impl RegRef {
pub fn byte_offset(&self) -> u8 {
match self.range {
RegRange::Half0 | RegRange::Regs(_) => 0,
RegRange::Half1 => 2,
RegRange::Byte0 | RegRange::Half0 | RegRange::Regs(_) => 0,
RegRange::Byte1 => 1,
RegRange::Byte2 | RegRange::Half1 => 2,
RegRange::Byte3 => 3,
}
}
}
@ -553,6 +567,10 @@ impl fmt::Display for DstLanes {
impl From<RegRange> for DstLanes {
fn from(range: RegRange) -> DstLanes {
match range {
RegRange::Byte0 => DstLanes::B0,
RegRange::Byte1 => DstLanes::B1,
RegRange::Byte2 => DstLanes::B2,
RegRange::Byte3 => DstLanes::B3,
RegRange::Half0 => DstLanes::H0,
RegRange::Half1 => DstLanes::H1,
RegRange::Regs(_) => DstLanes::All,

View file

@ -3,6 +3,7 @@
use crate::ir::*;
use compiler::bitset::*;
use compiler::smallvec::*;
use rustc_hash::FxHashMap;
struct BlockIp {
@ -82,22 +83,41 @@ impl TrivialLiveness {
}
}
fn reg_ref_for_hr(hr: u8, bytes: u8) -> RegRef {
let range = if bytes == 2 {
if hr & 1 == 0 {
fn widen_lanes(lanes: DstLanes) -> DstLanes {
use DstLanes::*;
match lanes {
None => AnyB,
All => panic!("Everything supports ALL"),
AnyB => AnyH,
AnyH | H0 | H1 => All,
B0 | B1 => H0,
B2 | B3 => H1,
}
}
fn reg_ref_for_byte(b: u8, bytes: u8) -> RegRef {
let range = if bytes >= 4 {
assert_eq!(bytes % 4, 0);
RegRange::Regs((bytes / 4).try_into().unwrap())
} else if bytes == 2 {
assert_eq!(b % 2, 0);
if b % 4 == 0 {
RegRange::Half0
} else {
RegRange::Half1
}
} else {
assert_eq!(bytes % 4, 0);
RegRange::Regs(bytes / 4)
assert_eq!(bytes, 1);
match b % 4 {
0 => RegRange::Byte0,
1 => RegRange::Byte1,
2 => RegRange::Byte2,
3 => RegRange::Byte3,
_ => panic!("bytes % 4 < 4"),
}
};
RegRef {
idx: hr >> 1,
range,
}
RegRef { idx: b / 4, range }
}
fn ra_trivial(s: &mut Shader) {
@ -105,8 +125,8 @@ fn ra_trivial(s: &mut Shader) {
// Allocate in units of half registers. We might be a dumb allocator but
// we can at least try to exercise Kraid's half register model.
let mut hr_used: BitSet = Default::default();
let mut ssa_hr: FxHashMap<SSAValue, u8> = Default::default();
let mut byte_used: BitSet = Default::default();
let mut ssa_b: FxHashMap<SSAValue, u8> = Default::default();
for (bi, block) in s.blocks.iter_mut().enumerate() {
for (ip, instr) in block.instrs.iter_mut().enumerate() {
@ -115,27 +135,31 @@ fn ra_trivial(s: &mut Shader) {
continue;
};
let mut vec_hr = 0;
let mut vec_b = 0;
for (i, ssa) in vec.iter().enumerate() {
let hr = *ssa_hr.get(ssa).unwrap();
let b = *ssa_b.get(ssa).unwrap();
if live.is_killed_by(ssa, bi, ip) {
let hr_count = ssa.bits().div_ceil(16);
for hr in hr..(hr + hr_count) {
hr_used.remove(hr.into());
let bytes = ssa.bits() / 8;
for b in b..(b + bytes) {
byte_used.remove(b.into());
}
}
if i == 0 {
vec_hr = hr;
vec_b = b;
} else {
// We don't know how to move registers
assert_eq!(hr, vec_hr + u8::try_from(i * 2).unwrap());
assert_eq!(b, vec_b + u8::try_from(i * 4).unwrap());
}
}
let reg = reg_ref_for_hr(vec_hr, vec.bytes());
let reg = reg_ref_for_byte(vec_b, vec.bytes());
let swz = match reg.range {
RegRange::Byte0 => Swizzle::B0000,
RegRange::Byte1 => Swizzle::B1111,
RegRange::Byte2 => Swizzle::B2222,
RegRange::Byte3 => Swizzle::B3333,
RegRange::Half0 => Swizzle::H00,
RegRange::Half1 => Swizzle::H11,
RegRange::Regs(_) => Swizzle::NONE,
@ -146,33 +170,82 @@ fn ra_trivial(s: &mut Shader) {
src.src_ref = reg.into();
}
for dst in instr.dsts_mut() {
let mut dst_regs = SmallVec::new();
for dst in instr.dsts() {
let DstRef::SSA(vec) = &dst.dst_ref else {
continue;
};
let hr_count = vec.bytes().div_ceil(2);
let hr_align = hr_count.next_power_of_two();
let vec_hr = hr_used.find_aligned_unset_range(
0, // start_point
hr_count.into(),
hr_align.into(),
0, // align_offset
);
let mut alloc_lanes = dst.lanes;
while !s.model.op_dst_supports_lanes(&instr.op, alloc_lanes) {
alloc_lanes = widen_lanes(alloc_lanes);
}
assert!(vec_hr <= 128, "Ran out of registers!");
let vec_hr = vec_hr as u8;
let bytes = vec.bytes();
let alloc_bytes = alloc_lanes.bytes(bytes);
let (align_mul, align_off) = if bytes > 4 {
debug_assert_eq!(alloc_lanes, DstLanes::All);
(bytes.next_power_of_two(), 0)
} else {
alloc_lanes.align()
};
let mut alloc_start = 0;
let (b, reg) = loop {
let b = byte_used.find_aligned_unset_range(
alloc_start,
alloc_bytes.into(),
align_mul.into(),
align_off.into(),
);
assert!(
b + usize::from(bytes) <= 256,
"Ran out of registers!"
);
let b = b as u8;
let reg = reg_ref_for_byte(b, alloc_bytes);
let lanes = DstLanes::from(reg.range);
match alloc_lanes {
DstLanes::All => debug_assert_eq!(lanes, DstLanes::All),
DstLanes::AnyB => debug_assert!(lanes.is_byte()),
DstLanes::AnyH => debug_assert!(lanes.is_half()),
_ => debug_assert_eq!(lanes, alloc_lanes),
}
if s.model.op_dst_supports_lanes(&instr.op, lanes) {
break (b, reg);
}
alloc_start = usize::from(b) + 1;
};
// In case when the SSA value is smaller than the region we
// just allocated, adjust accordingly.
let (dst_mul, dst_off) = dst.lanes.align();
let b = (b & !(dst_mul - 1)) | dst_off;
for (i, ssa) in vec.iter().enumerate() {
let hr = vec_hr + u8::try_from(i * 2).unwrap();
ssa_hr.insert(*ssa, hr);
ssa_b.insert(*ssa, b + u8::try_from(i * 4).unwrap());
}
for hr in vec_hr..(vec_hr + hr_count) {
hr_used.insert(hr.into());
// In case when the SSA value is smaller than the region we
// just allocated, this only marks the bytes consumed by the
// SSA value as used. This effectively kills the other bytes
// immediately.
for i in 0..bytes {
byte_used.insert(usize::from(b) + usize::from(i));
}
*dst = reg_ref_for_hr(vec_hr, vec.bytes()).into();
dst_regs.push(reg);
}
debug_assert_eq!(instr.dsts().len(), dst_regs.len());
for (dst, reg) in
instr.dsts_mut().iter_mut().zip(dst_regs.into_iter())
{
*dst = reg.into();
}
for dst in instr.dsts() {
@ -182,10 +255,10 @@ fn ra_trivial(s: &mut Shader) {
for ssa in vec {
if live.is_never_used(ssa) {
let first_hr = *ssa_hr.get(ssa).unwrap();
let hr_count = ssa.bits().div_ceil(16);
for hr in first_hr..(first_hr + hr_count) {
hr_used.remove(hr.into());
let vec_b = *ssa_b.get(ssa).unwrap();
let bytes = ssa.bits() / 8;
for b in 0..bytes {
byte_used.insert((vec_b + b).into());
}
}
}

View file

@ -35,6 +35,18 @@ fn validate_instr(instr: &Instr, ssa_vals: &mut FxHashSet<SSAValue>) {
}
}
SrcRef::Reg(reg) => match reg.range {
RegRange::Byte0 => {
assert!(src.swizzle.bytes_read() & !0b0001 == 0);
}
RegRange::Byte1 => {
assert!(src.swizzle.bytes_read() & !0b0010 == 0);
}
RegRange::Byte2 => {
assert!(src.swizzle.bytes_read() & !0b0100 == 0);
}
RegRange::Byte3 => {
assert!(src.swizzle.bytes_read() & !0b1000 == 0);
}
RegRange::Half0 => {
assert!(src.swizzle.bytes_read() & !0b0011 == 0);
}