From d127e2fffa79201d20fbaee8af68b4d26fd03f9f Mon Sep 17 00:00:00 2001 From: Faith Ekstrand Date: Fri, 5 Jun 2026 12:48:46 -0400 Subject: [PATCH] compiler/rust/enum_as_u8: Add an ConstU8EnumSet struct For enums which implement EnumAsU8, it's nice to be able to construct cheap, copyable sets of them. We were already doing this for variants. This codifies and exports it. Part-of: --- src/compiler/rust/enum_as_u8.rs | 60 +++++++++++++++++++++++--- src/compiler/rust/proc/enum_as_u8.rs | 17 ++++++-- src/panfrost/compiler/kraid/swizzle.rs | 6 +-- 3 files changed, 72 insertions(+), 11 deletions(-) diff --git a/src/compiler/rust/enum_as_u8.rs b/src/compiler/rust/enum_as_u8.rs index 568efe0916a..c919ec1cb76 100644 --- a/src/compiler/rust/enum_as_u8.rs +++ b/src/compiler/rust/enum_as_u8.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT use crate::bitset::ConstBitSet; +use std::fmt; /// A trait for enums which are `#[repr(u8)]` which provides some extra sugar /// on top. By deriving this trait with `#[derive(EnumAsU8)]`, you get @@ -9,16 +10,65 @@ use crate::bitset::ConstBitSet; /// an iterator over all valid variants in the enum. This trait works /// regardless of whether or not discriminant are explicitly specified. pub trait EnumAsU8: Sized { - const VARIANTS: ConstBitSet<8, u8>; + type VariantSet; + const VARIANTS: Self::VariantSet; const MAX_DISCRIMINANT: u8; fn as_u8(self) -> u8; unsafe fn from_u8_unchecked(u: u8) -> Self; +} - fn iter() -> impl Iterator { - Self::VARIANTS - .iter() - .map(|u| unsafe { Self::from_u8_unchecked(u) }) +/// A set of EnumAsU8 values. +#[derive(Clone, Copy, Eq, PartialEq)] +pub struct U8EnumSet { + set: ConstBitSet, + phantom: std::marker::PhantomData, +} + +impl U8EnumSet { + const ASSERT: () = { + assert!(size_of::() == 1); + assert!((E::MAX_DISCRIMINANT as usize) < N * 32); + }; + + /// SAFETY: Every value in the array must be a valid E discriminant. + pub const unsafe fn from_u8_array(arr: [u8; A]) -> Self { + let _ = Self::ASSERT; + U8EnumSet { + set: ConstBitSet::::from_array(arr), + phantom: std::marker::PhantomData, + } + } + + pub fn is_empty(&self) -> bool { + self.set.is_empty() + } + + pub fn contains(&self, e: E) -> bool { + self.set.contains(e.as_u8()) + } + + pub const fn contains_u8(&self, u: u8) -> bool { + self.set.contains(u) + } + + pub fn iter(&self) -> impl Iterator + use<'_, E, N> { + // SAFETY: We ensure that the only elements added to the set are valid + // E values. + self.set.iter().map(|u| unsafe { E::from_u8_unchecked(u) }) + } +} + +impl fmt::Debug for U8EnumSet { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "U8EnumSet{{")?; + for (i, v) in self.iter().enumerate() { + write!(f, "{v:?}")?; + if i > 0 { + write!(f, ", ")?; + } + } + write!(f, "}}") } } diff --git a/src/compiler/rust/proc/enum_as_u8.rs b/src/compiler/rust/proc/enum_as_u8.rs index ad6a442619e..9c1940cbfdf 100644 --- a/src/compiler/rust/proc/enum_as_u8.rs +++ b/src/compiler/rust/proc/enum_as_u8.rs @@ -59,12 +59,23 @@ pub fn derive_enum_as_u8(input: TokenStream) -> TokenStream { }); } + let var_set_u32s = if max_desc_ts.is_empty() { + (usize::from(max_desc) + 1).div_ceil(32) + } else { + // Worst case + 256 / 32 + }; + let ident_s = ident.to_string(); let try_from_err = format!("Invalid {ident_s} variant."); let imp = quote! { impl EnumAsU8 for #ident { - const VARIANTS: compiler::bitset::ConstBitSet<8, u8> = - compiler::bitset::ConstBitSet::<8, u8>::from_array([#variants_ts]); + type VariantSet = compiler::enum_as_u8::U8EnumSet<#ident, #var_set_u32s>; + const VARIANTS: compiler::enum_as_u8::U8EnumSet<#ident, #var_set_u32s> = { + unsafe { + compiler::enum_as_u8::U8EnumSet::from_u8_array([#variants_ts]) + } + }; const MAX_DISCRIMINANT: u8 = { let mut max_desc = #max_desc; #max_desc_ts @@ -90,7 +101,7 @@ pub fn derive_enum_as_u8(input: TokenStream) -> TokenStream { type Error = &'static str; fn try_from(u: u8) -> Result { - if Self::VARIANTS.contains(u) { + if Self::VARIANTS.contains_u8(u) { Ok(unsafe { Self::from_u8_unchecked(u) }) } else { Err(#try_from_err) diff --git a/src/panfrost/compiler/kraid/swizzle.rs b/src/panfrost/compiler/kraid/swizzle.rs index 26a72eb7bef..71e1c536df0 100644 --- a/src/panfrost/compiler/kraid/swizzle.rs +++ b/src/panfrost/compiler/kraid/swizzle.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT use crate::data_type::DataType; -use compiler::enum_as_u8::EnumAsU8; +use compiler::enum_as_u8::*; use compiler::float16::F16; use kraid_proc_macros::{AsmSwizzleWiden, EnumAsU8}; use std::fmt; @@ -236,7 +236,7 @@ impl Swizzle { let b = ((self.packed.get() >> (idx * 4)) & 0xf) as u8; // SAFETY: We only ever construct Swizzle from SwizzleByte - debug_assert!(SwizzleByte::VARIANTS.contains(b)); + debug_assert!(SwizzleByte::VARIANTS.contains_u8(b)); Some(unsafe { std::mem::transmute(b) }) } } @@ -250,7 +250,7 @@ impl Swizzle { let b = ((self.packed.get() >> (idx * 4)) & 0xf) as u8; // SAFETY: We only ever construct Swizzle from SwizzleByte - debug_assert!(SwizzleWord::VARIANTS.contains(b)); + debug_assert!(SwizzleWord::VARIANTS.contains_u8(b)); Some(unsafe { std::mem::transmute(b) }) } else { None