nak: Add 16-bit float operations

Signed-off-by: Mary Guillemard <mary.guillemard@collabora.com>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/27635>
This commit is contained in:
Mary Guillemard 2024-02-15 12:09:35 +01:00 committed by Marge Bot
parent 6b2ce802b7
commit 567cae69c3
7 changed files with 650 additions and 34 deletions

View file

@ -217,6 +217,31 @@ pub trait SSABuilder: Builder {
dst dst
} }
/// Emits an `OpHAdd2` computing the packed-fp16 sum of `x` and `y`.
///
/// Uses the plain form: no saturation, no flush-to-zero, and a packed
/// fp16 (not `.F32`) result in a single GPR.
fn hadd2(&mut self, x: Src, y: Src) -> SSARef {
    let sum = self.alloc_ssa(RegFile::GPR, 1);
    self.push_op(OpHAdd2 {
        dst: sum.into(),
        srcs: [x, y],
        f32: false,
        ftz: false,
        saturate: false,
    });
    sum
}
/// Emits an `OpHSet2` comparing `x` and `y` component-wise with `cmp_op`.
///
/// The accumulator is trivial (`And` with `SrcRef::True`) and FTZ is off;
/// the result is written to a single GPR.
fn hset2(&mut self, cmp_op: FloatCmpOp, x: Src, y: Src) -> SSARef {
    let dst = self.alloc_ssa(RegFile::GPR, 1);
    self.push_op(OpHSet2 {
        dst: dst.into(),
        set_op: PredSetOp::And,
        // Field-init shorthand (was the redundant `cmp_op: cmp_op`).
        cmp_op,
        srcs: [x, y],
        ftz: false,
        accum: SrcRef::True.into(),
    });
    dst
}
fn dsetp(&mut self, cmp_op: FloatCmpOp, x: Src, y: Src) -> SSARef { fn dsetp(&mut self, cmp_op: FloatCmpOp, x: Src, y: Src) -> SSARef {
let dst = self.alloc_ssa(RegFile::Pred, 1); let dst = self.alloc_ssa(RegFile::Pred, 1);
self.push_op(OpDSetP { self.push_op(OpDSetP {

View file

@ -535,6 +535,17 @@ impl SM70Instr {
self.encode_alu_base(opcode, dst, src0, src1, src2, false); self.encode_alu_base(opcode, dst, src0, src1, src2, false);
} }
/// Encodes a three-source fp16 ALU instruction.
///
/// Thin wrapper around `encode_alu_base` with the trailing flag set to
/// `true`, which selects the fp16 encoding variant (the plain
/// `encode_alu` passes `false`) — see `encode_alu_base` for the flag's
/// exact effect.
fn encode_fp16_alu(
    &mut self,
    opcode: u16,
    dst: Option<Dst>,
    src0: ALUSrc,
    src1: ALUSrc,
    src2: ALUSrc,
) {
    self.encode_alu_base(opcode, dst, src0, src1, src2, true);
}
fn set_instr_deps(&mut self, deps: &InstrDeps) { fn set_instr_deps(&mut self, deps: &InstrDeps) {
self.set_field(105..109, deps.delay); self.set_field(105..109, deps.delay);
self.set_bit(109, deps.yld); self.set_bit(109, deps.yld);
@ -804,6 +815,159 @@ impl SM70Instr {
self.set_pred_src(87..90, 90, op.accum); self.set_pred_src(87..90, 90, op.accum);
} }
/// Encodes HADD2 (opcode 0x030).
///
/// A register or zero second source goes in the src1 slot; any other
/// source kind must be placed in the src2 slot instead.
fn encode_hadd2(&mut self, op: &OpHAdd2) {
    let (src1, src2) = match op.srcs[1].src_ref {
        SrcRef::Reg(_) | SrcRef::Zero => {
            (ALUSrc::from_src(&op.srcs[1]), ALUSrc::None)
        }
        _ => (ALUSrc::None, ALUSrc::from_src(&op.srcs[1])),
    };
    self.encode_fp16_alu(
        0x030,
        Some(op.dst),
        ALUSrc::from_src(&op.srcs[0]),
        src1,
        src2,
    );
    self.set_bit(77, op.saturate);
    self.set_bit(78, op.f32);
    self.set_bit(80, op.ftz);
    self.set_bit(85, false); // .BF16_V2 (SM90+)
}
/// Encodes HFMA2 (opcode 0x031): packed fp16 fused multiply-add,
/// `dst = srcs[0] * srcs[1] + srcs[2]`.
fn encode_hfma2(&mut self, op: &OpHFma2) {
    // HFMA2 doesn't have fneg and fabs on SRC2.
    assert!(op.srcs[2].src_mod.is_none());
    self.encode_fp16_alu(
        0x031,
        Some(op.dst),
        ALUSrc::from_src(&op.srcs[0]),
        ALUSrc::from_src(&op.srcs[1]),
        ALUSrc::from_src(&op.srcs[2]),
    );
    self.set_bit(76, op.dnz);
    self.set_bit(77, op.saturate);
    self.set_bit(78, op.f32);
    self.set_bit(79, false); // .RELU (SM86+)
    self.set_bit(80, op.ftz);
    self.set_bit(85, false); // .BF16_V2 (SM86+)
}
/// Encodes HMUL2 (opcode 0x032): packed fp16 multiply.
fn encode_hmul2(&mut self, op: &OpHMul2) {
    self.encode_fp16_alu(
        0x032,
        Some(op.dst),
        ALUSrc::from_src(&op.srcs[0]),
        ALUSrc::from_src(&op.srcs[1]),
        ALUSrc::None,
    );
    self.set_bit(76, op.dnz);
    self.set_bit(77, op.saturate);
    self.set_bit(78, false); // .F32 (SM70-SM75)
    self.set_bit(79, false); // .RELU (SM86+)
    self.set_bit(80, op.ftz);
    self.set_bit(85, false); // .BF16_V2 (SM90+)
}
/// Encodes HSET2 (opcode 0x033).
///
/// A register or zero second source goes in the src1 slot; any other
/// source kind must be placed in the src2 slot instead.
fn encode_hset2(&mut self, op: &OpHSet2) {
    let (src1, src2) = match op.srcs[1].src_ref {
        SrcRef::Reg(_) | SrcRef::Zero => {
            (ALUSrc::from_src(&op.srcs[1]), ALUSrc::None)
        }
        _ => (ALUSrc::None, ALUSrc::from_src(&op.srcs[1])),
    };
    self.encode_fp16_alu(
        0x033,
        Some(op.dst),
        ALUSrc::from_src(&op.srcs[0]),
        src1,
        src2,
    );
    self.set_bit(65, false); // .BF16_V2 (SM90+)
    self.set_pred_set_op(69..71, op.set_op);
    // .BF: differentiates between integer and fp16 output.
    self.set_bit(71, true);
    self.set_float_cmp_op(76..80, op.cmp_op);
    self.set_bit(80, op.ftz);
    self.set_pred_src(87..90, 90, op.accum);
}
/// Encodes HSETP2 (opcode 0x034): packed fp16 comparison writing two
/// predicate destinations (one per lane, or the combined .H_AND form).
fn encode_hsetp2(&mut self, op: &OpHSetP2) {
    // A register or zero second source goes in the src1 slot; any other
    // source kind must be placed in the src2 slot instead.
    match op.srcs[1].src_ref {
        SrcRef::Reg(_) | SrcRef::Zero => {
            self.encode_fp16_alu(
                0x034,
                None,
                ALUSrc::from_src(&op.srcs[0]),
                ALUSrc::from_src(&op.srcs[1]),
                ALUSrc::None,
            );
        }
        _ => {
            self.encode_fp16_alu(
                0x034,
                None,
                ALUSrc::from_src(&op.srcs[0]),
                ALUSrc::None,
                ALUSrc::from_src(&op.srcs[1]),
            );
        }
    }
    self.set_bit(65, false); // .BF16_V2 (SM90+)
    self.set_pred_set_op(69..71, op.set_op);
    self.set_bit(71, op.horizontal); // .H_AND
    self.set_float_cmp_op(76..80, op.cmp_op);
    self.set_bit(80, op.ftz);
    self.set_pred_dst(81..84, op.dsts[0]);
    self.set_pred_dst(84..87, op.dsts[1]);
    self.set_pred_src(87..90, 90, op.accum);
}
/// Encodes HMNMX2 (opcode 0x040): packed fp16 min/max, selected per the
/// `min` predicate source.  SM80+ only.
fn encode_hmnmx2(&mut self, op: &OpHMnMx2) {
    assert!(self.sm >= 80);
    self.encode_fp16_alu(
        0x040,
        Some(op.dst),
        ALUSrc::from_src(&op.srcs[0]),
        ALUSrc::from_src(&op.srcs[1]),
        ALUSrc::None,
    );
    // Always emit the packed-fp16 form; bit 78 selects .F32 on SM86.
    self.set_bit(78, false); // .F32 (SM86)
    self.set_bit(80, op.ftz);
    self.set_bit(81, false); // .NAN
    self.set_bit(82, false); // .XORSIGN
    self.set_bit(85, false); // .BF16_V2
    self.set_pred_src(87..90, 90, op.min);
}
fn encode_bmsk(&mut self, op: &OpBMsk) { fn encode_bmsk(&mut self, op: &OpBMsk) {
self.encode_alu( self.encode_alu(
0x01b, 0x01b,
@ -2175,6 +2339,12 @@ impl SM70Instr {
Op::DFma(op) => si.encode_dfma(op), Op::DFma(op) => si.encode_dfma(op),
Op::DMul(op) => si.encode_dmul(op), Op::DMul(op) => si.encode_dmul(op),
Op::DSetP(op) => si.encode_dsetp(op), Op::DSetP(op) => si.encode_dsetp(op),
Op::HAdd2(op) => si.encode_hadd2(op),
Op::HFma2(op) => si.encode_hfma2(op),
Op::HMul2(op) => si.encode_hmul2(op),
Op::HSet2(op) => si.encode_hset2(op),
Op::HSetP2(op) => si.encode_hsetp2(op),
Op::HMnMx2(op) => si.encode_hmnmx2(op),
Op::MuFu(op) => si.encode_mufu(op), Op::MuFu(op) => si.encode_mufu(op),
Op::BMsk(op) => si.encode_bmsk(op), Op::BMsk(op) => si.encode_bmsk(op),
Op::BRev(op) => si.encode_brev(op), Op::BRev(op) => si.encode_brev(op),

View file

@ -511,6 +511,16 @@ impl<'a> ShaderFromNir<'a> {
} }
} }
// Restricts an F16v2 source to just x if the ALU op is single-component. This
// must only be called for per-component sources (see nir_op_info::output_sizes
// for more details).
let restrict_f16v2_src = |mut src: Src| {
if alu.def.num_components == 1 {
src.src_swizzle = SrcSwizzle::Xx;
}
src
};
let dst: SSARef = match alu.op { let dst: SSARef = match alu.op {
nir_op_b2b1 => { nir_op_b2b1 => {
assert!(alu.get_src(0).bit_size() == 32); assert!(alu.get_src(0).bit_size() == 32);
@ -524,6 +534,7 @@ impl<'a> ShaderFromNir<'a> {
let hi = b.copy(0.into()); let hi = b.copy(0.into());
[lo[0], hi[0]].into() [lo[0], hi[0]].into()
} }
nir_op_b2f16 => b.sel(srcs[0].bnot(), 0.into(), 0x3c00.into()),
nir_op_b2f32 => { nir_op_b2f32 => {
b.sel(srcs[0].bnot(), 0.0_f32.into(), 1.0_f32.into()) b.sel(srcs[0].bnot(), 0.0_f32.into(), 1.0_f32.into())
} }
@ -688,7 +699,7 @@ impl<'a> ShaderFromNir<'a> {
} }
nir_op_fabs | nir_op_fadd | nir_op_fneg => { nir_op_fabs | nir_op_fadd | nir_op_fneg => {
let (x, y) = match alu.op { let (x, y) = match alu.op {
nir_op_fabs => (srcs[0].fabs(), 0.0_f32.into()), nir_op_fabs => (srcs[0].fabs(), 0.into()),
nir_op_fadd => (srcs[0], srcs[1]), nir_op_fadd => (srcs[0], srcs[1]),
nir_op_fneg => (Src::new_zero().fneg(), srcs[0].fneg()), nir_op_fneg => (Src::new_zero().fneg(), srcs[0].fneg()),
_ => panic!("Unhandled case"), _ => panic!("Unhandled case"),
@ -711,6 +722,19 @@ impl<'a> ShaderFromNir<'a> {
rnd_mode: self.float_ctl[ftype].rnd_mode, rnd_mode: self.float_ctl[ftype].rnd_mode,
ftz: self.float_ctl[ftype].ftz, ftz: self.float_ctl[ftype].ftz,
}); });
} else if alu.def.bit_size() == 16 {
assert!(
self.float_ctl[ftype].rnd_mode == FRndMode::NearestEven
);
dst = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpHAdd2 {
dst: dst.into(),
srcs: [restrict_f16v2_src(x), restrict_f16v2_src(y)],
saturate: self.try_saturate_alu_dst(&alu.def),
ftz: self.float_ctl[ftype].ftz,
f32: false,
});
} else { } else {
panic!("Unsupported float type: f{}", alu.def.bit_size()); panic!("Unsupported float type: f{}", alu.def.bit_size());
} }
@ -718,7 +742,6 @@ impl<'a> ShaderFromNir<'a> {
} }
nir_op_fceil | nir_op_ffloor | nir_op_fround_even nir_op_fceil | nir_op_ffloor | nir_op_fround_even
| nir_op_ftrunc => { | nir_op_ftrunc => {
assert!(alu.def.bit_size() == 32);
let dst = b.alloc_ssa(RegFile::GPR, 1); let dst = b.alloc_ssa(RegFile::GPR, 1);
let ty = FloatType::from_bits(alu.def.bit_size().into()); let ty = FloatType::from_bits(alu.def.bit_size().into());
let rnd_mode = match alu.op { let rnd_mode = match alu.op {
@ -730,6 +753,9 @@ impl<'a> ShaderFromNir<'a> {
}; };
let ftz = self.float_ctl[ty].ftz; let ftz = self.float_ctl[ty].ftz;
if b.sm() >= 70 { if b.sm() >= 70 {
assert!(
alu.def.bit_size() == 32 || alu.def.bit_size() == 16
);
b.push_op(OpFRnd { b.push_op(OpFRnd {
dst: dst.into(), dst: dst.into(),
src: srcs[0], src: srcs[0],
@ -739,6 +765,7 @@ impl<'a> ShaderFromNir<'a> {
ftz, ftz,
}); });
} else { } else {
assert!(alu.def.bit_size() == 32);
b.push_op(OpF2F { b.push_op(OpF2F {
dst: dst.into(), dst: dst.into(),
src: srcs[0], src: srcs[0],
@ -764,8 +791,9 @@ impl<'a> ShaderFromNir<'a> {
_ => panic!("Usupported float comparison"), _ => panic!("Usupported float comparison"),
}; };
let dst = b.alloc_ssa(RegFile::Pred, 1); let dst = b.alloc_ssa(RegFile::Pred, alu.def.num_components);
if alu.get_src(0).bit_size() == 64 { if alu.get_src(0).bit_size() == 64 {
assert!(alu.def.num_components == 1);
b.push_op(OpDSetP { b.push_op(OpDSetP {
dst: dst.into(), dst: dst.into(),
set_op: PredSetOp::And, set_op: PredSetOp::And,
@ -774,6 +802,7 @@ impl<'a> ShaderFromNir<'a> {
accum: SrcRef::True.into(), accum: SrcRef::True.into(),
}); });
} else if alu.get_src(0).bit_size() == 32 { } else if alu.get_src(0).bit_size() == 32 {
assert!(alu.def.num_components == 1);
b.push_op(OpFSetP { b.push_op(OpFSetP {
dst: dst.into(), dst: dst.into(),
set_op: PredSetOp::And, set_op: PredSetOp::And,
@ -782,6 +811,30 @@ impl<'a> ShaderFromNir<'a> {
accum: SrcRef::True.into(), accum: SrcRef::True.into(),
ftz: self.float_ctl[src_type].ftz, ftz: self.float_ctl[src_type].ftz,
}); });
} else if alu.get_src(0).bit_size() == 16 {
assert!(
alu.def.num_components == 1
|| alu.def.num_components == 2
);
let dsts = if alu.def.num_components == 2 {
[dst[0].into(), dst[1].into()]
} else {
[dst[0].into(), Dst::None]
};
b.push_op(OpHSetP2 {
dsts,
set_op: PredSetOp::And,
cmp_op: cmp_op,
srcs: [
restrict_f16v2_src(srcs[0]),
restrict_f16v2_src(srcs[1]),
],
accum: SrcRef::True.into(),
ftz: self.float_ctl[src_type].ftz,
horizontal: false,
});
} else { } else {
panic!( panic!(
"Unsupported float type: f{}", "Unsupported float type: f{}",
@ -814,6 +867,24 @@ impl<'a> ShaderFromNir<'a> {
ftz: self.float_ctl[ftype].ftz, ftz: self.float_ctl[ftype].ftz,
dnz: false, dnz: false,
}); });
} else if alu.def.bit_size() == 16 {
assert!(
self.float_ctl[ftype].rnd_mode == FRndMode::NearestEven
);
dst = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpHFma2 {
dst: dst.into(),
srcs: [
restrict_f16v2_src(srcs[0]),
restrict_f16v2_src(srcs[1]),
restrict_f16v2_src(srcs[2]),
],
saturate: self.try_saturate_alu_dst(&alu.def),
ftz: self.float_ctl[ftype].ftz,
dnz: false,
f32: false,
});
} else { } else {
panic!("Unsupported float type: f{}", alu.def.bit_size()); panic!("Unsupported float type: f{}", alu.def.bit_size());
} }
@ -857,6 +928,17 @@ impl<'a> ShaderFromNir<'a> {
min: (alu.op == nir_op_fmin).into(), min: (alu.op == nir_op_fmin).into(),
ftz: self.float_ctl.fp32.ftz, ftz: self.float_ctl.fp32.ftz,
}); });
} else if alu.def.bit_size() == 16 {
dst = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpHMnMx2 {
dst: dst.into(),
srcs: [
restrict_f16v2_src(srcs[0]),
restrict_f16v2_src(srcs[1]),
],
min: (alu.op == nir_op_fmin).into(),
ftz: self.float_ctl.fp16.ftz,
});
} else { } else {
panic!("Unsupported float type: f{}", alu.def.bit_size()); panic!("Unsupported float type: f{}", alu.def.bit_size());
} }
@ -883,6 +965,22 @@ impl<'a> ShaderFromNir<'a> {
ftz: self.float_ctl[ftype].ftz, ftz: self.float_ctl[ftype].ftz,
dnz: false, dnz: false,
}); });
} else if alu.def.bit_size() == 16 {
assert!(
self.float_ctl[ftype].rnd_mode == FRndMode::NearestEven
);
dst = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpHMul2 {
dst: dst.into(),
srcs: [
restrict_f16v2_src(srcs[0]),
restrict_f16v2_src(srcs[1]),
],
saturate: self.try_saturate_alu_dst(&alu.def),
ftz: self.float_ctl[ftype].ftz,
dnz: false,
});
} else { } else {
panic!("Unsupported float type: f{}", alu.def.bit_size()); panic!("Unsupported float type: f{}", alu.def.bit_size());
} }
@ -940,11 +1038,11 @@ impl<'a> ShaderFromNir<'a> {
b.mufu(MuFuOp::Rsq, srcs[0]) b.mufu(MuFuOp::Rsq, srcs[0])
} }
nir_op_fsat => { nir_op_fsat => {
assert!(alu.def.bit_size() == 32); let ftype = FloatType::from_bits(alu.def.bit_size().into());
if self.alu_src_is_saturated(&alu.srcs_as_slice()[0]) { if self.alu_src_is_saturated(&alu.srcs_as_slice()[0]) {
b.copy(srcs[0]) b.copy(srcs[0])
} else { } else if alu.def.bit_size() == 32 {
let ftype = FloatType::from_bits(alu.def.bit_size().into());
let dst = b.alloc_ssa(RegFile::GPR, 1); let dst = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpFAdd { b.push_op(OpFAdd {
dst: dst.into(), dst: dst.into(),
@ -954,6 +1052,22 @@ impl<'a> ShaderFromNir<'a> {
ftz: self.float_ctl[ftype].ftz, ftz: self.float_ctl[ftype].ftz,
}); });
dst dst
} else if alu.def.bit_size() == 16 {
assert!(
self.float_ctl[ftype].rnd_mode == FRndMode::NearestEven
);
let dst = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpHAdd2 {
dst: dst.into(),
srcs: [restrict_f16v2_src(srcs[0]), 0.into()],
saturate: true,
ftz: self.float_ctl[ftype].ftz,
f32: false,
});
dst
} else {
panic!("Unsupported float type: f{}", alu.def.bit_size());
} }
} }
nir_op_fsign => { nir_op_fsign => {
@ -968,6 +1082,17 @@ impl<'a> ShaderFromNir<'a> {
let lz = b.fset(FloatCmpOp::OrdLt, srcs[0], 0.into()); let lz = b.fset(FloatCmpOp::OrdLt, srcs[0], 0.into());
let gz = b.fset(FloatCmpOp::OrdGt, srcs[0], 0.into()); let gz = b.fset(FloatCmpOp::OrdGt, srcs[0], 0.into());
b.fadd(gz.into(), Src::from(lz).fneg()) b.fadd(gz.into(), Src::from(lz).fneg())
} else if alu.def.bit_size() == 16 {
let x = restrict_f16v2_src(srcs[0]);
let lz = restrict_f16v2_src(
b.hset2(FloatCmpOp::OrdLt, x, 0.into()).into(),
);
let gz = restrict_f16v2_src(
b.hset2(FloatCmpOp::OrdGt, x, 0.into()).into(),
);
b.hadd2(gz, lz.fneg())
} else { } else {
panic!("Unsupported float type: f{}", alu.def.bit_size()); panic!("Unsupported float type: f{}", alu.def.bit_size());
} }

View file

@ -2696,6 +2696,186 @@ impl DisplayOp for OpDSetP {
} }
impl_display_for_op!(OpDSetP); impl_display_for_op!(OpDSetP);
/// Packed 2x fp16 add: `dst = srcs[0] + srcs[1]`, per component.
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHAdd2 {
    pub dst: Dst,

    #[src_type(F16v2)]
    pub srcs: [Src; 2],

    /// Saturate the result (printed as `.sat`).
    pub saturate: bool,
    /// Flush denormals to zero (printed as `.ftz`).
    pub ftz: bool,
    /// Select the hardware `.F32` form (encoder bit 78).
    pub f32: bool,
}
impl DisplayOp for OpHAdd2 {
    /// Prints `hadd2` with modifiers in the order `.sat`, `.f32`, `.ftz`,
    /// followed by the two sources.
    fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "hadd2")?;
        if self.saturate {
            write!(f, ".sat")?;
        }
        if self.f32 {
            write!(f, ".f32")?;
        }
        if self.ftz {
            write!(f, ".ftz")?;
        }
        write!(f, " {} {}", self.srcs[0], self.srcs[1])
    }
}
impl_display_for_op!(OpHAdd2);
/// Packed 2x fp16 comparison writing its results to a GPR
/// (the encoder always sets the `.BF` fp16-output form).
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHSet2 {
    pub dst: Dst,

    /// How each lane's comparison result is combined with `accum`.
    pub set_op: PredSetOp,
    pub cmp_op: FloatCmpOp,

    #[src_type(F16v2)]
    pub srcs: [Src; 2],

    /// Predicate accumulator; `And` with `SrcRef::True` is the trivial case.
    #[src_type(Pred)]
    pub accum: Src,

    /// Flush denormals to zero.
    pub ftz: bool,
}
impl DisplayOp for OpHSet2 {
    /// Prints `hset2` with the comparison op and optional `.ftz`; the
    /// set-op and accumulator are only shown when the accumulator is
    /// non-trivial.
    fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ftz = if self.ftz { ".ftz" } else { "" };
        write!(f, "hset2{}{ftz}", self.cmp_op)?;
        if !self.set_op.is_trivial(&self.accum) {
            write!(f, "{}", self.set_op)?;
        }
        write!(f, " {} {}", self.srcs[0], self.srcs[1])?;
        if !self.set_op.is_trivial(&self.accum) {
            write!(f, " {}", self.accum)?;
        }
        Ok(())
    }
}
impl_display_for_op!(OpHSet2);
/// Packed 2x fp16 comparison writing two predicate destinations.
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHSetP2 {
    pub dsts: [Dst; 2],

    /// How each lane's comparison result is combined with `accum`.
    pub set_op: PredSetOp,
    pub cmp_op: FloatCmpOp,

    #[src_type(F16v2)]
    pub srcs: [Src; 2],

    /// Predicate accumulator; `And` with `SrcRef::True` is the trivial case.
    #[src_type(Pred)]
    pub accum: Src,

    /// Flush denormals to zero.
    pub ftz: bool,

    // When not set, each dst gets the result of its lane.
    // When set, the first dst gets the result of both lanes (res0 && res1)
    // and the second dst gets the negation !(res0 && res1)
    // before applying the accumulator.
    pub horizontal: bool,
}
impl DisplayOp for OpHSetP2 {
    /// Prints `hsetp2` with the comparison op and optional `.ftz`; the
    /// set-op and accumulator are only shown when the accumulator is
    /// non-trivial.
    fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_accum = !self.set_op.is_trivial(&self.accum);
        write!(f, "hsetp2{}", self.cmp_op)?;
        if self.ftz {
            write!(f, ".ftz")?;
        }
        if has_accum {
            write!(f, "{}", self.set_op)?;
        }
        write!(f, " {} {}", self.srcs[0], self.srcs[1])?;
        if has_accum {
            write!(f, " {}", self.accum)?;
        }
        Ok(())
    }
}
impl_display_for_op!(OpHSetP2);
/// Packed 2x fp16 multiply: `dst = srcs[0] * srcs[1]`, per component.
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHMul2 {
    pub dst: Dst,

    #[src_type(F16v2)]
    pub srcs: [Src; 2],

    /// Saturate the result (printed as `.sat`).
    pub saturate: bool,
    /// Flush denormals to zero (printed as `.ftz` when `dnz` is not set).
    pub ftz: bool,
    /// Denormals-are-zero mode (printed as `.dnz`, which replaces `.ftz`).
    pub dnz: bool,
}
impl DisplayOp for OpHMul2 {
    /// Prints `hmul2` with optional `.sat` and either `.dnz` or `.ftz`
    /// (`.dnz` takes precedence), then the two sources.
    fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sat = if self.saturate { ".sat" } else { "" };
        write!(f, "hmul2{sat}")?;
        if self.dnz {
            write!(f, ".dnz")?;
        } else if self.ftz {
            write!(f, ".ftz")?;
        }
        write!(f, " {} {}", self.srcs[0], self.srcs[1])
    }
}
impl_display_for_op!(OpHMul2);
/// Packed 2x fp16 fused multiply-add:
/// `dst = srcs[0] * srcs[1] + srcs[2]`, per component.
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHFma2 {
    pub dst: Dst,

    // NOTE: the SM70 encoder asserts srcs[2] has no fneg/fabs modifier.
    #[src_type(F16v2)]
    pub srcs: [Src; 3],

    /// Saturate the result (printed as `.sat`).
    pub saturate: bool,
    /// Flush denormals to zero (printed as `.ftz` when `dnz` is not set).
    pub ftz: bool,
    /// Denormals-are-zero mode (printed as `.dnz`, which replaces `.ftz`).
    pub dnz: bool,
    /// Select the hardware `.F32` form (encoder bit 78).
    pub f32: bool,
}
impl DisplayOp for OpHFma2 {
    /// Prints `hfma2` with optional `.sat`, `.f32`, and either `.dnz` or
    /// `.ftz` (`.dnz` takes precedence), then the three sources.
    fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sat = if self.saturate { ".sat" } else { "" };
        let f32 = if self.f32 { ".f32" } else { "" };
        write!(f, "hfma2{sat}{f32}")?;
        if self.dnz {
            write!(f, ".dnz")?;
        } else if self.ftz {
            write!(f, ".ftz")?;
        }
        write!(f, " {} {} {}", self.srcs[0], self.srcs[1], self.srcs[2])
    }
}
impl_display_for_op!(OpHFma2);
/// Packed 2x fp16 min/max, selected at runtime by the `min` predicate.
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpHMnMx2 {
    pub dst: Dst,

    #[src_type(F16v2)]
    pub srcs: [Src; 2],

    /// Predicate source: when true compute the minimum, otherwise the
    /// maximum.
    #[src_type(Pred)]
    pub min: Src,

    /// Flush denormals to zero (printed as `.ftz`).
    pub ftz: bool,
}
impl DisplayOp for OpHMnMx2 {
    /// Prints `hmnmx2` with optional `.ftz`, then the two sources and the
    /// min/max selection predicate.
    fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "hmnmx2")?;
        if self.ftz {
            write!(f, ".ftz")?;
        }
        write!(f, " {} {} {}", self.srcs[0], self.srcs[1], self.min)
    }
}
impl_display_for_op!(OpHMnMx2);
#[repr(C)] #[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)] #[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpBMsk { pub struct OpBMsk {
@ -4999,6 +5179,12 @@ pub enum Op {
DMnMx(OpDMnMx), DMnMx(OpDMnMx),
DMul(OpDMul), DMul(OpDMul),
DSetP(OpDSetP), DSetP(OpDSetP),
HAdd2(OpHAdd2),
HFma2(OpHFma2),
HMul2(OpHMul2),
HSet2(OpHSet2),
HSetP2(OpHSetP2),
HMnMx2(OpHMnMx2),
BMsk(OpBMsk), BMsk(OpBMsk),
BRev(OpBRev), BRev(OpBRev),
Bfe(OpBfe), Bfe(OpBfe),
@ -5434,6 +5620,12 @@ impl Instr {
| Op::FMul(_) | Op::FMul(_)
| Op::FSet(_) | Op::FSet(_)
| Op::FSetP(_) | Op::FSetP(_)
| Op::HAdd2(_)
| Op::HFma2(_)
| Op::HMul2(_)
| Op::HSet2(_)
| Op::HSetP2(_)
| Op::HMnMx2(_)
| Op::FSwzAdd(_) => true, | Op::FSwzAdd(_) => true,
// Multi-function unit is variable latency // Multi-function unit is variable latency

View file

@ -163,13 +163,23 @@ fn copy_alu_src_if_f20_overflow(
} }
} }
fn copy_alu_src_if_fabs( fn copy_alu_src_and_lower_fmod(
b: &mut impl SSABuilder, b: &mut impl SSABuilder,
src: &mut Src, src: &mut Src,
src_type: SrcType, src_type: SrcType,
) { ) {
if src.src_mod.has_fabs() {
match src_type { match src_type {
SrcType::F16 | SrcType::F16v2 => {
let val = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpHAdd2 {
dst: val.into(),
srcs: [Src::new_zero().fneg(), *src],
saturate: false,
ftz: false,
f32: false,
});
*src = val.into();
}
SrcType::F32 => { SrcType::F32 => {
let val = b.alloc_ssa(RegFile::GPR, 1); let val = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpFAdd { b.push_op(OpFAdd {
@ -192,6 +202,15 @@ fn copy_alu_src_if_fabs(
} }
_ => panic!("Invalid ffabs srouce type"), _ => panic!("Invalid ffabs srouce type"),
} }
}
fn copy_alu_src_if_fabs(
b: &mut impl SSABuilder,
src: &mut Src,
src_type: SrcType,
) {
if src.src_mod.has_fabs() {
copy_alu_src_and_lower_fmod(b, src, src_type);
} }
} }
@ -501,6 +520,49 @@ fn legalize_sm70_instr(
} }
copy_alu_src_if_not_reg(b, src0, SrcType::F32); copy_alu_src_if_not_reg(b, src0, SrcType::F32);
} }
Op::HAdd2(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
copy_alu_src_if_not_reg(b, src0, SrcType::F16v2);
}
Op::HFma2(op) => {
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
copy_alu_src_if_not_reg(b, src0, SrcType::F16v2);
copy_alu_src_if_not_reg(b, src1, SrcType::F16v2);
copy_alu_src_if_both_not_reg(b, src1, src2, SrcType::F16v2);
// HFMA2 doesn't have fabs or fneg on SRC2.
if !src2.src_mod.is_none() {
copy_alu_src_and_lower_fmod(b, src2, SrcType::F16v2);
}
}
Op::HMul2(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
copy_alu_src_if_not_reg(b, src0, SrcType::F16v2);
}
Op::HSet2(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if !src_is_reg(src0) && src_is_reg(src1) {
std::mem::swap(src0, src1);
op.cmp_op = op.cmp_op.flip();
}
copy_alu_src_if_not_reg(b, src0, SrcType::F16v2);
}
Op::HSetP2(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if !src_is_reg(src0) && src_is_reg(src1) {
std::mem::swap(src0, src1);
op.cmp_op = op.cmp_op.flip();
}
copy_alu_src_if_not_reg(b, src0, SrcType::F16v2);
}
Op::HMnMx2(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
copy_alu_src_if_not_reg(b, src0, SrcType::F16v2);
}
Op::MuFu(_) => (), // Nothing to do Op::MuFu(_) => (), // Nothing to do
Op::DAdd(op) => { Op::DAdd(op) => {
let [ref mut src0, ref mut src1] = op.srcs; let [ref mut src0, ref mut src1] = op.srcs;

View file

@ -395,6 +395,19 @@ impl CopyPropPass {
fn try_add_instr(&mut self, instr: &Instr) { fn try_add_instr(&mut self, instr: &Instr) {
match &instr.op { match &instr.op {
Op::HAdd2(add) => {
let dst = add.dst.as_ssa().unwrap();
assert!(dst.comps() == 1);
let dst = dst[0];
if !add.saturate {
if add.srcs[0].is_fneg_zero(SrcType::F16v2) {
self.add_copy(dst, SrcType::F16v2, add.srcs[1]);
} else if add.srcs[1].is_fneg_zero(SrcType::F16v2) {
self.add_copy(dst, SrcType::F16v2, add.srcs[0]);
}
}
}
Op::FAdd(add) => { Op::FAdd(add) => {
let dst = add.dst.as_ssa().unwrap(); let dst = add.dst.as_ssa().unwrap();
assert!(dst.comps() == 1); assert!(dst.comps() == 1);

View file

@ -146,14 +146,20 @@ nak_optimize_nir(nir_shader *nir, const struct nak_compiler *nak)
} }
static unsigned static unsigned
lower_bit_size_cb(const nir_instr *instr, void *_data) lower_bit_size_cb(const nir_instr *instr, void *data)
{ {
const struct nak_compiler *nak = data;
switch (instr->type) { switch (instr->type) {
case nir_instr_type_alu: { case nir_instr_type_alu: {
nir_alu_instr *alu = nir_instr_as_alu(instr); nir_alu_instr *alu = nir_instr_as_alu(instr);
if (nir_op_infos[alu->op].is_conversion) if (nir_op_infos[alu->op].is_conversion)
return 0; return 0;
const unsigned bit_size = nir_alu_instr_is_comparison(alu)
? alu->src[0].src.ssa->bit_size
: alu->def.bit_size;
switch (alu->op) { switch (alu->op) {
case nir_op_bit_count: case nir_op_bit_count:
case nir_op_ufind_msb: case nir_op_ufind_msb:
@ -164,17 +170,40 @@ lower_bit_size_cb(const nir_instr *instr, void *_data)
* source. * source.
*/ */
return alu->src[0].src.ssa->bit_size == 32 ? 0 : 32; return alu->src[0].src.ssa->bit_size == 32 ? 0 : 32;
case nir_op_fabs:
case nir_op_fadd:
case nir_op_fneg:
case nir_op_feq:
case nir_op_fge:
case nir_op_flt:
case nir_op_fneu:
case nir_op_fmul:
case nir_op_ffma:
case nir_op_ffmaz:
case nir_op_fsign:
case nir_op_fsat:
case nir_op_fceil:
case nir_op_ffloor:
case nir_op_fround_even:
case nir_op_ftrunc:
if (bit_size == 16 && nak->sm >= 70)
return 0;
break;
case nir_op_fmax:
case nir_op_fmin:
if (bit_size == 16 && nak->sm >= 80)
return 0;
break;
default: default:
break; break;
} }
const unsigned bit_size = nir_alu_instr_is_comparison(alu)
? alu->src[0].src.ssa->bit_size
: alu->def.bit_size;
if (bit_size >= 32) if (bit_size >= 32)
return 0; return 0;
/* TODO: Some hardware has native 16-bit support */
if (bit_size & (8 | 16)) if (bit_size & (8 | 16))
return 32; return 32;