diff --git a/src/panfrost/compiler/kraid/nir.rs b/src/panfrost/compiler/kraid/nir.rs
index de465c9f4fb..319d7ac4864 100644
--- a/src/panfrost/compiler/kraid/nir.rs
+++ b/src/panfrost/compiler/kraid/nir.rs
@@ -4,6 +4,7 @@
 #![allow(non_upper_case_globals)]
 
 use crate::builder::*;
+use crate::data_type::*;
 use crate::ir::*;
 use crate::ops::*;
 use crate::ssa_value::SSAValueAllocator;
@@ -82,6 +83,68 @@ impl<'a> ShaderFromNir<'a> {
         self.get_src_ssa(src).into()
     }
 
+    /// Builds the `Src` for one NIR ALU source, folding in its swizzle.
+    ///
+    /// `comps` is the number of components read from `src`.  The indexing
+    /// below treats each element of the `get_ssa()` result as 32 bits wide:
+    /// 8-bit and 16-bit components are picked with a byte or half swizzle
+    /// and must all sit in the same element, a 32-bit source is a single
+    /// element, and a 64-bit source is a pair of consecutive elements.
+    ///
+    /// Panics if `comps` exceeds what the bit size allows, if the swizzle
+    /// of a packed source spans more than one element, or if the bit size
+    /// is not 8, 16, 32 or 64.
+    fn get_alu_src(&self, src: &nir_alu_src, comps: u8) -> Src {
+        let src_vec = self.get_ssa(src.src.as_def());
+
+        match src.src.as_def().bit_size {
+            8 => {
+                assert!(comps <= 4);
+                // `w` is the element the bytes live in.  Every lane starts
+                // out selecting component 0's byte.
+                let w = src.swizzle[0] / 4;
+                let mut bytes = [src.swizzle[0] % 4; 4];
+                for i in 1..usize::from(comps) {
+                    assert!(src.swizzle[i] / 4 == w);
+                    bytes[i] = src.swizzle[i] % 4;
+                }
+                if comps == 2 {
+                    // For vec2's, make it symmetric
+                    bytes[2] = bytes[0];
+                    bytes[3] = bytes[1];
+                }
+                let swizzle = Swizzle::from_bytes(bytes);
+                Src::from(src_vec[usize::from(w)]).swizzle(swizzle)
+            }
+            16 => {
+                assert!(comps <= 2);
+                // Same scheme as above with two halves per element.
+                let w = src.swizzle[0] / 2;
+                let mut halves = [src.swizzle[0] % 2; 2];
+                if comps == 2 {
+                    assert!(src.swizzle[1] / 2 == w);
+                    halves[1] = src.swizzle[1] % 2;
+                }
+                let swizzle = Swizzle::from_halves(halves);
+                Src::from(src_vec[usize::from(w)]).swizzle(swizzle)
+            }
+            32 => {
+                assert!(comps == 1);
+                src_vec[usize::from(src.swizzle[0])].into()
+            }
+            64 => {
+                assert!(comps == 1);
+                // A 64-bit component occupies two consecutive elements.
+                [
+                    src_vec[usize::from(src.swizzle[0]) * 2],
+                    src_vec[usize::from(src.swizzle[0]) * 2 + 1],
+                ]
+                .into()
+            }
+            bit_size => panic!("Unsupported bit size: {bit_size}"),
+        }
+    }
+
     fn parse_const(
         &mut self,
         b: &mut impl SSABuilder,
@@ -246,7 +309,70 @@ impl<'a> ShaderFromNir<'a> {
             return;
         }
 
+        // Translate every source up front, folding in its swizzle.
+        let mut srcs_vec = Vec::new();
+        for i in 0..alu.info().num_inputs {
+            let comps = alu.src_components(i);
+            srcs_vec.push(self.get_alu_src(alu.get_src(i.into()), comps));
+        }
+        // Drop mutability now that the list is complete.
+        let srcs_vec = srcs_vec;
+
+        // Cloning ALU sources should always be cheap but the helper makes
+        // things more ergonomic.
+        let srcs = |i: usize| srcs_vec[i].clone();
+        // Data type of source `i` when interpreted as `num_type`.
+        let src_type = |i, num_type| {
+            let comps = alu.src_components(i);
+            let bits = alu.get_src(i.into()).bit_size();
+            DataType::get(comps, num_type, bits)
+        };
+
+        let dst = self.alloc_ssa(b, &alu.def);
+        // Data type of the destination when interpreted as `num_type`.
+        let dst_type = |num_type| {
+            DataType::get(alu.def.num_components, num_type, alu.def.bit_size)
+        };
+
         match alu.op {
+            nir_op_fabs => {
+                // TODO: Do we really want FAdd for this?
+                // NOTE(review): presumably the second operand is -0.0, so
+                // the add leaves |x| unchanged -- confirm.
+                b.push_op(OpFAdd {
+                    dst: dst.into(),
+                    dst_type: dst_type(NumericType::Float),
+                    srcs: [srcs(0).fabs(), Src::from(0).fneg()],
+                });
+            }
+            nir_op_fadd => {
+                b.push_op(OpFAdd {
+                    dst: dst.into(),
+                    dst_type: dst_type(NumericType::Float),
+                    srcs: [srcs(0), srcs(1)],
+                });
+            }
+            nir_op_feq16 | nir_op_feq32 | nir_op_fge16 | nir_op_fge32
+            | nir_op_flt16 | nir_op_flt32 | nir_op_fneu16 | nir_op_fneu32 => {
+                // One FCMP covers all of these; only `cmp_op` differs.
+                b.push_op(OpFCmp {
+                    dst: dst.into(),
+                    src_type: src_type(0, NumericType::Float),
+                    res_type: CmpResultType::M1,
+                    cmp_op: match alu.op {
+                        nir_op_feq16 | nir_op_feq32 => CmpOp::Eq,
+                        nir_op_fge16 | nir_op_fge32 => CmpOp::Ge,
+                        nir_op_flt16 | nir_op_flt32 => CmpOp::Lt,
+                        nir_op_fneu16 | nir_op_fneu32 => CmpOp::Ne,
+                        // Unreachable: the outer arm only matches the ops
+                        // listed above.
+                        _ => panic!("Unsupported float comparison"),
+                    },
+                    srcs: [srcs(0), srcs(1)],
+                    accum: 0.into(),
+                    accum_op: CmpAccumOp::None,
+                });
+            }
             _ => panic!("Unsupported ALU instruction: {}", alu.info().name()),
         }
     }
