kraid: Rework swizzles

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41841>
This commit is contained in:
Faith Ekstrand 2026-05-14 12:34:06 -04:00 committed by Marge Bot
parent 1631ebbc5a
commit 16d3ad1820
5 changed files with 565 additions and 80 deletions

View file

@ -101,6 +101,11 @@ impl DataType {
DataType::from_pieces(comps, num_type, bits)
}
pub const fn scalar_type(self) -> DataType {
let (_, num_type, bits) = self.to_pieces();
DataType::from_pieces(1, num_type, bits)
}
pub fn bits(&self) -> Option<NonZeroU8> {
NonZeroU8::new(self.to_pieces().2)
}

View file

@ -12,6 +12,7 @@ use proc_macro::TokenStream;
mod data_type;
mod ir;
mod swizzle;
#[proc_macro_attribute]
pub fn variants(attr: TokenStream, item: TokenStream) -> TokenStream {
@ -42,6 +43,16 @@ pub fn derive_data_type(input: TokenStream) -> TokenStream {
data_type::derive_data_type(input)
}
#[proc_macro_derive(AsmSwizzleWiden)]
pub fn derive_asm_swizzle_widen(input: TokenStream) -> TokenStream {
swizzle::derive_asm_swizzle_widen(input)
}
#[proc_macro_derive(EnumAsU8)]
pub fn derive_enum_as_u8(input: TokenStream) -> TokenStream {
compiler_proc::enum_as_u8::derive_enum_as_u8(input)
}
#[proc_macro_derive(FromVariants)]
pub fn derive_from_variants(input: TokenStream) -> TokenStream {
compiler_proc::from_variants::derive_from_variants(input)

View file

@ -0,0 +1,235 @@
// Copyright © 2026 Collabora, Ltd.
// SPDX-License-Identifier: MIT
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro2::TokenStream as TokenStream2;
use syn::*;
fn is_valid_data_type(comps: u8, num_type: char, bits: u8) -> bool {
if bits == 64 {
return comps == 1 && ['I', 'S', 'U'].contains(&num_type);
}
if bits == 8 && num_type == 'F' {
return false;
}
if comps * bits > 32 {
return false;
}
true
}
fn data_type_str(comps: u8, num_type: char, bits: u8) -> String {
if comps == 1 {
format!("{num_type}{bits}")
} else {
format!("V{comps}{num_type}{bits}")
}
}
fn data_type_ident(comps: u8, num_type: char, bits: u8) -> Ident {
Ident::new(&data_type_str(comps, num_type, bits), Span::call_site())
}
fn widen_call_for_asm_swizzle(
asm_swizzle: &Ident,
data_type: TokenStream2,
) -> TokenStream2 {
let swz_name = asm_swizzle.to_string();
let mut iter = swz_name.chars();
let widen_type = iter.next().unwrap();
let mut widen_ident = format!("widen_{widen_type}");
let mut widen_args = data_type;
for c in iter {
// Everything after the first character should be a
// number
let i = c.to_digit(10).unwrap() as u8;
widen_ident += "x";
widen_args.extend(quote! { , #i });
}
let widen_ident = widen_ident.to_lowercase();
let widen_ident = Ident::new(&widen_ident, Span::call_site());
quote! { Swizzle::#widen_ident(#widen_args) }
}
pub fn derive_asm_swizzle_widen(input: TokenStream) -> TokenStream {
let DeriveInput { ident, data, .. } = parse_macro_input!(input);
let enum_type = ident;
let Data::Enum(e) = data else {
panic!("Not an enum type");
};
const NUM_COMPS: [u8; 3] = [1, 2, 4];
const NUM_TYPES: [char; 4] = ['F', 'I', 'S', 'U'];
const BIT_SIZES: [u8; 4] = [8, 16, 32, 64];
let mut from_swizzle_dt_cases = TokenStream2::new();
let mut swizzle_consts = TokenStream2::new();
for &dt_comps in &NUM_COMPS {
for &num_type in &NUM_TYPES {
for &dt_bits in &BIT_SIZES {
if !is_valid_data_type(dt_comps, num_type, dt_bits) {
continue;
}
let (dt_ident, dt_case) = if dt_bits == 8 {
// 8-bit types don't do any format conversion so we can
// handle them together. This is good because there are
// a LOT of 8-bit types.
if num_type != 'I' {
continue;
}
let i = data_type_ident(dt_comps, 'I', dt_bits);
let s = data_type_ident(dt_comps, 'S', dt_bits);
let u = data_type_ident(dt_comps, 'U', dt_bits);
let case =
quote! { DataType::#i | DataType::#s | DataType::#u };
(i, case)
} else {
let t = data_type_ident(dt_comps, num_type, dt_bits);
let case = quote! { DataType::#t };
(t, case)
};
let mut from_swizzle_cases = quote! {
Swizzle::NONE => Some(#enum_type::None),
};
for v in &e.variants {
if v.ident == "None" {
continue;
}
let v_name = v.ident.to_string();
let mut v_iter = v_name.chars();
let widen_type = v_iter.next().unwrap();
let widen_comps = v_iter.count() as u8;
if widen_comps != dt_comps {
continue;
}
let widen_bits = match widen_type {
'B' => 8,
'H' => 16,
'W' => 32,
_ => panic!("Invalid widen: {}", v.ident),
};
// Bytes can't be widened to floats
if widen_bits == 8 && num_type == 'F' {
continue;
}
// This is widen, not narrow
if widen_bits > dt_bits {
continue;
}
// I types can't be widened
if num_type == 'I' && dt_bits != widen_bits {
continue;
}
// 32 and 64-bit can't swizzle unless we're widening
if dt_bits >= 32 && dt_bits == widen_bits {
continue;
}
let widen = widen_call_for_asm_swizzle(
&v.ident,
quote! { DataType::#dt_ident },
);
let v_ident = &v.ident;
let const_ident = Ident::new(
&format!("SWIZZLE_{dt_ident}_{v_ident}"),
Span::call_site(),
);
swizzle_consts.extend(quote! {
const #const_ident: Swizzle = #widen;
});
from_swizzle_cases.extend(quote! {
Self::#const_ident => Some(#enum_type::#v_ident),
});
}
from_swizzle_dt_cases.extend(quote! {
#dt_case => match swizzle {
#from_swizzle_cases
_ => None,
}
});
}
}
}
let mut to_swizzle_cases = TokenStream2::new();
let mut display_cases = TokenStream2::new();
for v in &e.variants {
if v.ident == "None" {
to_swizzle_cases.extend(quote! {
#enum_type::None => Swizzle::NONE,
});
display_cases.extend(quote! {
#enum_type::None => Ok(()),
});
} else {
let widen =
widen_call_for_asm_swizzle(&v.ident, quote! { data_type });
let v_ident = &v.ident;
to_swizzle_cases.extend(quote! {
#enum_type::#v_ident => #widen,
});
let v_disp = format!(".{v_ident}").to_lowercase();
display_cases.extend(quote! {
#enum_type::#v_ident => write!(f, #v_disp),
});
}
}
let imp = quote! {
impl #enum_type {
#swizzle_consts
pub fn from_swizzle(
src_type: DataType,
swizzle: Swizzle,
) -> Option<AsmSwizzleWiden> {
if swizzle == Swizzle::NONE {
Some(#enum_type::None)
} else {
match src_type {
#from_swizzle_dt_cases
_ => None,
}
}
}
pub fn to_swizzle(self, data_type: DataType) -> Swizzle {
match self {
#to_swizzle_cases
}
}
}
impl std::fmt::Display for #enum_type {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#display_cases
}
}
}
};
imp.into()
}

View file

@ -1,7 +1,10 @@
// Copyright © 2026 Collabora, Ltd.
// SPDX-License-Identifier: MIT
use crate::data_type::DataType;
use compiler::enum_as_u8::EnumAsU8;
use compiler::float16::F16;
use kraid_proc::{AsmSwizzleWiden, EnumAsU8};
use std::fmt;
use std::num::NonZeroU16;
@ -15,11 +18,9 @@ enum ByteMod {
}
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, EnumAsU8, PartialEq)]
enum SwizzleByte {
// There is no zero value by design
Invalid1 = 1,
Invalid2 = 2,
Zero = 3,
Byte0 = ((ByteMod::Byte as u8) << 2) | 0,
Byte1 = ((ByteMod::Byte as u8) << 2) | 1,
@ -38,9 +39,6 @@ enum SwizzleByte {
impl fmt::Display for SwizzleByte {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SwizzleByte::Invalid1 | SwizzleByte::Invalid2 => {
panic!("Invalid swizzle");
}
SwizzleByte::Zero => write!(f, "z"),
_ => {
let idx = self.byte_idx().unwrap();
@ -108,23 +106,59 @@ impl SwizzleByte {
}
}
#[repr(u8)]
#[derive(Clone, Copy, Debug, EnumAsU8, PartialEq)]
enum SwizzleWord {
// There is no zero value by design
Zero = SwizzleByte::Zero as u8,
Word0 = SwizzleByte::Byte0 as u8,
Word1 = SwizzleByte::Byte1 as u8,
Sign0 = SwizzleByte::Sign0 as u8,
Sign1 = SwizzleByte::Sign1 as u8,
}
impl fmt::Display for SwizzleWord {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SwizzleWord::Zero => write!(f, "z"),
_ => {
let idx = self.word_idx().unwrap();
match self.word_mod().unwrap() {
ByteMod::Byte => write!(f, "w{idx}"),
ByteMod::Sign => write!(f, "ws{idx}"),
_ => panic!("SwizzleWord doesn't use Fext"),
}
}
}
}
}
impl SwizzleWord {
const fn word(word_idx: u8) -> Self {
assert!(word_idx < 2);
unsafe { std::mem::transmute(SwizzleByte::byte(word_idx)) }
}
const fn sign(word_idx: u8) -> Self {
assert!(word_idx < 2);
unsafe { std::mem::transmute(SwizzleByte::sign(word_idx)) }
}
pub fn word_idx(self) -> Option<u8> {
SwizzleByte::byte_idx(unsafe { std::mem::transmute(self) })
}
pub fn word_mod(self) -> Option<ByteMod> {
SwizzleByte::byte_mod(unsafe { std::mem::transmute(self) })
}
}
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq)]
pub struct Swizzle {
packed: NonZeroU16,
}
macro_rules! swizzle {
($b0: ident, $b1: ident, $b2: ident, $b3: ident) => {
Swizzle::from_swizzle_bytes([
SwizzleByte::$b0,
SwizzleByte::$b1,
SwizzleByte::$b2,
SwizzleByte::$b3,
])
};
}
impl Swizzle {
/// The identity swizzle
pub const NONE: Swizzle = Swizzle::from_bytes([0, 1, 2, 3]);
@ -152,9 +186,6 @@ impl Swizzle {
/// 16-bit to a 32-bit floating point value.
pub const HF1: Swizzle = Swizzle::widen_f16(1);
pub const ALL64: Swizzle = swizzle!(Invalid1, Invalid1, Invalid1, Invalid1);
const LOW32: Swizzle = swizzle!(Invalid1, Invalid1, Invalid1, Invalid2);
const fn from_swizzle_bytes(bytes: [SwizzleByte; 4]) -> Swizzle {
let b0 = bytes[0] as u16;
let b1 = bytes[1] as u16;
@ -169,6 +200,63 @@ impl Swizzle {
Swizzle { packed }
}
const fn from_swizzle_words(words: [SwizzleWord; 2]) -> Swizzle {
if (words[0] as u8) == (SwizzleWord::Word0 as u8)
&& (words[1] as u8) == (SwizzleWord::Word1 as u8)
{
// We use the same NONE value for both
Swizzle::NONE
} else {
let w0 = words[0] as u16;
let w1 = words[1] as u16;
// We leave the high 8 bits zero for word swizzles
let packed = w0 | (w1 << 4);
// SAFETY: SwizzleWord cannot be zero so neither can packed
debug_assert!(packed != 0);
let packed = unsafe { NonZeroU16::new_unchecked(packed) };
Swizzle { packed }
}
}
#[inline]
pub const fn is_word_swizzle(&self) -> bool {
// We leave the high 8 bits zero for word swizzles
(self.packed.get() >> 8) == 0
}
#[inline]
const fn byte(&self, idx: u8) -> Option<SwizzleByte> {
assert!(idx < 4);
if self.is_word_swizzle() {
None
} else {
let b = ((self.packed.get() >> (idx * 4)) & 0xf) as u8;
// SAFETY: We only ever construct Swizzle from SwizzleByte
debug_assert!(SwizzleByte::VARIANTS.contains(b));
Some(unsafe { std::mem::transmute(b) })
}
}
#[inline]
const fn word(&self, idx: u8) -> Option<SwizzleWord> {
assert!(idx < 2);
if self.packed.get() == Self::NONE.packed.get() {
Some(SwizzleWord::word(idx))
} else if self.is_word_swizzle() {
let b = ((self.packed.get() >> (idx * 4)) & 0xf) as u8;
// SAFETY: We only ever construct Swizzle from SwizzleByte
debug_assert!(SwizzleWord::VARIANTS.contains(b));
Some(unsafe { std::mem::transmute(b) })
} else {
None
}
}
pub const fn from_bytes(bytes: [u8; 4]) -> Swizzle {
Swizzle::from_swizzle_bytes([
SwizzleByte::byte(bytes[0]),
@ -258,13 +346,87 @@ impl Swizzle {
])
}
const fn byte(&self, idx: u8) -> SwizzleByte {
assert!(idx < 4);
let b = ((self.packed.get() >> (idx * 4)) & 0xf) as u8;
pub const fn widen_s32(word: u8) -> Swizzle {
Swizzle::from_swizzle_words([
SwizzleWord::word(word),
SwizzleWord::sign(word),
])
}
// SAFETY: We only ever construct Swizzle from SwizzleByte
debug_assert!(b != 0);
unsafe { std::mem::transmute(b) }
pub const fn widen_u32(word: u8) -> Swizzle {
Swizzle::from_swizzle_words([
SwizzleWord::word(word),
SwizzleWord::Zero,
])
}
pub const fn widen_bx(src_type: DataType, byte: u8) -> Swizzle {
match src_type.scalar_type() {
DataType::I8 | DataType::S8 | DataType::U8 => {
Swizzle::replicate_byte(byte)
}
DataType::S16 => Swizzle::widen_v2s8(byte, byte),
DataType::U16 => Swizzle::widen_v2u8(byte, byte),
DataType::S32 | DataType::S64 => Swizzle::widen_s8(byte),
DataType::U32 | DataType::U64 => Swizzle::widen_u8(byte),
_ => panic!("Src type cannot read a bx swizzle"),
}
}
pub const fn widen_bxx(src_type: DataType, x: u8, y: u8) -> Swizzle {
match src_type {
DataType::V2I8 | DataType::V2S8 | DataType::V2U8 => {
Swizzle::from_bytes([x, y, x, y])
}
DataType::V2U16 => Swizzle::widen_v2u8(x, y),
DataType::V2S16 => Swizzle::widen_v2s8(x, y),
_ => panic!("Src type cannot read a bxx swizzle"),
}
}
const fn widen_bxxxx(
src_type: DataType,
x: u8,
y: u8,
z: u8,
w: u8,
) -> Swizzle {
assert!(matches!(
src_type,
DataType::V4I8 | DataType::V4S8 | DataType::V4U8
));
Swizzle::from_bytes([x, y, z, w])
}
pub const fn widen_hx(src_type: DataType, half: u8) -> Swizzle {
match src_type.scalar_type() {
DataType::F16 | DataType::I16 | DataType::S16 | DataType::U16 => {
Swizzle::replicate_half(half)
}
DataType::F32 => Swizzle::widen_f16(half),
DataType::S32 | DataType::S64 => Swizzle::widen_s16(half),
DataType::U32 | DataType::U64 => Swizzle::widen_u16(half),
_ => panic!("Src type cannot read an hx swizzle"),
}
}
const fn widen_hxx(src_type: DataType, x: u8, y: u8) -> Swizzle {
assert!(matches!(
src_type,
DataType::V2F16
| DataType::V2I16
| DataType::V2S16
| DataType::V2U16
));
Swizzle::from_halves([x, y])
}
pub const fn widen_wx(src_type: DataType, word: u8) -> Swizzle {
match src_type {
DataType::S64 => Swizzle::widen_s32(word),
DataType::U64 => Swizzle::widen_u32(word),
_ => panic!("Src type cannot read a wx swizzle"),
}
}
/// Applies this swizzle to a u32 value
@ -272,7 +434,7 @@ impl Swizzle {
let mut folded = 0_u32;
let mut has_fext = false;
for i in 0..4 {
let sb = self.byte(i);
let sb = self.byte(i)?;
if sb == SwizzleByte::Zero {
continue;
}
@ -300,25 +462,27 @@ impl Swizzle {
}
pub fn bytes_read(&self) -> u8 {
match *self {
Swizzle::ALL64 => 0xff,
Swizzle::LOW32 => 0x0f,
_ => {
let mut bytes = 0_u8;
for i in 0..4 {
if let Some(b) = self.byte(i).byte_idx() {
bytes |= 1 << b;
}
let mut bytes = 0_u8;
if self.is_word_swizzle() {
for i in 0..2 {
if let Some(w) = self.word(i).unwrap().word_idx() {
bytes |= 0xf << (w * 4);
}
}
} else {
for i in 0..4 {
if let Some(b) = self.byte(i).unwrap().byte_idx() {
bytes |= 1 << b;
}
bytes
}
}
bytes
}
pub fn replicates_byte(&self) -> bool {
let b0 = self.byte(0);
b0.is_byte_mod_idx_or_zero()
b0.is_some_and(SwizzleByte::is_byte_mod_idx_or_zero)
&& self.byte(1) == b0
&& self.byte(2) == b0
&& self.byte(3) == b0
@ -328,36 +492,33 @@ impl Swizzle {
let b0 = self.byte(0);
let b1 = self.byte(1);
b0.is_byte_mod_idx_or_zero()
&& b1.is_byte_mod_idx_or_zero()
b0.is_some_and(SwizzleByte::is_byte_mod_idx_or_zero)
&& b1.is_some_and(SwizzleByte::is_byte_mod_idx_or_zero)
&& self.byte(2) == b0
&& self.byte(3) == b1
}
pub fn swizzle(self, other: Swizzle) -> Option<Swizzle> {
if self == Swizzle::ALL64 || self == Swizzle::LOW32 {
unimplemented!("64-bit swizzling");
}
if other == Swizzle::ALL64 || other == Swizzle::LOW32 {
unimplemented!("64-bit swizzling");
}
if other == Swizzle::NONE {
return Some(self);
} else if self == Swizzle::NONE {
return Some(other);
}
// Disallow combining word swizzles for now
if self.is_word_swizzle() || other.is_word_swizzle() {
return None;
}
let mut has_fext = false;
let mut bytes = [SwizzleByte::Zero; 4];
for i in 0..4 {
let ob = other.byte(i);
let ob = other.byte(i).unwrap();
bytes[usize::from(i)] = if ob == SwizzleByte::Zero {
SwizzleByte::Zero
} else {
let obi = ob.byte_idx()?;
let sb = self.byte(obi);
let sb = self.byte(obi).unwrap();
if sb == SwizzleByte::Zero {
SwizzleByte::Zero
} else {
@ -405,37 +566,110 @@ impl Default for Swizzle {
impl fmt::Display for Swizzle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
Swizzle::NONE => Ok(()),
Swizzle::ALL64 => Ok(()),
Swizzle::LOW32 => write!(f, ".low32"),
Swizzle::H00 => write!(f, ".h00"),
Swizzle::H10 => write!(f, ".h10"),
Swizzle::H11 => write!(f, ".h11"),
Swizzle::HF0 => write!(f, ".hf0"),
Swizzle::HF1 => write!(f, ".hf1"),
_ => {
let mut is_bytes = true;
for i in 0..4 {
if self.byte(i).byte_mod() != Some(ByteMod::Byte) {
is_bytes = false;
break;
}
if *self == Swizzle::NONE {
Ok(())
} else if self.is_word_swizzle() {
let mut is_words = true;
for i in 0..4 {
if self.word(i).unwrap().word_mod() != Some(ByteMod::Byte) {
is_words = false;
break;
}
if is_bytes {
write!(f, ".b")?;
for i in 0..4 {
write!(f, "{}", self.byte(i).byte_idx().unwrap())?;
}
} else {
write!(f, ".")?;
for i in 0..4 {
write!(f, "{}", self.byte(i))?;
}
}
Ok(())
}
if is_words {
write!(f, ".w")?;
for i in 0..4 {
write!(f, "{}", self.word(i).unwrap().word_idx().unwrap())?;
}
} else {
write!(f, ".")?;
for i in 0..4 {
write!(f, "{}", self.word(i).unwrap())?;
}
}
Ok(())
} else {
let mut is_bytes = true;
for i in 0..4 {
if self.byte(i).unwrap().byte_mod() != Some(ByteMod::Byte) {
is_bytes = false;
break;
}
}
if is_bytes {
write!(f, ".b")?;
for i in 0..4 {
write!(f, "{}", self.byte(i).unwrap().byte_idx().unwrap())?;
}
} else {
write!(f, ".")?;
for i in 0..4 {
write!(f, "{}", self.byte(i).unwrap())?;
}
}
Ok(())
}
}
}
/// Swizzles and widens as they appear in the shader assembly. These are used
/// both for final codegen and for pretty printing.
#[repr(u8)]
#[derive(AsmSwizzleWiden, Clone, Copy, Debug, EnumAsU8, PartialEq)]
pub enum AsmSwizzleWiden {
None,
// 8-bit scalar swizzles
B0,
B1,
B2,
B3,
// 8-bit vec2 swizzles
B00,
B10,
B20,
B30,
B01,
B11,
B21,
B31,
B02,
B12,
B22,
B32,
B03,
B13,
B23,
B33,
// 8-bit vec4 swizzles that appear in the HW docs
B0000,
B0011,
B0101,
B0123,
B1032,
B1111,
B2222,
B2233,
B2301,
B2323,
B3210,
B3333,
// 16-bit scalar swizzles
H0,
H1,
// 16-bit vec2 swizzles
H00,
H01,
H10,
H11,
// 32-bit swizzles
W0,
W1,
}

View file

@ -12,7 +12,7 @@ fn validate_instr(instr: &Instr, ssa_vals: &mut FxHashSet<SSAValue>) {
}
}
if src.swizzle == Swizzle::ALL64 {
if src.swizzle.is_word_swizzle() {
assert!(src_type.bits().unwrap().get() == 64);
}