From 4f1c72e665f454e89e18d04a295dcf24be1bcdfc Mon Sep 17 00:00:00 2001
From: Faith Ekstrand
Date: Fri, 5 Jun 2026 10:23:18 -0400
Subject: [PATCH] kraid: RA per-byte

Part-of:
---
 src/panfrost/compiler/kraid/ir.rs       |  22 +++-
 src/panfrost/compiler/kraid/ra.rs       | 151 ++++++++++++++++++------
 src/panfrost/compiler/kraid/validate.rs |  12 ++
 3 files changed, 144 insertions(+), 41 deletions(-)

diff --git a/src/panfrost/compiler/kraid/ir.rs b/src/panfrost/compiler/kraid/ir.rs
index b45713d6221..6fbeda53ff1 100644
--- a/src/panfrost/compiler/kraid/ir.rs
+++ b/src/panfrost/compiler/kraid/ir.rs
@@ -111,6 +111,10 @@ impl From<&SmallConstant> for FAURef {
 /// write is simply masked.
 #[derive(Clone, Copy, PartialEq)]
 pub enum RegRange {
+    Byte0,
+    Byte1,
+    Byte2,
+    Byte3,
     Half0,
     Half1,
     Regs(u8),
@@ -125,6 +129,10 @@ pub struct RegRef {
 impl fmt::Display for RegRef {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match &self.range {
+            RegRange::Byte0 => write!(f, "r{}.b0", self.idx),
+            RegRange::Byte1 => write!(f, "r{}.b1", self.idx),
+            RegRange::Byte2 => write!(f, "r{}.b2", self.idx),
+            RegRange::Byte3 => write!(f, "r{}.b3", self.idx),
             RegRange::Half0 => write!(f, "r{}.h0", self.idx),
             RegRange::Half1 => write!(f, "r{}.h1", self.idx),
             RegRange::Regs(n) => {
@@ -141,6 +149,10 @@ impl fmt::Display for RegRef {
 impl RegRef {
     pub fn bytes(&self) -> u8 {
         match self.range {
+            RegRange::Byte0
+            | RegRange::Byte1
+            | RegRange::Byte2
+            | RegRange::Byte3 => 1,
             RegRange::Half0 | RegRange::Half1 => 2,
             RegRange::Regs(n) => n * 4,
         }
@@ -148,8 +160,10 @@ impl RegRef {
 
     pub fn byte_offset(&self) -> u8 {
         match self.range {
-            RegRange::Half0 | RegRange::Regs(_) => 0,
-            RegRange::Half1 => 2,
+            RegRange::Byte0 | RegRange::Half0 | RegRange::Regs(_) => 0,
+            RegRange::Byte1 => 1,
+            RegRange::Byte2 | RegRange::Half1 => 2,
+            RegRange::Byte3 => 3,
         }
     }
 }
@@ -553,6 +567,10 @@ impl fmt::Display for DstLanes {
 impl From<RegRange> for DstLanes {
     fn from(range: RegRange) -> DstLanes {
         match range {
+            RegRange::Byte0 => DstLanes::B0,
+            RegRange::Byte1 => DstLanes::B1,
+            RegRange::Byte2 => DstLanes::B2,
+            RegRange::Byte3 => DstLanes::B3,
             RegRange::Half0 => DstLanes::H0,
             RegRange::Half1 => DstLanes::H1,
             RegRange::Regs(_) => DstLanes::All,
diff --git a/src/panfrost/compiler/kraid/ra.rs b/src/panfrost/compiler/kraid/ra.rs
index 5fd7334a54c..15f58916e2a 100644
--- a/src/panfrost/compiler/kraid/ra.rs
+++ b/src/panfrost/compiler/kraid/ra.rs
@@ -3,6 +3,7 @@
 
 use crate::ir::*;
 use compiler::bitset::*;
+use compiler::smallvec::*;
 use rustc_hash::FxHashMap;
 
 struct BlockIp {
@@ -82,22 +83,41 @@ impl TrivialLiveness {
     }
 }
 
-fn reg_ref_for_hr(hr: u8, bytes: u8) -> RegRef {
-    let range = if bytes == 2 {
-        if hr & 1 == 0 {
+fn widen_lanes(lanes: DstLanes) -> DstLanes {
+    use DstLanes::*;
+    match lanes {
+        None => AnyB,
+        All => panic!("Everything supports ALL"),
+        AnyB => AnyH,
+        AnyH | H0 | H1 => All,
+        B0 | B1 => H0,
+        B2 | B3 => H1,
+    }
+}
+
+fn reg_ref_for_byte(b: u8, bytes: u8) -> RegRef {
+    let range = if bytes >= 4 {
+        assert_eq!(bytes % 4, 0);
+        RegRange::Regs((bytes / 4).try_into().unwrap())
+    } else if bytes == 2 {
+        assert_eq!(b % 2, 0);
+        if b % 4 == 0 {
             RegRange::Half0
         } else {
             RegRange::Half1
         }
     } else {
-        assert_eq!(bytes % 4, 0);
-        RegRange::Regs(bytes / 4)
+        assert_eq!(bytes, 1);
+        match b % 4 {
+            0 => RegRange::Byte0,
+            1 => RegRange::Byte1,
+            2 => RegRange::Byte2,
+            3 => RegRange::Byte3,
+            _ => panic!("b % 4 < 4"),
+        }
     };
 
-    RegRef {
-        idx: hr >> 1,
-        range,
-    }
+    RegRef { idx: b / 4, range }
 }
 
 fn ra_trivial(s: &mut Shader) {
@@ -105,8 +125,8 @@ fn ra_trivial(s: &mut Shader) {
 
     // Allocate in units of half registers. We might be a dumb allocator but
     // we can at least try to exercise Kraid's half register model.
-    let mut hr_used: BitSet = Default::default();
-    let mut ssa_hr: FxHashMap = Default::default();
+    let mut byte_used: BitSet = Default::default();
+    let mut ssa_b: FxHashMap = Default::default();
 
     for (bi, block) in s.blocks.iter_mut().enumerate() {
         for (ip, instr) in block.instrs.iter_mut().enumerate() {
@@ -115,27 +135,31 @@ fn ra_trivial(s: &mut Shader) {
                     continue;
                 };
 
-                let mut vec_hr = 0;
+                let mut vec_b = 0;
                 for (i, ssa) in vec.iter().enumerate() {
-                    let hr = *ssa_hr.get(ssa).unwrap();
+                    let b = *ssa_b.get(ssa).unwrap();
 
                     if live.is_killed_by(ssa, bi, ip) {
-                        let hr_count = ssa.bits().div_ceil(16);
-                        for hr in hr..(hr + hr_count) {
-                            hr_used.remove(hr.into());
+                        let bytes = ssa.bits() / 8;
+                        for b in b..(b + bytes) {
+                            byte_used.remove(b.into());
                         }
                     }
 
                     if i == 0 {
-                        vec_hr = hr;
+                        vec_b = b;
                     } else {
                         // We don't know how to move registers
-                        assert_eq!(hr, vec_hr + u8::try_from(i * 2).unwrap());
+                        assert_eq!(b, vec_b + u8::try_from(i * 4).unwrap());
                     }
                 }
 
-                let reg = reg_ref_for_hr(vec_hr, vec.bytes());
+                let reg = reg_ref_for_byte(vec_b, vec.bytes());
                 let swz = match reg.range {
+                    RegRange::Byte0 => Swizzle::B0000,
+                    RegRange::Byte1 => Swizzle::B1111,
+                    RegRange::Byte2 => Swizzle::B2222,
+                    RegRange::Byte3 => Swizzle::B3333,
                     RegRange::Half0 => Swizzle::H00,
                     RegRange::Half1 => Swizzle::H11,
                     RegRange::Regs(_) => Swizzle::NONE,
@@ -146,33 +170,82 @@ fn ra_trivial(s: &mut Shader) {
                 src.src_ref = reg.into();
             }
 
-            for dst in instr.dsts_mut() {
+            let mut dst_regs = SmallVec::new();
+            for dst in instr.dsts() {
                 let DstRef::SSA(vec) = &dst.dst_ref else {
                     continue;
                 };
 
-                let hr_count = vec.bytes().div_ceil(2);
-                let hr_align = hr_count.next_power_of_two();
-                let vec_hr = hr_used.find_aligned_unset_range(
-                    0, // start_point
-                    hr_count.into(),
-                    hr_align.into(),
-                    0, // align_offset
-                );
+                let mut alloc_lanes = dst.lanes;
+                while !s.model.op_dst_supports_lanes(&instr.op, alloc_lanes) {
+                    alloc_lanes = widen_lanes(alloc_lanes);
+                }
 
-                assert!(vec_hr <= 128, "Ran out of registers!");
-                let vec_hr = vec_hr as u8;
+                let bytes = vec.bytes();
+                let alloc_bytes = alloc_lanes.bytes(bytes);
+                let (align_mul, align_off) = if bytes > 4 {
+                    debug_assert_eq!(alloc_lanes, DstLanes::All);
+                    (bytes.next_power_of_two(), 0)
+                } else {
+                    alloc_lanes.align()
+                };
+
+                let mut alloc_start = 0;
+                let (b, reg) = loop {
+                    let b = byte_used.find_aligned_unset_range(
+                        alloc_start,
+                        alloc_bytes.into(),
+                        align_mul.into(),
+                        align_off.into(),
+                    );
+
+                    assert!(
+                        b + usize::from(bytes) <= 256,
+                        "Ran out of registers!"
+                    );
+                    let b = b as u8;
+                    let reg = reg_ref_for_byte(b, alloc_bytes);
+                    let lanes = DstLanes::from(reg.range);
+
+                    match alloc_lanes {
+                        DstLanes::All => debug_assert_eq!(lanes, DstLanes::All),
+                        DstLanes::AnyB => debug_assert!(lanes.is_byte()),
+                        DstLanes::AnyH => debug_assert!(lanes.is_half()),
+                        _ => debug_assert_eq!(lanes, alloc_lanes),
+                    }
+
+                    if s.model.op_dst_supports_lanes(&instr.op, lanes) {
+                        break (b, reg);
+                    }
+
+                    alloc_start = usize::from(b) + 1;
+                };
+
+                // If the SSA value is smaller than the region we just
+                // allocated, adjust accordingly.
+                let (dst_mul, dst_off) = dst.lanes.align();
+                let b = (b & !(dst_mul - 1)) | dst_off;
 
                 for (i, ssa) in vec.iter().enumerate() {
-                    let hr = vec_hr + u8::try_from(i * 2).unwrap();
-                    ssa_hr.insert(*ssa, hr);
+                    ssa_b.insert(*ssa, b + u8::try_from(i * 4).unwrap());
                 }
 
-                for hr in vec_hr..(vec_hr + hr_count) {
-                    hr_used.insert(hr.into());
+                // If the SSA value is smaller than the region we just
+                // allocated, this only marks the bytes consumed by the SSA
+                // value as used. This effectively kills the other bytes
+                // immediately.
+                for i in 0..bytes {
+                    byte_used.insert(usize::from(b) + usize::from(i));
                 }
 
-                *dst = reg_ref_for_hr(vec_hr, vec.bytes()).into();
+                dst_regs.push(reg);
+            }
+
+            debug_assert_eq!(instr.dsts().len(), dst_regs.len());
+            for (dst, reg) in
+                instr.dsts_mut().iter_mut().zip(dst_regs.into_iter())
+            {
+                *dst = reg.into();
             }
 
             for dst in instr.dsts() {
@@ -182,10 +255,10 @@ fn ra_trivial(s: &mut Shader) {
 
                 for ssa in vec {
                     if live.is_never_used(ssa) {
-                        let first_hr = *ssa_hr.get(ssa).unwrap();
-                        let hr_count = ssa.bits().div_ceil(16);
-                        for hr in first_hr..(first_hr + hr_count) {
-                            hr_used.remove(hr.into());
+                        let vec_b = *ssa_b.get(ssa).unwrap();
+                        let bytes = ssa.bits() / 8;
+                        for b in 0..bytes {
+                            byte_used.remove((vec_b + b).into());
                         }
                     }
                 }
diff --git a/src/panfrost/compiler/kraid/validate.rs b/src/panfrost/compiler/kraid/validate.rs
index 166cd13cee8..3eb9b808d24 100644
--- a/src/panfrost/compiler/kraid/validate.rs
+++ b/src/panfrost/compiler/kraid/validate.rs
@@ -35,6 +35,18 @@ fn validate_instr(instr: &Instr, ssa_vals: &mut FxHashSet) {
                 }
             }
             SrcRef::Reg(reg) => match reg.range {
+                RegRange::Byte0 => {
+                    assert!(src.swizzle.bytes_read() & !0b0001 == 0);
+                }
+                RegRange::Byte1 => {
+                    assert!(src.swizzle.bytes_read() & !0b0010 == 0);
+                }
+                RegRange::Byte2 => {
+                    assert!(src.swizzle.bytes_read() & !0b0100 == 0);
+                }
+                RegRange::Byte3 => {
+                    assert!(src.swizzle.bytes_read() & !0b1000 == 0);
+                }
                 RegRange::Half0 => {
                     assert!(src.swizzle.bytes_read() & !0b0011 == 0);
                 }