From cb708ad2ad8b9a1b2cdeb4899c0cfbb64a0e655b Mon Sep 17 00:00:00 2001 From: Faith Ekstrand Date: Sat, 9 May 2026 12:13:05 -0400 Subject: [PATCH] compiler/rust: Add LowerBoundedU32[Array] types This is a generalization of NAK's SSAValue and SSAValueArray structs. But instead of depending on NAK's bespoke invariants, this depends on something far simpler: A lower bound on the u32. As long as you can guarantee that the maximum array length is strictly less than the minimum U32 value, we can pull the same trick as NAK and generalize it into a LowerBoundedU32Array type. Reviewed-by: Mel Henning Acked-by: Karol Herbst --- src/compiler/rust/lib.rs | 1 + src/compiler/rust/lower_bounded.rs | 517 +++++++++++++++++++++++++++++ src/compiler/rust/meson.build | 1 + 3 files changed, 519 insertions(+) create mode 100644 src/compiler/rust/lower_bounded.rs diff --git a/src/compiler/rust/lib.rs b/src/compiler/rust/lib.rs index 318fda2d3d6..9f3dd92c3e9 100644 --- a/src/compiler/rust/lib.rs +++ b/src/compiler/rust/lib.rs @@ -7,6 +7,7 @@ pub mod bitset; pub mod cfg; pub mod dataflow; pub mod depth_first_search; +pub mod lower_bounded; pub mod memstream; pub mod nir; pub mod nir_instr_printer; diff --git a/src/compiler/rust/lower_bounded.rs b/src/compiler/rust/lower_bounded.rs new file mode 100644 index 00000000000..12fe7f67166 --- /dev/null +++ b/src/compiler/rust/lower_bounded.rs @@ -0,0 +1,517 @@ +// Copyright © 2026 Collabora, Ltd. +// SPDX-License-Identifier: MIT + +use std::cmp::Ordering; +use std::fmt; +use std::iter::FusedIterator; +use std::num::{NonZero, NonZeroU32}; +use std::ops; +use std::slice::SliceIndex; + +/// This is a generalization of NonZeroU32 which takes allow an arbitrary +/// lower bound. Because it is implemented as NonZeroU32 internally (and hints +/// to the compiler accordingly), the lower bound must be non-zero. +#[repr(transparent)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub struct LowerBoundedU32 { + n: NonZeroU32, +} + +impl LowerBoundedU32 { + pub const BITS: u32 = u32::BITS; + + pub const MIN: LowerBoundedU32 = { + Self { + n: NonZero::new(MIN).expect("MIN must be non-zero"), + } + }; + + pub const MAX: LowerBoundedU32 = Self { n: NonZeroU32::MAX }; + + pub const fn get(self) -> u32 { + self.n.get() + } + + pub const fn new(n: u32) -> Option> { + // Using Self::MIN forces a compile-time check for MIN > 0 + if n >= Self::MIN.get() { + // SAFETY: n is unsigned and we assert n >= MIN > 0 + unsafe { Some(Self::new_unchecked(n)) } + } else { + None + } + } + + pub const unsafe fn new_unchecked(n: u32) -> LowerBoundedU32 { + // Using Self::MIN forces a compile-time check for MIN > 0 + _ = Self::MIN; + let n = unsafe { NonZero::new_unchecked(n) }; + Self { n } + } +} + +impl fmt::Binary for LowerBoundedU32 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + self.n.fmt(f) + } +} + +macro_rules! impl_bitor_for_lbu32 { + ($typ: path) => { + impl ops::BitOr<$typ> for LowerBoundedU32 { + type Output = LowerBoundedU32; + + fn bitor(mut self, rhs: $typ) -> LowerBoundedU32 { + self |= rhs; + self + } + } + + impl ops::BitOrAssign<$typ> for LowerBoundedU32 { + fn bitor_assign(&mut self, rhs: $typ) { + // SAFETY: Setting more bits can only increase the value + self.n |= u32::from(rhs); + } + } + }; +} + +impl_bitor_for_lbu32!(u32); +impl_bitor_for_lbu32!(LowerBoundedU32); + +impl From> for u32 { + fn from(n: LowerBoundedU32) -> u32 { + n.get() + } +} + +impl Ord for LowerBoundedU32 { + fn cmp(&self, other: &Self) -> Ordering { + self.n.cmp(&other.n) + } +} + +impl PartialOrd for LowerBoundedU32 { + fn partial_cmp(&self, other: &Self) -> Option { + self.n.partial_cmp(&other.n) + } +} + +impl TryFrom for LowerBoundedU32 { + type Error = std::num::TryFromIntError; + + fn try_from( + n: u32, + ) -> Result, std::num::TryFromIntError> { + if let Some(n) = LowerBoundedU32::new(n) { + Ok(n) + } else { + // We can't construct TryFromIntError ourselves but we can + // trigger one to be generated. + u32::try_from(u64::MAX)?; + panic!("u64::MAX -> u32 should have generated an error"); + } + } +} + +/// This struct stores an array of LowerBoundedU32 as a flat fixed-length array +/// in memory. The array as presented to the user is variable-length with a +/// maximum length of `MAX_ARR_IDX + 1`. The lower bound on the data type is +/// required to be at least `MAX_ARR_IDX + 2` so that no LowerBoundedU32 can +/// ever be equal to the array length. By utilizing this constraint, we are +/// able to guarantee two useful invariants: +/// +/// 1. All but the last element is non-zero. If `MAX_ARR_IDX >= 1`, this +/// allows the compiler to optimize Option. If +/// `MAX_ARR_IDX >= 3`, the compiler can also optimize enums where one +/// variant is LowerBoundedU32Array and the other is Box. +/// +/// 2. The array always takes the same amount of space as a static array of max +/// length. The array size is hidden inside the array so we don't burn +/// extra bytes on a usize length. +/// +/// TODO: Improve the interface once the generic_const_exprs feature lands +/// in Rust. +#[repr(C)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub struct LowerBoundedU32Array { + arr: [LowerBoundedU32; MAX_ARR_IDX], + + /// Last array element or length + last: u32, +} + +impl + LowerBoundedU32Array +{ + /// The maximum length of this array type. This will always be + /// `MAX_ARR_IDX + 1`. + pub const MAX_LEN: usize = { + // Use LowerBoundedU32::MIN to force the invariants check. + _ = LowerBoundedU32::::MIN; + + let max_len = MAX_ARR_IDX + .checked_add(1) + .expect("MAX_ARR_IDX + 1 must not overflow"); + assert!( + max_len <= (u32::MAX as usize) && (max_len as u32) < MIN_U32, + "MAX_ARR_IDX + 1 must not be less than MIN_U32" + ); + max_len + }; + + // SAFETY: MAX_LEN < MIN_U32, which is a u32 + const MAX_LEN_U32: u32 = Self::MAX_LEN as u32; + + pub fn new() -> LowerBoundedU32Array { + // Use MAX_LEN to force the compile-time invariants check here + _ = LowerBoundedU32Array::::MAX_LEN; + LowerBoundedU32Array { + arr: [LowerBoundedU32::MAX; MAX_ARR_IDX], + last: 0, + } + } + + pub fn as_slice(&self) -> &[LowerBoundedU32] { + if self.last < Self::MAX_LEN_U32 { + &self.arr[0..(self.last as usize)] + } else { + // SAFETY: + // + // We only ever place a length or a LowerBoundedU32 in self.last. + // So if it's not a valid length, it must be a valid + // LowerBoundedU32. + debug_assert!(self.last >= MIN_U32); + unsafe { + std::slice::from_raw_parts( + &self.arr as *const _ as *const LowerBoundedU32, + Self::MAX_LEN, + ) + } + } + } + + pub fn as_mut_slice(&mut self) -> &mut [LowerBoundedU32] { + if self.last < Self::MAX_LEN_U32 { + &mut self.arr[0..(self.last as usize)] + } else { + // SAFETY: + // + // We only ever place a length or a LowerBoundedU32 in self.last. + // So if it's not a valid length, it must be a valid + // LowerBoundedU32. + debug_assert!(self.last >= MIN_U32); + unsafe { + std::slice::from_raw_parts_mut( + &mut self.arr as *mut _ as *mut LowerBoundedU32, + Self::MAX_LEN, + ) + } + } + } + + pub fn get(&self, idx: usize) -> Option<&LowerBoundedU32> { + self.as_slice().get(idx) + } + + pub fn get_mut( + &mut self, + idx: usize, + ) -> Option<&mut LowerBoundedU32> { + self.as_mut_slice().get_mut(idx) + } + + pub fn is_empty(&self) -> bool { + self.last == 0 + } + + pub fn iter(&self) -> impl Iterator> { + self.as_slice().iter() + } + + pub fn iter_mut( + &mut self, + ) -> impl Iterator> { + self.as_mut_slice().iter_mut() + } + + pub fn len(&self) -> usize { + if self.last < Self::MAX_LEN_U32 { + self.last as usize + } else { + Self::MAX_LEN + } + } + + /// Tries to pop an element off the end of the array. None is returned if + /// the array is empty. + pub fn pop(&mut self) -> Option> { + if self.last == 0 { + None + } else if self.last < Self::MAX_LEN_U32 { + self.last -= 1; + Some(self.arr[self.last as usize]) + } else { + // SAFETY: + // + // We only ever place a length or a LowerBoundedU32 in self.last. + // So if it's not a valid length, it must be a valid + // LowerBoundedU32. + debug_assert!(self.last >= MIN_U32); + let elem = unsafe { LowerBoundedU32::new_unchecked(self.last) }; + self.last = Self::MAX_LEN_U32 - 1; + Some(elem) + } + } + + /// Tries to push an element onto the array. If the array is full, it + /// returns Err. + pub fn try_push( + &mut self, + val: LowerBoundedU32, + ) -> Result<(), &'static str> { + if self.last < Self::MAX_LEN_U32 { + let idx = self.last as usize; + let new_len = self.last + 1; + if new_len < Self::MAX_LEN_U32 { + self.arr[idx] = val; + self.last = new_len; + } else { + self.last = val.get(); + } + Ok(()) + } else { + Err("Array is full") + } + } +} + +impl Default + for LowerBoundedU32Array +{ + fn default() -> Self { + LowerBoundedU32Array::new() + } +} + +impl ops::Index + for LowerBoundedU32Array +where + I: SliceIndex<[LowerBoundedU32]>, +{ + type Output = ]>>::Output; + + fn index(&self, index: I) -> &Self::Output { + self.as_slice().index(index) + } +} + +struct AssertArraySize {} + +impl AssertArraySize { + const ASSERT: () = { + assert!(N <= MAX_ARR_IDX + 1); + }; +} + +/// This implementation of From allows converting a fixed-length array of +/// LowerBoundedU32 to a LowerBoundedU32Array. Even though it's not specified +/// in the trait bound, it will throw a compile error if the array is too large +/// so it's safe to return a Self without Option or Result. +impl + From<[LowerBoundedU32; N]> + for LowerBoundedU32Array +{ + fn from(arr: [LowerBoundedU32; N]) -> Self { + // Throw a compile error if the array is too long + _ = AssertArraySize::::ASSERT; + arr.as_slice() + .try_into() + .expect("We already verified the array size") + } +} + +impl + TryFrom<&[LowerBoundedU32]> + for LowerBoundedU32Array +{ + type Error = &'static str; + + fn try_from( + arr: &[LowerBoundedU32], + ) -> Result { + let mut out: Self = Default::default(); + for val in arr.iter() { + out.try_push(*val)?; + } + Ok(out) + } +} + +pub struct LowerBoundedU32ArrayIntoIter< + const MIN_U32: u32, + const MAX_ARR_IDX: usize, +> { + arr: LowerBoundedU32Array, + idx: usize, +} + +impl Iterator + for LowerBoundedU32ArrayIntoIter +{ + type Item = LowerBoundedU32; + + fn next(&mut self) -> Option { + if self.idx < self.arr.len() { + let item = self.arr[self.idx]; + self.idx += 1; + Some(item) + } else { + None + } + } +} + +impl FusedIterator + for LowerBoundedU32ArrayIntoIter +{ +} + +impl IntoIterator + for LowerBoundedU32Array +{ + type Item = LowerBoundedU32; + type IntoIter = LowerBoundedU32ArrayIntoIter; + + // Required method + fn into_iter(self) -> Self::IntoIter { + LowerBoundedU32ArrayIntoIter { arr: self, idx: 0 } + } +} + +#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] +const _: () = { + assert!(size_of::>>() == 8); +}; + +#[cfg(test)] +mod tests { + use crate::lower_bounded::*; + + #[test] + fn test_u32_new() { + type TestU32 = LowerBoundedU32<5>; + + for i in 0..u32::from(TestU32::MIN) { + assert!(TestU32::new(i).is_none()); + } + + for i in 0..31 { + let u = 5 | (1 << i); + let Some(lb) = TestU32::new(u) else { + panic!("LowerBoundedU32::new() should have succeeded"); + }; + assert_eq!(lb.get(), u); + assert_eq!(u32::from(lb), u); + } + } + + #[test] + fn test_u32arr_push() { + type TestArray = LowerBoundedU32Array<5, 3>; + let test_data = [ + LowerBoundedU32::new(10_u32).unwrap(), + LowerBoundedU32::new(15_u32).unwrap(), + LowerBoundedU32::new(17_u32).unwrap(), + LowerBoundedU32::new(23_u32).unwrap(), + ]; + let test_data_extra = LowerBoundedU32::new(55_u32).unwrap(); + + // Sanity check before we test + assert_eq!(test_data.len(), TestArray::MAX_LEN); + + let mut arr: TestArray = Default::default(); + for (i, &d) in test_data.iter().enumerate() { + assert_eq!(arr.len(), i); + arr.try_push(d).expect("This push should not fail"); + } + assert_eq!(arr.len(), test_data.len()); + + arr.try_push(test_data_extra) + .expect_err("We tried to push one too many"); + + assert_eq!(arr.len(), test_data.len()); + + for (i, &d) in test_data.iter().enumerate() { + assert_eq!(arr[i], d); + } + } + + #[test] + fn test_u32arr_from_array() { + type TestArray = LowerBoundedU32Array<5, 3>; + + let test_data_short = [ + LowerBoundedU32::new(10_u32).unwrap(), + LowerBoundedU32::new(15_u32).unwrap(), + ]; + + let test_data_full = [ + LowerBoundedU32::new(10_u32).unwrap(), + LowerBoundedU32::new(15_u32).unwrap(), + LowerBoundedU32::new(17_u32).unwrap(), + LowerBoundedU32::new(23_u32).unwrap(), + ]; + + let arr: TestArray = test_data_short.clone().into(); + + assert_eq!(arr.len(), test_data_short.len()); + for (i, &d) in test_data_short.iter().enumerate() { + assert_eq!(arr[i], d); + } + + let arr: TestArray = test_data_full.clone().into(); + + assert_eq!(arr.len(), test_data_full.len()); + for (i, &d) in test_data_full.iter().enumerate() { + assert_eq!(arr[i], d); + } + } + + #[test] + fn test_u32arr_from_slice() { + type TestArray = LowerBoundedU32Array<5, 3>; + + let test_data = [ + LowerBoundedU32::new(10_u32).unwrap(), + LowerBoundedU32::new(15_u32).unwrap(), + LowerBoundedU32::new(17_u32).unwrap(), + LowerBoundedU32::new(23_u32).unwrap(), + LowerBoundedU32::new(55_u32).unwrap(), + ]; + let test_data_short = &test_data[0..2]; + let test_data_full = &test_data[0..4]; + let test_data_too_long = test_data.as_slice(); + + let arr: TestArray = test_data_short + .try_into() + .expect("Should have fit in the TestArray"); + + assert_eq!(arr.len(), test_data_short.len()); + for (i, &d) in test_data_short.iter().enumerate() { + assert_eq!(arr[i], d); + } + + let arr: TestArray = test_data_full + .try_into() + .expect("Should have fit in the TestArray"); + + assert_eq!(arr.len(), test_data_full.len()); + for (i, &d) in test_data_full.iter().enumerate() { + assert_eq!(arr[i], d); + } + + TestArray::try_from(test_data_too_long) + .expect_err("Input data should have been too long"); + } +} diff --git a/src/compiler/rust/meson.build b/src/compiler/rust/meson.build index e7638718350..1768872ef8a 100644 --- a/src/compiler/rust/meson.build +++ b/src/compiler/rust/meson.build @@ -7,6 +7,7 @@ _compiler_rs_sources = [ 'cfg.rs', 'dataflow.rs', 'depth_first_search.rs', + 'lower_bounded.rs', 'memstream.rs', 'nir_instr_printer.rs', 'nir.rs',