From d031365f7ca8bb26599c5af8f787c803cd82ded9 Mon Sep 17 00:00:00 2001
From: Karol Herbst
Date: Fri, 13 Mar 2026 10:56:41 +0100
Subject: [PATCH] nak: support MUFU.F16

Reviewed-by: Mel Henning
Part-of:
---
Notes (ignored by git-am; everything between the "---" separator and the
first "diff --git" line is free text):

  Summary of what the hunks below do:
  * SSABuilder::mufu() gains an `op_type: FloatType` argument that is stored
    in the new OpMuFu::op_type field.  The fsin/fcos/fexp2 builder helpers
    pass FloatType::F32.
  * from_nir.rs derives the type from alu.def.bit_size() for fcos/fsin
    (normalized_2_pi), flog2, frcp, frsq and fsqrt; fexp2 emits MuFuOp::Exp2
    directly for 16-bit defs and keeps using b.fexp2() otherwise.
  * OpMuFu stops deriving SrcsAsSlice and implements the source slice by
    hand so that the reported SrcType follows op_type (F16 or F32).
  * The SM70 encoder switches to encode_alu_base(), passing the source type
    explicitly, and sets bit 73 when op_type is F16.
  * The Display output becomes "mufu.{op}{op_type} {src}".

  Review notes:
  * The `assert!(alu.def.bit_size() == 32)` checks on flog2/frcp/frsq are
    removed.  A 64-bit def reaching these arms would now end up in the
    `unreachable!("MuFu does not support F64")` arm of attrs() (assuming
    FloatType::from_bits maps 64 to F64 -- TODO confirm).  Presumably 64-bit
    variants are lowered before this point; confirm.
  * The fexp2 arm reads the `bit_size` field, while every other arm touched
    here calls `bit_size()`.  Consider using the accessor for consistency.
  * OpMuFu::dst keeps `#[dst_type(F32)]` even when op_type is F16; confirm
    this is intended.
  * The 16-bit fexp2 path has no SM version check in this patch; confirm
    that MUFU.F16 is available on every target that can reach it.

 src/nouveau/compiler/nak/builder.rs     |  9 ++--
 src/nouveau/compiler/nak/from_nir.rs    | 63 ++++++++++++++++++-------
 src/nouveau/compiler/nak/ir.rs          | 28 +++++++++--
 src/nouveau/compiler/nak/sm70_encode.rs | 11 ++++-
 4 files changed, 87 insertions(+), 24 deletions(-)

diff --git a/src/nouveau/compiler/nak/builder.rs b/src/nouveau/compiler/nak/builder.rs
index 2c7352764fa..9fb0254eb2a 100644
--- a/src/nouveau/compiler/nak/builder.rs
+++ b/src/nouveau/compiler/nak/builder.rs
@@ -730,12 +730,13 @@ pub trait SSABuilder: Builder {
         dst
     }
 
-    fn mufu(&mut self, op: MuFuOp, src: Src) -> SSAValue {
+    fn mufu(&mut self, op: MuFuOp, src: Src, op_type: FloatType) -> SSAValue {
         let dst = self.alloc_ssa(RegFile::GPR);
         self.push_op(OpMuFu {
             dst: dst.into(),
             op: op,
             src: src,
+            op_type: op_type,
         });
         dst
     }
@@ -748,7 +749,7 @@ pub trait SSABuilder: Builder {
             op: RroOp::SinCos,
             src,
         });
-        self.mufu(MuFuOp::Sin, tmp.into())
+        self.mufu(MuFuOp::Sin, tmp.into(), FloatType::F32)
     }
 
     fn fcos(&mut self, src: Src) -> SSAValue {
@@ -759,7 +760,7 @@ pub trait SSABuilder: Builder {
             op: RroOp::SinCos,
             src,
         });
-        self.mufu(MuFuOp::Cos, tmp.into())
+        self.mufu(MuFuOp::Cos, tmp.into(), FloatType::F32)
     }
 
     fn fexp2(&mut self, src: Src) -> SSAValue {
@@ -774,7 +775,7 @@ pub trait SSABuilder: Builder {
             });
             tmp.into()
         };
