diff --git a/src/compiler/rust/as_slice.rs b/src/compiler/rust/as_slice.rs index 8e6657cc36a..c4e7e6cc04e 100644 --- a/src/compiler/rust/as_slice.rs +++ b/src/compiler/rust/as_slice.rs @@ -3,30 +3,13 @@ use std::ops::Deref; use std::ops::DerefMut; -use std::ops::Index; - -pub enum AttrList { - Array(&'static [T]), - Uniform(T), -} - -impl Index for AttrList { - type Output = T; - - fn index(&self, idx: usize) -> &T { - match self { - AttrList::Array(arr) => &arr[idx], - AttrList::Uniform(typ) => typ, - } - } -} pub trait AsSlice { type Attr; fn as_slice(&self) -> &[T]; fn as_mut_slice(&mut self) -> &mut [T]; - fn attrs(&self) -> AttrList; + fn attrs(&self) -> &'static [Self::Attr]; } impl AsSlice for Box @@ -41,7 +24,7 @@ where fn as_mut_slice(&mut self) -> &mut [T] { self.deref_mut().as_mut_slice() } - fn attrs(&self) -> AttrList { + fn attrs(&self) -> &'static [Self::Attr] { self.deref().attrs() } } diff --git a/src/compiler/rust/proc/as_slice.rs b/src/compiler/rust/proc/as_slice.rs index 6ea28506abf..725b88d4f67 100644 --- a/src/compiler/rust/proc/as_slice.rs +++ b/src/compiler/rust/proc/as_slice.rs @@ -142,9 +142,9 @@ pub fn derive_as_slice( } } - fn attrs(&self) -> AttrList { + fn attrs(&self) -> &'static [Self::Attr] { static ATTRS: [#attr_type; #count] = [#attrs]; - AttrList::Array(&ATTRS) + &ATTRS } } } @@ -161,8 +161,8 @@ pub fn derive_as_slice( &mut [] } - fn attrs(&self) -> AttrList { - AttrList::Uniform(#attr_type::DEFAULT) + fn attrs(&self) -> &'static [Self::Attr] { + &[] } } } @@ -203,7 +203,7 @@ pub fn derive_as_slice( } } - fn attrs(&self) -> AttrList { + fn attrs(&self) -> &'static [Self::Attr] { match self { #types_cases } diff --git a/src/compiler/rust/proc/from_variants.rs b/src/compiler/rust/proc/from_variants.rs new file mode 100644 index 00000000000..9c79986381d --- /dev/null +++ b/src/compiler/rust/proc/from_variants.rs @@ -0,0 +1,42 @@ +// Copyright © 2023 Collabora, Ltd. +// SPDX-License-Identifier: MIT + +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use syn::*; + +pub fn derive_from_variants(input: TokenStream) -> TokenStream { + let DeriveInput { ident, data, .. } = parse_macro_input!(input); + let enum_type = ident; + + let mut impls = TokenStream2::new(); + + if let Data::Enum(e) = data { + for v in e.variants { + let var_ident = v.ident; + let from_type = match v.fields { + Fields::Named(_) => { + panic!("FromVariants does not support named fields") + } + Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => unnamed, + Fields::Unit => continue, + }; + + assert!( + from_type.len() == 1, + "FromVariants does not support multiple unnamed fields" + ); + let from_type = &from_type.first().unwrap().ty; + + impls.extend(quote! { + impl From<#from_type> for #enum_type { + fn from (v: #from_type) -> #enum_type { + #enum_type::#var_ident(v) + } + } + }); + } + } + + impls.into() +} diff --git a/src/compiler/rust/proc/lib.rs b/src/compiler/rust/proc/lib.rs index 144658e1b0b..cb4f9940fd5 100644 --- a/src/compiler/rust/proc/lib.rs +++ b/src/compiler/rust/proc/lib.rs @@ -8,3 +8,4 @@ extern crate quote; extern crate syn; pub mod as_slice; +pub mod from_variants; diff --git a/src/nouveau/compiler/nak/ir.rs b/src/nouveau/compiler/nak/ir.rs index 8d5feffb9c7..29c9a5cb93a 100644 --- a/src/nouveau/compiler/nak/ir.rs +++ b/src/nouveau/compiler/nak/ir.rs @@ -1345,9 +1345,39 @@ impl SrcType { } } +pub enum AttrList { + Array(&'static [T]), + Uniform(T), +} + +impl Index for AttrList { + type Output = T; + + fn index(&self, idx: usize) -> &T { + match self { + AttrList::Array(arr) => &arr[idx], + AttrList::Uniform(typ) => typ, + } + } +} + pub type SrcTypeList = AttrList; -pub trait SrcsAsSlice: AsSlice { +pub trait SrcsAsSlice { + fn srcs_as_slice(&self) -> &[Src]; + + fn srcs_as_mut_slice(&mut self) -> &mut [Src]; + + fn src_types(&self) -> SrcTypeList; + + fn src_idx(&self, src: &Src) -> usize { + let r = self.srcs_as_slice().as_ptr_range(); + assert!(r.contains(&(src as *const Src))); + unsafe { (src as *const Src).offset_from(r.start) as usize } + } +} + +impl> SrcsAsSlice for T { fn srcs_as_slice(&self) -> &[Src] { self.as_slice() } @@ -1357,18 +1387,10 @@ pub trait SrcsAsSlice: AsSlice { } fn src_types(&self) -> SrcTypeList { - self.attrs() - } - - fn src_idx(&self, src: &Src) -> usize { - let r = self.srcs_as_slice().as_ptr_range(); - assert!(r.contains(&(src as *const Src))); - unsafe { (src as *const Src).offset_from(r.start) as usize } + AttrList::Array(self.attrs()) } } -impl> SrcsAsSlice for T {} - fn all_dsts_uniform(dsts: &[Dst]) -> bool { let mut uniform = None; for dst in dsts { @@ -1384,6 +1406,7 @@ fn all_dsts_uniform(dsts: &[Dst]) -> bool { } #[repr(u8)] +#[allow(dead_code)] #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub enum DstType { Pred, @@ -1403,20 +1426,14 @@ impl DstType { pub type DstTypeList = AttrList; -pub trait DstsAsSlice: AsSlice { - fn dsts_as_slice(&self) -> &[Dst] { - self.as_slice() - } +pub trait DstsAsSlice { + fn dsts_as_slice(&self) -> &[Dst]; - fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { - self.as_mut_slice() - } + fn dsts_as_mut_slice(&mut self) -> &mut [Dst]; // Currently only used by test code #[allow(dead_code)] - fn dst_types(&self) -> DstTypeList { - self.attrs() - } + fn dst_types(&self) -> DstTypeList; fn dst_idx(&self, dst: &Dst) -> usize { let r = self.dsts_as_slice().as_ptr_range(); @@ -1425,7 +1442,19 @@ pub trait DstsAsSlice: AsSlice { } } -impl> DstsAsSlice for T {} +impl> DstsAsSlice for T { + fn dsts_as_slice(&self) -> &[Dst] { + self.as_slice() + } + + fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { + self.as_mut_slice() + } + + fn dst_types(&self) -> DstTypeList { + AttrList::Array(self.attrs()) + } +} pub trait IsUniform { fn is_uniform(&self) -> bool; @@ -3178,18 +3207,16 @@ pub struct OpMuFu { pub op_type: FloatType, } -impl AsSlice for OpMuFu { - type Attr = SrcType; - - fn as_slice(&self) -> &[Src] { +impl SrcsAsSlice for OpMuFu { + fn srcs_as_slice(&self) -> &[Src] { std::slice::from_ref(&self.src) } - fn as_mut_slice(&mut self) -> &mut [Src] { + fn srcs_as_mut_slice(&mut self) -> &mut [Src] { std::slice::from_mut(&mut self.src) } - fn attrs(&self) -> SrcTypeList { + fn src_types(&self) -> SrcTypeList { let src_type = match self.op_type { FloatType::F16 => SrcType::F16, FloatType::F32 => SrcType::F32, @@ -4695,18 +4722,16 @@ pub struct OpF2F { pub integer_rnd: bool, } -impl AsSlice for OpF2F { - type Attr = SrcType; - - fn as_slice(&self) -> &[Src] { +impl SrcsAsSlice for OpF2F { + fn srcs_as_slice(&self) -> &[Src] { std::slice::from_ref(&self.src) } - fn as_mut_slice(&mut self) -> &mut [Src] { + fn srcs_as_mut_slice(&mut self) -> &mut [Src] { std::slice::from_mut(&mut self.src) } - fn attrs(&self) -> SrcTypeList { + fn src_types(&self) -> SrcTypeList { let src_type = match self.src_type { FloatType::F16 => SrcType::F16, FloatType::F32 => SrcType::F32, @@ -4716,18 +4741,16 @@ impl AsSlice for OpF2F { } } -impl AsSlice for OpF2F { - type Attr = DstType; - - fn as_slice(&self) -> &[Dst] { +impl DstsAsSlice for OpF2F { + fn dsts_as_slice(&self) -> &[Dst] { std::slice::from_ref(&self.dst) } - fn as_mut_slice(&mut self) -> &mut [Dst] { + fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { std::slice::from_mut(&mut self.dst) } - fn attrs(&self) -> DstTypeList { + fn dst_types(&self) -> DstTypeList { let dst_type = match self.dst_type { FloatType::F16 => DstType::F16, FloatType::F32 => DstType::F32, @@ -4792,18 +4815,16 @@ pub struct OpF2I { pub ftz: bool, } -impl AsSlice for OpF2I { - type Attr = SrcType; - - fn as_slice(&self) -> &[Src] { +impl SrcsAsSlice for OpF2I { + fn srcs_as_slice(&self) -> &[Src] { std::slice::from_ref(&self.src) } - fn as_mut_slice(&mut self) -> &mut [Src] { + fn srcs_as_mut_slice(&mut self) -> &mut [Src] { std::slice::from_mut(&mut self.src) } - fn attrs(&self) -> SrcTypeList { + fn src_types(&self) -> SrcTypeList { let src_type = match self.src_type { FloatType::F16 => SrcType::F16, FloatType::F32 => SrcType::F32, @@ -4835,18 +4856,16 @@ pub struct OpI2F { pub rnd_mode: FRndMode, } -impl AsSlice for OpI2F { - type Attr = SrcType; - - fn as_slice(&self) -> &[Src] { +impl SrcsAsSlice for OpI2F { + fn srcs_as_slice(&self) -> &[Src] { std::slice::from_ref(&self.src) } - fn as_mut_slice(&mut self) -> &mut [Src] { + fn srcs_as_mut_slice(&mut self) -> &mut [Src] { std::slice::from_mut(&mut self.src) } - fn attrs(&self) -> SrcTypeList { + fn src_types(&self) -> SrcTypeList { if self.src_type.bits() <= 32 { SrcTypeList::Uniform(SrcType::ALU) } else { @@ -4855,18 +4874,16 @@ impl AsSlice for OpI2F { } } -impl AsSlice for OpI2F { - type Attr = DstType; - - fn as_slice(&self) -> &[Dst] { +impl DstsAsSlice for OpI2F { + fn dsts_as_slice(&self) -> &[Dst] { std::slice::from_ref(&self.dst) } - fn as_mut_slice(&mut self) -> &mut [Dst] { + fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { std::slice::from_mut(&mut self.dst) } - fn attrs(&self) -> DstTypeList { + fn dst_types(&self) -> DstTypeList { let dst_type = match self.dst_type { FloatType::F16 => DstType::F16, FloatType::F32 => DstType::F32, @@ -4937,18 +4954,16 @@ pub struct OpFRnd { pub ftz: bool, } -impl AsSlice for OpFRnd { - type Attr = SrcType; - - fn as_slice(&self) -> &[Src] { +impl SrcsAsSlice for OpFRnd { + fn srcs_as_slice(&self) -> &[Src] { std::slice::from_ref(&self.src) } - fn as_mut_slice(&mut self) -> &mut [Src] { + fn srcs_as_mut_slice(&mut self) -> &mut [Src] { std::slice::from_mut(&mut self.src) } - fn attrs(&self) -> SrcTypeList { + fn src_types(&self) -> SrcTypeList { let src_type = match self.src_type { FloatType::F16 => SrcType::F16, FloatType::F32 => SrcType::F32, @@ -7787,18 +7802,16 @@ impl OpPhiSrcs { } } -impl AsSlice for OpPhiSrcs { - type Attr = SrcType; - - fn as_slice(&self) -> &[Src] { +impl SrcsAsSlice for OpPhiSrcs { + fn srcs_as_slice(&self) -> &[Src] { &self.srcs.b } - fn as_mut_slice(&mut self) -> &mut [Src] { + fn srcs_as_mut_slice(&mut self) -> &mut [Src] { &mut self.srcs.b } - fn attrs(&self) -> SrcTypeList { + fn src_types(&self) -> SrcTypeList { SrcTypeList::Uniform(SrcType::GPR) } } @@ -7836,18 +7849,16 @@ impl OpPhiDsts { } } -impl AsSlice for OpPhiDsts { - type Attr = DstType; - - fn as_slice(&self) -> &[Dst] { +impl DstsAsSlice for OpPhiDsts { + fn dsts_as_slice(&self) -> &[Dst] { &self.dsts.b } - fn as_mut_slice(&mut self) -> &mut [Dst] { + fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { &mut self.dsts.b } - fn attrs(&self) -> DstTypeList { + fn dst_types(&self) -> DstTypeList { DstTypeList::Uniform(DstType::Vec) } } @@ -7953,34 +7964,30 @@ impl OpParCopy { } } -impl AsSlice for OpParCopy { - type Attr = SrcType; - - fn as_slice(&self) -> &[Src] { +impl SrcsAsSlice for OpParCopy { + fn srcs_as_slice(&self) -> &[Src] { &self.dsts_srcs.b } - fn as_mut_slice(&mut self) -> &mut [Src] { + fn srcs_as_mut_slice(&mut self) -> &mut [Src] { &mut self.dsts_srcs.b } - fn attrs(&self) -> SrcTypeList { + fn src_types(&self) -> SrcTypeList { SrcTypeList::Uniform(SrcType::GPR) } } -impl AsSlice for OpParCopy { - type Attr = DstType; - - fn as_slice(&self) -> &[Dst] { +impl DstsAsSlice for OpParCopy { + fn dsts_as_slice(&self) -> &[Dst] { &self.dsts_srcs.a } - fn as_mut_slice(&mut self) -> &mut [Dst] { + fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { &mut self.dsts_srcs.a } - fn attrs(&self) -> DstTypeList { + fn dst_types(&self) -> DstTypeList { DstTypeList::Uniform(DstType::Vec) } } @@ -8009,18 +8016,16 @@ pub struct OpRegOut { pub srcs: Vec, } -impl AsSlice for OpRegOut { - type Attr = SrcType; - - fn as_slice(&self) -> &[Src] { +impl SrcsAsSlice for OpRegOut { + fn srcs_as_slice(&self) -> &[Src] { &self.srcs } - fn as_mut_slice(&mut self) -> &mut [Src] { + fn srcs_as_mut_slice(&mut self) -> &mut [Src] { &mut self.srcs } - fn attrs(&self) -> SrcTypeList { + fn src_types(&self) -> SrcTypeList { SrcTypeList::Uniform(SrcType::GPR) } } @@ -8251,6 +8256,16 @@ const _: () = { debug_assert!(size_of::() == 16); }; +// The DisplayOp constraint exists to keep the type system from recursing +impl From for Op +where + Box: Into, +{ + fn from(op: T) -> Self { + Box::new(op).into() + } +} + impl Op { pub fn is_branch(&self) -> bool { matches!( diff --git a/src/nouveau/compiler/nak/ir_proc.rs b/src/nouveau/compiler/nak/ir_proc.rs index 4f9ad615468..3fd9dccaa04 100644 --- a/src/nouveau/compiler/nak/ir_proc.rs +++ b/src/nouveau/compiler/nak/ir_proc.rs @@ -14,12 +14,98 @@ use syn::*; #[proc_macro_derive(SrcsAsSlice, attributes(src_type))] pub fn derive_srcs_as_slice(input: TokenStream) -> TokenStream { - derive_as_slice(input, "Src", "src_type", "SrcType") + let input2 = input.clone(); + let DeriveInput { ident, data, .. } = parse_macro_input!(input2); + + if let Data::Enum(e) = data { + let mut as_slice_cases = TokenStream2::new(); + let mut as_mut_slice_cases = TokenStream2::new(); + let mut types_cases = TokenStream2::new(); + for v in e.variants { + let case = v.ident; + as_slice_cases.extend(quote! { + #ident::#case(x) => x.srcs_as_slice(), + }); + as_mut_slice_cases.extend(quote! { + #ident::#case(x) => x.srcs_as_mut_slice(), + }); + types_cases.extend(quote! { + #ident::#case(x) => x.src_types(), + }); + } + quote! { + impl SrcsAsSlice for #ident { + fn srcs_as_slice(&self) -> &[Src] { + match self { + #as_slice_cases + } + } + + fn srcs_as_mut_slice(&mut self) -> &mut [Src] { + match self { + #as_mut_slice_cases + } + } + + fn src_types(&self) -> SrcTypeList { + match self { + #types_cases + } + } + } + } + .into() + } else { + derive_as_slice(input, "Src", "src_type", "SrcType") + } } #[proc_macro_derive(DstsAsSlice, attributes(dst_type))] pub fn derive_dsts_as_slice(input: TokenStream) -> TokenStream { - derive_as_slice(input, "Dst", "dst_type", "DstType") + let input2 = input.clone(); + let DeriveInput { ident, data, .. } = parse_macro_input!(input2); + + if let Data::Enum(e) = data { + let mut as_slice_cases = TokenStream2::new(); + let mut as_mut_slice_cases = TokenStream2::new(); + let mut types_cases = TokenStream2::new(); + for v in e.variants { + let case = v.ident; + as_slice_cases.extend(quote! { + #ident::#case(x) => x.dsts_as_slice(), + }); + as_mut_slice_cases.extend(quote! { + #ident::#case(x) => x.dsts_as_mut_slice(), + }); + types_cases.extend(quote! { + #ident::#case(x) => x.dst_types(), + }); + } + quote! { + impl DstsAsSlice for #ident { + fn dsts_as_slice(&self) -> &[Dst] { + match self { + #as_slice_cases + } + } + + fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { + match self { + #as_mut_slice_cases + } + } + + fn dst_types(&self) -> DstTypeList { + match self { + #types_cases + } + } + } + } + .into() + } else { + derive_as_slice(input, "Dst", "dst_type", "DstType") + } } #[proc_macro_derive(DisplayOp)] @@ -59,73 +145,7 @@ pub fn enum_derive_display_op(input: TokenStream) -> TokenStream { } } -fn into_box_inner_type<'a>(from_type: &'a syn::Type) -> Option<&'a syn::Type> { - let last = match from_type { - Type::Path(TypePath { path, .. }) => path.segments.last()?, - _ => return None, - }; - - if last.ident != "Box" { - return None; - } - - let PathArguments::AngleBracketed(AngleBracketedGenericArguments { - args, - .. - }) = &last.arguments - else { - panic!("Expected Box (with angle brackets)"); - }; - - for arg in args { - if let GenericArgument::Type(inner_type) = arg { - return Some(inner_type); - } - } - panic!("Expected Box to use a type argument"); -} - #[proc_macro_derive(FromVariants)] pub fn derive_from_variants(input: TokenStream) -> TokenStream { - let DeriveInput { ident, data, .. } = parse_macro_input!(input); - let enum_type = ident; - - let mut impls = TokenStream2::new(); - - if let Data::Enum(e) = data { - for v in e.variants { - let var_ident = v.ident; - let from_type = match v.fields { - Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => unnamed, - _ => panic!("Expected Op(OpFoo)"), - }; - - assert!(from_type.len() == 1, "Expected Op(OpFoo)"); - let from_type = &from_type.first().unwrap().ty; - - let quote = quote! { - impl From<#from_type> for #enum_type { - fn from (op: #from_type) -> #enum_type { - #enum_type::#var_ident(op) - } - } - }; - - impls.extend(quote); - - if let Some(inner_type) = into_box_inner_type(from_type) { - let quote = quote! { - impl From<#inner_type> for #enum_type { - fn from(value: #inner_type) -> Self { - From::from(Box::new(value)) - } - } - }; - - impls.extend(quote); - } - } - } - - impls.into() + compiler_proc::from_variants::derive_from_variants(input) }