nak: Add dst_type decorations

This is similar to the src_type decorations we have scattered all over
the IR, only for destinations.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/30275>
This commit is contained in:
Faith Ekstrand 2024-07-19 13:26:17 -05:00 committed by Marge Bot
parent 9e25b6c0ff
commit 841737925f
2 changed files with 203 additions and 89 deletions

View file

@ -1045,22 +1045,6 @@ impl SrcMod {
}
}
#[repr(u8)]
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
pub enum SrcType {
SSA,
GPR,
ALU,
F16,
F16v2,
F32,
F64,
I32,
B32,
Pred,
Bar,
}
#[derive(Clone, Copy, PartialEq)]
#[allow(dead_code)]
pub enum SrcSwizzle {
@ -1395,26 +1379,44 @@ impl fmt::Display for Src {
}
}
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SrcType {
SSA,
GPR,
ALU,
F16,
F16v2,
F32,
F64,
I32,
B32,
Pred,
Bar,
}
impl SrcType {
const DEFAULT: SrcType = SrcType::GPR;
}
pub enum SrcTypeList {
Array(&'static [SrcType]),
Uniform(SrcType),
pub enum TypeList<T: 'static> {
Array(&'static [T]),
Uniform(T),
}
impl Index<usize> for SrcTypeList {
type Output = SrcType;
impl<T: 'static> Index<usize> for TypeList<T> {
type Output = T;
fn index(&self, idx: usize) -> &SrcType {
fn index(&self, idx: usize) -> &T {
match self {
SrcTypeList::Array(arr) => &arr[idx],
SrcTypeList::Uniform(typ) => typ,
TypeList::Array(arr) => &arr[idx],
TypeList::Uniform(typ) => typ,
}
}
}
pub type SrcTypeList = TypeList<SrcType>;
pub trait SrcsAsSlice {
fn srcs_as_slice(&self) -> &[Src];
fn srcs_as_mut_slice(&mut self) -> &mut [Src];
@ -1435,9 +1437,29 @@ fn all_dsts_uniform(dsts: &[Dst]) -> bool {
uniform == Some(true)
}
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DstType {
Pred,
GPR,
F16,
F16v2,
F32,
F64,
Bar,
Vec,
}
impl DstType {
const DEFAULT: DstType = DstType::Vec;
}
pub type DstTypeList = TypeList<DstType>;
pub trait DstsAsSlice {
fn dsts_as_slice(&self) -> &[Dst];
fn dsts_as_mut_slice(&mut self) -> &mut [Dst];
fn dst_types(&self) -> DstTypeList;
fn is_uniform(&self) -> bool {
all_dsts_uniform(self.dsts_as_slice())
@ -2335,6 +2357,7 @@ pub struct AttrAccess {
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpFAdd {
#[dst_type(F32)]
pub dst: Dst,
#[src_type(F32)]
@ -2363,6 +2386,7 @@ impl_display_for_op!(OpFAdd);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpFFma {
#[dst_type(F32)]
pub dst: Dst,
#[src_type(F32)]
@ -2394,6 +2418,7 @@ impl_display_for_op!(OpFFma);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpFMnMx {
#[dst_type(F32)]
pub dst: Dst,
#[src_type(F32)]
@ -2420,6 +2445,7 @@ impl_display_for_op!(OpFMnMx);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpFMul {
#[dst_type(F32)]
pub dst: Dst,
#[src_type(F32)]
@ -2451,7 +2477,9 @@ impl_display_for_op!(OpFMul);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpFSet {
#[dst_type(F32)]
pub dst: Dst,
pub cmp_op: FloatCmpOp,
#[src_type(F32)]
@ -2475,6 +2503,7 @@ impl_display_for_op!(OpFSet);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpFSetP {
#[dst_type(Pred)]
pub dst: Dst,
pub set_op: PredSetOp,
@ -2528,6 +2557,7 @@ impl fmt::Display for FSwzAddOp {
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpFSwzAdd {
#[dst_type(F32)]
pub dst: Dst,
#[src_type(GPR)]
@ -2582,7 +2612,9 @@ impl fmt::Display for RroOp {
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpRro {
#[dst_type(F32)]
pub dst: Dst,
pub op: RroOp,
#[src_type(F32)]
@ -2631,7 +2663,9 @@ impl fmt::Display for MuFuOp {
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpMuFu {
#[dst_type(F32)]
pub dst: Dst,
pub op: MuFuOp,
#[src_type(F32)]
@ -2648,6 +2682,7 @@ impl_display_for_op!(OpMuFu);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpDAdd {
#[dst_type(F64)]
pub dst: Dst,
#[src_type(F64)]
@ -2670,6 +2705,7 @@ impl_display_for_op!(OpDAdd);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpDMul {
#[dst_type(F64)]
pub dst: Dst,
#[src_type(F64)]
@ -2692,6 +2728,7 @@ impl_display_for_op!(OpDMul);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpDFma {
#[dst_type(F64)]
pub dst: Dst,
#[src_type(F64)]
@ -2714,6 +2751,7 @@ impl_display_for_op!(OpDFma);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpDMnMx {
#[dst_type(F64)]
pub dst: Dst,
#[src_type(F64)]
@ -2733,6 +2771,7 @@ impl_display_for_op!(OpDMnMx);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpDSetP {
#[dst_type(Pred)]
pub dst: Dst,
pub set_op: PredSetOp,
@ -2763,6 +2802,7 @@ impl_display_for_op!(OpDSetP);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHAdd2 {
#[dst_type(F16v2)]
pub dst: Dst,
#[src_type(F16v2)]
@ -2789,6 +2829,7 @@ impl_display_for_op!(OpHAdd2);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHSet2 {
#[dst_type(F16v2)]
pub dst: Dst,
pub set_op: PredSetOp,
@ -2822,6 +2863,7 @@ impl_display_for_op!(OpHSet2);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHSetP2 {
#[dst_type(Pred)]
pub dsts: [Dst; 2],
pub set_op: PredSetOp,
@ -2861,6 +2903,7 @@ impl_display_for_op!(OpHSetP2);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHMul2 {
#[dst_type(F16v2)]
pub dst: Dst,
#[src_type(F16v2)]
@ -2888,6 +2931,7 @@ impl_display_for_op!(OpHMul2);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHFma2 {
#[dst_type(F16v2)]
pub dst: Dst,
#[src_type(F16v2)]
@ -2917,6 +2961,7 @@ impl_display_for_op!(OpHFma2);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHMnMx2 {
#[dst_type(F16v2)]
pub dst: Dst,
#[src_type(F16v2)]
@ -2943,6 +2988,7 @@ impl_display_for_op!(OpHMnMx2);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpBMsk {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(ALU)]
@ -2965,6 +3011,7 @@ impl_display_for_op!(OpBMsk);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpBRev {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(ALU)]
@ -2984,6 +3031,7 @@ impl_display_for_op!(OpBRev);
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpBfe {
/// Where to insert the bits.
#[dst_type(GPR)]
pub dst: Dst,
/// The source of bits to extract.
@ -3025,6 +3073,7 @@ impl_display_for_op!(OpBfe);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpFlo {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(ALU)]
@ -3048,6 +3097,7 @@ impl_display_for_op!(OpFlo);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpIAbs {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(ALU)]
@ -3065,6 +3115,7 @@ impl_display_for_op!(OpIAbs);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpIAdd2 {
#[dst_type(GPR)]
pub dst: Dst,
pub carry_out: Dst,
@ -3103,7 +3154,10 @@ impl DisplayOp for OpIAdd2X {
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpIAdd3 {
#[dst_type(GPR)]
pub dst: Dst,
#[dst_type(Pred)]
pub overflow: [Dst; 2],
#[src_type(I32)]
@ -3124,7 +3178,10 @@ impl_display_for_op!(OpIAdd3);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpIAdd3X {
#[dst_type(GPR)]
pub dst: Dst,
#[dst_type(Pred)]
pub overflow: [Dst; 2],
#[src_type(B32)]
@ -3152,6 +3209,7 @@ impl_display_for_op!(OpIAdd3X);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpIDp4 {
#[dst_type(GPR)]
pub dst: Dst,
pub src_types: [IntType; 2],
@ -3178,6 +3236,7 @@ impl_display_for_op!(OpIDp4);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpIMad {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(ALU)]
@ -3197,6 +3256,7 @@ impl_display_for_op!(OpIMad);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpIMul {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(ALU)]
@ -3226,6 +3286,7 @@ impl DisplayOp for OpIMul {
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpIMad64 {
#[dst_type(Vec)]
pub dst: Dst,
#[src_type(ALU)]
@ -3248,7 +3309,9 @@ impl_display_for_op!(OpIMad64);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpIMnMx {
#[dst_type(GPR)]
pub dst: Dst,
pub cmp_type: IntCmpType,
#[src_type(ALU)]
@ -3272,6 +3335,7 @@ impl_display_for_op!(OpIMnMx);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpISetP {
#[dst_type(Pred)]
pub dst: Dst,
pub set_op: PredSetOp,
@ -3313,6 +3377,7 @@ impl_display_for_op!(OpISetP);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpLop2 {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(B32)]
@ -3330,6 +3395,7 @@ impl DisplayOp for OpLop2 {
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpLop3 {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(ALU)]
@ -3371,6 +3437,7 @@ impl fmt::Display for ShflOp {
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpShf {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(GPR)]
@ -3412,6 +3479,7 @@ impl_display_for_op!(OpShf);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpShl {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(GPR)]
@ -3437,6 +3505,7 @@ impl DisplayOp for OpShl {
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpShr {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(GPR)]
@ -3463,10 +3532,8 @@ impl DisplayOp for OpShr {
}
#[repr(C)]
#[derive(DstsAsSlice)]
pub struct OpF2F {
pub dst: Dst,
pub src: Src,
pub src_type: FloatType,
@ -3500,6 +3567,25 @@ impl SrcsAsSlice for OpF2F {
}
}
impl DstsAsSlice for OpF2F {
fn dsts_as_slice(&self) -> &[Dst] {
std::slice::from_ref(&self.dst)
}
fn dsts_as_mut_slice(&mut self) -> &mut [Dst] {
std::slice::from_mut(&mut self.dst)
}
fn dst_types(&self) -> DstTypeList {
let dst_type = match self.dst_type {
FloatType::F16 => DstType::F16,
FloatType::F32 => DstType::F32,
FloatType::F64 => DstType::F64,
};
DstTypeList::Uniform(dst_type)
}
}
impl DisplayOp for OpF2F {
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "f2f")?;
@ -3521,6 +3607,7 @@ impl_display_for_op!(OpF2F);
#[repr(C)]
#[derive(DstsAsSlice)]
pub struct OpF2I {
#[dst_type(GPR)]
pub dst: Dst,
pub src: Src,
@ -3563,10 +3650,8 @@ impl DisplayOp for OpF2I {
impl_display_for_op!(OpF2I);
#[repr(C)]
#[derive(DstsAsSlice)]
pub struct OpI2F {
pub dst: Dst,
pub src: Src,
pub dst_type: FloatType,
@ -3592,6 +3677,25 @@ impl SrcsAsSlice for OpI2F {
}
}
impl DstsAsSlice for OpI2F {
fn dsts_as_slice(&self) -> &[Dst] {
std::slice::from_ref(&self.dst)
}
fn dsts_as_mut_slice(&mut self) -> &mut [Dst] {
std::slice::from_mut(&mut self.dst)
}
fn dst_types(&self) -> DstTypeList {
let dst_type = match self.dst_type {
FloatType::F16 => DstType::F16,
FloatType::F32 => DstType::F32,
FloatType::F64 => DstType::F64,
};
DstTypeList::Uniform(dst_type)
}
}
impl DisplayOp for OpI2F {
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
@ -3607,6 +3711,7 @@ impl_display_for_op!(OpI2F);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpI2I {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(ALU)]
@ -3641,6 +3746,7 @@ impl_display_for_op!(OpI2I);
#[repr(C)]
#[derive(DstsAsSlice)]
pub struct OpFRnd {
#[dst_type(F32)]
pub dst: Dst,
pub src: Src,
@ -3685,6 +3791,7 @@ impl_display_for_op!(OpFRnd);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpMov {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(ALU)]
@ -3793,6 +3900,7 @@ impl fmt::Display for PrmtMode {
#[derive(SrcsAsSlice, DstsAsSlice)]
/// Permutes `srcs` into `dst` using `selection`.
pub struct OpPrmt {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(ALU)]
@ -3853,6 +3961,7 @@ impl_display_for_op!(OpPrmt);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpSel {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(Pred)]
@ -3872,7 +3981,10 @@ impl_display_for_op!(OpSel);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpShfl {
#[dst_type(GPR)]
pub dst: Dst,
#[dst_type(Pred)]
pub in_bounds: Dst,
#[src_type(SSA)]
@ -3897,6 +4009,7 @@ impl_display_for_op!(OpShfl);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpPLop3 {
#[dst_type(Pred)]
pub dsts: [Dst; 2],
#[src_type(Pred)]
@ -3923,6 +4036,7 @@ impl_display_for_op!(OpPLop3);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpPSetP {
#[dst_type(Pred)]
pub dsts: [Dst; 2],
pub ops: [PredSetOp; 2],
@ -3944,6 +4058,7 @@ impl DisplayOp for OpPSetP {
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpPopC {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(ALU)]
@ -3960,6 +4075,7 @@ impl_display_for_op!(OpPopC);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpR2UR {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(GPR)]
@ -4669,6 +4785,7 @@ impl_display_for_op!(OpBMov);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpBreak {
#[dst_type(Bar)]
pub bar_out: Dst,
#[src_type(Bar)]
@ -4688,6 +4805,7 @@ impl_display_for_op!(OpBreak);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpBSSy {
#[dst_type(Bar)]
pub bar_out: Dst,
#[src_type(Pred)]
@ -4788,6 +4906,7 @@ impl_display_for_op!(OpCS2R);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpIsberd {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(SSA)]
@ -4903,7 +5022,10 @@ impl fmt::Display for VoteOp {
pub struct OpVote {
pub op: VoteOp,
#[dst_type(GPR)]
pub ballot: Dst,
#[dst_type(Pred)]
pub vote: Dst,
#[src_type(Pred)]
@ -5125,6 +5247,10 @@ impl DstsAsSlice for OpPhiDsts {
&mut self.dsts.b
}
fn dst_types(&self) -> DstTypeList {
DstTypeList::Uniform(DstType::Vec)
}
fn is_uniform(&self) -> bool {
false
}
@ -5253,6 +5379,10 @@ impl DstsAsSlice for OpParCopy {
fn dsts_as_mut_slice(&mut self) -> &mut [Dst] {
&mut self.dsts_srcs.a
}
fn dst_types(&self) -> DstTypeList {
DstTypeList::Uniform(DstType::Vec)
}
}
impl DisplayOp for OpParCopy {

View file

@ -47,10 +47,10 @@ fn count_type(ty: &Type, search_type: &str) -> usize {
}
}
fn get_src_type(field: &Field) -> Option<String> {
fn get_type_attr(field: &Field, ty_attr: &str) -> Option<String> {
for attr in &field.attrs {
if let Meta::List(ml) = &attr.meta {
if ml.path.is_ident("src_type") {
if ml.path.is_ident(ty_attr) {
return Some(format!("{}", ml.tokens));
}
}
@ -71,9 +71,13 @@ fn derive_as_slice(
let trait_name = Ident::new(trait_name, Span::call_site());
let elem_type = Ident::new(search_type, Span::call_site());
let as_slice =
Ident::new(&format!("{}_as_slice", func_prefix), Span::call_site());
Ident::new(&format!("{func_prefix}s_as_slice"), Span::call_site());
let as_mut_slice =
Ident::new(&format!("{}_as_mut_slice", func_prefix), Span::call_site());
Ident::new(&format!("{func_prefix}s_as_mut_slice"), Span::call_site());
let types_fn =
Ident::new(&format!("{func_prefix}_types"), Span::call_site());
let ty_attr = format!("{func_prefix}_type");
let ty_type = Ident::new(&format!("{search_type}Type"), Span::call_site());
match data {
Data::Struct(s) => {
@ -95,41 +99,36 @@ fn derive_as_slice(
let mut first = None;
let mut count = 0_usize;
let mut found_last = false;
let mut src_types = TokenStream2::new();
let mut types = TokenStream2::new();
if let Fields::Named(named) = s.fields {
for f in named.named {
let ty_count = count_type(&f.ty, search_type);
if search_type == "Src" {
let src_type = get_src_type(&f);
if ty_count == 0 && !src_type.is_none() {
panic!(
"src_type attribute is only allowed on sources"
);
}
let src_type = if let Some(s) = src_type {
let s = syn::parse_str::<Ident>(&s).unwrap();
quote! { SrcType::#s, }
} else {
quote! { SrcType::DEFAULT, }
};
for _ in 0..ty_count {
src_types.extend(src_type.clone());
}
}
let ty = get_type_attr(&f, &ty_attr);
if ty_count > 0 {
assert!(
!found_last,
"All fields of type {} must be consecutive",
search_type
"All fields of type {search_type} must be consecutive",
);
let ty = if let Some(s) = ty {
let s = syn::parse_str::<Ident>(&s).unwrap();
quote! { #ty_type::#s, }
} else {
quote! { #ty_type::DEFAULT, }
};
first.get_or_insert(f.ident);
for _ in 0..ty_count {
types.extend(ty.clone());
}
count += ty_count;
} else {
assert!(
ty.is_none(),
"{ty_attr} attribute is only allowed on {search_type}"
);
if !first.is_none() {
found_last = true;
}
@ -139,17 +138,6 @@ fn derive_as_slice(
panic!("Fields are not named");
}
let src_type_func = if search_type == "Src" {
quote! {
fn src_types(&self) -> SrcTypeList {
static SRC_TYPES: [SrcType; #count] = [#src_types];
SrcTypeList::Array(&SRC_TYPES)
}
}
} else {
TokenStream2::new()
};
if let Some(name) = first {
quote! {
impl #trait_name for #ident {
@ -167,7 +155,10 @@ fn derive_as_slice(
}
}
#src_type_func
fn #types_fn(&self) -> TypeList<#ty_type> {
static TYPES: [#ty_type; #count] = [#types];
TypeList::Array(&TYPES)
}
}
}
} else {
@ -181,7 +172,9 @@ fn derive_as_slice(
&mut []
}
#src_type_func
fn #types_fn(&self) -> TypeList<#ty_type> {
TypeList::Uniform(#ty_type::DEFAULT)
}
}
}
}
@ -190,7 +183,7 @@ fn derive_as_slice(
Data::Enum(e) => {
let mut as_slice_cases = TokenStream2::new();
let mut as_mut_slice_cases = TokenStream2::new();
let mut src_types_cases = TokenStream2::new();
let mut types_cases = TokenStream2::new();
let mut is_uniform_cases = TokenStream2::new();
for v in e.variants {
let case = v.ident;
@ -200,28 +193,15 @@ fn derive_as_slice(
as_mut_slice_cases.extend(quote! {
#ident::#case(x) => x.#as_mut_slice(),
});
if search_type == "Src" {
src_types_cases.extend(quote! {
#ident::#case(x) => x.src_types(),
});
}
types_cases.extend(quote! {
#ident::#case(x) => x.#types_fn(),
});
if search_type == "Dst" {
is_uniform_cases.extend(quote! {
#ident::#case(x) => x.is_uniform(),
});
}
}
let src_type_func = if search_type == "Src" {
quote! {
fn src_types(&self) -> SrcTypeList {
match self {
#src_types_cases
}
}
}
} else {
TokenStream2::new()
};
let is_uniform_func = if search_type == "Dst" {
quote! {
fn is_uniform(&self) -> bool {
@ -247,7 +227,11 @@ fn derive_as_slice(
}
}
#src_type_func
fn #types_fn(&self) -> TypeList<#ty_type> {
match self {
#types_cases
}
}
#is_uniform_func
}
}
@ -259,12 +243,12 @@ fn derive_as_slice(
#[proc_macro_derive(SrcsAsSlice, attributes(src_type))]
pub fn derive_srcs_as_slice(input: TokenStream) -> TokenStream {
derive_as_slice(input, "SrcsAsSlice", "srcs", "Src")
derive_as_slice(input, "SrcsAsSlice", "src", "Src")
}
#[proc_macro_derive(DstsAsSlice)]
#[proc_macro_derive(DstsAsSlice, attributes(dst_type))]
pub fn derive_dsts_as_slice(input: TokenStream) -> TokenStream {
derive_as_slice(input, "DstsAsSlice", "dsts", "Dst")
derive_as_slice(input, "DstsAsSlice", "dst", "Dst")
}
#[proc_macro_derive(DisplayOp)]