-        self.mufu(MuFuOp::Exp2, tmp)
+        self.mufu(MuFuOp::Exp2, tmp, FloatType::F32)
     }
 
     fn prmt(&mut self, x: Src, y: Src, sel: [u8; 4]) -> SSAValue {
diff --git a/src/nouveau/compiler/nak/from_nir.rs b/src/nouveau/compiler/nak/from_nir.rs
index c57c27fcccd..7a9c229bdab 100644
--- a/src/nouveau/compiler/nak/from_nir.rs
+++ b/src/nouveau/compiler/nak/from_nir.rs
@@ -978,7 +978,12 @@ impl<'a> ShaderFromNir<'a> {
             nir_op_fcos => b.fcos(srcs(0)).into(),
             nir_op_fcos_normalized_2_pi => {
                 assert!(self.sm.sm() >= 70);
-                b.mufu(MuFuOp::Cos, srcs(0)).into()
+                b.mufu(
+                    MuFuOp::Cos,
+                    srcs(0),
+                    FloatType::from_bits(alu.def.bit_size().into()),
+                )
+                .into()
             }
             nir_op_feq | nir_op_fge | nir_op_flt | nir_op_fneu => {
                 let src_type =
@@ -1044,7 +1049,13 @@ impl<'a> ShaderFromNir<'a> {
                 }
                 dst
             }
-            nir_op_fexp2 => b.fexp2(srcs(0)).into(),
+            nir_op_fexp2 => {
+                if alu.def.bit_size == 16 {
+                    b.mufu(MuFuOp::Exp2, srcs(0), FloatType::F16).into()
+                } else {
+                    b.fexp2(srcs(0)).into()
+                }
+            }
             nir_op_ffma => {
                 let ftype = FloatType::from_bits(alu.def.bit_size().into());
                 let dst;
@@ -1108,10 +1119,13 @@ impl<'a> ShaderFromNir<'a> {
                 });
                 dst.into()
             }
-            nir_op_flog2 => {
-                assert!(alu.def.bit_size() == 32);
-                b.mufu(MuFuOp::Log2, srcs(0)).into()
-            }
+            nir_op_flog2 => b
+                .mufu(
+                    MuFuOp::Log2,
+                    srcs(0),
+                    FloatType::from_bits(alu.def.bit_size().into()),
+                )
+                .into(),
             nir_op_fmax | nir_op_fmin => {
                 let dst;
                 if alu.def.bit_size() == 64 {
@@ -1243,14 +1257,20 @@ impl<'a> ShaderFromNir<'a> {
                 }
                 .into()
             }
-            nir_op_frcp => {
-                assert!(alu.def.bit_size() == 32);
-                b.mufu(MuFuOp::Rcp, srcs(0)).into()
-            }
-            nir_op_frsq => {
-                assert!(alu.def.bit_size() == 32);
-                b.mufu(MuFuOp::Rsq, srcs(0)).into()
-            }
+            nir_op_frcp => b
+                .mufu(
+                    MuFuOp::Rcp,
+                    srcs(0),
+                    FloatType::from_bits(alu.def.bit_size().into()),
+                )
+                .into(),
+            nir_op_frsq => b
+                .mufu(
+                    MuFuOp::Rsq,
+                    srcs(0),
+                    FloatType::from_bits(alu.def.bit_size().into()),
+                )
+                .into(),
             nir_op_fsat => {
                 let ftype = FloatType::from_bits(alu.def.bit_size().into());
 
@@ -1315,9 +1335,20 @@ impl<'a> ShaderFromNir<'a> {
             nir_op_fsin => b.fsin(srcs(0)).into(),
             nir_op_fsin_normalized_2_pi => {
                 assert!(self.sm.sm() >= 70);
-                b.mufu(MuFuOp::Sin, srcs(0)).into()
+                b.mufu(
+                    MuFuOp::Sin,
+                    srcs(0),
+                    FloatType::from_bits(alu.def.bit_size().into()),
+                )
+                .into()
             }
-            nir_op_fsqrt => b.mufu(MuFuOp::Sqrt, srcs(0)).into(),
+            nir_op_fsqrt => b
+                .mufu(
+                    MuFuOp::Sqrt,
+                    srcs(0),
+                    FloatType::from_bits(alu.def.bit_size().into()),
+                )
+                .into(),
             nir_op_i2f16 | nir_op_i2f32 | nir_op_i2f64 => {
                 let src_bits = alu.get_src(0).src.bit_size();
                 let dst_bits = alu.def.bit_size();
diff --git a/src/nouveau/compiler/nak/ir.rs b/src/nouveau/compiler/nak/ir.rs
index a49d8cdfef9..2f9feda1aad 100644
--- a/src/nouveau/compiler/nak/ir.rs
+++ b/src/nouveau/compiler/nak/ir.rs
@@ -3138,20 +3138,42 @@ impl fmt::Display for MuFuOp {
 }
 
 #[repr(C)]
-#[derive(SrcsAsSlice, DstsAsSlice)]
+#[derive(DstsAsSlice)]
 pub struct OpMuFu {
     #[dst_type(F32)]
     pub dst: Dst,
 
     pub op: MuFuOp,
 
-    #[src_type(F32)]
     pub src: Src,
+
+    pub op_type: FloatType,
+}
+
+impl AsSlice<Src> for OpMuFu {
+    type Attr = SrcType;
+
+    fn as_slice(&self) -> &[Src] {
+        std::slice::from_ref(&self.src)
+    }
+
+    fn as_mut_slice(&mut self) -> &mut [Src] {
+        std::slice::from_mut(&mut self.src)
+    }
+
+    fn attrs(&self) -> SrcTypeList {
+        let src_type = match self.op_type {
+            FloatType::F16 => SrcType::F16,
+            FloatType::F32 => SrcType::F32,
+            FloatType::F64 => unreachable!("MuFu does not support F64"),
+        };
+        SrcTypeList::Uniform(src_type)
+    }
 }
 
 impl DisplayOp for OpMuFu {
     fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "mufu.{} {}", self.op, self.src)
+        write!(f, "mufu.{}{} {}", self.op, self.op_type, self.src)
     }
 }
 impl_display_for_op!(OpMuFu);
diff --git a/src/nouveau/compiler/nak/sm70_encode.rs b/src/nouveau/compiler/nak/sm70_encode.rs
index b9d7eb89ca7..5236615a470 100644
--- a/src/nouveau/compiler/nak/sm70_encode.rs
+++ b/src/nouveau/compiler/nak/sm70_encode.rs
@@ -982,7 +982,16 @@ impl SM70Op for OpMuFu {
     }
 
     fn encode(&self, e: &mut SM70Encoder<'_>) {
-        e.encode_alu(0x108, Some(&self.dst), None, Some(&self.src), None);
+        e.encode_alu_base(
+            0x108,
+            Some(&self.dst),
+            None,
+            Some(&self.src),
+            None,
+            self.src_types()[0].into(),
+        );
+
+        e.set_bit(73, self.op_type == FloatType::F16);
         e.set_field(
             74..80,
             match self.op {