nak: Add a Foldable trait

This is for ops that we know how to constant-fold.  We'll use this to
generate unit tests.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/30275>
This commit is contained in:
Faith Ekstrand 2024-07-17 21:44:56 -05:00 committed by Marge Bot
parent 841737925f
commit e09dd8e201

View file

@ -1421,6 +1421,12 @@ pub trait SrcsAsSlice {
fn srcs_as_slice(&self) -> &[Src];
fn srcs_as_mut_slice(&mut self) -> &mut [Src];
fn src_types(&self) -> SrcTypeList;
fn src_idx(&self, src: &Src) -> usize {
let r = self.srcs_as_slice().as_ptr_range();
assert!(r.contains(&(src as *const Src)));
unsafe { (src as *const Src).offset_from(r.start) as usize }
}
}
fn all_dsts_uniform(dsts: &[Dst]) -> bool {
@ -1461,6 +1467,12 @@ pub trait DstsAsSlice {
fn dsts_as_mut_slice(&mut self) -> &mut [Dst];
fn dst_types(&self) -> DstTypeList;
fn dst_idx(&self, dst: &Dst) -> usize {
let r = self.dsts_as_slice().as_ptr_range();
assert!(r.contains(&(dst as *const Dst)));
unsafe { (dst as *const Dst).offset_from(r.start) as usize }
}
fn is_uniform(&self) -> bool {
all_dsts_uniform(self.dsts_as_slice())
}
@ -1491,6 +1503,102 @@ fn fmt_dst_slice(f: &mut fmt::Formatter<'_>, dsts: &[Dst]) -> fmt::Result {
Ok(())
}
#[allow(dead_code)]
#[derive(Clone, Copy)]
pub enum FoldData {
Pred(bool),
U32(u32),
Vec2([u32; 2]),
}
pub struct OpFoldData<'a> {
pub dsts: &'a mut [FoldData],
pub srcs: &'a [FoldData],
}
impl OpFoldData<'_> {
#[allow(dead_code)]
pub fn get_pred_src(&self, op: &impl SrcsAsSlice, src: &Src) -> bool {
let i = op.src_idx(src);
match src.src_ref {
SrcRef::Zero | SrcRef::Imm32(_) => panic!("Expected a predicate"),
SrcRef::True => true,
SrcRef::False => false,
_ => {
if let FoldData::Pred(b) = self.srcs[i] {
b
} else {
panic!("FoldData is not a predicate");
}
}
}
}
pub fn get_u32_src(&self, op: &impl SrcsAsSlice, src: &Src) -> u32 {
let i = op.src_idx(src);
match src.src_ref {
SrcRef::Zero => 0,
SrcRef::Imm32(imm) => imm,
SrcRef::True | SrcRef::False => panic!("Unexpected predicate"),
_ => {
if let FoldData::U32(u) = self.srcs[i] {
u
} else {
panic!("FoldData is not a U32");
}
}
}
}
#[allow(dead_code)]
pub fn get_f32_src(&self, op: &impl SrcsAsSlice, src: &Src) -> f32 {
f32::from_bits(self.get_u32_src(op, src))
}
#[allow(dead_code)]
pub fn get_f64_src(&self, op: &impl SrcsAsSlice, src: &Src) -> f64 {
let i = op.src_idx(src);
match src.src_ref {
SrcRef::Zero => 0.0,
SrcRef::Imm32(imm) => f64::from_bits(u64::from(imm) << 32),
SrcRef::True | SrcRef::False => panic!("Unexpected predicate"),
_ => {
if let FoldData::Vec2(v) = self.srcs[i] {
let u = u64::from(v[0]) | (u64::from(v[1]) << 32);
f64::from_bits(u)
} else {
panic!("FoldData is not a U32");
}
}
}
}
#[allow(dead_code)]
pub fn set_pred_dst(&mut self, op: &impl DstsAsSlice, dst: &Dst, b: bool) {
self.dsts[op.dst_idx(dst)] = FoldData::Pred(b);
}
pub fn set_u32_dst(&mut self, op: &impl DstsAsSlice, dst: &Dst, u: u32) {
self.dsts[op.dst_idx(dst)] = FoldData::U32(u);
}
#[allow(dead_code)]
pub fn set_f32_dst(&mut self, op: &impl DstsAsSlice, dst: &Dst, f: f32) {
self.set_u32_dst(op, dst, f.to_bits());
}
#[allow(dead_code)]
pub fn set_f64_dst(&mut self, op: &impl DstsAsSlice, dst: &Dst, f: f64) {
let u = f.to_bits();
let v = [u as u32, (u >> 32) as u32];
self.dsts[op.dst_idx(dst)] = FoldData::Vec2(v);
}
}
pub trait Foldable: SrcsAsSlice + DstsAsSlice {
fn fold(&self, sm: &dyn ShaderModel, f: &mut OpFoldData<'_>);
}
pub trait DisplayOp: DstsAsSlice {
fn fmt_dsts(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt_dst_slice(f, self.dsts_as_slice())
@ -3095,7 +3203,7 @@ impl DisplayOp for OpFlo {
impl_display_for_op!(OpFlo);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
#[derive(Clone, SrcsAsSlice, DstsAsSlice)]
pub struct OpIAbs {
#[dst_type(GPR)]
pub dst: Dst,
@ -3104,6 +3212,14 @@ pub struct OpIAbs {
pub src: Src,
}
impl Foldable for OpIAbs {
fn fold(&self, _sm: &dyn ShaderModel, f: &mut OpFoldData<'_>) {
let src = f.get_u32_src(self, &self.src);
let dst = (src as i32).abs() as u32;
f.set_u32_dst(self, &self.dst, dst);
}
}
impl DisplayOp for OpIAbs {
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "iabs {}", self.src)