diff --git a/src/nouveau/compiler/nak/from_nir.rs b/src/nouveau/compiler/nak/from_nir.rs index 5316fd5da48..4251d04bfd6 100644 --- a/src/nouveau/compiler/nak/from_nir.rs +++ b/src/nouveau/compiler/nak/from_nir.rs @@ -1466,8 +1466,6 @@ impl<'a> ShaderFromNir<'a> { nir_op_ixor => b.lop2(LogicOp2::Xor, srcs[0], srcs[1]), nir_op_pack_half_2x16_split | nir_op_pack_half_2x16_rtz_split => { assert!(alu.get_src(0).bit_size() == 32); - let low = b.alloc_ssa(RegFile::GPR, 1); - let high = b.alloc_ssa(RegFile::GPR, 1); let rnd_mode = match alu.op { nir_op_pack_half_2x16_split => FRndMode::NearestEven, @@ -1475,32 +1473,46 @@ impl<'a> ShaderFromNir<'a> { _ => panic!("Unhandled fp16 pack op"), }; - b.push_op(OpF2F { - dst: low.into(), - src: srcs[0], - src_type: FloatType::F32, - dst_type: FloatType::F16, - rnd_mode: rnd_mode, - ftz: false, - high: false, - integer_rnd: false, - }); + if self.sm.sm() >= 86 { + let result: SSARef = b.alloc_ssa(RegFile::GPR, 1); + b.push_op(OpF2FP { + dst: result.into(), + srcs: [srcs[1], srcs[0]], + rnd_mode: rnd_mode, + }); - let src_bits = usize::from(alu.get_src(1).bit_size()); - let src_type = FloatType::from_bits(src_bits); - assert!(matches!(src_type, FloatType::F32)); - b.push_op(OpF2F { - dst: high.into(), - src: srcs[1], - src_type: FloatType::F32, - dst_type: FloatType::F16, - rnd_mode: rnd_mode, - ftz: false, - high: false, - integer_rnd: false, - }); + result + } else { + let low = b.alloc_ssa(RegFile::GPR, 1); + let high = b.alloc_ssa(RegFile::GPR, 1); - b.prmt(low.into(), high.into(), [0, 1, 4, 5]) + b.push_op(OpF2F { + dst: low.into(), + src: srcs[0], + src_type: FloatType::F32, + dst_type: FloatType::F16, + rnd_mode: rnd_mode, + ftz: false, + high: false, + integer_rnd: false, + }); + + let src_bits = usize::from(alu.get_src(1).bit_size()); + let src_type = FloatType::from_bits(src_bits); + assert!(matches!(src_type, FloatType::F32)); + b.push_op(OpF2F { + dst: high.into(), + src: srcs[1], + src_type: FloatType::F32, + dst_type: FloatType::F16, + rnd_mode: rnd_mode, + ftz: false, + high: false, + integer_rnd: false, + }); + + b.prmt(low.into(), high.into(), [0, 1, 4, 5]) + } } nir_op_prmt_nv => { let dst = b.alloc_ssa(RegFile::GPR, 1); diff --git a/src/nouveau/compiler/nak/hw_tests.rs b/src/nouveau/compiler/nak/hw_tests.rs index 86822464154..c6de62f95b7 100644 --- a/src/nouveau/compiler/nak/hw_tests.rs +++ b/src/nouveau/compiler/nak/hw_tests.rs @@ -1135,3 +1135,60 @@ fn test_shr64() { } } } + +#[test] +fn test_f2fp_pack_ab() { + let run = RunSingleton::get(); + let mut b = TestShaderBuilder::new(run.sm.as_ref()); + + let srcs = SSARef::from([ + b.ld_test_data(0, MemType::B32)[0], + b.ld_test_data(4, MemType::B32)[0], + ]); + + let dst = b.alloc_ssa(RegFile::GPR, 1); + b.push_op(OpF2FP { + dst: dst.into(), + srcs: [srcs[0].into(), srcs[1].into()], + rnd_mode: FRndMode::NearestEven, + }); + b.st_test_data(8, MemType::B32, dst[0].into()); + + let dst = b.alloc_ssa(RegFile::GPR, 1); + b.push_op(OpF2FP { + dst: dst.into(), + srcs: [srcs[0].into(), 2.0.into()], + rnd_mode: FRndMode::Zero, + }); + b.st_test_data(12, MemType::B32, dst[0].into()); + + let bin = b.compile(); + + fn f32_to_u32(val: f32) -> u32 { + u32::from_le_bytes(val.to_le_bytes()) + } + + let zero = f32_to_u32(0.0); + let one = f32_to_u32(1.0); + let two = f32_to_u32(2.0); + let complex = f32_to_u32(1.4556); + + let mut data = Vec::new(); + data.push([one, two, 0, 0]); + data.push([one, zero, 0, 0]); + data.push([complex, zero, 0, 0]); + run.run.run(&bin, &mut data).unwrap(); + + // { 1.0fp16, 2.0fp16 } + assert_eq!(data[0][2], 0x3c004000); + // { 1.0fp16, 2.0fp16 } + assert_eq!(data[0][3], 0x3c004000); + // { 1.0fp16, 0.0fp16 } + assert_eq!(data[1][2], 0x3c000000); + // { 1.0fp16, 0.0fp16 } + assert_eq!(data[1][3], 0x3c004000); + // { 1.456fp16, 0.0fp16 } + assert_eq!(data[2][2], 0x3dd30000); + // { 1.455fp16, 0.0fp16 } + assert_eq!(data[2][3], 0x3dd24000); +} diff --git a/src/nouveau/compiler/nak/ir.rs b/src/nouveau/compiler/nak/ir.rs index e79ff3aeaa0..c268bcd87cf 100644 --- a/src/nouveau/compiler/nak/ir.rs +++ b/src/nouveau/compiler/nak/ir.rs @@ -4058,6 +4058,29 @@ impl DisplayOp for OpF2F { } impl_display_for_op!(OpF2F); +#[repr(C)] +#[derive(DstsAsSlice, SrcsAsSlice)] +pub struct OpF2FP { + #[dst_type(GPR)] + pub dst: Dst, + + #[src_type(ALU)] + pub srcs: [Src; 2], + + pub rnd_mode: FRndMode, +} + +impl DisplayOp for OpF2FP { + fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "f2fp.pack_ab")?; + if self.rnd_mode != FRndMode::NearestEven { + write!(f, "{}", self.rnd_mode)?; + } + write!(f, " {}, {}", self.srcs[0], self.srcs[1],) + } +} +impl_display_for_op!(OpF2FP); + #[repr(C)] #[derive(DstsAsSlice)] pub struct OpF2I { @@ -6159,6 +6182,7 @@ pub enum Op { Shl(OpShl), Shr(OpShr), F2F(OpF2F), + F2FP(OpF2FP), F2I(OpF2I), I2F(OpI2F), I2I(OpI2I), @@ -6606,7 +6630,8 @@ impl Instr { pub fn has_fixed_latency(&self, sm: u8) -> bool { match &self.op { // Float ALU - Op::FAdd(_) + Op::F2FP(_) + | Op::FAdd(_) | Op::FFma(_) | Op::FMnMx(_) | Op::FMul(_) diff --git a/src/nouveau/compiler/nak/sm70.rs b/src/nouveau/compiler/nak/sm70.rs index fb13e405886..b4c8ea40df5 100644 --- a/src/nouveau/compiler/nak/sm70.rs +++ b/src/nouveau/compiler/nak/sm70.rs @@ -1921,6 +1921,44 @@ impl SM70Op for OpF2F { } } +impl SM70Op for OpF2FP { + fn legalize(&mut self, b: &mut LegalizeBuilder) { + let gpr = op_gpr(self); + let [src0, src1] = &mut self.srcs; + swap_srcs_if_not_reg(src0, src1, gpr); + + b.copy_alu_src_if_not_reg(src0, gpr, SrcType::ALU); + } + + fn encode(&self, e: &mut SM70Encoder<'_>) { + if src_is_zero_or_gpr(&self.srcs[1]) { + e.encode_alu( + 0x03e, + Some(&self.dst), + Some(&self.srcs[0]), + Some(&self.srcs[1]), + Some(&Src::new_zero()), + ) + } else { + e.encode_alu( + 0x03e, + Some(&self.dst), + None, + Some(&self.srcs[1]), + Some(&self.srcs[0]), + ) + }; + + // .MERGE_C behavior + // Use src1 and src2, src0 is unused + // src1 get converted and packed in the lower 16 bits of dest. + // src2 lower or high 16 bits (decided by .H1 flag) get packed in the upper of dest. + e.set_bit(78, false); // TODO: .MERGE_C + e.set_bit(72, false); // .H1 (MERGE_C only) + e.set_rnd_mode(79..81, self.rnd_mode); + } +} + impl SM70Op for OpF2I { fn legalize(&mut self, _b: &mut LegalizeBuilder) { // Nothing to do @@ -3397,6 +3435,7 @@ macro_rules! as_sm70_op_match { Op::PopC(op) => op, Op::Shf(op) => op, Op::F2F(op) => op, + Op::F2FP(op) => op, Op::F2I(op) => op, Op::I2F(op) => op, Op::FRnd(op) => op,