diff --git a/src/panfrost/compiler/kraid/ops.rs b/src/panfrost/compiler/kraid/ops.rs
index fda0c864117..ebf513fa641 100644
--- a/src/panfrost/compiler/kraid/ops.rs
+++ b/src/panfrost/compiler/kraid/ops.rs
@@ -66,6 +66,136 @@ impl fmt::Display for OpEnd {
     }
 }
 
+/// Floating-point addition of two sources, printed as `FADD.<dst_type>`.
+#[repr(C)]
+#[derive(Clone, Opcode)]
+#[variants(dst_type in [F16, V2F16, F32])]
+pub struct OpFAdd {
+    pub dst: Dst,
+    pub dst_type: DataType,
+    pub srcs: [Src; 2],
+}
+
+impl fmt::Display for OpFAdd {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "{} = FADD.{} {} {}",
+            &self.dst, &self.dst_type, &self.srcs[0], &self.srcs[1]
+        )
+    }
+}
+
+/// Type of `OpFCmp::accum_op`, printed as a `_AND`/`_OR` suffix on the
+/// mnemonic (`None` prints nothing).
+///
+/// NOTE(review): presumably selects how the comparison result is combined
+/// with the `accum` source -- confirm.
+#[derive(Clone, Copy, Eq, Hash, PartialEq)]
+pub enum CmpAccumOp {
+    None,
+    And,
+    Or,
+}
+
+impl fmt::Display for CmpAccumOp {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            CmpAccumOp::None => Ok(()),
+            CmpAccumOp::And => write!(f, "_AND"),
+            CmpAccumOp::Or => write!(f, "_OR"),
+        }
+    }
+}
+
+/// Type of `OpFCmp::res_type`, printed as `.i1`, `.f1` or `.m1`.
+///
+/// NOTE(review): presumably the encoding of a true result (integer 1,
+/// float 1.0, all-ones mask) -- confirm.
+#[derive(Clone, Copy, Eq, Hash, PartialEq)]
+pub enum CmpResultType {
+    I1,
+    F1,
+    M1,
+}
+
+impl fmt::Display for CmpResultType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            CmpResultType::I1 => write!(f, ".i1"),
+            CmpResultType::F1 => write!(f, ".f1"),
+            CmpResultType::M1 => write!(f, ".m1"),
+        }
+    }
+}
+
+/// Comparison operator, printed as a dot-prefixed suffix such as `.eq`.
+#[derive(Clone, Copy, Eq, Hash, PartialEq)]
+pub enum CmpOp {
+    Eq,
+    Gt,
+    Ge,
+    Ne,
+    Lt,
+    Le,
+    GtLt,
+    Total,
+}
+
+impl fmt::Display for CmpOp {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // Each suffix carries its own leading dot: `OpFCmp`'s format string
+        // does not insert one before this field.
+        match self {
+            CmpOp::Eq => write!(f, ".eq"),
+            CmpOp::Gt => write!(f, ".gt"),
+            CmpOp::Ge => write!(f, ".ge"),
+            CmpOp::Ne => write!(f, ".ne"),
+            CmpOp::Lt => write!(f, ".lt"),
+            CmpOp::Le => write!(f, ".le"),
+            CmpOp::GtLt => write!(f, ".gtlt"),
+            CmpOp::Total => write!(f, ".total"),
+        }
+    }
+}
+
+/// Floating-point comparison of `srcs[0]` against `srcs[1]`, printed as
+/// `FCMP<accum_op>.<src_type><res_type><cmp_op>`.
+///
+/// `accum` is not part of the printed form.
+#[repr(C)]
+#[derive(Clone, Opcode)]
+#[variants(src_type in [F16, V2F16, F32])]
+pub struct OpFCmp {
+    pub dst: Dst,
+
+    pub src_type: DataType,
+    pub res_type: CmpResultType,
+    pub cmp_op: CmpOp,
+
+    pub srcs: [Src; 2],
+
+    #[src_type(VNIN)]
+    pub accum: Src,
+    pub accum_op: CmpAccumOp,
+}
+
+impl fmt::Display for OpFCmp {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "{} = FCMP{}.{}{}{} {} {}",
+            &self.dst,
+            self.accum_op,
+            self.src_type,
+            self.res_type,
+            self.cmp_op,
+            &self.srcs[0],
+            &self.srcs[1],
+        )
+    }
+}
+
 #[repr(C)]
 #[derive(Clone, Opcode)]
 pub struct OpMkVecV2I8 {
@@ -129,6 +259,8 @@ impl fmt::Display for OpMov {
 pub enum Op {
     Branch(OpBranch),
     End(OpEnd),
+    FAdd(OpFAdd),
+    FCmp(OpFCmp),
     MkVecV2I8(OpMkVecV2I8),
     MkVecV4I8(OpMkVecV4I8),
     Mov(OpMov),