diff --git a/src/compiler/rust/enum_as_u8.rs b/src/compiler/rust/enum_as_u8.rs index 568efe0916a..c919ec1cb76 100644 --- a/src/compiler/rust/enum_as_u8.rs +++ b/src/compiler/rust/enum_as_u8.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT use crate::bitset::ConstBitSet; +use std::fmt; /// A trait for enums which are `#[repr(u8)]` which provides some extra sugar /// on top. By deriving this trait with `#[derive(EnumAsU8)]`, you get @@ -9,16 +10,65 @@ use crate::bitset::ConstBitSet; /// an iterator over all valid variants in the enum. This trait works /// regardless of whether or not discriminant are explicitly specified. pub trait EnumAsU8: Sized { - const VARIANTS: ConstBitSet<8, u8>; + type VariantSet; + const VARIANTS: Self::VariantSet; const MAX_DISCRIMINANT: u8; fn as_u8(self) -> u8; unsafe fn from_u8_unchecked(u: u8) -> Self; +} - fn iter() -> impl Iterator { - Self::VARIANTS - .iter() - .map(|u| unsafe { Self::from_u8_unchecked(u) }) +/// A set of EnumAsU8 values. +#[derive(Clone, Copy, Eq, PartialEq)] +pub struct U8EnumSet { + set: ConstBitSet, + phantom: std::marker::PhantomData, +} + +impl U8EnumSet { + const ASSERT: () = { + assert!(size_of::() == 1); + assert!((E::MAX_DISCRIMINANT as usize) < N * 32); + }; + + /// SAFETY: Every value in the array must be a valid E discriminant. + pub const unsafe fn from_u8_array(arr: [u8; A]) -> Self { + let _ = Self::ASSERT; + U8EnumSet { + set: ConstBitSet::::from_array(arr), + phantom: std::marker::PhantomData, + } + } + + pub fn is_empty(&self) -> bool { + self.set.is_empty() + } + + pub fn contains(&self, e: E) -> bool { + self.set.contains(e.as_u8()) + } + + pub const fn contains_u8(&self, u: u8) -> bool { + self.set.contains(u) + } + + pub fn iter(&self) -> impl Iterator + use<'_, E, N> { + // SAFETY: We ensure that the only elements added to the set are valid + // E values. + self.set.iter().map(|u| unsafe { E::from_u8_unchecked(u) }) + } +} + +impl fmt::Debug for U8EnumSet { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "U8EnumSet{{")?; + for (i, v) in self.iter().enumerate() { + write!(f, "{v:?}")?; + if i > 0 { + write!(f, ", ")?; + } + } + write!(f, "}}") } } diff --git a/src/compiler/rust/proc/enum_as_u8.rs b/src/compiler/rust/proc/enum_as_u8.rs index ad6a442619e..9c1940cbfdf 100644 --- a/src/compiler/rust/proc/enum_as_u8.rs +++ b/src/compiler/rust/proc/enum_as_u8.rs @@ -59,12 +59,23 @@ pub fn derive_enum_as_u8(input: TokenStream) -> TokenStream { }); } + let var_set_u32s = if max_desc_ts.is_empty() { + (usize::from(max_desc) + 1).div_ceil(32) + } else { + // Worst case + 256 / 32 + }; + let ident_s = ident.to_string(); let try_from_err = format!("Invalid {ident_s} variant."); let imp = quote! { impl EnumAsU8 for #ident { - const VARIANTS: compiler::bitset::ConstBitSet<8, u8> = - compiler::bitset::ConstBitSet::<8, u8>::from_array([#variants_ts]); + type VariantSet = compiler::enum_as_u8::U8EnumSet<#ident, #var_set_u32s>; + const VARIANTS: compiler::enum_as_u8::U8EnumSet<#ident, #var_set_u32s> = { + unsafe { + compiler::enum_as_u8::U8EnumSet::from_u8_array([#variants_ts]) + } + }; const MAX_DISCRIMINANT: u8 = { let mut max_desc = #max_desc; #max_desc_ts @@ -90,7 +101,7 @@ pub fn derive_enum_as_u8(input: TokenStream) -> TokenStream { type Error = &'static str; fn try_from(u: u8) -> Result { - if Self::VARIANTS.contains(u) { + if Self::VARIANTS.contains_u8(u) { Ok(unsafe { Self::from_u8_unchecked(u) }) } else { Err(#try_from_err) diff --git a/src/panfrost/compiler/kraid/swizzle.rs b/src/panfrost/compiler/kraid/swizzle.rs index 26a72eb7bef..71e1c536df0 100644 --- a/src/panfrost/compiler/kraid/swizzle.rs +++ b/src/panfrost/compiler/kraid/swizzle.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT use crate::data_type::DataType; -use compiler::enum_as_u8::EnumAsU8; +use compiler::enum_as_u8::*; use compiler::float16::F16; use kraid_proc_macros::{AsmSwizzleWiden, EnumAsU8}; use std::fmt; @@ -236,7 +236,7 @@ impl Swizzle { let b = ((self.packed.get() >> (idx * 4)) & 0xf) as u8; // SAFETY: We only ever construct Swizzle from SwizzleByte - debug_assert!(SwizzleByte::VARIANTS.contains(b)); + debug_assert!(SwizzleByte::VARIANTS.contains_u8(b)); Some(unsafe { std::mem::transmute(b) }) } } @@ -250,7 +250,7 @@ impl Swizzle { let b = ((self.packed.get() >> (idx * 4)) & 0xf) as u8; // SAFETY: We only ever construct Swizzle from SwizzleByte - debug_assert!(SwizzleWord::VARIANTS.contains(b)); + debug_assert!(SwizzleWord::VARIANTS.contains_u8(b)); Some(unsafe { std::mem::transmute(b) }) } else { None