From 001de6d71bd788f784b0a96177dab010221d31d9 Mon Sep 17 00:00:00 2001
From: Mel Henning
Date: Mon, 6 Apr 2026 15:49:53 -0400
Subject: [PATCH] nak: Fix mufu's f16 bit on sm90+

Fixes multiple cts tests on blackwell, including e.g.
dEQP-VK.spirv_assembly.instruction.graphics.float16.arithmetic_2.opfdiv_tessc

Fixes: d031365f7ca ("nak: support MUFU.F16")
Reviewed-by: Karol Herbst
Part-of:
---
Note: the blob ids on the "index" lines below are those of the original
commit; they were not regenerated after the review changes to this patch.

 src/nouveau/compiler/nak/nvdisasm_tests.rs | 44 ++++++++++++++++++++++
 src/nouveau/compiler/nak/sm70_encode.rs    | 17 ++++++++-
 2 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/src/nouveau/compiler/nak/nvdisasm_tests.rs b/src/nouveau/compiler/nak/nvdisasm_tests.rs
index f703bebec83..e2cc03cde91 100644
--- a/src/nouveau/compiler/nak/nvdisasm_tests.rs
+++ b/src/nouveau/compiler/nak/nvdisasm_tests.rs
@@ -1009,3 +1009,47 @@ pub fn test_isbewr() {
         c.check(sm);
     }
 }
+
+// Builds an OpMuFu for every listed MuFuOp with both F32 and F16 op types
+// (the Rcp64H/Rsq64H + F16 combinations are skipped) and, for every SM in
+// sm_list(), pushes each one to DisasmCheck together with the disassembly
+// string built for it, before calling check().
+#[test]
+pub fn test_mufu() {
+    let r2 = RegRef::new(RegFile::GPR, 2, 1);
+    let r3 = RegRef::new(RegFile::GPR, 3, 1);
+
+    use MuFuOp::*;
+    let ops = [Cos, Sin, Exp2, Log2, Rcp, Rsq, Rcp64H, Rsq64H, Sqrt, Tanh];
+    let op_types = [(FloatType::F32, ""), (FloatType::F16, ".f16")];
+
+    for &sm in sm_list() {
+        let mut c = DisasmCheck::new();
+
+        for op in ops {
+            for (op_type, op_type_str) in op_types {
+                match (op, op_type) {
+                    (Rcp64H | Rsq64H, FloatType::F16) => continue,
+                    _ => (),
+                }
+                let instr = OpMuFu {
+                    dst: Dst::Reg(r2),
+                    src: SrcRef::Reg(r3).into(),
+                    op,
+                    op_type,
+                };
+                // Exp2 and Log2 use the .ex2/.lg2 suffixes; every other op
+                // uses its Display form.
+                let op_str = match op {
+                    Exp2 => ".ex2".into(),
+                    Log2 => ".lg2".into(),
+                    _ => format!(".{op}"),
+                };
+                let disasm = format!("mufu{op_str}{op_type_str} r2, r3;");
+                c.push(instr, disasm);
+            }
+        }
+
+        c.check(sm);
+    }
+}
diff --git a/src/nouveau/compiler/nak/sm70_encode.rs b/src/nouveau/compiler/nak/sm70_encode.rs
index 5236615a470..63de8e2a5c2 100644
--- a/src/nouveau/compiler/nak/sm70_encode.rs
+++ b/src/nouveau/compiler/nak/sm70_encode.rs
@@ -991,7 +991,22 @@ impl SM70Op for OpMuFu {
             self.src_types()[0].into(),
         );
 
-        e.set_bit(73, self.op_type == FloatType::F16);
+        // sm90+ encodes the operand type as a small field starting at bit
+        // 72 (0 = f32, 1 = f16) rather than as a lone f16 flag at bit 73.
+        // The range is two bits wide so the .bf16 value noted below fits.
+        if e.sm >= 90 {
+            e.set_field(
+                72..74,
+                match self.op_type {
+                    FloatType::F32 => 0u8,
+                    FloatType::F16 => 1u8,
+                    /* .bf16 => 2 */
+                    FloatType::F64 => unreachable!(),
+                },
+            );
+        } else {
+            e.set_bit(73, self.op_type == FloatType::F16);
+        }
         e.set_field(
             74..80,
             match self.op {