kraid: Add some float alu ops

This also sets up the rest of the ALU op infrastructure, including
handling source swizzles coming in from NIR.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41841>
This commit is contained in:
Faith Ekstrand 2026-05-11 23:28:29 -04:00 committed by Marge Bot
parent 2febabbd3c
commit b2adfd28e9
2 changed files with 217 additions and 0 deletions

View file

@ -4,6 +4,7 @@
#![allow(non_upper_case_globals)]
use crate::builder::*;
use crate::data_type::*;
use crate::ir::*;
use crate::ops::*;
use crate::ssa_value::SSAValueAllocator;
@ -82,6 +83,53 @@ impl<'a> ShaderFromNir<'a> {
self.get_src_ssa(src).into()
}
/// Builds the `Src` for one NIR ALU source, applying the NIR swizzle.
///
/// `comps` is the number of components of `src` that are read.  What happens
/// depends on the bit size of the SSA def behind `src`:
///
///  * 8-bit: at most four components, all selected from the same group of
///    four (`swizzle / 4`).  That group picks one element of the source
///    vector and the `swizzle % 4` values become a byte swizzle on it.
///  * 16-bit: at most two components, both selected from the same pair
///    (`swizzle / 2`), with `swizzle % 2` becoming a half swizzle.
///  * 32-bit: exactly one component, element `swizzle[0]` of the vector.
///  * 64-bit: exactly one component, built from elements `2 * swizzle[0]`
///    and `2 * swizzle[0] + 1` of the vector.
///
/// NOTE(review): this presumably means each element returned by `get_ssa`
/// is a 32-bit value -- confirm against `get_ssa`.
///
/// # Panics
///
/// Panics if `comps` is out of range for the bit size, if the selected
/// components do not share one vector element, or if the bit size is not
/// 8, 16, 32, or 64.
fn get_alu_src(&self, src: &nir_alu_src, comps: u8) -> Src {
    let src_vec = self.get_ssa(src.src.as_def());

    match src.src.as_def().bit_size {
        8 => {
            assert!(comps <= 4);
            let w = src.swizzle[0] / 4;

            // Lanes past `comps` keep replicating component 0.
            let mut bytes = [src.swizzle[0] % 4; 4];
            let tail = bytes
                .iter_mut()
                .enumerate()
                .take(usize::from(comps))
                .skip(1);
            for (i, byte) in tail {
                assert!(src.swizzle[i] / 4 == w);
                *byte = src.swizzle[i] % 4;
            }

            if comps == 2 {
                // For vec2's, make it symmetric
                bytes = [bytes[0], bytes[1], bytes[0], bytes[1]];
            }

            let byte_swizzle = Swizzle::from_bytes(bytes);
            Src::from(src_vec[usize::from(w)]).swizzle(byte_swizzle)
        }
        16 => {
            assert!(comps <= 2);
            let w = src.swizzle[0] / 2;
            let first = src.swizzle[0] % 2;

            // A scalar read replicates its half into both lanes.
            let halves = if comps == 2 {
                assert!(src.swizzle[1] / 2 == w);
                [first, src.swizzle[1] % 2]
            } else {
                [first, first]
            };

            let half_swizzle = Swizzle::from_halves(halves);
            Src::from(src_vec[usize::from(w)]).swizzle(half_swizzle)
        }
        32 => {
            assert!(comps == 1);
            let idx = usize::from(src.swizzle[0]);
            src_vec[idx].into()
        }
        64 => {
            assert!(comps == 1);
            // A 64-bit component occupies two consecutive vector elements.
            let lo = usize::from(src.swizzle[0]) * 2;
            [src_vec[lo], src_vec[lo + 1]].into()
        }
        bit_size => panic!("Unsupported bit size: {bit_size}"),
    }
}
fn parse_const(
&mut self,
b: &mut impl SSABuilder,
@ -246,7 +294,61 @@ impl<'a> ShaderFromNir<'a> {
return;
}
let mut srcs_vec = Vec::new();
for i in 0..alu.info().num_inputs {
let comps = alu.src_components(i);
srcs_vec.push(self.get_alu_src(alu.get_src(i.into()), comps));
}
let srcs_vec = srcs_vec;
// Cloning ALU sources should always be cheap but the helper makes
// things more ergonomic.
let srcs = |i: usize| srcs_vec[i].clone();
let src_type = |i, num_type| {
let comps = alu.src_components(i);
let bits = alu.get_src(i.into()).bit_size();
DataType::get(comps, num_type, bits)
};
let dst = self.alloc_ssa(b, &alu.def);
let dst_type = |num_type| {
DataType::get(alu.def.num_components, num_type, alu.def.bit_size)
};
match alu.op {
nir_op_fabs => {
// TODO: Do we really want FAdd for this?
b.push_op(OpFAdd {
dst: dst.into(),
dst_type: dst_type(NumericType::Float),
srcs: [srcs(0).fabs(), Src::from(0).fneg()],
});
}
nir_op_fadd => {
b.push_op(OpFAdd {
dst: dst.into(),
dst_type: dst_type(NumericType::Float),
srcs: [srcs(0), srcs(1)],
});
}
nir_op_feq16 | nir_op_feq32 | nir_op_fge16 | nir_op_fge32
| nir_op_flt16 | nir_op_flt32 | nir_op_fneu16 | nir_op_fneu32 => {
b.push_op(OpFCmp {
dst: dst.into(),
src_type: src_type(0, NumericType::Float),
res_type: CmpResultType::M1,
cmp_op: match alu.op {
nir_op_feq16 | nir_op_feq32 => CmpOp::Eq,
nir_op_fge16 | nir_op_fge32 => CmpOp::Ge,
nir_op_flt16 | nir_op_flt32 => CmpOp::Lt,
nir_op_fneu16 | nir_op_fneu32 => CmpOp::Ne,
_ => panic!("Usupported float comparison"),
},
srcs: [srcs(0), srcs(1)],
accum: 0.into(),
accum_op: CmpAccumOp::None,
});
}
_ => panic!("Unsupported ALU instruction: {}", alu.info().name()),
}
}

