kraid: Handle nir_op_mov/vec/[un]pack

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41841>
This commit is contained in:
Faith Ekstrand 2026-05-11 23:25:55 -04:00 committed by Marge Bot
parent 765b89ed2e
commit 2febabbd3c
3 changed files with 188 additions and 1 deletions

View file

@ -37,6 +37,33 @@ pub trait SSABuilder: Builder {
});
def
}
/// Emits an `OpMkVecV2I8` that combines `x` and `y`, in that order, and
/// returns the freshly allocated destination value.
///
/// The destination is allocated with `alloc_ssa(16)` (presumably a 16-bit
/// value holding the two 8-bit lanes -- confirm against `alloc_ssa`).
fn mkvec_v2i8(&mut self, x: Src, y: Src) -> SSAValue {
    let packed = self.alloc_ssa(16);
    let op = OpMkVecV2I8 {
        dst: packed.into(),
        srcs: [x, y],
    };
    self.push_op(op);
    packed
}
/// Builds a vector from two sources by forwarding to [`Self::mkvec_v4i8`].
///
/// Each source contributes its bytes 0 and 1, so the four operands passed
/// on are `x.byte(0)`, `x.byte(1)`, `y.byte(0)`, `y.byte(1)`, in that order.
/// Returns the value produced by `mkvec_v4i8`.
fn mkvec_v2i16(&mut self, x: Src, y: Src) -> SSAValue {
    // Only the first byte selection of each source needs a clone; the
    // second one is the last use and can consume the source directly.
    let x_lo = x.clone().byte(0);
    let x_hi = x.byte(1);
    let y_lo = y.clone().byte(0);
    let y_hi = y.byte(1);
    self.mkvec_v4i8(x_lo, x_hi, y_lo, y_hi)
}
/// Emits an `OpMkVecV4I8` that combines `x`, `y`, `z` and `w`, in that
/// order, and returns the freshly allocated destination value.
///
/// The destination is allocated with `alloc_ssa(32)` (presumably one
/// 32-bit value holding the four 8-bit lanes -- confirm against
/// `alloc_ssa`).
fn mkvec_v4i8(&mut self, x: Src, y: Src, z: Src, w: Src) -> SSAValue {
    let packed = self.alloc_ssa(32);
    let op = OpMkVecV4I8 {
        dst: packed.into(),
        srcs: [x, y, z, w],
    };
    self.push_op(op);
    packed
}
}
pub struct InstrBuilder {

View file

@ -10,7 +10,7 @@ use crate::ssa_value::SSAValueAllocator;
use compiler::bindings::*;
use compiler::nir::*;
use rustc_hash::FxHashMap;
use std::cmp::max;
use std::cmp::{max, min};
#[derive(Default)]
struct BlockLabelMap {
@ -140,6 +140,117 @@ impl<'a> ShaderFromNir<'a> {
}
}
/// Translates one NIR ALU instruction into ops emitted through `b` and
/// records the resulting SSA values for `alu.def` via `set_ssa`.
///
/// Only the data-movement ops (`mov`, `vecN` and the `pack_*` family listed
/// below) are implemented. They are lowered in three steps:
///
/// 1. Gather one `(def, component)` pair per scalar input.
/// 2. Turn each pair into a `Src`, selecting the byte/half inside the
///    containing value for 8/16-bit data and splitting 64-bit components
///    into two consecutive values.
/// 3. Re-pack the flattened sources into destination values using
///    `mov_i16`, `mkvec_v2i8`, `mkvec_v4i8`, `mkvec_v2i16` or `mov_i32`.
///    Missing trailing lanes are filled with `0`.
///
/// # Panics
///
/// Panics if the source bit size is not 8, 16, 32 or 64, or if `alu.op` is
/// any ALU op other than the ones matched below.
fn parse_alu(&mut self, b: &mut impl SSABuilder, alu: &nir_alu_instr) {
    // Handle vectors and pack ops as a special case since they're the only
    // ALU ops that can produce more than 16B. They are also the only ALU
    // ops which we allow to consume small (8 and 16-bit) vector data
    // scattered across multiple dwords
    if matches!(
        alu.op,
        nir_op_mov
            | nir_op_pack_32_4x8
            | nir_op_pack_32_4x8_split
            | nir_op_pack_32_2x16
            | nir_op_pack_32_2x16_split
            | nir_op_pack_64_2x32
            | nir_op_pack_64_2x32_split
            | nir_op_pack_64_4x16
            | nir_op_vec2
            | nir_op_vec3
            | nir_op_vec4
            | nir_op_vec5
            | nir_op_vec8
            | nir_op_vec16
    ) {
        // Step 1: one (def, swizzled component) pair per scalar input.
        // Single-input ops read every component of their one source;
        // multi-input ops read the first swizzled component of each source.
        let mut nsrcs = Vec::new();
        if alu.info().num_inputs == 1 {
            let src = alu.get_src(0);
            for c in 0..usize::from(alu.src_components(0)) {
                nsrcs.push((src.src.as_def(), src.swizzle[c]));
            }
        } else {
            for src in alu.srcs_as_slice() {
                nsrcs.push((src.src.as_def(), src.swizzle[0]));
            }
        }

        // Step 2: resolve each pair to a Src. The index arithmetic below
        // treats the values returned by `get_ssa` as holding four 8-bit,
        // two 16-bit or one 32-bit component each.
        let src_bit_size = alu.get_src(0).src.bit_size();
        let mut srcs = Vec::new();
        match src_bit_size {
            8 => {
                for (def, c) in nsrcs {
                    let ssa = self.get_ssa(def)[usize::from(c) / 4];
                    srcs.push(Src::from(ssa).byte(c % 4));
                }
            }
            16 => {
                for (def, c) in nsrcs {
                    let ssa = self.get_ssa(def)[usize::from(c) / 2];
                    srcs.push(Src::from(ssa).half(c % 2));
                }
            }
            32 => {
                for (def, c) in nsrcs {
                    let ssa = self.get_ssa(def)[usize::from(c)];
                    srcs.push(Src::from(ssa));
                }
            }
            64 => {
                for (def, c) in nsrcs {
                    // Each 64-bit component occupies two consecutive values.
                    let vec = self.get_ssa(def);
                    let base = usize::from(c) * 2;
                    srcs.push(Src::from(vec[base]));
                    srcs.push(Src::from(vec[base + 1]));
                }
            }
            _ => panic!("Unsupported bit size: {src_bit_size}"),
        }

        // We flattened i64 to v2i32
        let src_bit_size = min(src_bit_size, 32);

        // Step 3: re-pack the flattened sources into destination values.
        let num_srcs = srcs.len();
        let mut srcs = srcs.into_iter();
        let dst_vec: Vec<_> = if num_srcs == 1 && src_bit_size <= 16 {
            let x = srcs.next().unwrap();
            vec![b.mov_i16(x)]
        } else if num_srcs == 2 && src_bit_size == 8 {
            let x = srcs.next().unwrap();
            let y = srcs.next().unwrap();
            vec![b.mkvec_v2i8(x, y)]
        } else if src_bit_size == 8 {
            // Four bytes per destination value; pad the tail with zeros.
            let mut packed = Vec::new();
            while let Some(x) = srcs.next() {
                let y = srcs.next().unwrap_or(0.into());
                let z = srcs.next().unwrap_or(0.into());
                let w = srcs.next().unwrap_or(0.into());
                packed.push(b.mkvec_v4i8(x, y, z, w));
            }
            packed
        } else if src_bit_size == 16 {
            // Two halves per destination value; pad the tail with zero.
            let mut packed = Vec::new();
            while let Some(x) = srcs.next() {
                let y = srcs.next().unwrap_or(0.into());
                packed.push(b.mkvec_v2i16(x, y));
            }
            packed
        } else if src_bit_size == 32 {
            srcs.map(|src| b.mov_i32(src)).collect()
        } else {
            panic!("Unsupported bit size: {src_bit_size}")
        };

        self.set_ssa(&alu.def, dst_vec);
        return;
    }

    // No other ALU op is implemented yet.
    match alu.op {
        _ => panic!("Unsupported ALU instruction: {}", alu.info().name()),
    }
}
fn parse_block(
&mut self,
ssa_alloc: &mut SSAValueAllocator,
@ -153,6 +264,9 @@ impl<'a> ShaderFromNir<'a> {
nir_instr_type_load_const => {
self.parse_const(&mut b, ni.as_load_const().unwrap())
}
nir_instr_type_alu => {
self.parse_alu(&mut b, ni.as_alu().unwrap())
}
_ => panic!("Unsupported instruction type"),
}
}

View file

@ -66,6 +66,50 @@ impl fmt::Display for OpEnd {
}
}
/// Op that builds a two-lane vector of 8-bit values from two sources.
///
/// The destination is tagged `V2I8` and both sources are tagged `I8` through
/// the `dst_type` / `src_type` attributes consumed by the `Opcode` derive.
#[repr(C)]
#[derive(Clone, Opcode)]
pub struct OpMkVecV2I8 {
// Destination receiving the assembled vector.
#[dst_type(V2I8)]
pub dst: Dst,
// The two source lanes, in the order they are printed by `Display`.
#[src_type(I8)]
pub srcs: [Src; 2],
}
/// Textual form of the op: `<dst> = MKVEC.v2i8 <src0> <src1>`.
impl fmt::Display for OpMkVecV2I8 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The mnemonic must name the two-lane variant; printing `v4i8` here
        // would make this op indistinguishable from `OpMkVecV4I8` in dumps.
        write!(
            f,
            "{} = MKVEC.v2i8 {} {}",
            &self.dst, &self.srcs[0], &self.srcs[1],
        )
    }
}
/// Op that builds a four-lane vector of 8-bit values from four sources.
///
/// The destination is tagged `V4I8` and all sources are tagged `I8` through
/// the `dst_type` / `src_type` attributes consumed by the `Opcode` derive.
#[repr(C)]
#[derive(Clone, Opcode)]
pub struct OpMkVecV4I8 {
// Destination receiving the assembled vector.
#[dst_type(V4I8)]
pub dst: Dst,
// The four source lanes, in the order they are printed by `Display`.
#[src_type(I8)]
pub srcs: [Src; 4],
}
/// Textual form of the op: `<dst> = MKVEC.v4i8 <src0> <src1> <src2> <src3>`.
impl fmt::Display for OpMkVecV4I8 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pull the destination and the four lanes out by name so the format
        // call reads in the same order as the printed text.
        let Self {
            dst,
            srcs: [lane0, lane1, lane2, lane3],
        } = self;
        write!(
            f,
            "{} = MKVEC.v4i8 {} {} {} {}",
            dst, lane0, lane1, lane2, lane3,
        )
    }
}
#[repr(C)]
#[derive(Clone, Opcode)]
#[variants(dst_type in [I16, I32])]
@ -85,5 +129,7 @@ impl fmt::Display for OpMov {
/// Sum type over every op struct; each variant wraps the op struct of the
/// matching name.
pub enum Op {
// Wraps `OpBranch`.
Branch(OpBranch),
// Wraps `OpEnd`.
End(OpEnd),
// Wraps `OpMkVecV2I8` (two 8-bit sources into a `V2I8` destination).
MkVecV2I8(OpMkVecV2I8),
// Wraps `OpMkVecV4I8` (four 8-bit sources into a `V4I8` destination).
MkVecV4I8(OpMkVecV4I8),
// Wraps `OpMov`.
Mov(OpMov),
}