diff --git a/src/compiler/rust/enum_as_u8.rs b/src/compiler/rust/enum_as_u8.rs new file mode 100644 index 00000000000..ad778d84d7a --- /dev/null +++ b/src/compiler/rust/enum_as_u8.rs @@ -0,0 +1,23 @@ +// Copyright © 2023 Collabora, Ltd. +// SPDX-License-Identifier: MIT + +use crate::bitset::ConstBitSet; + +/// A trait for enums which are `#[repr(u8)]` which provides some extra sugar +/// on top. By deriving this trait with `#[derive(EnumAsU8)]`, you get +/// `From for u8` and `TryFrom for MyEnum` for free as well as +/// an iterator over all valid variants in the enum. This trait works +/// regardless of whether or not discriminant are explicitly specified. +pub trait EnumAsU8: Sized { + const VARIANTS: ConstBitSet<8, u8>; + + fn as_u8(self) -> u8; + + unsafe fn from_u8_unchecked(u: u8) -> Self; + + fn iter() -> impl Iterator { + Self::VARIANTS + .iter() + .map(|u| unsafe { Self::from_u8_unchecked(u) }) + } +} diff --git a/src/compiler/rust/lib.rs b/src/compiler/rust/lib.rs index 1ac2d946303..dfc5788b2ed 100644 --- a/src/compiler/rust/lib.rs +++ b/src/compiler/rust/lib.rs @@ -7,6 +7,7 @@ pub mod bitset; pub mod cfg; pub mod dataflow; pub mod depth_first_search; +pub mod enum_as_u8; pub mod float16; pub mod lower_bounded; pub mod memstream; diff --git a/src/compiler/rust/meson.build b/src/compiler/rust/meson.build index 6de5a5ebe16..a18f0eb2bf6 100644 --- a/src/compiler/rust/meson.build +++ b/src/compiler/rust/meson.build @@ -7,6 +7,7 @@ _compiler_rs_sources = [ 'cfg.rs', 'dataflow.rs', 'depth_first_search.rs', + 'enum_as_u8.rs', 'float16.rs', 'lower_bounded.rs', 'memstream.rs', diff --git a/src/compiler/rust/proc/enum_as_u8.rs b/src/compiler/rust/proc/enum_as_u8.rs new file mode 100644 index 00000000000..9a83ce8beaa --- /dev/null +++ b/src/compiler/rust/proc/enum_as_u8.rs @@ -0,0 +1,74 @@ +// Copyright © 2023 Collabora, Ltd. +// SPDX-License-Identifier: MIT + +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use syn::*; + +pub fn derive_enum_as_u8(input: TokenStream) -> TokenStream { + let DeriveInput { + attrs, ident, data, .. + } = parse_macro_input!(input); + + let Data::Enum(e) = data else { + panic!("EnumAsU8 can only be derived for enum types"); + }; + + let mut has_repr_u8 = false; + for attr in attrs { + if let Meta::List(ml) = attr.meta { + if ml.path.is_ident("repr") && format!("{}", ml.tokens) == "u8" { + has_repr_u8 = true; + break; + } + } + } + + if !has_repr_u8 { + panic!("EnumAsU8 can only be derived for enum which are #[repr(u8)]"); + }; + + let mut variants = TokenStream2::new(); + for v in e.variants { + let v_ident = v.ident; + variants.extend(quote! { + #ident::#v_ident as u8, + }); + } + + let ident_s = ident.to_string(); + let try_from_err = format!("Invalid {ident_s} variant."); + let imp = quote! { + impl EnumAsU8 for #ident { + const VARIANTS: compiler::bitset::ConstBitSet<8, u8> = + compiler::bitset::ConstBitSet::<8, u8>::from_array([#variants]); + + fn as_u8(self) -> u8 { + self as u8 + } + + unsafe fn from_u8_unchecked(u: u8) -> Self { + unsafe { std::mem::transmute(u) } + } + } + + impl From<#ident> for u8 { + fn from(e: #ident) -> Self { + e.as_u8() + } + } + + impl TryFrom for #ident { + type Error = &'static str; + + fn try_from(u: u8) -> Result { + if Self::VARIANTS.contains(u) { + Ok(unsafe { Self::from_u8_unchecked(u) }) + } else { + Err(#try_from_err) + } + } + } + }; + imp.into() +} diff --git a/src/compiler/rust/proc/lib.rs b/src/compiler/rust/proc/lib.rs index cb4f9940fd5..e632295252b 100644 --- a/src/compiler/rust/proc/lib.rs +++ b/src/compiler/rust/proc/lib.rs @@ -8,4 +8,5 @@ extern crate quote; extern crate syn; pub mod as_slice; +pub mod enum_as_u8; pub mod from_variants;