View file

@ -66,6 +66,119 @@ impl fmt::Display for OpEnd {
}
}
/// Floating-point add op.
///
/// Its `Display` form is `dst = FADD.<dst_type> <srcs[0]> <srcs[1]>`.
///
/// NOTE(review): the `variants` attribute presumably restricts `dst_type`
/// to F16, V2F16, or F32 -- confirm against the `Opcode` derive.
#[repr(C)]
#[derive(Clone, Opcode)]
#[variants(dst_type in [F16, V2F16, F32])]
pub struct OpFAdd {
/// Destination of the add.
pub dst: Dst,
/// Data type of the operation, printed after `FADD.`.
pub dst_type: DataType,
/// The two operands.
pub srcs: [Src; 2],
}
impl fmt::Display for OpFAdd {
    /// Prints the op as `dst = FADD.<dst_type> <src0> <src1>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            dst,
            dst_type,
            srcs: [lhs, rhs],
        } = self;
        write!(f, "{} = FADD.{} {} {}", dst, dst_type, lhs, rhs)
    }
}
/// How a comparison result is combined with the op's accumulator source.
///
/// The `Display` form is the suffix appended directly to the mnemonic:
/// nothing for `None`, `_AND` for `And`, and `_OR` for `Or`.
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
pub enum CmpAccumOp {
    None,
    And,
    Or,
}

impl fmt::Display for CmpAccumOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let suffix = match *self {
            // No accumulation: the mnemonic gets no suffix at all.
            CmpAccumOp::None => return Ok(()),
            CmpAccumOp::And => "_AND",
            CmpAccumOp::Or => "_OR",
        };
        f.write_str(suffix)
    }
}
/// The form in which a comparison writes its result.
///
/// The `Display` form is a dot-prefixed modifier: `.i1`, `.f1`, or `.m1`.
///
/// NOTE(review): presumably `I1` is integer 1, `F1` is float 1.0, and `M1`
/// is an all-ones mask -- confirm against the hardware documentation.
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
pub enum CmpResultType {
    I1,
    F1,
    M1,
}

