nak: Use LowerBoundedU32 for SSAValue

Reviewed-by: Mel Henning <mhenning@darkrefraction.com>
Acked-by: Karol Herbst <kherbst@redhat.com>
This commit is contained in:
Faith Ekstrand 2026-05-09 22:31:48 -04:00
parent cb708ad2ad
commit 23ea5bb09b

View file

@ -3,11 +3,12 @@
use crate::ir::{HasRegFile, RegFile};
use compiler::bitset::IntoBitIndex;
use std::array;
use compiler::lower_bounded::*;
use std::fmt;
use std::num::NonZeroU32;
use std::ops::{Deref, DerefMut};
type SSAValueInner = LowerBoundedU32<{ SSAValue::IDX_DELTA }>;
/// An SSA value
///
/// Each SSA in NAK represents a single 32-bit or 1-bit (if a predicate) value
@ -21,19 +22,19 @@ use std::ops::{Deref, DerefMut};
/// register file. This way the index can be used to index tightly-packed data
/// structures such as bitsets without having to determine separate ranges for
/// each register file.
#[repr(transparent)]
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
pub struct SSAValue {
packed: NonZeroU32,
packed: SSAValueInner,
}
impl SSAValue {
const IDX_DELTA: u32 = 17;
/// Returns an SSA value with the given register file and index
fn new(file: RegFile, idx: u32) -> SSAValue {
assert!(
idx > 0
&& idx < (1 << 29) - u32::try_from(SSARef::LARGE_SIZE).unwrap()
);
let mut packed = idx;
assert!(idx < (1 << 29) - SSAValue::IDX_DELTA);
let mut packed = idx + SSAValue::IDX_DELTA;
assert!(u8::from(file) < 8);
packed |= u32::from(u8::from(file)) << 29;
SSAValue {
@ -43,7 +44,7 @@ impl SSAValue {
/// Returns the index of this SSA value
pub fn idx(&self) -> u32 {
self.packed.get() & 0x1fffffff
(self.packed.get() & 0x1fffffff) - SSAValue::IDX_DELTA
}
}
@ -73,67 +74,15 @@ impl fmt::Debug for SSAValue {
}
}
#[derive(Clone, Eq, Hash, PartialEq)]
struct SSAValueArray<const SIZE: usize> {
v: [SSAValue; SIZE],
}
impl<const SIZE: usize> SSAValueArray<SIZE> {
/// Returns a new SSA reference.
///
/// # Panics
///
/// This method will panic if the number of `SSAValue`s in the slice is
/// greater than `SIZE`.
#[inline]
fn new(comps: &[SSAValue]) -> Self {
assert!(comps.len() > 0 && comps.len() <= SIZE);
let mut r = Self {
v: [SSAValue {
packed: NonZeroU32::MAX,
}; SIZE],
};
r.v[..comps.len()].copy_from_slice(comps);
if comps.len() < SIZE {
r.v[SIZE - 1].packed =
(comps.len() as u32).wrapping_neg().try_into().unwrap();
}
r
}
/// Returns the number of components in this SSA reference.
fn comps(&self) -> u8 {
let size: u8 = SIZE.try_into().unwrap();
if self.v[SIZE - 1].packed.get() >= u32::MAX - (u32::from(size) - 1) {
self.v[SIZE - 1].packed.get().wrapping_neg() as u8
} else {
size
}
}
}
impl<const SIZE: usize> Deref for SSAValueArray<SIZE> {
type Target = [SSAValue];
fn deref(&self) -> &[SSAValue] {
let comps = usize::from(self.comps());
&self.v[..comps]
}
}
impl<const SIZE: usize> DerefMut for SSAValueArray<SIZE> {
fn deref_mut(&mut self) -> &mut [SSAValue] {
let comps = usize::from(self.comps());
&mut self.v[..comps]
}
}
type SSARefSmall =
LowerBoundedU32Array<{ SSAValue::IDX_DELTA }, { SSARef::SMALL_MAX_IDX }>;
type SSARefLarge =
LowerBoundedU32Array<{ SSAValue::IDX_DELTA }, { SSARef::LARGE_MAX_IDX }>;
#[derive(Clone, Eq, Hash, PartialEq)]
enum SSARefInner {
Small(SSAValueArray<{ SSARef::SMALL_SIZE }>),
Large(Box<SSAValueArray<{ SSARef::LARGE_SIZE }>>),
Small(SSARefSmall),
Large(Box<SSARefLarge>),
}
/// A reference to one or more SSA values
@ -161,8 +110,8 @@ const _: () = {
};
impl SSARef {
const SMALL_SIZE: usize = 4;
const LARGE_SIZE: usize = 16;
const SMALL_MAX_IDX: usize = 3;
const LARGE_MAX_IDX: usize = 15;
/// Returns a new SSA reference.
///
@ -172,12 +121,15 @@ impl SSARef {
/// fit in an SSARef.
#[inline]
pub fn new(comps: &[SSAValue]) -> SSARef {
assert!(comps.len() > 0);
// SAFETY: SSAValue is repr(transparent) of SSAValueInner
let bounded: &[SSAValueInner] = unsafe { std::mem::transmute(comps) };
SSARef {
v: if comps.len() > Self::SMALL_SIZE {
v: if comps.len() > SSARefSmall::MAX_LEN {
Self::cold();
SSARefInner::Large(Box::new(SSAValueArray::new(comps)))
SSARefInner::Large(Box::new(bounded.try_into().unwrap()))
} else {
SSARefInner::Small(SSAValueArray::new(comps))
SSARefInner::Small(bounded.try_into().unwrap())
},
}
}
@ -188,24 +140,34 @@ impl SSARef {
///
/// This method will panic if the number of SSA values in the slice do not
/// fit in an SSARef.
fn from_iter(mut it: impl ExactSizeIterator<Item = SSAValue>) -> Self {
fn from_iter(it: impl ExactSizeIterator<Item = SSAValue>) -> Self {
let len = it.len();
assert!(len > 0 && len <= Self::LARGE_SIZE);
let v: [SSAValue; Self::LARGE_SIZE] = array::from_fn(|_| {
it.next().unwrap_or(SSAValue {
packed: NonZeroU32::MAX,
})
});
Self::new(&v[..len])
assert!(len > 0 && len <= SSARefLarge::MAX_LEN);
SSARef {
v: if len > SSARefSmall::MAX_LEN {
Self::cold();
let mut arr: SSARefLarge = Default::default();
for ssa in it {
arr.try_push(ssa.packed).unwrap();
}
SSARefInner::Large(Box::new(arr))
} else {
let mut arr: SSARefSmall = Default::default();
for ssa in it {
arr.try_push(ssa.packed).unwrap();
}
SSARefInner::Small(arr)
},
}
}
/// Returns the number of components in this SSA reference.
pub fn comps(&self) -> u8 {
match &self.v {
SSARefInner::Small(x) => x.comps(),
SSARefInner::Small(x) => x.len() as u8,
SSARefInner::Large(x) => {
Self::cold();
x.comps()
x.len() as u8
}
}
}
@ -219,25 +181,29 @@ impl Deref for SSARef {
type Target = [SSAValue];
fn deref(&self) -> &[SSAValue] {
match &self.v {
SSARefInner::Small(x) => x.deref(),
let s = match &self.v {
SSARefInner::Small(x) => x.as_slice(),
SSARefInner::Large(x) => {
Self::cold();
x.deref()
x.as_slice()
}
}
};
// SAFETY: SSAValue is repr(transparent) of SSAValueInner
unsafe { std::mem::transmute(s) }
}
}
impl DerefMut for SSARef {
fn deref_mut(&mut self) -> &mut [SSAValue] {
match &mut self.v {
SSARefInner::Small(x) => x.deref_mut(),
let s = match &mut self.v {
SSARefInner::Small(x) => x.as_mut_slice(),
SSARefInner::Large(x) => {
Self::cold();
x.deref_mut()
x.as_mut_slice()
}
}
};
// SAFETY: SSAValue is repr(transparent) of SSAValueInner
unsafe { std::mem::transmute(s) }
}
}
@ -247,7 +213,7 @@ impl TryFrom<&[SSAValue]> for SSARef {
fn try_from(comps: &[SSAValue]) -> Result<Self, Self::Error> {
if comps.len() == 0 {
Err("Empty vector")
} else if comps.len() > Self::LARGE_SIZE {
} else if comps.len() > SSARefLarge::MAX_LEN {
Err("Too many vector components")
} else {
Ok(SSARef::new(comps))