nak: Plumb through float controls

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26557>
This commit is contained in:
Faith Ekstrand 2023-12-06 17:41:22 -06:00 committed by Marge Bot
parent 29bfdcd7c1
commit 1c84c8183c
2 changed files with 126 additions and 22 deletions

View file

@ -64,6 +64,7 @@ nak_bindings_rs = rust.bindgen(
'--raw-line', '#![allow(non_upper_case_globals)]',
'--allowlist-type', 'exec_list',
'--allowlist-type', 'exec_node',
'--allowlist-type', 'float_controls',
'--allowlist-type', 'gl_access_qualifier',
'--allowlist-type', 'gl_frag_result',
'--allowlist-type', 'gl_interp_mode',

View file

@ -14,6 +14,7 @@ use nak_bindings::*;
use std::cmp::max;
use std::collections::{HashMap, HashSet};
use std::ops::Index;
fn init_info_from_nir(nir: &nir_shader, sm: u8) -> ShaderInfo {
ShaderInfo {
@ -144,9 +145,94 @@ impl<'a> PhiAllocMap<'a> {
}
}
struct PerSizeFloatControls {
pub ftz: bool,
pub rnd_mode: FRndMode,
}
struct ShaderFloatControls {
pub fp16: PerSizeFloatControls,
pub fp32: PerSizeFloatControls,
pub fp64: PerSizeFloatControls,
}
impl Default for ShaderFloatControls {
fn default() -> Self {
Self {
fp16: PerSizeFloatControls {
ftz: false,
rnd_mode: FRndMode::NearestEven,
},
fp32: PerSizeFloatControls {
ftz: false,
rnd_mode: FRndMode::NearestEven,
},
fp64: PerSizeFloatControls {
ftz: false,
rnd_mode: FRndMode::NearestEven,
},
}
}
}
impl ShaderFloatControls {
fn from_nir(nir: &nir_shader) -> ShaderFloatControls {
let nir_fc = nir.info.float_controls_execution_mode;
let mut fc: ShaderFloatControls = Default::default();
if (nir_fc & FLOAT_CONTROLS_DENORM_PRESERVE_FP16) != 0 {
fc.fp16.ftz = false;
} else if (nir_fc & FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP16) != 0 {
fc.fp16.ftz = true;
}
if (nir_fc & FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP16) != 0 {
fc.fp16.rnd_mode = FRndMode::NearestEven;
} else if (nir_fc & FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP16) != 0 {
fc.fp16.rnd_mode = FRndMode::Zero;
}
if (nir_fc & FLOAT_CONTROLS_DENORM_PRESERVE_FP32) != 0 {
fc.fp32.ftz = false;
} else if (nir_fc & FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP32) != 0 {
fc.fp32.ftz = true;
}
if (nir_fc & FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP32) != 0 {
fc.fp32.rnd_mode = FRndMode::NearestEven;
} else if (nir_fc & FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP32) != 0 {
fc.fp32.rnd_mode = FRndMode::Zero;
}
if (nir_fc & FLOAT_CONTROLS_DENORM_PRESERVE_FP64) != 0 {
fc.fp64.ftz = false;
} else if (nir_fc & FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP64) != 0 {
fc.fp64.ftz = true;
}
if (nir_fc & FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP64) != 0 {
fc.fp64.rnd_mode = FRndMode::NearestEven;
} else if (nir_fc & FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP64) != 0 {
fc.fp64.rnd_mode = FRndMode::Zero;
}
fc
}
}
impl Index<FloatType> for ShaderFloatControls {
type Output = PerSizeFloatControls;
fn index(&self, idx: FloatType) -> &PerSizeFloatControls {
match idx {
FloatType::F16 => &self.fp16,
FloatType::F32 => &self.fp32,
FloatType::F64 => &self.fp64,
}
}
}
struct ShaderFromNir<'a> {
nir: &'a nir_shader,
info: ShaderInfo,
float_ctl: ShaderFloatControls,
cfg: CFGBuilder<u32, BasicBlock>,
label_alloc: LabelAllocator,
block_label: HashMap<u32, Label>,
@ -162,6 +248,7 @@ impl<'a> ShaderFromNir<'a> {
Self {
nir: nir,
info: init_info_from_nir(nir, sm),
float_ctl: ShaderFloatControls::from_nir(nir),
cfg: CFGBuilder::new(),
label_alloc: LabelAllocator::new(),
block_label: HashMap::new(),
@ -489,18 +576,25 @@ impl<'a> ShaderFromNir<'a> {
| nir_op_f2f32 | nir_op_f2f64 => {
let src_bits = alu.get_src(0).src.bit_size();
let dst_bits = alu.def.bit_size();
let src_type = FloatType::from_bits(src_bits.into());
let dst_type = FloatType::from_bits(dst_bits.into());
let dst = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpF2F {
dst: dst.into(),
src: srcs[0],
src_type: FloatType::from_bits(src_bits.into()),
dst_type: FloatType::from_bits(dst_bits.into()),
dst_type: dst_type,
rnd_mode: match alu.op {
nir_op_f2f16_rtne => FRndMode::NearestEven,
nir_op_f2f16_rtz => FRndMode::Zero,
_ => FRndMode::NearestEven,
_ => self.float_ctl[dst_type].rnd_mode,
},
ftz: if src_bits < dst_bits {
self.float_ctl[src_type].ftz
} else {
self.float_ctl[dst_type].ftz
},
ftz: true,
high: false,
});
dst
@ -524,18 +618,19 @@ impl<'a> ShaderFromNir<'a> {
| nir_op_f2u8 | nir_op_f2u16 | nir_op_f2u32 | nir_op_f2u64 => {
let src_bits = usize::from(alu.get_src(0).bit_size());
let dst_bits = alu.def.bit_size();
let src_type = FloatType::from_bits(src_bits);
let dst = b.alloc_ssa(RegFile::GPR, dst_bits.div_ceil(32));
let dst_is_signed = alu.info().output_type & 2 != 0;
b.push_op(OpF2I {
dst: dst.into(),
src: srcs[0],
src_type: FloatType::from_bits(src_bits),
src_type: src_type,
dst_type: IntType::from_bits(
dst_bits.into(),
dst_is_signed,
),
rnd_mode: FRndMode::Zero,
ftz: false,
ftz: self.float_ctl[src_type].ftz,
});
dst
}
@ -546,6 +641,7 @@ impl<'a> ShaderFromNir<'a> {
nir_op_fneg => (Src::new_zero().fneg(), srcs[0].fneg()),
_ => panic!("Unhandled case"),
};
let ftype = FloatType::from_bits(alu.def.bit_size().into());
assert!(alu.def.bit_size() == 32);
let dst = b.alloc_ssa(RegFile::GPR, 1);
let saturate = self.try_saturate_alu_dst(&alu.def);
@ -553,8 +649,8 @@ impl<'a> ShaderFromNir<'a> {
dst: dst.into(),
srcs: [x, y],
saturate: saturate,
rnd_mode: FRndMode::NearestEven,
ftz: false,
rnd_mode: self.float_ctl[ftype].rnd_mode,
ftz: self.float_ctl[ftype].ftz,
});
dst
}
@ -586,14 +682,15 @@ impl<'a> ShaderFromNir<'a> {
nir_op_feq => b.fsetp(FloatCmpOp::OrdEq, srcs[0], srcs[1]),
nir_op_fexp2 => b.mufu(MuFuOp::Exp2, srcs[0]),
nir_op_ffma => {
let ftype = FloatType::from_bits(alu.def.bit_size().into());
assert!(alu.def.bit_size() == 32);
let dst = b.alloc_ssa(RegFile::GPR, 1);
let ffma = OpFFma {
dst: dst.into(),
srcs: [srcs[0], srcs[1], srcs[2]],
saturate: self.try_saturate_alu_dst(&alu.def),
rnd_mode: FRndMode::NearestEven,
ftz: false,
rnd_mode: self.float_ctl[ftype].rnd_mode,
ftz: self.float_ctl[ftype].ftz,
};
b.push_op(ffma);
dst
@ -617,19 +714,20 @@ impl<'a> ShaderFromNir<'a> {
dst: dst.into(),
srcs: [srcs[0], srcs[1]],
min: (alu.op == nir_op_fmin).into(),
ftz: false,
ftz: self.float_ctl.fp32.ftz,
});
dst
}
nir_op_fmul => {
let ftype = FloatType::from_bits(alu.def.bit_size().into());
assert!(alu.def.bit_size() == 32);
let dst = b.alloc_ssa(RegFile::GPR, 1);
let fmul = OpFMul {
dst: dst.into(),
srcs: [srcs[0], srcs[1]],
saturate: self.try_saturate_alu_dst(&alu.def),
rnd_mode: FRndMode::NearestEven,
ftz: false,
rnd_mode: self.float_ctl[ftype].rnd_mode,
ftz: self.float_ctl[ftype].ftz,
};
b.push_op(fmul);
dst
@ -672,13 +770,14 @@ impl<'a> ShaderFromNir<'a> {
if self.alu_src_is_saturated(&alu.srcs_as_slice()[0]) {
b.copy(srcs[0])
} else {
let ftype = FloatType::from_bits(alu.def.bit_size().into());
let dst = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpFAdd {
dst: dst.into(),
srcs: [srcs[0], 0.into()],
saturate: true,
rnd_mode: FRndMode::NearestEven,
ftz: false,
rnd_mode: self.float_ctl[ftype].rnd_mode,
ftz: self.float_ctl[ftype].ftz,
});
dst
}
@ -698,13 +797,14 @@ impl<'a> ShaderFromNir<'a> {
nir_op_i2f16 | nir_op_i2f32 | nir_op_i2f64 => {
let src_bits = alu.get_src(0).src.bit_size();
let dst_bits = alu.def.bit_size();
let dst_type = FloatType::from_bits(dst_bits.into());
let dst = b.alloc_ssa(RegFile::GPR, dst_bits.div_ceil(32));
b.push_op(OpI2F {
dst: dst.into(),
src: srcs[0],
dst_type: FloatType::from_bits(dst_bits.into()),
dst_type: dst_type,
src_type: IntType::from_bits(src_bits.into(), true),
rnd_mode: FRndMode::NearestEven,
rnd_mode: self.float_ctl[dst_type].rnd_mode,
});
dst
}
@ -996,13 +1096,14 @@ impl<'a> ShaderFromNir<'a> {
nir_op_u2f16 | nir_op_u2f32 | nir_op_u2f64 => {
let src_bits = alu.get_src(0).src.bit_size();
let dst_bits = alu.def.bit_size();
let dst_type = FloatType::from_bits(dst_bits.into());
let dst = b.alloc_ssa(RegFile::GPR, dst_bits.div_ceil(32));
b.push_op(OpI2F {
dst: dst.into(),
src: srcs[0],
dst_type: FloatType::from_bits(dst_bits.into()),
dst_type: dst_type,
src_type: IntType::from_bits(src_bits.into(), false),
rnd_mode: FRndMode::NearestEven,
rnd_mode: self.float_ctl[dst_type].rnd_mode,
});
dst
}
@ -1137,6 +1238,7 @@ impl<'a> ShaderFromNir<'a> {
// TODO: Real coarse derivatives
assert!(alu.def.bit_size() == 32);
let ftype = FloatType::F32;
let scratch = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpShfl {
@ -1159,8 +1261,8 @@ impl<'a> ShaderFromNir<'a> {
FSwzAddOp::SubLeft,
FSwzAddOp::SubRight,
],
rnd_mode: FRndMode::NearestEven,
ftz: false,
rnd_mode: self.float_ctl[ftype].rnd_mode,
ftz: self.float_ctl[ftype].ftz,
});
dst
@ -1169,6 +1271,7 @@ impl<'a> ShaderFromNir<'a> {
// TODO: Real coarse derivatives
assert!(alu.def.bit_size() == 32);
let ftype = FloatType::F32;
let scratch = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpShfl {
@ -1191,8 +1294,8 @@ impl<'a> ShaderFromNir<'a> {
FSwzAddOp::SubRight,
FSwzAddOp::SubRight,
],
rnd_mode: FRndMode::NearestEven,
ftz: false,
rnd_mode: self.float_ctl[ftype].rnd_mode,
ftz: self.float_ctl[ftype].ftz,
});
dst