nak: rework swizzling on scalar FP16 ops

Instructions that take a F16 value can generally select which component to
read from. This lets us get rid of some PRMTs.

This also cleans up partial support for it for F2F and streamlines
everything into a uniform model, as previously it wasn't wired up
generally and copy prop didn't always propagate the swizzle through.

This also makes it unnecessary to apply an Xx swizzle to scalar FP16
sources.

Totals from 907 (0.08% of 1163204) affected shaders:
CodeSize: 40856816 -> 40843408 (-0.03%); split: -0.03%, +0.00%
Static cycle count: 20898101 -> 20895619 (-0.01%); split: -0.01%, +0.00%

Reviewed-by: Mel Henning <mhenning@darkrefraction.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40392>
This commit is contained in:
Karol Herbst 2026-03-16 15:11:33 +01:00 committed by Marge Bot
parent aa39da8338
commit 67bfbc7535
7 changed files with 160 additions and 117 deletions

View file

@ -820,11 +820,7 @@ impl<'a> ShaderFromNir<'a> {
let src_type = FloatType::from_bits(src_bits.into());
let dst_type = FloatType::from_bits(dst_bits.into());
let mut src = srcs(0);
if src_bits == 16 {
src = restrict_f16v2_src(src);
}
let src = srcs(0);
let dst = b.alloc_ssa_vec(RegFile::GPR, dst_bits.div_ceil(32));
b.push_op(OpF2F {
dst: dst.clone().into(),
@ -2319,13 +2315,7 @@ impl<'a> ShaderFromNir<'a> {
let src_type =
FloatType::from_bits(src_bit_size.into());
let mut src = self.get_src(&srcs[0]);
if src_bit_size == 16
&& intrin.def.num_components() == 1
{
src = src.swizzle(SrcSwizzle::Xx);
}
let src = self.get_src(&srcs[0]);
b.push_op(OpF2F {
dst: dst.clone().into(),
src,

View file

@ -890,7 +890,7 @@ impl SrcMod {
}
}
#[derive(Clone, Copy, PartialEq)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SrcSwizzle {
None,
Xx,
@ -976,6 +976,7 @@ impl Src {
self
}
#[allow(dead_code)]
pub fn swizzle(mut self, src_swizzle: SrcSwizzle) -> Src {
// Since we only have xx, yy, and xy, for any composition of swizzles,
// the inner-most non-xy swizzle wins.
@ -1012,7 +1013,10 @@ impl Src {
Some(match src_type {
SrcType::F16 => {
let low = u & 0xFFFF;
let low = match self.src_swizzle {
SrcSwizzle::None | SrcSwizzle::Xx => u & 0xffff,
SrcSwizzle::Yy => u >> 16,
};
match self.src_mod {
SrcMod::None => low,
@ -4644,23 +4648,6 @@ pub struct OpF2F {
pub integer_rnd: bool,
}
impl OpF2F {
pub fn is_high(&self) -> bool {
if matches!(self.src_type, FloatType::F16) {
// OpF2F with the same source and destination types is only allowed
// pre-Volta and only with F32.
assert!(!matches!(self.dst_type, FloatType::F16));
matches!(self.src.src_swizzle, SrcSwizzle::Yy)
} else if matches!(self.dst_type, FloatType::F16) {
self.dst_high
} else {
assert!(!self.dst_high);
false
}
}
}
impl AsSlice<Src> for OpF2F {
type Attr = SrcType;
@ -4674,7 +4661,7 @@ impl AsSlice<Src> for OpF2F {
fn attrs(&self) -> SrcTypeList {
let src_type = match self.src_type {
FloatType::F16 => SrcType::F16v2,
FloatType::F16 => SrcType::F16,
FloatType::F32 => SrcType::F32,
FloatType::F64 => SrcType::F64,
};

View file

@ -313,7 +313,10 @@ impl<'a> CopyPropPass<'a> {
// Turn the swizzle into a permute. For F16, we use Xx to
// indicate that it only takes the bottom 16 bits.
let swizzle_prmt: [u8; 4] = match src_type {
SrcType::F16 => [0, 1, 0, 1],
SrcType::F16 => match src.src_swizzle {
SrcSwizzle::None | SrcSwizzle::Xx => [0, 1, 0, 1],
SrcSwizzle::Yy => [2, 3, 2, 3],
},
SrcType::F16v2 => match src.src_swizzle {
SrcSwizzle::None => [0, 1, 2, 3],
SrcSwizzle::Xx => [0, 1, 0, 1],
@ -353,12 +356,11 @@ impl<'a> CopyPropPass<'a> {
// See if that permute is a valid swizzle
let new_swizzle = match src_type {
SrcType::F16 => {
if combined != [0, 1, 0, 1] {
return;
}
SrcSwizzle::None
}
SrcType::F16 => match combined {
[0, 1, _, _] => SrcSwizzle::None,
[2, 3, _, _] => SrcSwizzle::Yy,
_ => return,
},
SrcType::F16v2 => match combined {
[0, 1, 2, 3] => SrcSwizzle::None,
[0, 1, 0, 1] => SrcSwizzle::Xx,

View file

@ -1449,7 +1449,7 @@ impl SM20Op for OpF2F {
e.set_field(23..25, (self.src_type.bits() / 8).ilog2());
e.set_rnd_mode(49..51, self.rnd_mode);
e.set_bit(55, self.ftz);
e.set_bit(56, self.is_high());
e.set_bit(56, self.src.src_swizzle == SrcSwizzle::Yy);
}
}
@ -1468,7 +1468,7 @@ impl SM20Op for OpF2I {
e.set_field(23..25, (self.src_type.bits() / 8).ilog2());
e.set_rnd_mode(49..51, self.rnd_mode);
e.set_bit(55, self.ftz);
e.set_bit(56, false); // .high
e.set_bit(56, self.src.src_swizzle == SrcSwizzle::Yy);
}
}

View file

@ -1580,7 +1580,7 @@ impl SM32Op for OpF2F {
e.set_field(12..14, (self.src_type.bits() / 8).ilog2());
e.set_rnd_mode(42..44, self.rnd_mode);
e.set_bit(44, self.is_high());
e.set_bit(44, self.src.src_swizzle == SrcSwizzle::Yy);
e.set_bit(45, self.integer_rnd);
e.set_bit(47, self.ftz);
e.set_bit(48, src.src_mod.has_fneg());
@ -1623,7 +1623,7 @@ impl SM32Op for OpF2I {
e.set_bit(14, self.dst_type.is_signed());
e.set_rnd_mode(42..44, self.rnd_mode);
// 44: .h1
e.set_bit(44, self.src.src_swizzle == SrcSwizzle::Yy);
e.set_bit(47, self.ftz);
e.set_bit(48, self.src.src_mod.has_fneg());
e.set_bit(50, false); // dst.CC

View file

@ -1788,7 +1788,7 @@ impl SM50Op for OpF2F {
e.set_field(10..12, (self.src_type.bits() / 8).ilog2());
e.set_rnd_mode(39..41, self.rnd_mode);
e.set_bit(41, self.is_high());
e.set_bit(41, self.src.src_swizzle == SrcSwizzle::Yy);
e.set_bit(42, self.integer_rnd);
e.set_bit(44, self.ftz);
e.set_bit(50, false); // saturate
@ -1833,6 +1833,7 @@ impl SM50Op for OpF2I {
e.set_bit(12, self.dst_type.is_signed());
e.set_rnd_mode(39..41, self.rnd_mode);
e.set_bit(41, self.src.src_swizzle == SrcSwizzle::Yy);
e.set_bit(44, self.ftz);
e.set_bit(47, false); // .CC
}

View file

@ -406,16 +406,47 @@ impl OffsetStride {
}
}
/// Coarse classification of a source's type, passed through the SM70 ALU
/// encoding helpers so they can tell FP16 sources apart from everything else.
#[derive(Clone, Copy, Debug, PartialEq)]
enum SM70SrcType {
/// Corresponds to `SrcType::F16`.
F16,
/// Corresponds to `SrcType::F16v2`.
F16v2,
/// Any source type other than the two FP16 ones.
Other,
}
impl From<SrcType> for SM70SrcType {
    /// Collapses a `SrcType` into the encoder's three-way classification:
    /// `F16` and `F16v2` keep their identity, everything else becomes `Other`.
    fn from(value: SrcType) -> Self {
        match value {
            SrcType::F16 => Self::F16,
            SrcType::F16v2 => Self::F16v2,
            _ => Self::Other,
        }
    }
}
impl SM70SrcType {
    /// Returns `true` for every variant except `Other`, i.e. for both the
    /// `F16` and `F16v2` source types.
    fn is_fp16(self) -> bool {
        !matches!(self, Self::Other)
    }
}
impl SM70Encoder<'_> {
fn set_swizzle(&mut self, range: Range<usize>, swizzle: SrcSwizzle) {
fn set_swizzle(
&mut self,
range: Range<usize>,
swizzle: SrcSwizzle,
src_type: SM70SrcType,
) {
assert!(range.len() == 2);
self.set_field(
range,
match swizzle {
SrcSwizzle::None => 0x00_u8,
SrcSwizzle::Xx => 0x02_u8,
SrcSwizzle::Yy => 0x03_u8,
match (src_type, swizzle) {
(_, SrcSwizzle::None) => 0x00_u8,
(SM70SrcType::F16, SrcSwizzle::Xx) => 0x00_u8,
(SM70SrcType::F16, SrcSwizzle::Yy) => 0x01_u8,
(SM70SrcType::F16v2, SrcSwizzle::Xx) => 0x02_u8,
(SM70SrcType::F16v2, SrcSwizzle::Yy) => 0x03_u8,
_ => unreachable!("unsupported {swizzle:?} for {src_type:?}"),
},
);
}
@ -427,7 +458,7 @@ impl SM70Encoder<'_> {
neg_bit: usize,
swizzle_range: Range<usize>,
file: RegFile,
is_fp16_alu: bool,
src_type: SM70SrcType,
reg: &ALURegRef,
) {
match file {
@ -439,8 +470,8 @@ impl SM70Encoder<'_> {
self.set_bit(abs_bit, reg.abs);
self.set_bit(neg_bit, reg.neg);
if is_fp16_alu {
self.set_swizzle(swizzle_range, reg.swizzle);
if src_type.is_fp16() {
self.set_swizzle(swizzle_range, reg.swizzle, src_type);
} else {
assert!(reg.swizzle == SrcSwizzle::None);
}
@ -450,21 +481,21 @@ impl SM70Encoder<'_> {
&mut self,
src: &ALUSrc,
file: RegFile,
is_fp16_alu: bool,
src_type: SM70SrcType,
) {
let reg = match src {
ALUSrc::None => return,
ALUSrc::Reg(reg) => reg,
_ => panic!("Invalid ALU src"),
};
self.set_alu_reg(24..32, 73, 72, 74..76, file, is_fp16_alu, reg);
self.set_alu_reg(24..32, 73, 72, 74..76, file, src_type, reg);
}
fn encode_alu_src2(
&mut self,
src: &ALUSrc,
file: RegFile,
is_fp16_alu: bool,
src_type: SM70SrcType,
) {
let reg = match src {
ALUSrc::None => return,
@ -473,34 +504,26 @@ impl SM70Encoder<'_> {
};
self.set_alu_reg(
64..72,
if is_fp16_alu { 83 } else { 74 },
if is_fp16_alu { 84 } else { 75 },
if src_type.is_fp16() { 83 } else { 74 },
if src_type.is_fp16() { 84 } else { 75 },
81..83,
file,
is_fp16_alu,
src_type,
reg,
);
}
fn encode_alu_reg(&mut self, reg: &ALURegRef, is_fp16_alu: bool) {
self.set_alu_reg(
32..40,
62,
63,
60..62,
RegFile::GPR,
is_fp16_alu,
reg,
);
fn encode_alu_reg(&mut self, reg: &ALURegRef, src_type: SM70SrcType) {
self.set_alu_reg(32..40, 62, 63, 60..62, RegFile::GPR, src_type, reg);
}
fn encode_alu_ureg(&mut self, reg: &ALURegRef, is_fp16_alu: bool) {
fn encode_alu_ureg(&mut self, reg: &ALURegRef, src_type: SM70SrcType) {
self.set_ureg(32..40, reg.reg);
self.set_bit(62, reg.abs);
self.set_bit(63, reg.neg);
if is_fp16_alu {
self.set_swizzle(60..62, reg.swizzle);
if src_type.is_fp16() {
self.set_swizzle(60..62, reg.swizzle, src_type);
} else {
assert!(reg.swizzle == SrcSwizzle::None);
}
@ -512,13 +535,13 @@ impl SM70Encoder<'_> {
self.set_field(32..64, *imm);
}
fn encode_alu_cb(&mut self, cb: &ALUCBufRef, is_fp16_alu: bool) {
fn encode_alu_cb(&mut self, cb: &ALUCBufRef, src_type: SM70SrcType) {
self.set_src_cb(32..59, 91, &cb.cb);
self.set_bit(62, cb.abs);
self.set_bit(63, cb.neg);
if is_fp16_alu {
self.set_swizzle(60..62, cb.swizzle);
if src_type.is_fp16() {
self.set_swizzle(60..62, cb.swizzle, src_type);
} else {
assert!(cb.swizzle == SrcSwizzle::None);
}
@ -531,7 +554,7 @@ impl SM70Encoder<'_> {
src0: Option<&Src>,
src1: Option<&Src>,
src2: Option<&Src>,
is_fp16_alu: bool,
src_type: SM70SrcType,
) {
if let Some(dst) = dst {
self.set_dst(dst);
@ -541,19 +564,19 @@ impl SM70Encoder<'_> {
let src1 = ALUSrc::from_src(self, src1, false);
let src2 = ALUSrc::from_src(self, src2, false);
self.encode_alu_src0(&src0, RegFile::GPR, is_fp16_alu);
self.encode_alu_src0(&src0, RegFile::GPR, src_type);
let form = match &src2 {
ALUSrc::None | ALUSrc::Reg(_) => {
self.encode_alu_src2(&src2, RegFile::GPR, is_fp16_alu);
self.encode_alu_src2(&src2, RegFile::GPR, src_type);
match &src1 {
ALUSrc::None => 1_u8, // form
ALUSrc::Reg(reg1) => {
self.encode_alu_reg(reg1, is_fp16_alu);
self.encode_alu_reg(reg1, src_type);
1_u8 // form
}
ALUSrc::UReg(reg1) => {
self.encode_alu_ureg(reg1, is_fp16_alu);
self.encode_alu_ureg(reg1, src_type);
6_u8 // form
}
ALUSrc::Imm32(imm1) => {
@ -561,25 +584,25 @@ impl SM70Encoder<'_> {
4_u8 // form
}
ALUSrc::CBuf(cb1) => {
self.encode_alu_cb(cb1, is_fp16_alu);
self.encode_alu_cb(cb1, src_type);
5_u8 // form
}
}
}
ALUSrc::UReg(reg2) => {
self.encode_alu_ureg(reg2, is_fp16_alu);
self.encode_alu_src2(&src1, RegFile::GPR, is_fp16_alu);
self.encode_alu_ureg(reg2, src_type);
self.encode_alu_src2(&src1, RegFile::GPR, src_type);
7_u8 // form
}
ALUSrc::Imm32(imm2) => {
self.encode_alu_imm(imm2);
self.encode_alu_src2(&src1, RegFile::GPR, is_fp16_alu);
self.encode_alu_src2(&src1, RegFile::GPR, src_type);
2_u8 // form
}
ALUSrc::CBuf(cb2) => {
// TODO set_src_cx
self.encode_alu_cb(cb2, is_fp16_alu);
self.encode_alu_src2(&src1, RegFile::GPR, is_fp16_alu);
self.encode_alu_cb(cb2, src_type);
self.encode_alu_src2(&src1, RegFile::GPR, src_type);
3_u8 // form
}
};
@ -596,10 +619,12 @@ impl SM70Encoder<'_> {
src1: Option<&Src>,
src2: Option<&Src>,
) {
self.encode_alu_base(opcode, dst, src0, src1, src2, false);
// The SrcType really only matters for FP16, so make this
// convenient for all the other ops
self.encode_alu_base(opcode, dst, src0, src1, src2, SM70SrcType::Other);
}
fn encode_fp16_alu(
fn encode_fp16v2_alu(
&mut self,
opcode: u16,
dst: Option<&Dst>,
@ -607,16 +632,17 @@ impl SM70Encoder<'_> {
src1: Option<&Src>,
src2: Option<&Src>,
) {
self.encode_alu_base(opcode, dst, src0, src1, src2, true);
self.encode_alu_base(opcode, dst, src0, src1, src2, SM70SrcType::F16v2);
}
fn encode_ualu(
fn encode_ualu_base(
&mut self,
opcode: u16,
dst: Option<&Dst>,
src0: Option<&Src>,
src1: Option<&Src>,
src2: Option<&Src>,
src_type: SM70SrcType,
) {
if let Some(dst) = dst {
self.set_udst(dst);
@ -629,14 +655,14 @@ impl SM70Encoder<'_> {
// All uniform ALU requires bit 91 set
self.set_bit(91, true);
self.encode_alu_src0(&src0, RegFile::UGPR, false);
self.encode_alu_src0(&src0, RegFile::UGPR, src_type);
let form = match &src2 {
ALUSrc::None | ALUSrc::Reg(_) => {
self.encode_alu_src2(&src2, RegFile::UGPR, false);
self.encode_alu_src2(&src2, RegFile::UGPR, src_type);
match &src1 {
ALUSrc::None => 1_u8, // form
ALUSrc::Reg(reg1) => {
self.encode_alu_ureg(reg1, false);
self.encode_alu_ureg(reg1, src_type);
1_u8 // form
}
ALUSrc::UReg(_) => panic!("UALU never has UReg"),
@ -650,7 +676,7 @@ impl SM70Encoder<'_> {
ALUSrc::UReg(_) => panic!("UALU never has UReg"),
ALUSrc::Imm32(imm2) => {
self.encode_alu_imm(imm2);
self.encode_alu_src2(&src1, RegFile::UGPR, false);
self.encode_alu_src2(&src1, RegFile::UGPR, src_type);
2_u8 // form
}
ALUSrc::CBuf(_) => panic!("UALU does not support cbufs"),
@ -660,6 +686,19 @@ impl SM70Encoder<'_> {
self.set_field(9..12, form);
}
/// Convenience wrapper around `encode_ualu_base` for ops that do not need an
/// FP16 source type: forwards every argument unchanged and passes
/// `SM70SrcType::Other` as the source type.
fn encode_ualu(
&mut self,
opcode: u16,
dst: Option<&Dst>,
src0: Option<&Src>,
src1: Option<&Src>,
src2: Option<&Src>,
) {
// The SrcType really only matters for FP16, so make this
// convenient for all the other ops
self.encode_ualu_base(opcode, dst, src0, src1, src2, SM70SrcType::Other)
}
fn set_rnd_mode(&mut self, range: Range<usize>, rnd_mode: FRndMode) {
assert!(range.len() == 2);
self.set_field(
@ -1073,7 +1112,7 @@ impl SM70Op for OpHAdd2 {
fn encode(&self, e: &mut SM70Encoder<'_>) {
if src_is_zero_or_gpr(&self.srcs[1]) {
e.encode_fp16_alu(
e.encode_fp16v2_alu(
0x030,
Some(&self.dst),
Some(&self.srcs[0]),
@ -1081,7 +1120,7 @@ impl SM70Op for OpHAdd2 {
None,
)
} else {
e.encode_fp16_alu(
e.encode_fp16v2_alu(
0x030,
Some(&self.dst),
Some(&self.srcs[0]),
@ -1107,7 +1146,7 @@ impl SM70Op for OpHFma2 {
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
e.encode_fp16_alu(
e.encode_fp16v2_alu(
0x031,
Some(&self.dst),
Some(&self.srcs[0]),
@ -1133,7 +1172,7 @@ impl SM70Op for OpHMul2 {
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
e.encode_fp16_alu(
e.encode_fp16v2_alu(
0x032,
Some(&self.dst),
Some(&self.srcs[0]),
@ -1163,7 +1202,7 @@ impl SM70Op for OpHSet2 {
fn encode(&self, e: &mut SM70Encoder<'_>) {
if src_is_zero_or_gpr(&self.srcs[1]) {
e.encode_fp16_alu(
e.encode_fp16v2_alu(
0x033,
Some(&self.dst),
Some(&self.srcs[0]),
@ -1171,7 +1210,7 @@ impl SM70Op for OpHSet2 {
None,
)
} else {
e.encode_fp16_alu(
e.encode_fp16v2_alu(
0x033,
Some(&self.dst),
Some(&self.srcs[0]),
@ -1205,7 +1244,7 @@ impl SM70Op for OpHSetP2 {
fn encode(&self, e: &mut SM70Encoder<'_>) {
if src_is_zero_or_gpr(&self.srcs[1]) {
e.encode_fp16_alu(
e.encode_fp16v2_alu(
0x034,
None,
Some(&self.srcs[0]),
@ -1213,7 +1252,7 @@ impl SM70Op for OpHSetP2 {
None,
)
} else {
e.encode_fp16_alu(
e.encode_fp16v2_alu(
0x034,
None,
Some(&self.srcs[0]),
@ -1246,7 +1285,7 @@ impl SM70Op for OpHMnMx2 {
fn encode(&self, e: &mut SM70Encoder<'_>) {
assert!(e.sm >= 80);
e.encode_fp16_alu(
e.encode_fp16v2_alu(
0x040,
Some(&self.dst),
Some(&self.srcs[0]),
@ -1914,17 +1953,21 @@ impl SM70Op for OpF2F {
fn encode(&self, e: &mut SM70Encoder<'_>) {
assert!(!self.integer_rnd);
// The swizzle is handled by the .high bit below.
let src = self.src.clone().without_swizzle();
if self.src_type.bits() <= 32 && self.dst_type.bits() <= 32 {
e.encode_alu(0x104, Some(&self.dst), None, Some(&src), None)
let opcode = if self.src_type.bits() <= 32 && self.dst_type.bits() <= 32
{
0x104
} else {
e.encode_alu(0x110, Some(&self.dst), None, Some(&src), None)
0x110
};
if self.is_high() {
e.set_field(60..62, 1_u8); // .H1
}
e.encode_alu_base(
opcode,
Some(&self.dst),
None,
Some(&self.src),
None,
self.src_types()[0].into(),
);
e.set_field(75..77, (self.dst_type.bits() / 8).ilog2());
e.set_rnd_mode(78..80, self.rnd_mode);
@ -1965,12 +2008,22 @@ impl SM70Op for OpF2I {
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
if self.src_type.bits() <= 32 && self.dst_type.bits() <= 32 {
e.encode_alu(0x105, Some(&self.dst), None, Some(&self.src), None)
let opcode = if self.src_type.bits() <= 32 && self.dst_type.bits() <= 32
{
0x105
} else {
e.encode_alu(0x111, Some(&self.dst), None, Some(&self.src), None)
0x111
};
e.encode_alu_base(
opcode,
Some(&self.dst),
None,
Some(&self.src),
None,
self.src_types()[0].into(),
);
e.set_bit(72, self.dst_type.is_signed());
e.set_field(75..77, (self.dst_type.bits() / 8).ilog2());
e.set_bit(77, false); // NTZ
@ -2006,12 +2059,22 @@ impl SM70Op for OpFRnd {
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
if self.src_type.bits() <= 32 && self.dst_type.bits() <= 32 {
e.encode_alu(0x107, Some(&self.dst), None, Some(&self.src), None)
let opcode = if self.src_type.bits() <= 32 && self.dst_type.bits() <= 32
{
0x107
} else {
e.encode_alu(0x113, Some(&self.dst), None, Some(&self.src), None)
0x113
};
e.encode_alu_base(
opcode,
Some(&self.dst),
None,
Some(&self.src),
None,
self.src_types()[0].into(),
);
e.set_field(84..86, (self.src_type.bits() / 8).ilog2());
e.set_bit(80, self.ftz);
e.set_rnd_mode(78..80, self.rnd_mode);
@ -2033,7 +2096,7 @@ impl SM70Op for OpMov {
let src = ALUSrc::from_src(e, Some(&self.src), true);
let form: u8 = match &src {
ALUSrc::Reg(reg) => {
e.encode_alu_ureg(reg, false);
e.encode_alu_ureg(reg, SM70SrcType::Other);
0x6 // form
}
ALUSrc::Imm32(imm) => {