impl fmt::Display for CmpResultType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let modifier = match *self {
            CmpResultType::I1 => ".i1",
            CmpResultType::F1 => ".f1",
            CmpResultType::M1 => ".m1",
        };
        f.write_str(modifier)
    }
}
/// The predicate evaluated by a comparison op.
///
/// The `Display` form is a dot-prefixed modifier (`.eq`, `.gt`, ...) that
/// is concatenated directly onto the preceding modifier of the mnemonic, so
/// every variant must print its own leading `.`.
///
/// NOTE(review): the exact semantics of `GtLt` and `Total` are not visible
/// here -- confirm against the hardware documentation.
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
pub enum CmpOp {
    Eq,
    Gt,
    Ge,
    Ne,
    Lt,
    Le,
    GtLt,
    Total,
}

impl fmt::Display for CmpOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let modifier = match self {
            CmpOp::Eq => ".eq",
            CmpOp::Gt => ".gt",
            CmpOp::Ge => ".ge",
            CmpOp::Ne => ".ne",
            CmpOp::Lt => ".lt",
            CmpOp::Le => ".le",
            // Previously printed without the leading dot, which glued it
            // onto the preceding modifier (e.g. `.m1gtlt`).
            CmpOp::GtLt => ".gtlt",
            CmpOp::Total => ".total",
        };
        f.write_str(modifier)
    }
}
/// Floating-point comparison op.
///
/// Its `Display` form is
/// `dst = FCMP<accum_op>.<src_type><res_type><cmp_op> <srcs[0]> <srcs[1]>`.
///
/// NOTE(review): the `variants` attribute presumably restricts `src_type`
/// to F16, V2F16, or F32 -- confirm against the `Opcode` derive.
#[repr(C)]
#[derive(Clone, Opcode)]
#[variants(src_type in [F16, V2F16, F32])]
pub struct OpFCmp {
/// Destination of the comparison result.
pub dst: Dst,
/// Data type of the two compared sources.
pub src_type: DataType,
/// Form in which the result is written.
pub res_type: CmpResultType,
/// Predicate to evaluate.
pub cmp_op: CmpOp,
/// The two operands being compared.
pub srcs: [Src; 2],
/// Accumulator source used together with `accum_op`.
///
/// NOTE(review): the `Display` impl does not print this source -- confirm
/// whether it should when `accum_op` is not `CmpAccumOp::None`.
#[src_type(VNIN)]
pub accum: Src,
/// How the result is combined with `accum`.
pub accum_op: CmpAccumOp,
}
impl fmt::Display for OpFCmp {
    /// Prints the op as
    /// `dst = FCMP<accum_op>.<src_type><res_type><cmp_op> <src0> <src1>`.
    ///
    /// The accumulator, result-type and compare-op modifiers supply their
    /// own separators.  The `accum` source is not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            dst,
            src_type,
            res_type,
            cmp_op,
            srcs: [lhs, rhs],
            accum_op,
            ..
        } = self;
        write!(
            f,
            "{} = FCMP{}.{}{}{} {} {}",
            dst, accum_op, src_type, res_type, cmp_op, lhs, rhs,
        )
    }
}
#[repr(C)]
#[derive(Clone, Opcode)]
pub struct OpMkVecV2I8 {
@ -129,6 +242,8 @@ impl fmt::Display for OpMov {
pub enum Op {
Branch(OpBranch),
End(OpEnd),
FAdd(OpFAdd),
FCmp(OpFCmp),
MkVecV2I8(OpMkVecV2I8),
MkVecV4I8(OpMkVecV4I8),
Mov(OpMov),