diff --git a/src/nouveau/compiler/nak/from_nir.rs b/src/nouveau/compiler/nak/from_nir.rs index 6098ec33d96..11a64b084c0 100644 --- a/src/nouveau/compiler/nak/from_nir.rs +++ b/src/nouveau/compiler/nak/from_nir.rs @@ -820,11 +820,7 @@ impl<'a> ShaderFromNir<'a> { let src_type = FloatType::from_bits(src_bits.into()); let dst_type = FloatType::from_bits(dst_bits.into()); - let mut src = srcs(0); - if src_bits == 16 { - src = restrict_f16v2_src(src); - } - + let src = srcs(0); let dst = b.alloc_ssa_vec(RegFile::GPR, dst_bits.div_ceil(32)); b.push_op(OpF2F { dst: dst.clone().into(), @@ -2319,13 +2315,7 @@ impl<'a> ShaderFromNir<'a> { let src_type = FloatType::from_bits(src_bit_size.into()); - let mut src = self.get_src(&srcs[0]); - if src_bit_size == 16 - && intrin.def.num_components() == 1 - { - src = src.swizzle(SrcSwizzle::Xx); - } - + let src = self.get_src(&srcs[0]); b.push_op(OpF2F { dst: dst.clone().into(), src, diff --git a/src/nouveau/compiler/nak/ir.rs b/src/nouveau/compiler/nak/ir.rs index 4c95178ac2c..3990a035f64 100644 --- a/src/nouveau/compiler/nak/ir.rs +++ b/src/nouveau/compiler/nak/ir.rs @@ -890,7 +890,7 @@ impl SrcMod { } } -#[derive(Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum SrcSwizzle { None, Xx, @@ -976,6 +976,7 @@ impl Src { self } + #[allow(dead_code)] pub fn swizzle(mut self, src_swizzle: SrcSwizzle) -> Src { // Since we only have xx, yy, and xy, for any composition of swizzles, // the inner-most non-xy swizzle wins. 
@@ -1012,7 +1013,10 @@ impl Src { Some(match src_type { SrcType::F16 => { - let low = u & 0xFFFF; + let low = match self.src_swizzle { + SrcSwizzle::None | SrcSwizzle::Xx => u & 0xffff, + SrcSwizzle::Yy => u >> 16, + }; match self.src_mod { SrcMod::None => low, @@ -4644,23 +4648,6 @@ pub struct OpF2F { pub integer_rnd: bool, } -impl OpF2F { - pub fn is_high(&self) -> bool { - if matches!(self.src_type, FloatType::F16) { - // OpF2F with the same source and destination types is only allowed - // pre-Volta and only with F32. - assert!(!matches!(self.dst_type, FloatType::F16)); - - matches!(self.src.src_swizzle, SrcSwizzle::Yy) - } else if matches!(self.dst_type, FloatType::F16) { - self.dst_high - } else { - assert!(!self.dst_high); - false - } - } -} - impl AsSlice for OpF2F { type Attr = SrcType; @@ -4674,7 +4661,7 @@ impl AsSlice for OpF2F { fn attrs(&self) -> SrcTypeList { let src_type = match self.src_type { - FloatType::F16 => SrcType::F16v2, + FloatType::F16 => SrcType::F16, FloatType::F32 => SrcType::F32, FloatType::F64 => SrcType::F64, }; diff --git a/src/nouveau/compiler/nak/opt_copy_prop.rs b/src/nouveau/compiler/nak/opt_copy_prop.rs index 88e906f662d..ec3932baa0f 100644 --- a/src/nouveau/compiler/nak/opt_copy_prop.rs +++ b/src/nouveau/compiler/nak/opt_copy_prop.rs @@ -313,7 +313,10 @@ impl<'a> CopyPropPass<'a> { // Turn the swizzle into a permute. For F16, we use Xx to // indicate that it only takes the bottom 16 bits. 
let swizzle_prmt: [u8; 4] = match src_type { - SrcType::F16 => [0, 1, 0, 1], + SrcType::F16 => match src.src_swizzle { + SrcSwizzle::None | SrcSwizzle::Xx => [0, 1, 0, 1], + SrcSwizzle::Yy => [2, 3, 2, 3], + }, SrcType::F16v2 => match src.src_swizzle { SrcSwizzle::None => [0, 1, 2, 3], SrcSwizzle::Xx => [0, 1, 0, 1], @@ -353,12 +356,11 @@ impl<'a> CopyPropPass<'a> { // See if that permute is a valid swizzle let new_swizzle = match src_type { - SrcType::F16 => { - if combined != [0, 1, 0, 1] { - return; - } - SrcSwizzle::None - } + SrcType::F16 => match combined { + [0, 1, _, _] => SrcSwizzle::None, + [2, 3, _, _] => SrcSwizzle::Yy, + _ => return, + }, SrcType::F16v2 => match combined { [0, 1, 2, 3] => SrcSwizzle::None, [0, 1, 0, 1] => SrcSwizzle::Xx, diff --git a/src/nouveau/compiler/nak/sm20.rs b/src/nouveau/compiler/nak/sm20.rs index 24c138e0a09..b320f26b432 100644 --- a/src/nouveau/compiler/nak/sm20.rs +++ b/src/nouveau/compiler/nak/sm20.rs @@ -1449,7 +1449,7 @@ impl SM20Op for OpF2F { e.set_field(23..25, (self.src_type.bits() / 8).ilog2()); e.set_rnd_mode(49..51, self.rnd_mode); e.set_bit(55, self.ftz); - e.set_bit(56, self.is_high()); + e.set_bit(56, self.src.src_swizzle == SrcSwizzle::Yy); } } @@ -1468,7 +1468,7 @@ impl SM20Op for OpF2I { e.set_field(23..25, (self.src_type.bits() / 8).ilog2()); e.set_rnd_mode(49..51, self.rnd_mode); e.set_bit(55, self.ftz); - e.set_bit(56, false); // .high + e.set_bit(56, self.src.src_swizzle == SrcSwizzle::Yy); } } diff --git a/src/nouveau/compiler/nak/sm32.rs b/src/nouveau/compiler/nak/sm32.rs index f5fadf1dccb..f689ddcae0f 100644 --- a/src/nouveau/compiler/nak/sm32.rs +++ b/src/nouveau/compiler/nak/sm32.rs @@ -1580,7 +1580,7 @@ impl SM32Op for OpF2F { e.set_field(12..14, (self.src_type.bits() / 8).ilog2()); e.set_rnd_mode(42..44, self.rnd_mode); - e.set_bit(44, self.is_high()); + e.set_bit(44, self.src.src_swizzle == SrcSwizzle::Yy); e.set_bit(45, self.integer_rnd); e.set_bit(47, self.ftz); e.set_bit(48, 
src.src_mod.has_fneg()); @@ -1623,7 +1623,7 @@ impl SM32Op for OpF2I { e.set_bit(14, self.dst_type.is_signed()); e.set_rnd_mode(42..44, self.rnd_mode); - // 44: .h1 + e.set_bit(44, self.src.src_swizzle == SrcSwizzle::Yy); e.set_bit(47, self.ftz); e.set_bit(48, self.src.src_mod.has_fneg()); e.set_bit(50, false); // dst.CC diff --git a/src/nouveau/compiler/nak/sm50.rs b/src/nouveau/compiler/nak/sm50.rs index 14f51b65ba5..c31c41c2f2a 100644 --- a/src/nouveau/compiler/nak/sm50.rs +++ b/src/nouveau/compiler/nak/sm50.rs @@ -1788,7 +1788,7 @@ impl SM50Op for OpF2F { e.set_field(10..12, (self.src_type.bits() / 8).ilog2()); e.set_rnd_mode(39..41, self.rnd_mode); - e.set_bit(41, self.is_high()); + e.set_bit(41, self.src.src_swizzle == SrcSwizzle::Yy); e.set_bit(42, self.integer_rnd); e.set_bit(44, self.ftz); e.set_bit(50, false); // saturate @@ -1833,6 +1833,7 @@ impl SM50Op for OpF2I { e.set_bit(12, self.dst_type.is_signed()); e.set_rnd_mode(39..41, self.rnd_mode); + e.set_bit(41, self.src.src_swizzle == SrcSwizzle::Yy); e.set_bit(44, self.ftz); e.set_bit(47, false); // .CC } diff --git a/src/nouveau/compiler/nak/sm70_encode.rs b/src/nouveau/compiler/nak/sm70_encode.rs index 967a69c7e0e..b9d7eb89ca7 100644 --- a/src/nouveau/compiler/nak/sm70_encode.rs +++ b/src/nouveau/compiler/nak/sm70_encode.rs @@ -406,16 +406,47 @@ impl OffsetStride { } } +#[derive(Clone, Copy, Debug, PartialEq)] +enum SM70SrcType { + F16, + F16v2, + Other, +} + +impl From<SrcType> for SM70SrcType { + fn from(value: SrcType) -> Self { + match value { + SrcType::F16 => SM70SrcType::F16, + SrcType::F16v2 => SM70SrcType::F16v2, + _ => SM70SrcType::Other, + } + } +} + +impl SM70SrcType { + fn is_fp16(self) -> bool { + self != Self::Other + } +} + impl SM70Encoder<'_> { - fn set_swizzle(&mut self, range: Range<usize>, swizzle: SrcSwizzle) { + fn set_swizzle( + &mut self, + range: Range<usize>, + swizzle: SrcSwizzle, + src_type: SM70SrcType, + ) { assert!(range.len() == 2); self.set_field( range, - match swizzle { - SrcSwizzle::None
=> 0x00_u8, - SrcSwizzle::Xx => 0x02_u8, - SrcSwizzle::Yy => 0x03_u8, + match (src_type, swizzle) { + (_, SrcSwizzle::None) => 0x00_u8, + (SM70SrcType::F16, SrcSwizzle::Xx) => 0x00_u8, + (SM70SrcType::F16, SrcSwizzle::Yy) => 0x01_u8, + (SM70SrcType::F16v2, SrcSwizzle::Xx) => 0x02_u8, + (SM70SrcType::F16v2, SrcSwizzle::Yy) => 0x03_u8, + _ => unreachable!("unsupported {swizzle:?} for {src_type:?}"), }, ); } @@ -427,7 +458,7 @@ impl SM70Encoder<'_> { neg_bit: usize, swizzle_range: Range<usize>, file: RegFile, - is_fp16_alu: bool, + src_type: SM70SrcType, reg: &ALURegRef, ) { match file { @@ -439,8 +470,8 @@ impl SM70Encoder<'_> { self.set_bit(abs_bit, reg.abs); self.set_bit(neg_bit, reg.neg); - if is_fp16_alu { - self.set_swizzle(swizzle_range, reg.swizzle); + if src_type.is_fp16() { + self.set_swizzle(swizzle_range, reg.swizzle, src_type); } else { assert!(reg.swizzle == SrcSwizzle::None); } @@ -450,21 +481,21 @@ impl SM70Encoder<'_> { &mut self, src: &ALUSrc, file: RegFile, - is_fp16_alu: bool, + src_type: SM70SrcType, ) { let reg = match src { ALUSrc::None => return, ALUSrc::Reg(reg) => reg, _ => panic!("Invalid ALU src"), }; - self.set_alu_reg(24..32, 73, 72, 74..76, file, is_fp16_alu, reg); + self.set_alu_reg(24..32, 73, 72, 74..76, file, src_type, reg); } fn encode_alu_src2( &mut self, src: &ALUSrc, file: RegFile, - is_fp16_alu: bool, + src_type: SM70SrcType, ) { let reg = match src { ALUSrc::None => return, @@ -473,34 +504,26 @@ impl SM70Encoder<'_> { }; self.set_alu_reg( 64..72, - if is_fp16_alu { 83 } else { 74 }, - if is_fp16_alu { 84 } else { 75 }, + if src_type.is_fp16() { 83 } else { 74 }, + if src_type.is_fp16() { 84 } else { 75 }, 81..83, file, - is_fp16_alu, + src_type, reg, ); } - fn encode_alu_reg(&mut self, reg: &ALURegRef, is_fp16_alu: bool) { - self.set_alu_reg( - 32..40, - 62, - 63, - 60..62, - RegFile::GPR, - is_fp16_alu, - reg, - ); + fn encode_alu_reg(&mut self, reg: &ALURegRef, src_type: SM70SrcType) { + self.set_alu_reg(32..40, 62, 63, 60..62,
RegFile::GPR, src_type, reg); } - fn encode_alu_ureg(&mut self, reg: &ALURegRef, is_fp16_alu: bool) { + fn encode_alu_ureg(&mut self, reg: &ALURegRef, src_type: SM70SrcType) { self.set_ureg(32..40, reg.reg); self.set_bit(62, reg.abs); self.set_bit(63, reg.neg); - if is_fp16_alu { - self.set_swizzle(60..62, reg.swizzle); + if src_type.is_fp16() { + self.set_swizzle(60..62, reg.swizzle, src_type); } else { assert!(reg.swizzle == SrcSwizzle::None); } @@ -512,13 +535,13 @@ impl SM70Encoder<'_> { self.set_field(32..64, *imm); } - fn encode_alu_cb(&mut self, cb: &ALUCBufRef, is_fp16_alu: bool) { + fn encode_alu_cb(&mut self, cb: &ALUCBufRef, src_type: SM70SrcType) { self.set_src_cb(32..59, 91, &cb.cb); self.set_bit(62, cb.abs); self.set_bit(63, cb.neg); - if is_fp16_alu { - self.set_swizzle(60..62, cb.swizzle); + if src_type.is_fp16() { + self.set_swizzle(60..62, cb.swizzle, src_type); } else { assert!(cb.swizzle == SrcSwizzle::None); } @@ -531,7 +554,7 @@ impl SM70Encoder<'_> { src0: Option<&Src>, src1: Option<&Src>, src2: Option<&Src>, - is_fp16_alu: bool, + src_type: SM70SrcType, ) { if let Some(dst) = dst { self.set_dst(dst); @@ -541,19 +564,19 @@ impl SM70Encoder<'_> { let src1 = ALUSrc::from_src(self, src1, false); let src2 = ALUSrc::from_src(self, src2, false); - self.encode_alu_src0(&src0, RegFile::GPR, is_fp16_alu); + self.encode_alu_src0(&src0, RegFile::GPR, src_type); let form = match &src2 { ALUSrc::None | ALUSrc::Reg(_) => { - self.encode_alu_src2(&src2, RegFile::GPR, is_fp16_alu); + self.encode_alu_src2(&src2, RegFile::GPR, src_type); match &src1 { ALUSrc::None => 1_u8, // form ALUSrc::Reg(reg1) => { - self.encode_alu_reg(reg1, is_fp16_alu); + self.encode_alu_reg(reg1, src_type); 1_u8 // form } ALUSrc::UReg(reg1) => { - self.encode_alu_ureg(reg1, is_fp16_alu); + self.encode_alu_ureg(reg1, src_type); 6_u8 // form } ALUSrc::Imm32(imm1) => { @@ -561,25 +584,25 @@ impl SM70Encoder<'_> { 4_u8 // form } ALUSrc::CBuf(cb1) => { - self.encode_alu_cb(cb1, 
is_fp16_alu); + self.encode_alu_cb(cb1, src_type); 5_u8 // form } } } ALUSrc::UReg(reg2) => { - self.encode_alu_ureg(reg2, is_fp16_alu); - self.encode_alu_src2(&src1, RegFile::GPR, is_fp16_alu); + self.encode_alu_ureg(reg2, src_type); + self.encode_alu_src2(&src1, RegFile::GPR, src_type); 7_u8 // form } ALUSrc::Imm32(imm2) => { self.encode_alu_imm(imm2); - self.encode_alu_src2(&src1, RegFile::GPR, is_fp16_alu); + self.encode_alu_src2(&src1, RegFile::GPR, src_type); 2_u8 // form } ALUSrc::CBuf(cb2) => { // TODO set_src_cx - self.encode_alu_cb(cb2, is_fp16_alu); - self.encode_alu_src2(&src1, RegFile::GPR, is_fp16_alu); + self.encode_alu_cb(cb2, src_type); + self.encode_alu_src2(&src1, RegFile::GPR, src_type); 3_u8 // form } }; @@ -596,10 +619,12 @@ impl SM70Encoder<'_> { src1: Option<&Src>, src2: Option<&Src>, ) { - self.encode_alu_base(opcode, dst, src0, src1, src2, false); + // The SrcType really only matters for FP16, so make this + // convenient for all the other ops + self.encode_alu_base(opcode, dst, src0, src1, src2, SM70SrcType::Other); } - fn encode_fp16_alu( + fn encode_fp16v2_alu( &mut self, opcode: u16, dst: Option<&Dst>, @@ -607,16 +632,17 @@ impl SM70Encoder<'_> { src1: Option<&Src>, src2: Option<&Src>, ) { - self.encode_alu_base(opcode, dst, src0, src1, src2, true); + self.encode_alu_base(opcode, dst, src0, src1, src2, SM70SrcType::F16v2); } - fn encode_ualu( + fn encode_ualu_base( &mut self, opcode: u16, dst: Option<&Dst>, src0: Option<&Src>, src1: Option<&Src>, src2: Option<&Src>, + src_type: SM70SrcType, ) { if let Some(dst) = dst { self.set_udst(dst); @@ -629,14 +655,14 @@ impl SM70Encoder<'_> { // All uniform ALU requires bit 91 set self.set_bit(91, true); - self.encode_alu_src0(&src0, RegFile::UGPR, false); + self.encode_alu_src0(&src0, RegFile::UGPR, src_type); let form = match &src2 { ALUSrc::None | ALUSrc::Reg(_) => { - self.encode_alu_src2(&src2, RegFile::UGPR, false); + self.encode_alu_src2(&src2, RegFile::UGPR, src_type); match &src1 { 
ALUSrc::None => 1_u8, // form ALUSrc::Reg(reg1) => { - self.encode_alu_ureg(reg1, false); + self.encode_alu_ureg(reg1, src_type); 1_u8 // form } ALUSrc::UReg(_) => panic!("UALU never has UReg"), @@ -650,7 +676,7 @@ impl SM70Encoder<'_> { ALUSrc::UReg(_) => panic!("UALU never has UReg"), ALUSrc::Imm32(imm2) => { self.encode_alu_imm(imm2); - self.encode_alu_src2(&src1, RegFile::UGPR, false); + self.encode_alu_src2(&src1, RegFile::UGPR, src_type); 2_u8 // form } ALUSrc::CBuf(_) => panic!("UALU does not support cbufs"), @@ -660,6 +686,19 @@ impl SM70Encoder<'_> { self.set_field(9..12, form); } + fn encode_ualu( + &mut self, + opcode: u16, + dst: Option<&Dst>, + src0: Option<&Src>, + src1: Option<&Src>, + src2: Option<&Src>, + ) { + // The SrcType really only matters for FP16, so make this + // convenient for all the other ops + self.encode_ualu_base(opcode, dst, src0, src1, src2, SM70SrcType::Other) + } + fn set_rnd_mode(&mut self, range: Range<usize>, rnd_mode: FRndMode) { assert!(range.len() == 2); self.set_field( @@ -1073,7 +1112,7 @@ impl SM70Op for OpHAdd2 { fn encode(&self, e: &mut SM70Encoder<'_>) { if src_is_zero_or_gpr(&self.srcs[1]) { - e.encode_fp16_alu( + e.encode_fp16v2_alu( 0x030, Some(&self.dst), Some(&self.srcs[0]), @@ -1081,7 +1120,7 @@ impl SM70Op for OpHAdd2 { None, ) } else { - e.encode_fp16_alu( + e.encode_fp16v2_alu( 0x030, Some(&self.dst), Some(&self.srcs[0]), @@ -1107,7 +1146,7 @@ impl SM70Op for OpHFma2 { } fn encode(&self, e: &mut SM70Encoder<'_>) { - e.encode_fp16_alu( + e.encode_fp16v2_alu( 0x031, Some(&self.dst), Some(&self.srcs[0]), @@ -1133,7 +1172,7 @@ impl SM70Op for OpHMul2 { } fn encode(&self, e: &mut SM70Encoder<'_>) { - e.encode_fp16_alu( + e.encode_fp16v2_alu( 0x032, Some(&self.dst), Some(&self.srcs[0]), @@ -1163,7 +1202,7 @@ impl SM70Op for OpHSet2 { fn encode(&self, e: &mut SM70Encoder<'_>) { if src_is_zero_or_gpr(&self.srcs[1]) { - e.encode_fp16_alu( + e.encode_fp16v2_alu( 0x033, Some(&self.dst), Some(&self.srcs[0]), @@ -1171,7 +1210,7
@@ impl SM70Op for OpHSet2 { None, ) } else { - e.encode_fp16_alu( + e.encode_fp16v2_alu( 0x033, Some(&self.dst), Some(&self.srcs[0]), @@ -1205,7 +1244,7 @@ impl SM70Op for OpHSetP2 { fn encode(&self, e: &mut SM70Encoder<'_>) { if src_is_zero_or_gpr(&self.srcs[1]) { - e.encode_fp16_alu( + e.encode_fp16v2_alu( 0x034, None, Some(&self.srcs[0]), @@ -1213,7 +1252,7 @@ impl SM70Op for OpHSetP2 { None, ) } else { - e.encode_fp16_alu( + e.encode_fp16v2_alu( 0x034, None, Some(&self.srcs[0]), @@ -1246,7 +1285,7 @@ impl SM70Op for OpHMnMx2 { fn encode(&self, e: &mut SM70Encoder<'_>) { assert!(e.sm >= 80); - e.encode_fp16_alu( + e.encode_fp16v2_alu( 0x040, Some(&self.dst), Some(&self.srcs[0]), @@ -1914,17 +1953,21 @@ impl SM70Op for OpF2F { fn encode(&self, e: &mut SM70Encoder<'_>) { assert!(!self.integer_rnd); - // The swizzle is handled by the .high bit below. - let src = self.src.clone().without_swizzle(); - if self.src_type.bits() <= 32 && self.dst_type.bits() <= 32 { - e.encode_alu(0x104, Some(&self.dst), None, Some(&src), None) + let opcode = if self.src_type.bits() <= 32 && self.dst_type.bits() <= 32 + { + 0x104 } else { - e.encode_alu(0x110, Some(&self.dst), None, Some(&src), None) + 0x110 }; - if self.is_high() { - e.set_field(60..62, 1_u8); // .H1 - } + e.encode_alu_base( + opcode, + Some(&self.dst), + None, + Some(&self.src), + None, + self.src_types()[0].into(), + ); e.set_field(75..77, (self.dst_type.bits() / 8).ilog2()); e.set_rnd_mode(78..80, self.rnd_mode); @@ -1965,12 +2008,22 @@ impl SM70Op for OpF2I { } fn encode(&self, e: &mut SM70Encoder<'_>) { - if self.src_type.bits() <= 32 && self.dst_type.bits() <= 32 { - e.encode_alu(0x105, Some(&self.dst), None, Some(&self.src), None) + let opcode = if self.src_type.bits() <= 32 && self.dst_type.bits() <= 32 + { + 0x105 } else { - e.encode_alu(0x111, Some(&self.dst), None, Some(&self.src), None) + 0x111 }; + e.encode_alu_base( + opcode, + Some(&self.dst), + None, + Some(&self.src), + None, + 
self.src_types()[0].into(), + ); + e.set_bit(72, self.dst_type.is_signed()); e.set_field(75..77, (self.dst_type.bits() / 8).ilog2()); e.set_bit(77, false); // NTZ @@ -2006,12 +2059,22 @@ impl SM70Op for OpFRnd { } fn encode(&self, e: &mut SM70Encoder<'_>) { - if self.src_type.bits() <= 32 && self.dst_type.bits() <= 32 { - e.encode_alu(0x107, Some(&self.dst), None, Some(&self.src), None) + let opcode = if self.src_type.bits() <= 32 && self.dst_type.bits() <= 32 + { + 0x107 } else { - e.encode_alu(0x113, Some(&self.dst), None, Some(&self.src), None) + 0x113 }; + e.encode_alu_base( + opcode, + Some(&self.dst), + None, + Some(&self.src), + None, + self.src_types()[0].into(), + ); + e.set_field(84..86, (self.src_type.bits() / 8).ilog2()); e.set_bit(80, self.ftz); e.set_rnd_mode(78..80, self.rnd_mode); @@ -2033,7 +2096,7 @@ impl SM70Op for OpMov { let src = ALUSrc::from_src(e, Some(&self.src), true); let form: u8 = match &src { ALUSrc::Reg(reg) => { - e.encode_alu_ureg(reg, false); + e.encode_alu_ureg(reg, SM70SrcType::Other); 0x6 // form } ALUSrc::Imm32(imm) => {