mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2025-12-21 13:40:16 +01:00
nak: Add 16-bits float operations
Signed-off-by: Mary Guillemard <mary.guillemard@collabora.com> Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/27635>
This commit is contained in:
parent
6b2ce802b7
commit
567cae69c3
7 changed files with 650 additions and 34 deletions
|
|
@ -217,6 +217,31 @@ pub trait SSABuilder: Builder {
|
|||
dst
|
||||
}
|
||||
|
||||
fn hadd2(&mut self, x: Src, y: Src) -> SSARef {
|
||||
let dst = self.alloc_ssa(RegFile::GPR, 1);
|
||||
self.push_op(OpHAdd2 {
|
||||
dst: dst.into(),
|
||||
srcs: [x, y],
|
||||
saturate: false,
|
||||
ftz: false,
|
||||
f32: false,
|
||||
});
|
||||
dst
|
||||
}
|
||||
|
||||
fn hset2(&mut self, cmp_op: FloatCmpOp, x: Src, y: Src) -> SSARef {
|
||||
let dst = self.alloc_ssa(RegFile::GPR, 1);
|
||||
self.push_op(OpHSet2 {
|
||||
dst: dst.into(),
|
||||
set_op: PredSetOp::And,
|
||||
cmp_op: cmp_op,
|
||||
srcs: [x, y],
|
||||
ftz: false,
|
||||
accum: SrcRef::True.into(),
|
||||
});
|
||||
dst
|
||||
}
|
||||
|
||||
fn dsetp(&mut self, cmp_op: FloatCmpOp, x: Src, y: Src) -> SSARef {
|
||||
let dst = self.alloc_ssa(RegFile::Pred, 1);
|
||||
self.push_op(OpDSetP {
|
||||
|
|
|
|||
|
|
@ -535,6 +535,17 @@ impl SM70Instr {
|
|||
self.encode_alu_base(opcode, dst, src0, src1, src2, false);
|
||||
}
|
||||
|
||||
fn encode_fp16_alu(
|
||||
&mut self,
|
||||
opcode: u16,
|
||||
dst: Option<Dst>,
|
||||
src0: ALUSrc,
|
||||
src1: ALUSrc,
|
||||
src2: ALUSrc,
|
||||
) {
|
||||
self.encode_alu_base(opcode, dst, src0, src1, src2, true);
|
||||
}
|
||||
|
||||
fn set_instr_deps(&mut self, deps: &InstrDeps) {
|
||||
self.set_field(105..109, deps.delay);
|
||||
self.set_bit(109, deps.yld);
|
||||
|
|
@ -804,6 +815,159 @@ impl SM70Instr {
|
|||
self.set_pred_src(87..90, 90, op.accum);
|
||||
}
|
||||
|
||||
fn encode_hadd2(&mut self, op: &OpHAdd2) {
|
||||
match op.srcs[1].src_ref {
|
||||
SrcRef::Reg(_) | SrcRef::Zero => {
|
||||
self.encode_fp16_alu(
|
||||
0x030,
|
||||
Some(op.dst),
|
||||
ALUSrc::from_src(&op.srcs[0]),
|
||||
ALUSrc::from_src(&op.srcs[1]),
|
||||
ALUSrc::None,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
self.encode_fp16_alu(
|
||||
0x030,
|
||||
Some(op.dst),
|
||||
ALUSrc::from_src(&op.srcs[0]),
|
||||
ALUSrc::None,
|
||||
ALUSrc::from_src(&op.srcs[1]),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
self.set_bit(77, op.saturate);
|
||||
self.set_bit(78, op.f32);
|
||||
self.set_bit(80, op.ftz);
|
||||
self.set_bit(85, false); // .BF16_V2 (SM90+)
|
||||
}
|
||||
|
||||
fn encode_hfma2(&mut self, op: &OpHFma2) {
|
||||
// HFMA2 doesn't have fneg and fabs on SRC2.
|
||||
assert!(op.srcs[2].src_mod.is_none());
|
||||
|
||||
self.encode_fp16_alu(
|
||||
0x031,
|
||||
Some(op.dst),
|
||||
ALUSrc::from_src(&op.srcs[0]),
|
||||
ALUSrc::from_src(&op.srcs[1]),
|
||||
ALUSrc::from_src(&op.srcs[2]),
|
||||
);
|
||||
|
||||
self.set_bit(76, op.dnz);
|
||||
self.set_bit(77, op.saturate);
|
||||
self.set_bit(78, op.f32);
|
||||
self.set_bit(79, false); // .RELU (SM86+)
|
||||
self.set_bit(80, op.ftz);
|
||||
self.set_bit(85, false); // .BF16_V2 (SM86+)
|
||||
}
|
||||
|
||||
fn encode_hmul2(&mut self, op: &OpHMul2) {
|
||||
self.encode_fp16_alu(
|
||||
0x032,
|
||||
Some(op.dst),
|
||||
ALUSrc::from_src(&op.srcs[0]),
|
||||
ALUSrc::from_src(&op.srcs[1]),
|
||||
ALUSrc::None,
|
||||
);
|
||||
|
||||
self.set_bit(76, op.dnz);
|
||||
self.set_bit(77, op.saturate);
|
||||
self.set_bit(78, false); // .F32 (SM70-SM75)
|
||||
self.set_bit(79, false); // .RELU (SM86+)
|
||||
self.set_bit(80, op.ftz);
|
||||
self.set_bit(85, false); // .BF16_V2 (SM90+)
|
||||
}
|
||||
|
||||
fn encode_hset2(&mut self, op: &OpHSet2) {
|
||||
match op.srcs[1].src_ref {
|
||||
SrcRef::Reg(_) | SrcRef::Zero => {
|
||||
self.encode_fp16_alu(
|
||||
0x033,
|
||||
Some(op.dst),
|
||||
ALUSrc::from_src(&op.srcs[0]),
|
||||
ALUSrc::from_src(&op.srcs[1]),
|
||||
ALUSrc::None,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
self.encode_fp16_alu(
|
||||
0x033,
|
||||
Some(op.dst),
|
||||
ALUSrc::from_src(&op.srcs[0]),
|
||||
ALUSrc::None,
|
||||
ALUSrc::from_src(&op.srcs[1]),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
self.set_bit(65, false); // .BF16_V2 (SM90+)
|
||||
self.set_pred_set_op(69..71, op.set_op);
|
||||
|
||||
// This differentiate between integer and fp16 output
|
||||
self.set_bit(71, true); // .BF
|
||||
self.set_float_cmp_op(76..80, op.cmp_op);
|
||||
self.set_bit(80, op.ftz);
|
||||
|
||||
self.set_pred_src(87..90, 90, op.accum);
|
||||
}
|
||||
|
||||
fn encode_hsetp2(&mut self, op: &OpHSetP2) {
|
||||
match op.srcs[1].src_ref {
|
||||
SrcRef::Reg(_) | SrcRef::Zero => {
|
||||
self.encode_fp16_alu(
|
||||
0x034,
|
||||
None,
|
||||
ALUSrc::from_src(&op.srcs[0]),
|
||||
ALUSrc::from_src(&op.srcs[1]),
|
||||
ALUSrc::None,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
self.encode_fp16_alu(
|
||||
0x034,
|
||||
None,
|
||||
ALUSrc::from_src(&op.srcs[0]),
|
||||
ALUSrc::None,
|
||||
ALUSrc::from_src(&op.srcs[1]),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
self.set_bit(65, false); // .BF16_V2 (SM90+)
|
||||
self.set_pred_set_op(69..71, op.set_op);
|
||||
self.set_bit(71, op.horizontal); // .H_AND
|
||||
self.set_float_cmp_op(76..80, op.cmp_op);
|
||||
self.set_bit(80, op.ftz);
|
||||
|
||||
self.set_pred_dst(81..84, op.dsts[0]);
|
||||
self.set_pred_dst(84..87, op.dsts[1]);
|
||||
|
||||
self.set_pred_src(87..90, 90, op.accum);
|
||||
}
|
||||
|
||||
fn encode_hmnmx2(&mut self, op: &OpHMnMx2) {
|
||||
assert!(self.sm >= 80);
|
||||
|
||||
self.encode_fp16_alu(
|
||||
0x040,
|
||||
Some(op.dst),
|
||||
ALUSrc::from_src(&op.srcs[0]),
|
||||
ALUSrc::from_src(&op.srcs[1]),
|
||||
ALUSrc::None,
|
||||
);
|
||||
|
||||
// This differentiate between integer and fp16 output
|
||||
self.set_bit(78, false); // .F32 (SM86)
|
||||
self.set_bit(80, op.ftz);
|
||||
self.set_bit(81, false); // .NAN
|
||||
self.set_bit(82, false); // .XORSIGN
|
||||
self.set_bit(85, false); // .BF16_V2
|
||||
|
||||
self.set_pred_src(87..90, 90, op.min);
|
||||
}
|
||||
|
||||
fn encode_bmsk(&mut self, op: &OpBMsk) {
|
||||
self.encode_alu(
|
||||
0x01b,
|
||||
|
|
@ -2175,6 +2339,12 @@ impl SM70Instr {
|
|||
Op::DFma(op) => si.encode_dfma(op),
|
||||
Op::DMul(op) => si.encode_dmul(op),
|
||||
Op::DSetP(op) => si.encode_dsetp(op),
|
||||
Op::HAdd2(op) => si.encode_hadd2(op),
|
||||
Op::HFma2(op) => si.encode_hfma2(op),
|
||||
Op::HMul2(op) => si.encode_hmul2(op),
|
||||
Op::HSet2(op) => si.encode_hset2(op),
|
||||
Op::HSetP2(op) => si.encode_hsetp2(op),
|
||||
Op::HMnMx2(op) => si.encode_hmnmx2(op),
|
||||
Op::MuFu(op) => si.encode_mufu(op),
|
||||
Op::BMsk(op) => si.encode_bmsk(op),
|
||||
Op::BRev(op) => si.encode_brev(op),
|
||||
|
|
|
|||
|
|
@ -511,6 +511,16 @@ impl<'a> ShaderFromNir<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
// Restricts an F16v2 source to just x if the ALU op is single-component. This
|
||||
// must only be called for per-component sources (see nir_op_info::output_sizes
|
||||
// for more details).
|
||||
let restrict_f16v2_src = |mut src: Src| {
|
||||
if alu.def.num_components == 1 {
|
||||
src.src_swizzle = SrcSwizzle::Xx;
|
||||
}
|
||||
src
|
||||
};
|
||||
|
||||
let dst: SSARef = match alu.op {
|
||||
nir_op_b2b1 => {
|
||||
assert!(alu.get_src(0).bit_size() == 32);
|
||||
|
|
@ -524,6 +534,7 @@ impl<'a> ShaderFromNir<'a> {
|
|||
let hi = b.copy(0.into());
|
||||
[lo[0], hi[0]].into()
|
||||
}
|
||||
nir_op_b2f16 => b.sel(srcs[0].bnot(), 0.into(), 0x3c00.into()),
|
||||
nir_op_b2f32 => {
|
||||
b.sel(srcs[0].bnot(), 0.0_f32.into(), 1.0_f32.into())
|
||||
}
|
||||
|
|
@ -688,7 +699,7 @@ impl<'a> ShaderFromNir<'a> {
|
|||
}
|
||||
nir_op_fabs | nir_op_fadd | nir_op_fneg => {
|
||||
let (x, y) = match alu.op {
|
||||
nir_op_fabs => (srcs[0].fabs(), 0.0_f32.into()),
|
||||
nir_op_fabs => (srcs[0].fabs(), 0.into()),
|
||||
nir_op_fadd => (srcs[0], srcs[1]),
|
||||
nir_op_fneg => (Src::new_zero().fneg(), srcs[0].fneg()),
|
||||
_ => panic!("Unhandled case"),
|
||||
|
|
@ -711,6 +722,19 @@ impl<'a> ShaderFromNir<'a> {
|
|||
rnd_mode: self.float_ctl[ftype].rnd_mode,
|
||||
ftz: self.float_ctl[ftype].ftz,
|
||||
});
|
||||
} else if alu.def.bit_size() == 16 {
|
||||
assert!(
|
||||
self.float_ctl[ftype].rnd_mode == FRndMode::NearestEven
|
||||
);
|
||||
|
||||
dst = b.alloc_ssa(RegFile::GPR, 1);
|
||||
b.push_op(OpHAdd2 {
|
||||
dst: dst.into(),
|
||||
srcs: [restrict_f16v2_src(x), restrict_f16v2_src(y)],
|
||||
saturate: self.try_saturate_alu_dst(&alu.def),
|
||||
ftz: self.float_ctl[ftype].ftz,
|
||||
f32: false,
|
||||
});
|
||||
} else {
|
||||
panic!("Unsupported float type: f{}", alu.def.bit_size());
|
||||
}
|
||||
|
|
@ -718,7 +742,6 @@ impl<'a> ShaderFromNir<'a> {
|
|||
}
|
||||
nir_op_fceil | nir_op_ffloor | nir_op_fround_even
|
||||
| nir_op_ftrunc => {
|
||||
assert!(alu.def.bit_size() == 32);
|
||||
let dst = b.alloc_ssa(RegFile::GPR, 1);
|
||||
let ty = FloatType::from_bits(alu.def.bit_size().into());
|
||||
let rnd_mode = match alu.op {
|
||||
|
|
@ -730,6 +753,9 @@ impl<'a> ShaderFromNir<'a> {
|
|||
};
|
||||
let ftz = self.float_ctl[ty].ftz;
|
||||
if b.sm() >= 70 {
|
||||
assert!(
|
||||
alu.def.bit_size() == 32 || alu.def.bit_size() == 16
|
||||
);
|
||||
b.push_op(OpFRnd {
|
||||
dst: dst.into(),
|
||||
src: srcs[0],
|
||||
|
|
@ -739,6 +765,7 @@ impl<'a> ShaderFromNir<'a> {
|
|||
ftz,
|
||||
});
|
||||
} else {
|
||||
assert!(alu.def.bit_size() == 32);
|
||||
b.push_op(OpF2F {
|
||||
dst: dst.into(),
|
||||
src: srcs[0],
|
||||
|
|
@ -764,8 +791,9 @@ impl<'a> ShaderFromNir<'a> {
|
|||
_ => panic!("Usupported float comparison"),
|
||||
};
|
||||
|
||||
let dst = b.alloc_ssa(RegFile::Pred, 1);
|
||||
let dst = b.alloc_ssa(RegFile::Pred, alu.def.num_components);
|
||||
if alu.get_src(0).bit_size() == 64 {
|
||||
assert!(alu.def.num_components == 1);
|
||||
b.push_op(OpDSetP {
|
||||
dst: dst.into(),
|
||||
set_op: PredSetOp::And,
|
||||
|
|
@ -774,6 +802,7 @@ impl<'a> ShaderFromNir<'a> {
|
|||
accum: SrcRef::True.into(),
|
||||
});
|
||||
} else if alu.get_src(0).bit_size() == 32 {
|
||||
assert!(alu.def.num_components == 1);
|
||||
b.push_op(OpFSetP {
|
||||
dst: dst.into(),
|
||||
set_op: PredSetOp::And,
|
||||
|
|
@ -782,6 +811,30 @@ impl<'a> ShaderFromNir<'a> {
|
|||
accum: SrcRef::True.into(),
|
||||
ftz: self.float_ctl[src_type].ftz,
|
||||
});
|
||||
} else if alu.get_src(0).bit_size() == 16 {
|
||||
assert!(
|
||||
alu.def.num_components == 1
|
||||
|| alu.def.num_components == 2
|
||||
);
|
||||
|
||||
let dsts = if alu.def.num_components == 2 {
|
||||
[dst[0].into(), dst[1].into()]
|
||||
} else {
|
||||
[dst[0].into(), Dst::None]
|
||||
};
|
||||
|
||||
b.push_op(OpHSetP2 {
|
||||
dsts,
|
||||
set_op: PredSetOp::And,
|
||||
cmp_op: cmp_op,
|
||||
srcs: [
|
||||
restrict_f16v2_src(srcs[0]),
|
||||
restrict_f16v2_src(srcs[1]),
|
||||
],
|
||||
accum: SrcRef::True.into(),
|
||||
ftz: self.float_ctl[src_type].ftz,
|
||||
horizontal: false,
|
||||
});
|
||||
} else {
|
||||
panic!(
|
||||
"Unsupported float type: f{}",
|
||||
|
|
@ -814,6 +867,24 @@ impl<'a> ShaderFromNir<'a> {
|
|||
ftz: self.float_ctl[ftype].ftz,
|
||||
dnz: false,
|
||||
});
|
||||
} else if alu.def.bit_size() == 16 {
|
||||
assert!(
|
||||
self.float_ctl[ftype].rnd_mode == FRndMode::NearestEven
|
||||
);
|
||||
|
||||
dst = b.alloc_ssa(RegFile::GPR, 1);
|
||||
b.push_op(OpHFma2 {
|
||||
dst: dst.into(),
|
||||
srcs: [
|
||||
restrict_f16v2_src(srcs[0]),
|
||||
restrict_f16v2_src(srcs[1]),
|
||||
restrict_f16v2_src(srcs[2]),
|
||||
],
|
||||
saturate: self.try_saturate_alu_dst(&alu.def),
|
||||
ftz: self.float_ctl[ftype].ftz,
|
||||
dnz: false,
|
||||
f32: false,
|
||||
});
|
||||
} else {
|
||||
panic!("Unsupported float type: f{}", alu.def.bit_size());
|
||||
}
|
||||
|
|
@ -857,6 +928,17 @@ impl<'a> ShaderFromNir<'a> {
|
|||
min: (alu.op == nir_op_fmin).into(),
|
||||
ftz: self.float_ctl.fp32.ftz,
|
||||
});
|
||||
} else if alu.def.bit_size() == 16 {
|
||||
dst = b.alloc_ssa(RegFile::GPR, 1);
|
||||
b.push_op(OpHMnMx2 {
|
||||
dst: dst.into(),
|
||||
srcs: [
|
||||
restrict_f16v2_src(srcs[0]),
|
||||
restrict_f16v2_src(srcs[1]),
|
||||
],
|
||||
min: (alu.op == nir_op_fmin).into(),
|
||||
ftz: self.float_ctl.fp16.ftz,
|
||||
});
|
||||
} else {
|
||||
panic!("Unsupported float type: f{}", alu.def.bit_size());
|
||||
}
|
||||
|
|
@ -883,6 +965,22 @@ impl<'a> ShaderFromNir<'a> {
|
|||
ftz: self.float_ctl[ftype].ftz,
|
||||
dnz: false,
|
||||
});
|
||||
} else if alu.def.bit_size() == 16 {
|
||||
assert!(
|
||||
self.float_ctl[ftype].rnd_mode == FRndMode::NearestEven
|
||||
);
|
||||
|
||||
dst = b.alloc_ssa(RegFile::GPR, 1);
|
||||
b.push_op(OpHMul2 {
|
||||
dst: dst.into(),
|
||||
srcs: [
|
||||
restrict_f16v2_src(srcs[0]),
|
||||
restrict_f16v2_src(srcs[1]),
|
||||
],
|
||||
saturate: self.try_saturate_alu_dst(&alu.def),
|
||||
ftz: self.float_ctl[ftype].ftz,
|
||||
dnz: false,
|
||||
});
|
||||
} else {
|
||||
panic!("Unsupported float type: f{}", alu.def.bit_size());
|
||||
}
|
||||
|
|
@ -940,11 +1038,11 @@ impl<'a> ShaderFromNir<'a> {
|
|||
b.mufu(MuFuOp::Rsq, srcs[0])
|
||||
}
|
||||
nir_op_fsat => {
|
||||
assert!(alu.def.bit_size() == 32);
|
||||
let ftype = FloatType::from_bits(alu.def.bit_size().into());
|
||||
|
||||
if self.alu_src_is_saturated(&alu.srcs_as_slice()[0]) {
|
||||
b.copy(srcs[0])
|
||||
} else {
|
||||
let ftype = FloatType::from_bits(alu.def.bit_size().into());
|
||||
} else if alu.def.bit_size() == 32 {
|
||||
let dst = b.alloc_ssa(RegFile::GPR, 1);
|
||||
b.push_op(OpFAdd {
|
||||
dst: dst.into(),
|
||||
|
|
@ -954,6 +1052,22 @@ impl<'a> ShaderFromNir<'a> {
|
|||
ftz: self.float_ctl[ftype].ftz,
|
||||
});
|
||||
dst
|
||||
} else if alu.def.bit_size() == 16 {
|
||||
assert!(
|
||||
self.float_ctl[ftype].rnd_mode == FRndMode::NearestEven
|
||||
);
|
||||
|
||||
let dst = b.alloc_ssa(RegFile::GPR, 1);
|
||||
b.push_op(OpHAdd2 {
|
||||
dst: dst.into(),
|
||||
srcs: [restrict_f16v2_src(srcs[0]), 0.into()],
|
||||
saturate: true,
|
||||
ftz: self.float_ctl[ftype].ftz,
|
||||
f32: false,
|
||||
});
|
||||
dst
|
||||
} else {
|
||||
panic!("Unsupported float type: f{}", alu.def.bit_size());
|
||||
}
|
||||
}
|
||||
nir_op_fsign => {
|
||||
|
|
@ -968,6 +1082,17 @@ impl<'a> ShaderFromNir<'a> {
|
|||
let lz = b.fset(FloatCmpOp::OrdLt, srcs[0], 0.into());
|
||||
let gz = b.fset(FloatCmpOp::OrdGt, srcs[0], 0.into());
|
||||
b.fadd(gz.into(), Src::from(lz).fneg())
|
||||
} else if alu.def.bit_size() == 16 {
|
||||
let x = restrict_f16v2_src(srcs[0]);
|
||||
|
||||
let lz = restrict_f16v2_src(
|
||||
b.hset2(FloatCmpOp::OrdLt, x, 0.into()).into(),
|
||||
);
|
||||
let gz = restrict_f16v2_src(
|
||||
b.hset2(FloatCmpOp::OrdGt, x, 0.into()).into(),
|
||||
);
|
||||
|
||||
b.hadd2(gz, lz.fneg())
|
||||
} else {
|
||||
panic!("Unsupported float type: f{}", alu.def.bit_size());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2696,6 +2696,186 @@ impl DisplayOp for OpDSetP {
|
|||
}
|
||||
impl_display_for_op!(OpDSetP);
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(SrcsAsSlice, DstsAsSlice)]
|
||||
pub struct OpHAdd2 {
|
||||
pub dst: Dst,
|
||||
|
||||
#[src_type(F16v2)]
|
||||
pub srcs: [Src; 2],
|
||||
|
||||
pub saturate: bool,
|
||||
pub ftz: bool,
|
||||
pub f32: bool,
|
||||
}
|
||||
|
||||
impl DisplayOp for OpHAdd2 {
|
||||
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let sat = if self.saturate { ".sat" } else { "" };
|
||||
let f32 = if self.f32 { ".f32" } else { "" };
|
||||
write!(f, "hadd2{sat}{f32}")?;
|
||||
if self.ftz {
|
||||
write!(f, ".ftz")?;
|
||||
}
|
||||
write!(f, " {} {}", self.srcs[0], self.srcs[1])
|
||||
}
|
||||
}
|
||||
impl_display_for_op!(OpHAdd2);
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(SrcsAsSlice, DstsAsSlice)]
|
||||
pub struct OpHSet2 {
|
||||
pub dst: Dst,
|
||||
|
||||
pub set_op: PredSetOp,
|
||||
pub cmp_op: FloatCmpOp,
|
||||
|
||||
#[src_type(F16v2)]
|
||||
pub srcs: [Src; 2],
|
||||
|
||||
#[src_type(Pred)]
|
||||
pub accum: Src,
|
||||
|
||||
pub ftz: bool,
|
||||
}
|
||||
|
||||
impl DisplayOp for OpHSet2 {
|
||||
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let ftz = if self.ftz { ".ftz" } else { "" };
|
||||
write!(f, "hset2{}{ftz}", self.cmp_op)?;
|
||||
if !self.set_op.is_trivial(&self.accum) {
|
||||
write!(f, "{}", self.set_op)?;
|
||||
}
|
||||
write!(f, " {} {}", self.srcs[0], self.srcs[1])?;
|
||||
if !self.set_op.is_trivial(&self.accum) {
|
||||
write!(f, " {}", self.accum)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
impl_display_for_op!(OpHSet2);
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(SrcsAsSlice, DstsAsSlice)]
|
||||
pub struct OpHSetP2 {
|
||||
pub dsts: [Dst; 2],
|
||||
|
||||
pub set_op: PredSetOp,
|
||||
pub cmp_op: FloatCmpOp,
|
||||
|
||||
#[src_type(F16v2)]
|
||||
pub srcs: [Src; 2],
|
||||
|
||||
#[src_type(Pred)]
|
||||
pub accum: Src,
|
||||
|
||||
pub ftz: bool,
|
||||
|
||||
// When not set, each dsts get the result of each lanes.
|
||||
// When set, the first dst gets the result of both lanes (res0 && res1)
|
||||
// and the second dst gets the negation !(res0 && res1)
|
||||
// before applying the accumulator.
|
||||
pub horizontal: bool,
|
||||
}
|
||||
|
||||
impl DisplayOp for OpHSetP2 {
|
||||
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let ftz = if self.ftz { ".ftz" } else { "" };
|
||||
write!(f, "hsetp2{}{ftz}", self.cmp_op)?;
|
||||
if !self.set_op.is_trivial(&self.accum) {
|
||||
write!(f, "{}", self.set_op)?;
|
||||
}
|
||||
write!(f, " {} {}", self.srcs[0], self.srcs[1])?;
|
||||
if !self.set_op.is_trivial(&self.accum) {
|
||||
write!(f, " {}", self.accum)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
impl_display_for_op!(OpHSetP2);
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(SrcsAsSlice, DstsAsSlice)]
|
||||
pub struct OpHMul2 {
|
||||
pub dst: Dst,
|
||||
|
||||
#[src_type(F16v2)]
|
||||
pub srcs: [Src; 2],
|
||||
|
||||
pub saturate: bool,
|
||||
pub ftz: bool,
|
||||
pub dnz: bool,
|
||||
}
|
||||
|
||||
impl DisplayOp for OpHMul2 {
|
||||
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let sat = if self.saturate { ".sat" } else { "" };
|
||||
write!(f, "hmul2{sat}")?;
|
||||
if self.dnz {
|
||||
write!(f, ".dnz")?;
|
||||
} else if self.ftz {
|
||||
write!(f, ".ftz")?;
|
||||
}
|
||||
write!(f, " {} {}", self.srcs[0], self.srcs[1])
|
||||
}
|
||||
}
|
||||
impl_display_for_op!(OpHMul2);
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(SrcsAsSlice, DstsAsSlice)]
|
||||
pub struct OpHFma2 {
|
||||
pub dst: Dst,
|
||||
|
||||
#[src_type(F16v2)]
|
||||
pub srcs: [Src; 3],
|
||||
|
||||
pub saturate: bool,
|
||||
pub ftz: bool,
|
||||
pub dnz: bool,
|
||||
pub f32: bool,
|
||||
}
|
||||
|
||||
impl DisplayOp for OpHFma2 {
|
||||
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let sat = if self.saturate { ".sat" } else { "" };
|
||||
let f32 = if self.f32 { ".f32" } else { "" };
|
||||
write!(f, "hfma2{sat}{f32}")?;
|
||||
if self.dnz {
|
||||
write!(f, ".dnz")?;
|
||||
} else if self.ftz {
|
||||
write!(f, ".ftz")?;
|
||||
}
|
||||
write!(f, " {} {} {}", self.srcs[0], self.srcs[1], self.srcs[2])
|
||||
}
|
||||
}
|
||||
impl_display_for_op!(OpHFma2);
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(SrcsAsSlice, DstsAsSlice)]
|
||||
pub struct OpHMnMx2 {
|
||||
pub dst: Dst,
|
||||
|
||||
#[src_type(F16v2)]
|
||||
pub srcs: [Src; 2],
|
||||
|
||||
#[src_type(Pred)]
|
||||
pub min: Src,
|
||||
|
||||
pub ftz: bool,
|
||||
}
|
||||
|
||||
impl DisplayOp for OpHMnMx2 {
|
||||
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let ftz = if self.ftz { ".ftz" } else { "" };
|
||||
write!(
|
||||
f,
|
||||
"hmnmx2{ftz} {} {} {}",
|
||||
self.srcs[0], self.srcs[1], self.min
|
||||
)
|
||||
}
|
||||
}
|
||||
impl_display_for_op!(OpHMnMx2);
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(SrcsAsSlice, DstsAsSlice)]
|
||||
pub struct OpBMsk {
|
||||
|
|
@ -4999,6 +5179,12 @@ pub enum Op {
|
|||
DMnMx(OpDMnMx),
|
||||
DMul(OpDMul),
|
||||
DSetP(OpDSetP),
|
||||
HAdd2(OpHAdd2),
|
||||
HFma2(OpHFma2),
|
||||
HMul2(OpHMul2),
|
||||
HSet2(OpHSet2),
|
||||
HSetP2(OpHSetP2),
|
||||
HMnMx2(OpHMnMx2),
|
||||
BMsk(OpBMsk),
|
||||
BRev(OpBRev),
|
||||
Bfe(OpBfe),
|
||||
|
|
@ -5434,6 +5620,12 @@ impl Instr {
|
|||
| Op::FMul(_)
|
||||
| Op::FSet(_)
|
||||
| Op::FSetP(_)
|
||||
| Op::HAdd2(_)
|
||||
| Op::HFma2(_)
|
||||
| Op::HMul2(_)
|
||||
| Op::HSet2(_)
|
||||
| Op::HSetP2(_)
|
||||
| Op::HMnMx2(_)
|
||||
| Op::FSwzAdd(_) => true,
|
||||
|
||||
// Multi-function unit is variable latency
|
||||
|
|
|
|||
|
|
@ -163,35 +163,54 @@ fn copy_alu_src_if_f20_overflow(
|
|||
}
|
||||
}
|
||||
|
||||
fn copy_alu_src_and_lower_fmod(
|
||||
b: &mut impl SSABuilder,
|
||||
src: &mut Src,
|
||||
src_type: SrcType,
|
||||
) {
|
||||
match src_type {
|
||||
SrcType::F16 | SrcType::F16v2 => {
|
||||
let val = b.alloc_ssa(RegFile::GPR, 1);
|
||||
b.push_op(OpHAdd2 {
|
||||
dst: val.into(),
|
||||
srcs: [Src::new_zero().fneg(), *src],
|
||||
saturate: false,
|
||||
ftz: false,
|
||||
f32: false,
|
||||
});
|
||||
*src = val.into();
|
||||
}
|
||||
SrcType::F32 => {
|
||||
let val = b.alloc_ssa(RegFile::GPR, 1);
|
||||
b.push_op(OpFAdd {
|
||||
dst: val.into(),
|
||||
srcs: [Src::new_zero().fneg(), *src],
|
||||
saturate: false,
|
||||
rnd_mode: FRndMode::NearestEven,
|
||||
ftz: false,
|
||||
});
|
||||
*src = val.into();
|
||||
}
|
||||
SrcType::F64 => {
|
||||
let val = b.alloc_ssa(RegFile::GPR, 2);
|
||||
b.push_op(OpDAdd {
|
||||
dst: val.into(),
|
||||
srcs: [Src::new_zero().fneg(), *src],
|
||||
rnd_mode: FRndMode::NearestEven,
|
||||
});
|
||||
*src = val.into();
|
||||
}
|
||||
_ => panic!("Invalid ffabs srouce type"),
|
||||
}
|
||||
}
|
||||
|
||||
fn copy_alu_src_if_fabs(
|
||||
b: &mut impl SSABuilder,
|
||||
src: &mut Src,
|
||||
src_type: SrcType,
|
||||
) {
|
||||
if src.src_mod.has_fabs() {
|
||||
match src_type {
|
||||
SrcType::F32 => {
|
||||
let val = b.alloc_ssa(RegFile::GPR, 1);
|
||||
b.push_op(OpFAdd {
|
||||
dst: val.into(),
|
||||
srcs: [Src::new_zero().fneg(), *src],
|
||||
saturate: false,
|
||||
rnd_mode: FRndMode::NearestEven,
|
||||
ftz: false,
|
||||
});
|
||||
*src = val.into();
|
||||
}
|
||||
SrcType::F64 => {
|
||||
let val = b.alloc_ssa(RegFile::GPR, 2);
|
||||
b.push_op(OpDAdd {
|
||||
dst: val.into(),
|
||||
srcs: [Src::new_zero().fneg(), *src],
|
||||
rnd_mode: FRndMode::NearestEven,
|
||||
});
|
||||
*src = val.into();
|
||||
}
|
||||
_ => panic!("Invalid ffabs srouce type"),
|
||||
}
|
||||
copy_alu_src_and_lower_fmod(b, src, src_type);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -501,6 +520,49 @@ fn legalize_sm70_instr(
|
|||
}
|
||||
copy_alu_src_if_not_reg(b, src0, SrcType::F32);
|
||||
}
|
||||
Op::HAdd2(op) => {
|
||||
let [ref mut src0, ref mut src1] = op.srcs;
|
||||
swap_srcs_if_not_reg(src0, src1);
|
||||
copy_alu_src_if_not_reg(b, src0, SrcType::F16v2);
|
||||
}
|
||||
Op::HFma2(op) => {
|
||||
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
|
||||
swap_srcs_if_not_reg(src0, src1);
|
||||
copy_alu_src_if_not_reg(b, src0, SrcType::F16v2);
|
||||
copy_alu_src_if_not_reg(b, src1, SrcType::F16v2);
|
||||
copy_alu_src_if_both_not_reg(b, src1, src2, SrcType::F16v2);
|
||||
|
||||
// HFMA2 doesn't have fabs or fneg on SRC2.
|
||||
if !src2.src_mod.is_none() {
|
||||
copy_alu_src_and_lower_fmod(b, src2, SrcType::F16v2);
|
||||
}
|
||||
}
|
||||
Op::HMul2(op) => {
|
||||
let [ref mut src0, ref mut src1] = op.srcs;
|
||||
swap_srcs_if_not_reg(src0, src1);
|
||||
copy_alu_src_if_not_reg(b, src0, SrcType::F16v2);
|
||||
}
|
||||
Op::HSet2(op) => {
|
||||
let [ref mut src0, ref mut src1] = op.srcs;
|
||||
if !src_is_reg(src0) && src_is_reg(src1) {
|
||||
std::mem::swap(src0, src1);
|
||||
op.cmp_op = op.cmp_op.flip();
|
||||
}
|
||||
copy_alu_src_if_not_reg(b, src0, SrcType::F16v2);
|
||||
}
|
||||
Op::HSetP2(op) => {
|
||||
let [ref mut src0, ref mut src1] = op.srcs;
|
||||
if !src_is_reg(src0) && src_is_reg(src1) {
|
||||
std::mem::swap(src0, src1);
|
||||
op.cmp_op = op.cmp_op.flip();
|
||||
}
|
||||
copy_alu_src_if_not_reg(b, src0, SrcType::F16v2);
|
||||
}
|
||||
Op::HMnMx2(op) => {
|
||||
let [ref mut src0, ref mut src1] = op.srcs;
|
||||
swap_srcs_if_not_reg(src0, src1);
|
||||
copy_alu_src_if_not_reg(b, src0, SrcType::F16v2);
|
||||
}
|
||||
Op::MuFu(_) => (), // Nothing to do
|
||||
Op::DAdd(op) => {
|
||||
let [ref mut src0, ref mut src1] = op.srcs;
|
||||
|
|
|
|||
|
|
@ -395,6 +395,19 @@ impl CopyPropPass {
|
|||
|
||||
fn try_add_instr(&mut self, instr: &Instr) {
|
||||
match &instr.op {
|
||||
Op::HAdd2(add) => {
|
||||
let dst = add.dst.as_ssa().unwrap();
|
||||
assert!(dst.comps() == 1);
|
||||
let dst = dst[0];
|
||||
|
||||
if !add.saturate {
|
||||
if add.srcs[0].is_fneg_zero(SrcType::F16v2) {
|
||||
self.add_copy(dst, SrcType::F16v2, add.srcs[1]);
|
||||
} else if add.srcs[1].is_fneg_zero(SrcType::F16v2) {
|
||||
self.add_copy(dst, SrcType::F16v2, add.srcs[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
Op::FAdd(add) => {
|
||||
let dst = add.dst.as_ssa().unwrap();
|
||||
assert!(dst.comps() == 1);
|
||||
|
|
|
|||
|
|
@ -146,14 +146,20 @@ nak_optimize_nir(nir_shader *nir, const struct nak_compiler *nak)
|
|||
}
|
||||
|
||||
static unsigned
|
||||
lower_bit_size_cb(const nir_instr *instr, void *_data)
|
||||
lower_bit_size_cb(const nir_instr *instr, void *data)
|
||||
{
|
||||
const struct nak_compiler *nak = data;
|
||||
|
||||
switch (instr->type) {
|
||||
case nir_instr_type_alu: {
|
||||
nir_alu_instr *alu = nir_instr_as_alu(instr);
|
||||
if (nir_op_infos[alu->op].is_conversion)
|
||||
return 0;
|
||||
|
||||
const unsigned bit_size = nir_alu_instr_is_comparison(alu)
|
||||
? alu->src[0].src.ssa->bit_size
|
||||
: alu->def.bit_size;
|
||||
|
||||
switch (alu->op) {
|
||||
case nir_op_bit_count:
|
||||
case nir_op_ufind_msb:
|
||||
|
|
@ -164,17 +170,40 @@ lower_bit_size_cb(const nir_instr *instr, void *_data)
|
|||
* source.
|
||||
*/
|
||||
return alu->src[0].src.ssa->bit_size == 32 ? 0 : 32;
|
||||
|
||||
case nir_op_fabs:
|
||||
case nir_op_fadd:
|
||||
case nir_op_fneg:
|
||||
case nir_op_feq:
|
||||
case nir_op_fge:
|
||||
case nir_op_flt:
|
||||
case nir_op_fneu:
|
||||
case nir_op_fmul:
|
||||
case nir_op_ffma:
|
||||
case nir_op_ffmaz:
|
||||
case nir_op_fsign:
|
||||
case nir_op_fsat:
|
||||
case nir_op_fceil:
|
||||
case nir_op_ffloor:
|
||||
case nir_op_fround_even:
|
||||
case nir_op_ftrunc:
|
||||
if (bit_size == 16 && nak->sm >= 70)
|
||||
return 0;
|
||||
break;
|
||||
|
||||
case nir_op_fmax:
|
||||
case nir_op_fmin:
|
||||
if (bit_size == 16 && nak->sm >= 80)
|
||||
return 0;
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
const unsigned bit_size = nir_alu_instr_is_comparison(alu)
|
||||
? alu->src[0].src.ssa->bit_size
|
||||
: alu->def.bit_size;
|
||||
if (bit_size >= 32)
|
||||
return 0;
|
||||
|
||||
/* TODO: Some hardware has native 16-bit support */
|
||||
if (bit_size & (8 | 16))
|
||||
return 32;
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue