From 841737925ff454a9e8748346354be09aefa77766 Mon Sep 17 00:00:00 2001 From: Faith Ekstrand Date: Fri, 19 Jul 2024 13:26:17 -0500 Subject: [PATCH] nak: Add dst_type decorations This is similar to the src_type decorations we have scattered all over the IR, only for destinations. Part-of: --- src/nouveau/compiler/nak/ir.rs | 186 +++++++++++++++++++++++----- src/nouveau/compiler/nak/ir_proc.rs | 106 +++++++--------- 2 files changed, 203 insertions(+), 89 deletions(-) diff --git a/src/nouveau/compiler/nak/ir.rs b/src/nouveau/compiler/nak/ir.rs index 413fdc9f46f..826f3e112d3 100644 --- a/src/nouveau/compiler/nak/ir.rs +++ b/src/nouveau/compiler/nak/ir.rs @@ -1045,22 +1045,6 @@ impl SrcMod { } } -#[repr(u8)] -#[derive(Clone, Copy, Eq, Hash, PartialEq)] -pub enum SrcType { - SSA, - GPR, - ALU, - F16, - F16v2, - F32, - F64, - I32, - B32, - Pred, - Bar, -} - #[derive(Clone, Copy, PartialEq)] #[allow(dead_code)] pub enum SrcSwizzle { @@ -1395,26 +1379,44 @@ impl fmt::Display for Src { } } +#[repr(u8)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum SrcType { + SSA, + GPR, + ALU, + F16, + F16v2, + F32, + F64, + I32, + B32, + Pred, + Bar, +} + impl SrcType { const DEFAULT: SrcType = SrcType::GPR; } -pub enum SrcTypeList { - Array(&'static [SrcType]), - Uniform(SrcType), +pub enum TypeList { + Array(&'static [T]), + Uniform(T), } -impl Index for SrcTypeList { - type Output = SrcType; +impl Index for TypeList { + type Output = T; - fn index(&self, idx: usize) -> &SrcType { + fn index(&self, idx: usize) -> &T { match self { - SrcTypeList::Array(arr) => &arr[idx], - SrcTypeList::Uniform(typ) => typ, + TypeList::Array(arr) => &arr[idx], + TypeList::Uniform(typ) => typ, } } } +pub type SrcTypeList = TypeList; + pub trait SrcsAsSlice { fn srcs_as_slice(&self) -> &[Src]; fn srcs_as_mut_slice(&mut self) -> &mut [Src]; @@ -1435,9 +1437,29 @@ fn all_dsts_uniform(dsts: &[Dst]) -> bool { uniform == Some(true) } +#[repr(u8)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum DstType { + Pred, + GPR, + F16, + F16v2, + F32, + F64, + Bar, + Vec, +} + +impl DstType { + const DEFAULT: DstType = DstType::Vec; +} + +pub type DstTypeList = TypeList; + pub trait DstsAsSlice { fn dsts_as_slice(&self) -> &[Dst]; fn dsts_as_mut_slice(&mut self) -> &mut [Dst]; + fn dst_types(&self) -> DstTypeList; fn is_uniform(&self) -> bool { all_dsts_uniform(self.dsts_as_slice()) @@ -2335,6 +2357,7 @@ pub struct AttrAccess { #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpFAdd { + #[dst_type(F32)] pub dst: Dst, #[src_type(F32)] @@ -2363,6 +2386,7 @@ impl_display_for_op!(OpFAdd); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpFFma { + #[dst_type(F32)] pub dst: Dst, #[src_type(F32)] @@ -2394,6 +2418,7 @@ impl_display_for_op!(OpFFma); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpFMnMx { + #[dst_type(F32)] pub dst: Dst, #[src_type(F32)] @@ -2420,6 +2445,7 @@ impl_display_for_op!(OpFMnMx); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpFMul { + #[dst_type(F32)] pub dst: Dst, #[src_type(F32)] @@ -2451,7 +2477,9 @@ impl_display_for_op!(OpFMul); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpFSet { + #[dst_type(F32)] pub dst: Dst, + pub cmp_op: FloatCmpOp, #[src_type(F32)] @@ -2475,6 +2503,7 @@ impl_display_for_op!(OpFSet); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpFSetP { + #[dst_type(Pred)] pub dst: Dst, pub set_op: PredSetOp, @@ -2528,6 +2557,7 @@ impl fmt::Display for FSwzAddOp { #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpFSwzAdd { + #[dst_type(F32)] pub dst: Dst, #[src_type(GPR)] @@ -2582,7 +2612,9 @@ impl fmt::Display for RroOp { #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpRro { + #[dst_type(F32)] pub dst: Dst, + pub op: RroOp, #[src_type(F32)] @@ -2631,7 +2663,9 @@ impl fmt::Display for MuFuOp { #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpMuFu { + #[dst_type(F32)] pub dst: Dst, + pub op: MuFuOp, #[src_type(F32)] @@ -2648,6 +2682,7 @@ impl_display_for_op!(OpMuFu); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpDAdd { + #[dst_type(F64)] pub dst: Dst, #[src_type(F64)] @@ -2670,6 +2705,7 @@ impl_display_for_op!(OpDAdd); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpDMul { + #[dst_type(F64)] pub dst: Dst, #[src_type(F64)] @@ -2692,6 +2728,7 @@ impl_display_for_op!(OpDMul); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpDFma { + #[dst_type(F64)] pub dst: Dst, #[src_type(F64)] @@ -2714,6 +2751,7 @@ impl_display_for_op!(OpDFma); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpDMnMx { + #[dst_type(F64)] pub dst: Dst, #[src_type(F64)] @@ -2733,6 +2771,7 @@ impl_display_for_op!(OpDMnMx); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpDSetP { + #[dst_type(Pred)] pub dst: Dst, pub set_op: PredSetOp, @@ -2763,6 +2802,7 @@ impl_display_for_op!(OpDSetP); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpHAdd2 { + #[dst_type(F16v2)] pub dst: Dst, #[src_type(F16v2)] @@ -2789,6 +2829,7 @@ impl_display_for_op!(OpHAdd2); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpHSet2 { + #[dst_type(F16v2)] pub dst: Dst, pub set_op: PredSetOp, @@ -2822,6 +2863,7 @@ impl_display_for_op!(OpHSet2); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpHSetP2 { + #[dst_type(Pred)] pub dsts: [Dst; 2], pub set_op: PredSetOp, @@ -2861,6 +2903,7 @@ impl_display_for_op!(OpHSetP2); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpHMul2 { + #[dst_type(F16v2)] pub dst: Dst, #[src_type(F16v2)] @@ -2888,6 +2931,7 @@ impl_display_for_op!(OpHMul2); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpHFma2 { + #[dst_type(F16v2)] pub dst: Dst, #[src_type(F16v2)] @@ -2917,6 +2961,7 @@ impl_display_for_op!(OpHFma2); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpHMnMx2 { + #[dst_type(F16v2)] pub dst: Dst, #[src_type(F16v2)] @@ -2943,6 +2988,7 @@ impl_display_for_op!(OpHMnMx2); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpBMsk { + #[dst_type(GPR)] pub dst: Dst, #[src_type(ALU)] @@ -2965,6 +3011,7 @@ impl_display_for_op!(OpBMsk); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpBRev { + #[dst_type(GPR)] pub dst: Dst, #[src_type(ALU)] @@ -2984,6 +3031,7 @@ impl_display_for_op!(OpBRev); #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpBfe { /// Where to insert the bits. + #[dst_type(GPR)] pub dst: Dst, /// The source of bits to extract. @@ -3025,6 +3073,7 @@ impl_display_for_op!(OpBfe); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpFlo { + #[dst_type(GPR)] pub dst: Dst, #[src_type(ALU)] @@ -3048,6 +3097,7 @@ impl_display_for_op!(OpFlo); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpIAbs { + #[dst_type(GPR)] pub dst: Dst, #[src_type(ALU)] @@ -3065,6 +3115,7 @@ impl_display_for_op!(OpIAbs); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpIAdd2 { + #[dst_type(GPR)] pub dst: Dst, pub carry_out: Dst, @@ -3103,7 +3154,10 @@ impl DisplayOp for OpIAdd2X { #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpIAdd3 { + #[dst_type(GPR)] pub dst: Dst, + + #[dst_type(Pred)] pub overflow: [Dst; 2], #[src_type(I32)] @@ -3124,7 +3178,10 @@ impl_display_for_op!(OpIAdd3); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpIAdd3X { + #[dst_type(GPR)] pub dst: Dst, + + #[dst_type(Pred)] pub overflow: [Dst; 2], #[src_type(B32)] @@ -3152,6 +3209,7 @@ impl_display_for_op!(OpIAdd3X); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpIDp4 { + #[dst_type(GPR)] pub dst: Dst, pub src_types: [IntType; 2], @@ -3178,6 +3236,7 @@ impl_display_for_op!(OpIDp4); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpIMad { + #[dst_type(GPR)] pub dst: Dst, #[src_type(ALU)] @@ -3197,6 +3256,7 @@ impl_display_for_op!(OpIMad); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpIMul { + #[dst_type(GPR)] pub dst: Dst, #[src_type(ALU)] @@ -3226,6 +3286,7 @@ impl DisplayOp for OpIMul { #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpIMad64 { + #[dst_type(Vec)] pub dst: Dst, #[src_type(ALU)] @@ -3248,7 +3309,9 @@ impl_display_for_op!(OpIMad64); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpIMnMx { + #[dst_type(GPR)] pub dst: Dst, + pub cmp_type: IntCmpType, #[src_type(ALU)] @@ -3272,6 +3335,7 @@ impl_display_for_op!(OpIMnMx); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpISetP { + #[dst_type(Pred)] pub dst: Dst, pub set_op: PredSetOp, @@ -3313,6 +3377,7 @@ impl_display_for_op!(OpISetP); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpLop2 { + #[dst_type(GPR)] pub dst: Dst, #[src_type(B32)] @@ -3330,6 +3395,7 @@ impl DisplayOp for OpLop2 { #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpLop3 { + #[dst_type(GPR)] pub dst: Dst, #[src_type(ALU)] @@ -3371,6 +3437,7 @@ impl fmt::Display for ShflOp { #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpShf { + #[dst_type(GPR)] pub dst: Dst, #[src_type(GPR)] @@ -3412,6 +3479,7 @@ impl_display_for_op!(OpShf); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpShl { + #[dst_type(GPR)] pub dst: Dst, #[src_type(GPR)] @@ -3437,6 +3505,7 @@ impl DisplayOp for OpShl { #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpShr { + #[dst_type(GPR)] pub dst: Dst, #[src_type(GPR)] @@ -3463,10 +3532,8 @@ impl DisplayOp for OpShr { } #[repr(C)] -#[derive(DstsAsSlice)] pub struct OpF2F { pub dst: Dst, - pub src: Src, pub src_type: FloatType, @@ -3500,6 +3567,25 @@ impl SrcsAsSlice for OpF2F { } } +impl DstsAsSlice for OpF2F { + fn dsts_as_slice(&self) -> &[Dst] { + std::slice::from_ref(&self.dst) + } + + fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { + std::slice::from_mut(&mut self.dst) + } + + fn dst_types(&self) -> DstTypeList { + let dst_type = match self.dst_type { + FloatType::F16 => DstType::F16, + FloatType::F32 => DstType::F32, + FloatType::F64 => DstType::F64, + }; + DstTypeList::Uniform(dst_type) + } +} + impl DisplayOp for OpF2F { fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "f2f")?; @@ -3521,6 +3607,7 @@ impl_display_for_op!(OpF2F); #[repr(C)] #[derive(DstsAsSlice)] pub struct OpF2I { + #[dst_type(GPR)] pub dst: Dst, pub src: Src, @@ -3563,10 +3650,8 @@ impl DisplayOp for OpF2I { impl_display_for_op!(OpF2I); #[repr(C)] -#[derive(DstsAsSlice)] pub struct OpI2F { pub dst: Dst, - pub src: Src, pub dst_type: FloatType, @@ -3592,6 +3677,25 @@ impl SrcsAsSlice for OpI2F { } } +impl DstsAsSlice for OpI2F { + fn dsts_as_slice(&self) -> &[Dst] { + std::slice::from_ref(&self.dst) + } + + fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { + std::slice::from_mut(&mut self.dst) + } + + fn dst_types(&self) -> DstTypeList { + let dst_type = match self.dst_type { + FloatType::F16 => DstType::F16, + FloatType::F32 => DstType::F32, + FloatType::F64 => DstType::F64, + }; + DstTypeList::Uniform(dst_type) + } +} + impl DisplayOp for OpI2F { fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( @@ -3607,6 +3711,7 @@ impl_display_for_op!(OpI2F); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpI2I { + #[dst_type(GPR)] pub dst: Dst, #[src_type(ALU)] @@ -3641,6 +3746,7 @@ impl_display_for_op!(OpI2I); #[repr(C)] #[derive(DstsAsSlice)] pub struct OpFRnd { + #[dst_type(F32)] pub dst: Dst, pub src: Src, @@ -3685,6 +3791,7 @@ impl_display_for_op!(OpFRnd); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpMov { + #[dst_type(GPR)] pub dst: Dst, #[src_type(ALU)] @@ -3793,6 +3900,7 @@ impl fmt::Display for PrmtMode { #[derive(SrcsAsSlice, DstsAsSlice)] /// Permutes `srcs` into `dst` using `selection`. pub struct OpPrmt { + #[dst_type(GPR)] pub dst: Dst, #[src_type(ALU)] @@ -3853,6 +3961,7 @@ impl_display_for_op!(OpPrmt); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpSel { + #[dst_type(GPR)] pub dst: Dst, #[src_type(Pred)] @@ -3872,7 +3981,10 @@ impl_display_for_op!(OpSel); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpShfl { + #[dst_type(GPR)] pub dst: Dst, + + #[dst_type(Pred)] pub in_bounds: Dst, #[src_type(SSA)] @@ -3897,6 +4009,7 @@ impl_display_for_op!(OpShfl); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpPLop3 { + #[dst_type(Pred)] pub dsts: [Dst; 2], #[src_type(Pred)] @@ -3923,6 +4036,7 @@ impl_display_for_op!(OpPLop3); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpPSetP { + #[dst_type(Pred)] pub dsts: [Dst; 2], pub ops: [PredSetOp; 2], @@ -3944,6 +4058,7 @@ impl DisplayOp for OpPSetP { #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpPopC { + #[dst_type(GPR)] pub dst: Dst, #[src_type(ALU)] @@ -3960,6 +4075,7 @@ impl_display_for_op!(OpPopC); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpR2UR { + #[dst_type(GPR)] pub dst: Dst, #[src_type(GPR)] @@ -4669,6 +4785,7 @@ impl_display_for_op!(OpBMov); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpBreak { + #[dst_type(Bar)] pub bar_out: Dst, #[src_type(Bar)] @@ -4688,6 +4805,7 @@ impl_display_for_op!(OpBreak); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpBSSy { + #[dst_type(Bar)] pub bar_out: Dst, #[src_type(Pred)] @@ -4788,6 +4906,7 @@ impl_display_for_op!(OpCS2R); #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpIsberd { + #[dst_type(GPR)] pub dst: Dst, #[src_type(SSA)] @@ -4903,7 +5022,10 @@ impl fmt::Display for VoteOp { pub struct OpVote { pub op: VoteOp, + #[dst_type(GPR)] pub ballot: Dst, + + #[dst_type(Pred)] pub vote: Dst, #[src_type(Pred)] @@ -5125,6 +5247,10 @@ impl DstsAsSlice for OpPhiDsts { &mut self.dsts.b } + fn dst_types(&self) -> DstTypeList { + DstTypeList::Uniform(DstType::Vec) + } + fn is_uniform(&self) -> bool { false } @@ -5253,6 +5379,10 @@ impl DstsAsSlice for OpParCopy { fn dsts_as_mut_slice(&mut self) -> &mut [Dst] { &mut self.dsts_srcs.a } + + fn dst_types(&self) -> DstTypeList { + DstTypeList::Uniform(DstType::Vec) + } } impl DisplayOp for OpParCopy { diff --git a/src/nouveau/compiler/nak/ir_proc.rs b/src/nouveau/compiler/nak/ir_proc.rs index 032c2652178..fbd96dbffc9 100644 --- a/src/nouveau/compiler/nak/ir_proc.rs +++ b/src/nouveau/compiler/nak/ir_proc.rs @@ -47,10 +47,10 @@ fn count_type(ty: &Type, search_type: &str) -> usize { } } -fn get_src_type(field: &Field) -> Option { +fn get_type_attr(field: &Field, ty_attr: &str) -> Option { for attr in &field.attrs { if let Meta::List(ml) = &attr.meta { - if ml.path.is_ident("src_type") { + if ml.path.is_ident(ty_attr) { return Some(format!("{}", ml.tokens)); } } @@ -71,9 +71,13 @@ fn derive_as_slice( let trait_name = Ident::new(trait_name, Span::call_site()); let elem_type = Ident::new(search_type, Span::call_site()); let as_slice = - Ident::new(&format!("{}_as_slice", func_prefix), Span::call_site()); + Ident::new(&format!("{func_prefix}s_as_slice"), Span::call_site()); let as_mut_slice = - Ident::new(&format!("{}_as_mut_slice", func_prefix), Span::call_site()); + Ident::new(&format!("{func_prefix}s_as_mut_slice"), Span::call_site()); + let types_fn = + Ident::new(&format!("{func_prefix}_types"), Span::call_site()); + let ty_attr = format!("{func_prefix}_type"); + let ty_type = Ident::new(&format!("{search_type}Type"), Span::call_site()); match data { Data::Struct(s) => { @@ -95,41 +99,36 @@ fn derive_as_slice( let mut first = None; let mut count = 0_usize; let mut found_last = false; - let mut src_types = TokenStream2::new(); + let mut types = TokenStream2::new(); if let Fields::Named(named) = s.fields { for f in named.named { let ty_count = count_type(&f.ty, search_type); - - if search_type == "Src" { - let src_type = get_src_type(&f); - if ty_count == 0 && !src_type.is_none() { - panic!( - "src_type attribute is only allowed on sources" - ); - } - - let src_type = if let Some(s) = src_type { - let s = syn::parse_str::(&s).unwrap(); - quote! { SrcType::#s, } - } else { - quote! { SrcType::DEFAULT, } - }; - - for _ in 0..ty_count { - src_types.extend(src_type.clone()); - } - } + let ty = get_type_attr(&f, &ty_attr); if ty_count > 0 { assert!( !found_last, - "All fields of type {} must be consecutive", - search_type + "All fields of type {search_type} must be consecutive", ); + + let ty = if let Some(s) = ty { + let s = syn::parse_str::(&s).unwrap(); + quote! { #ty_type::#s, } + } else { + quote! { #ty_type::DEFAULT, } + }; + first.get_or_insert(f.ident); + for _ in 0..ty_count { + types.extend(ty.clone()); + } count += ty_count; } else { + assert!( + ty.is_none(), + "{ty_attr} attribute is only allowed on {search_type}" + ); if !first.is_none() { found_last = true; } @@ -139,17 +138,6 @@ fn derive_as_slice( panic!("Fields are not named"); } - let src_type_func = if search_type == "Src" { - quote! { - fn src_types(&self) -> SrcTypeList { - static SRC_TYPES: [SrcType; #count] = [#src_types]; - SrcTypeList::Array(&SRC_TYPES) - } - } - } else { - TokenStream2::new() - }; - if let Some(name) = first { quote! { impl #trait_name for #ident { @@ -167,7 +155,10 @@ fn derive_as_slice( } } - #src_type_func + fn #types_fn(&self) -> TypeList<#ty_type> { + static TYPES: [#ty_type; #count] = [#types]; + TypeList::Array(&TYPES) + } } } } else { @@ -181,7 +172,9 @@ fn derive_as_slice( &mut [] } - #src_type_func + fn #types_fn(&self) -> TypeList<#ty_type> { + TypeList::Uniform(#ty_type::DEFAULT) + } } } } @@ -190,7 +183,7 @@ fn derive_as_slice( Data::Enum(e) => { let mut as_slice_cases = TokenStream2::new(); let mut as_mut_slice_cases = TokenStream2::new(); - let mut src_types_cases = TokenStream2::new(); + let mut types_cases = TokenStream2::new(); let mut is_uniform_cases = TokenStream2::new(); for v in e.variants { let case = v.ident; @@ -200,28 +193,15 @@ fn derive_as_slice( as_mut_slice_cases.extend(quote! { #ident::#case(x) => x.#as_mut_slice(), }); - if search_type == "Src" { - src_types_cases.extend(quote! { - #ident::#case(x) => x.src_types(), - }); - } + types_cases.extend(quote! { + #ident::#case(x) => x.#types_fn(), + }); if search_type == "Dst" { is_uniform_cases.extend(quote! { #ident::#case(x) => x.is_uniform(), }); } } - let src_type_func = if search_type == "Src" { - quote! { - fn src_types(&self) -> SrcTypeList { - match self { - #src_types_cases - } - } - } - } else { - TokenStream2::new() - }; let is_uniform_func = if search_type == "Dst" { quote! { fn is_uniform(&self) -> bool { @@ -247,7 +227,11 @@ fn derive_as_slice( } } - #src_type_func + fn #types_fn(&self) -> TypeList<#ty_type> { + match self { + #types_cases + } + } #is_uniform_func } } @@ -259,12 +243,12 @@ fn derive_as_slice( #[proc_macro_derive(SrcsAsSlice, attributes(src_type))] pub fn derive_srcs_as_slice(input: TokenStream) -> TokenStream { - derive_as_slice(input, "SrcsAsSlice", "srcs", "Src") + derive_as_slice(input, "SrcsAsSlice", "src", "Src") } -#[proc_macro_derive(DstsAsSlice)] +#[proc_macro_derive(DstsAsSlice, attributes(dst_type))] pub fn derive_dsts_as_slice(input: TokenStream) -> TokenStream { - derive_as_slice(input, "DstsAsSlice", "dsts", "Dst") + derive_as_slice(input, "DstsAsSlice", "dst", "Dst") } #[proc_macro_derive(DisplayOp)]