mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-06-21 01:38:23 +02:00
compiler/rust/enum_as_u8: Add an ConstU8EnumSet struct
For enums which implement EnumAsU8, it's nice to be able to construct cheap, copyable sets of them. We were already doing this for variants. This codifies and exports it. Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/42200>
This commit is contained in:
parent
651fe2a7f9
commit
d127e2fffa
3 changed files with 72 additions and 11 deletions
|
|
@ -2,6 +2,7 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
|
||||
use crate::bitset::ConstBitSet;
|
||||
use std::fmt;
|
||||
|
||||
/// A trait for enums which are `#[repr(u8)]` which provides some extra sugar
|
||||
/// on top. By deriving this trait with `#[derive(EnumAsU8)]`, you get
|
||||
|
|
@ -9,16 +10,65 @@ use crate::bitset::ConstBitSet;
|
|||
/// an iterator over all valid variants in the enum. This trait works
|
||||
/// regardless of whether or not discriminant are explicitly specified.
|
||||
pub trait EnumAsU8: Sized {
|
||||
const VARIANTS: ConstBitSet<8, u8>;
|
||||
type VariantSet;
|
||||
const VARIANTS: Self::VariantSet;
|
||||
const MAX_DISCRIMINANT: u8;
|
||||
|
||||
fn as_u8(self) -> u8;
|
||||
|
||||
unsafe fn from_u8_unchecked(u: u8) -> Self;
|
||||
}
|
||||
|
||||
fn iter() -> impl Iterator<Item = Self> {
|
||||
Self::VARIANTS
|
||||
.iter()
|
||||
.map(|u| unsafe { Self::from_u8_unchecked(u) })
|
||||
/// A set of EnumAsU8 values.
|
||||
#[derive(Clone, Copy, Eq, PartialEq)]
|
||||
pub struct U8EnumSet<E: EnumAsU8, const N: usize> {
|
||||
set: ConstBitSet<N, u8>,
|
||||
phantom: std::marker::PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<E: EnumAsU8, const N: usize> U8EnumSet<E, N> {
|
||||
const ASSERT: () = {
|
||||
assert!(size_of::<E>() == 1);
|
||||
assert!((E::MAX_DISCRIMINANT as usize) < N * 32);
|
||||
};
|
||||
|
||||
/// SAFETY: Every value in the array must be a valid E discriminant.
|
||||
pub const unsafe fn from_u8_array<const A: usize>(arr: [u8; A]) -> Self {
|
||||
let _ = Self::ASSERT;
|
||||
U8EnumSet {
|
||||
set: ConstBitSet::<N, u8>::from_array(arr),
|
||||
phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.set.is_empty()
|
||||
}
|
||||
|
||||
pub fn contains(&self, e: E) -> bool {
|
||||
self.set.contains(e.as_u8())
|
||||
}
|
||||
|
||||
pub const fn contains_u8(&self, u: u8) -> bool {
|
||||
self.set.contains(u)
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = E> + use<'_, E, N> {
|
||||
// SAFETY: We ensure that the only elements added to the set are valid
|
||||
// E values.
|
||||
self.set.iter().map(|u| unsafe { E::from_u8_unchecked(u) })
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: EnumAsU8 + fmt::Debug, const N: usize> fmt::Debug for U8EnumSet<E, N> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "U8EnumSet{{")?;
|
||||
for (i, v) in self.iter().enumerate() {
|
||||
write!(f, "{v:?}")?;
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, "}}")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,12 +59,23 @@ pub fn derive_enum_as_u8(input: TokenStream) -> TokenStream {
|
|||
});
|
||||
}
|
||||
|
||||
let var_set_u32s = if max_desc_ts.is_empty() {
|
||||
(usize::from(max_desc) + 1).div_ceil(32)
|
||||
} else {
|
||||
// Worst case
|
||||
256 / 32
|
||||
};
|
||||
|
||||
let ident_s = ident.to_string();
|
||||
let try_from_err = format!("Invalid {ident_s} variant.");
|
||||
let imp = quote! {
|
||||
impl EnumAsU8 for #ident {
|
||||
const VARIANTS: compiler::bitset::ConstBitSet<8, u8> =
|
||||
compiler::bitset::ConstBitSet::<8, u8>::from_array([#variants_ts]);
|
||||
type VariantSet = compiler::enum_as_u8::U8EnumSet<#ident, #var_set_u32s>;
|
||||
const VARIANTS: compiler::enum_as_u8::U8EnumSet<#ident, #var_set_u32s> = {
|
||||
unsafe {
|
||||
compiler::enum_as_u8::U8EnumSet::from_u8_array([#variants_ts])
|
||||
}
|
||||
};
|
||||
const MAX_DISCRIMINANT: u8 = {
|
||||
let mut max_desc = #max_desc;
|
||||
#max_desc_ts
|
||||
|
|
@ -90,7 +101,7 @@ pub fn derive_enum_as_u8(input: TokenStream) -> TokenStream {
|
|||
type Error = &'static str;
|
||||
|
||||
fn try_from(u: u8) -> Result<Self, &'static str> {
|
||||
if Self::VARIANTS.contains(u) {
|
||||
if Self::VARIANTS.contains_u8(u) {
|
||||
Ok(unsafe { Self::from_u8_unchecked(u) })
|
||||
} else {
|
||||
Err(#try_from_err)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
|
||||
use crate::data_type::DataType;
|
||||
use compiler::enum_as_u8::EnumAsU8;
|
||||
use compiler::enum_as_u8::*;
|
||||
use compiler::float16::F16;
|
||||
use kraid_proc_macros::{AsmSwizzleWiden, EnumAsU8};
|
||||
use std::fmt;
|
||||
|
|
@ -236,7 +236,7 @@ impl Swizzle {
|
|||
let b = ((self.packed.get() >> (idx * 4)) & 0xf) as u8;
|
||||
|
||||
// SAFETY: We only ever construct Swizzle from SwizzleByte
|
||||
debug_assert!(SwizzleByte::VARIANTS.contains(b));
|
||||
debug_assert!(SwizzleByte::VARIANTS.contains_u8(b));
|
||||
Some(unsafe { std::mem::transmute(b) })
|
||||
}
|
||||
}
|
||||
|
|
@ -250,7 +250,7 @@ impl Swizzle {
|
|||
let b = ((self.packed.get() >> (idx * 4)) & 0xf) as u8;
|
||||
|
||||
// SAFETY: We only ever construct Swizzle from SwizzleByte
|
||||
debug_assert!(SwizzleWord::VARIANTS.contains(b));
|
||||
debug_assert!(SwizzleWord::VARIANTS.contains_u8(b));
|
||||
Some(unsafe { std::mem::transmute(b) })
|
||||
} else {
|
||||
None
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue