nak/from_nir: Turn srcs into a closure

Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34794>
This commit is contained in:
Mel Henning 2025-04-30 18:36:33 -04:00 committed by Marge Bot
parent 854b2d5882
commit f21557154b
2 changed files with 154 additions and 128 deletions

View file

@ -633,7 +633,7 @@ impl<'a> ShaderFromNir<'a> {
}
let nir_srcs = alu.srcs_as_slice();
let mut srcs: Vec<Src> = Vec::new();
let mut srcs_vec: Vec<Option<Src>> = Vec::new();
for (i, alu_src) in nir_srcs.iter().enumerate() {
let bit_size = alu_src.src.bit_size();
let comps = alu.src_components(i.try_into().unwrap());
@ -643,7 +643,7 @@ impl<'a> ShaderFromNir<'a> {
1 => {
assert!(comps == 1);
let s = usize::from(alu_src.swizzle[0]);
srcs.push(ssa[s].into());
srcs_vec.push(Some(ssa[s].into()));
}
8 | 16 => {
let num_bytes = usize::from(comps * (bit_size / 8));
@ -676,21 +676,23 @@ impl<'a> ShaderFromNir<'a> {
}
}
srcs.push(b.prmt4(prmt_srcs, prmt).into());
srcs_vec.push(Some(b.prmt4(prmt_srcs, prmt).into()));
}
32 => {
assert!(comps == 1);
let s = usize::from(alu_src.swizzle[0]);
srcs.push(ssa[s].into());
srcs_vec.push(Some(ssa[s].into()));
}
64 => {
assert!(comps == 1);
let s = usize::from(alu_src.swizzle[0]);
srcs.push([ssa[s * 2], ssa[s * 2 + 1]].into());
srcs_vec.push(Some([ssa[s * 2], ssa[s * 2 + 1]].into()));
}
_ => panic!("Invalid bit size: {bit_size}"),
}
}
let mut srcs =
|i: usize| -> Src { std::mem::take(&mut srcs_vec[i]).unwrap() };
// Restricts an F16v2 source to just x if the ALU op is single-component. This
// must only be called for per-component sources (see nir_op_info::output_sizes
@ -705,35 +707,35 @@ impl<'a> ShaderFromNir<'a> {
let dst: SSARef = match alu.op {
nir_op_b2b1 => {
assert!(alu.get_src(0).bit_size() == 32);
b.isetp(IntCmpType::I32, IntCmpOp::Ne, srcs[0], 0.into())
b.isetp(IntCmpType::I32, IntCmpOp::Ne, srcs(0), 0.into())
.into()
}
nir_op_b2b32 | nir_op_b2i8 | nir_op_b2i16 | nir_op_b2i32 => {
b.sel(srcs[0].bnot(), 0.into(), 1.into()).into()
b.sel(srcs(0).bnot(), 0.into(), 1.into()).into()
}
nir_op_b2i64 => {
let lo = b.sel(srcs[0].bnot(), 0.into(), 1.into());
let lo = b.sel(srcs(0).bnot(), 0.into(), 1.into());
let hi = b.copy(0.into());
[lo, hi].into()
}
nir_op_b2f16 => {
b.sel(srcs[0].bnot(), 0.into(), 0x3c00.into()).into()
b.sel(srcs(0).bnot(), 0.into(), 0x3c00.into()).into()
}
nir_op_b2f32 => {
b.sel(srcs[0].bnot(), 0.0_f32.into(), 1.0_f32.into()).into()
b.sel(srcs(0).bnot(), 0.0_f32.into(), 1.0_f32.into()).into()
}
nir_op_b2f64 => {
let lo = b.copy(0.into());
let hi = b.sel(srcs[0].bnot(), 0.into(), 0x3ff00000.into());
let hi = b.sel(srcs(0).bnot(), 0.into(), 0x3ff00000.into());
[lo, hi].into()
}
nir_op_bcsel => b.sel(srcs[0], srcs[1], srcs[2]).into(),
nir_op_bcsel => b.sel(srcs(0), srcs(1), srcs(2)).into(),
nir_op_bfm => {
let dst = b.alloc_ssa(RegFile::GPR);
b.push_op(OpBMsk {
dst: dst.into(),
pos: srcs[1],
width: srcs[0],
pos: srcs(1),
width: srcs(0),
wrap: true,
});
dst.into()
@ -742,16 +744,16 @@ impl<'a> ShaderFromNir<'a> {
let dst = b.alloc_ssa(RegFile::GPR);
b.push_op(OpPopC {
dst: dst.into(),
src: srcs[0],
src: srcs(0),
});
dst.into()
}
nir_op_bitfield_reverse => b.brev(srcs[0]).into(),
nir_op_bitfield_reverse => b.brev(srcs(0)).into(),
nir_op_ibitfield_extract | nir_op_ubitfield_extract => {
let range = b.alloc_ssa(RegFile::GPR);
b.push_op(OpPrmt {
dst: range.into(),
srcs: [srcs[1], srcs[2]],
srcs: [srcs(1), srcs(2)],
sel: 0x0040.into(),
mode: PrmtMode::Index,
});
@ -759,7 +761,7 @@ impl<'a> ShaderFromNir<'a> {
let dst = b.alloc_ssa(RegFile::GPR);
b.push_op(OpBfe {
dst: dst.into(),
base: srcs[0],
base: srcs(0),
signed: !matches!(alu.op, nir_op_ubitfield_extract),
range: range.into(),
reverse: false,
@ -777,25 +779,25 @@ impl<'a> ShaderFromNir<'a> {
assert!(elem < 4);
let byte = elem;
let zero = 4;
b.prmt(srcs[0], 0.into(), [byte, zero, zero, zero])
b.prmt(srcs(0), 0.into(), [byte, zero, zero, zero])
}
nir_op_extract_i8 => {
assert!(elem < 4);
let byte = elem;
let sign = byte | 0x8;
b.prmt(srcs[0], 0.into(), [byte, sign, sign, sign])
b.prmt(srcs(0), 0.into(), [byte, sign, sign, sign])
}
nir_op_extract_u16 => {
assert!(elem < 2);
let byte = elem * 2;
let zero = 4;
b.prmt(srcs[0], 0.into(), [byte, byte + 1, zero, zero])
b.prmt(srcs(0), 0.into(), [byte, byte + 1, zero, zero])
}
nir_op_extract_i16 => {
assert!(elem < 2);
let byte = elem * 2;
let sign = (byte + 1) | 0x8;
b.prmt(srcs[0], 0.into(), [byte, byte + 1, sign, sign])
b.prmt(srcs(0), 0.into(), [byte, byte + 1, sign, sign])
}
_ => panic!("Unknown extract op: {}", alu.op),
}
@ -811,7 +813,7 @@ impl<'a> ShaderFromNir<'a> {
let dst = b.alloc_ssa_vec(RegFile::GPR, dst_bits.div_ceil(32));
b.push_op(OpF2F {
dst: dst.into(),
src: srcs[0],
src: srcs(0),
src_type: FloatType::from_bits(src_bits.into()),
dst_type: dst_type,
rnd_mode: match alu.op {
@ -830,7 +832,7 @@ impl<'a> ShaderFromNir<'a> {
dst
}
nir_op_find_lsb => {
let rev = b.brev(srcs[0]);
let rev = b.brev(srcs(0));
let dst = b.alloc_ssa(RegFile::GPR);
b.push_op(OpFlo {
dst: dst.into(),
@ -855,7 +857,7 @@ impl<'a> ShaderFromNir<'a> {
let tmp_type = IntType::from_bits(32, dst_is_signed);
b.push_op(OpF2I {
dst: tmp.into(),
src: srcs[0],
src: srcs(0),
src_type,
dst_type: tmp_type,
rnd_mode: FRndMode::Zero,
@ -873,7 +875,7 @@ impl<'a> ShaderFromNir<'a> {
} else {
b.push_op(OpF2I {
dst: dst.into(),
src: srcs[0],
src: srcs(0),
src_type,
dst_type,
rnd_mode: FRndMode::Zero,
@ -884,9 +886,9 @@ impl<'a> ShaderFromNir<'a> {
}
nir_op_fabs | nir_op_fadd | nir_op_fneg => {
let (x, y) = match alu.op {
nir_op_fabs => (Src::new_zero().fneg(), srcs[0].fabs()),
nir_op_fadd => (srcs[0], srcs[1]),
nir_op_fneg => (Src::new_zero().fneg(), srcs[0].fneg()),
nir_op_fabs => (Src::new_zero().fneg(), srcs(0).fabs()),
nir_op_fadd => (srcs(0), srcs(1)),
nir_op_fneg => (Src::new_zero().fneg(), srcs(0).fneg()),
_ => panic!("Unhandled case"),
};
let ftype = FloatType::from_bits(alu.def.bit_size().into());
@ -943,7 +945,7 @@ impl<'a> ShaderFromNir<'a> {
);
b.push_op(OpFRnd {
dst: dst.into(),
src: srcs[0],
src: srcs(0),
src_type: ty,
dst_type: ty,
rnd_mode,
@ -953,7 +955,7 @@ impl<'a> ShaderFromNir<'a> {
assert!(alu.def.bit_size() == 32);
b.push_op(OpF2F {
dst: dst.into(),
src: srcs[0],
src: srcs(0),
src_type: ty,
dst_type: ty,
rnd_mode,
@ -964,7 +966,7 @@ impl<'a> ShaderFromNir<'a> {
}
dst.into()
}
nir_op_fcos => b.fcos(srcs[0]).into(),
nir_op_fcos => b.fcos(srcs(0)).into(),
nir_op_feq | nir_op_fge | nir_op_flt | nir_op_fneu => {
let src_type =
FloatType::from_bits(alu.get_src(0).bit_size().into());
@ -984,7 +986,7 @@ impl<'a> ShaderFromNir<'a> {
dst: dst.into(),
set_op: PredSetOp::And,
cmp_op: cmp_op,
srcs: [srcs[0], srcs[1]],
srcs: [srcs(0), srcs(1)],
accum: SrcRef::True.into(),
});
} else if alu.get_src(0).bit_size() == 32 {
@ -993,7 +995,7 @@ impl<'a> ShaderFromNir<'a> {
dst: dst.into(),
set_op: PredSetOp::And,
cmp_op: cmp_op,
srcs: [srcs[0], srcs[1]],
srcs: [srcs(0), srcs(1)],
accum: SrcRef::True.into(),
ftz: self.float_ctl[src_type].ftz,
});
@ -1014,8 +1016,8 @@ impl<'a> ShaderFromNir<'a> {
set_op: PredSetOp::And,
cmp_op: cmp_op,
srcs: [
restrict_f16v2_src(srcs[0]),
restrict_f16v2_src(srcs[1]),
restrict_f16v2_src(srcs(0)),
restrict_f16v2_src(srcs(1)),
],
accum: SrcRef::True.into(),
ftz: self.float_ctl[src_type].ftz,
@ -1029,7 +1031,7 @@ impl<'a> ShaderFromNir<'a> {
}
dst
}
nir_op_fexp2 => b.fexp2(srcs[0]).into(),
nir_op_fexp2 => b.fexp2(srcs(0)).into(),
nir_op_ffma => {
let ftype = FloatType::from_bits(alu.def.bit_size().into());
let dst;
@ -1038,14 +1040,14 @@ impl<'a> ShaderFromNir<'a> {
dst = b.alloc_ssa_vec(RegFile::GPR, 2);
b.push_op(OpDFma {
dst: dst.into(),
srcs: [srcs[0], srcs[1], srcs[2]],
srcs: [srcs(0), srcs(1), srcs(2)],
rnd_mode: self.float_ctl[ftype].rnd_mode,
});
} else if alu.def.bit_size() == 32 {
dst = b.alloc_ssa_vec(RegFile::GPR, 1);
b.push_op(OpFFma {
dst: dst.into(),
srcs: [srcs[0], srcs[1], srcs[2]],
srcs: [srcs(0), srcs(1), srcs(2)],
saturate: self.try_saturate_alu_dst(&alu.def),
rnd_mode: self.float_ctl[ftype].rnd_mode,
// The hardware doesn't like FTZ+DNZ and DNZ implies FTZ
@ -1062,9 +1064,9 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpHFma2 {
dst: dst.into(),
srcs: [
restrict_f16v2_src(srcs[0]),
restrict_f16v2_src(srcs[1]),
restrict_f16v2_src(srcs[2]),
restrict_f16v2_src(srcs(0)),
restrict_f16v2_src(srcs(1)),
restrict_f16v2_src(srcs(2)),
],
saturate: self.try_saturate_alu_dst(&alu.def),
ftz: self.float_ctl[ftype].ftz,
@ -1083,7 +1085,7 @@ impl<'a> ShaderFromNir<'a> {
let dst = b.alloc_ssa(RegFile::GPR);
b.push_op(OpFFma {
dst: dst.into(),
srcs: [srcs[0], srcs[1], srcs[2]],
srcs: [srcs(0), srcs(1), srcs(2)],
saturate: self.try_saturate_alu_dst(&alu.def),
rnd_mode: self.float_ctl.fp32.rnd_mode,
// The hardware doesn't like FTZ+DNZ and DNZ implies FTZ
@ -1095,7 +1097,7 @@ impl<'a> ShaderFromNir<'a> {
}
nir_op_flog2 => {
assert!(alu.def.bit_size() == 32);
b.mufu(MuFuOp::Log2, srcs[0]).into()
b.mufu(MuFuOp::Log2, srcs(0)).into()
}
nir_op_fmax | nir_op_fmin => {
let dst;
@ -1103,14 +1105,14 @@ impl<'a> ShaderFromNir<'a> {
dst = b.alloc_ssa_vec(RegFile::GPR, 2);
b.push_op(OpDMnMx {
dst: dst.into(),
srcs: [srcs[0], srcs[1]],
srcs: [srcs(0), srcs(1)],
min: (alu.op == nir_op_fmin).into(),
});
} else if alu.def.bit_size() == 32 {
dst = b.alloc_ssa_vec(RegFile::GPR, 1);
b.push_op(OpFMnMx {
dst: dst.into(),
srcs: [srcs[0], srcs[1]],
srcs: [srcs(0), srcs(1)],
min: (alu.op == nir_op_fmin).into(),
ftz: self.float_ctl.fp32.ftz,
});
@ -1119,8 +1121,8 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpHMnMx2 {
dst: dst.into(),
srcs: [
restrict_f16v2_src(srcs[0]),
restrict_f16v2_src(srcs[1]),
restrict_f16v2_src(srcs(0)),
restrict_f16v2_src(srcs(1)),
],
min: (alu.op == nir_op_fmin).into(),
ftz: self.float_ctl.fp16.ftz,
@ -1138,14 +1140,14 @@ impl<'a> ShaderFromNir<'a> {
dst = b.alloc_ssa_vec(RegFile::GPR, 2);
b.push_op(OpDMul {
dst: dst.into(),
srcs: [srcs[0], srcs[1]],
srcs: [srcs(0), srcs(1)],
rnd_mode: self.float_ctl[ftype].rnd_mode,
});
} else if alu.def.bit_size() == 32 {
dst = b.alloc_ssa_vec(RegFile::GPR, 1);
b.push_op(OpFMul {
dst: dst.into(),
srcs: [srcs[0], srcs[1]],
srcs: [srcs(0), srcs(1)],
saturate: self.try_saturate_alu_dst(&alu.def),
rnd_mode: self.float_ctl[ftype].rnd_mode,
ftz: self.float_ctl[ftype].ftz,
@ -1160,8 +1162,8 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpHMul2 {
dst: dst.into(),
srcs: [
restrict_f16v2_src(srcs[0]),
restrict_f16v2_src(srcs[1]),
restrict_f16v2_src(srcs(0)),
restrict_f16v2_src(srcs(1)),
],
saturate: self.try_saturate_alu_dst(&alu.def),
ftz: self.float_ctl[ftype].ftz,
@ -1179,7 +1181,7 @@ impl<'a> ShaderFromNir<'a> {
let dst = b.alloc_ssa(RegFile::GPR);
b.push_op(OpFMul {
dst: dst.into(),
srcs: [srcs[0], srcs[1]],
srcs: [srcs(0), srcs(1)],
saturate: self.try_saturate_alu_dst(&alu.def),
rnd_mode: self.float_ctl.fp32.rnd_mode,
// The hardware doesn't like FTZ+DNZ and DNZ implies FTZ
@ -1193,7 +1195,7 @@ impl<'a> ShaderFromNir<'a> {
let tmp = b.alloc_ssa(RegFile::GPR);
b.push_op(OpF2F {
dst: tmp.into(),
src: srcs[0],
src: srcs(0),
src_type: FloatType::F32,
dst_type: FloatType::F16,
rnd_mode: FRndMode::NearestEven,
@ -1218,12 +1220,12 @@ impl<'a> ShaderFromNir<'a> {
// that manually
let denorm = b.fsetp(
FloatCmpOp::OrdLt,
srcs[0].fabs(),
srcs(0).fabs(),
0x38800000.into(),
);
// Get the correctly signed zero
let zero =
b.lop2(LogicOp2::And, srcs[0], 0x80000000.into());
b.lop2(LogicOp2::And, srcs(0), 0x80000000.into());
b.sel(denorm.into(), zero.into(), dst.into())
} else {
dst
@ -1232,22 +1234,22 @@ impl<'a> ShaderFromNir<'a> {
}
nir_op_frcp => {
assert!(alu.def.bit_size() == 32);
b.mufu(MuFuOp::Rcp, srcs[0]).into()
b.mufu(MuFuOp::Rcp, srcs(0)).into()
}
nir_op_frsq => {
assert!(alu.def.bit_size() == 32);
b.mufu(MuFuOp::Rsq, srcs[0]).into()
b.mufu(MuFuOp::Rsq, srcs(0)).into()
}
nir_op_fsat => {
let ftype = FloatType::from_bits(alu.def.bit_size().into());
if self.alu_src_is_saturated(&alu.srcs_as_slice()[0]) {
b.copy(srcs[0]).into()
b.copy(srcs(0)).into()
} else if alu.def.bit_size() == 32 {
let dst = b.alloc_ssa(RegFile::GPR);
b.push_op(OpFAdd {
dst: dst.into(),
srcs: [srcs[0], 0.into()],
srcs: [srcs(0), 0.into()],
saturate: true,
rnd_mode: self.float_ctl[ftype].rnd_mode,
ftz: self.float_ctl[ftype].ftz,
@ -1261,7 +1263,7 @@ impl<'a> ShaderFromNir<'a> {
let dst = b.alloc_ssa(RegFile::GPR);
b.push_op(OpHAdd2 {
dst: dst.into(),
srcs: [restrict_f16v2_src(srcs[0]), 0.into()],
srcs: [restrict_f16v2_src(srcs(0)), 0.into()],
saturate: true,
ftz: self.float_ctl[ftype].ftz,
f32: false,
@ -1272,19 +1274,20 @@ impl<'a> ShaderFromNir<'a> {
}
}
nir_op_fsign => {
let src0 = srcs(0);
if alu.def.bit_size() == 64 {
let lz = b.dsetp(FloatCmpOp::OrdLt, srcs[0], 0.into());
let gz = b.dsetp(FloatCmpOp::OrdGt, srcs[0], 0.into());
let lz = b.dsetp(FloatCmpOp::OrdLt, src0.clone(), 0.into());
let gz = b.dsetp(FloatCmpOp::OrdGt, src0, 0.into());
let hi = b.sel(lz.into(), 0xbff00000.into(), 0.into());
let hi = b.sel(gz.into(), 0x3ff00000.into(), hi.into());
let lo = b.copy(0.into());
[lo, hi].into()
} else if alu.def.bit_size() == 32 {
let lz = b.fset(FloatCmpOp::OrdLt, srcs[0], 0.into());
let gz = b.fset(FloatCmpOp::OrdGt, srcs[0], 0.into());
let lz = b.fset(FloatCmpOp::OrdLt, src0.clone(), 0.into());
let gz = b.fset(FloatCmpOp::OrdGt, src0, 0.into());
b.fadd(gz.into(), Src::from(lz).fneg()).into()
} else if alu.def.bit_size() == 16 {
let x = restrict_f16v2_src(srcs[0]);
let x = restrict_f16v2_src(src0);
let lz = restrict_f16v2_src(
b.hset2(FloatCmpOp::OrdLt, x, 0.into()).into(),
@ -1298,8 +1301,8 @@ impl<'a> ShaderFromNir<'a> {
panic!("Unsupported float type: f{}", alu.def.bit_size());
}
}
nir_op_fsin => b.fsin(srcs[0]).into(),
nir_op_fsqrt => b.mufu(MuFuOp::Sqrt, srcs[0]).into(),
nir_op_fsin => b.fsin(srcs(0)).into(),
nir_op_fsqrt => b.mufu(MuFuOp::Sqrt, srcs(0)).into(),
nir_op_i2f16 | nir_op_i2f32 | nir_op_i2f64 => {
let src_bits = alu.get_src(0).src.bit_size();
let dst_bits = alu.def.bit_size();
@ -1307,7 +1310,7 @@ impl<'a> ShaderFromNir<'a> {
let dst = b.alloc_ssa_vec(RegFile::GPR, dst_bits.div_ceil(32));
b.push_op(OpI2F {
dst: dst.into(),
src: srcs[0],
src: srcs(0),
dst_type: dst_type,
src_type: IntType::from_bits(src_bits.into(), true),
rnd_mode: self.float_ctl[dst_type].rnd_mode,
@ -1347,10 +1350,10 @@ impl<'a> ShaderFromNir<'a> {
let prmt_lo: [u8; 4] = prmt[0..4].try_into().unwrap();
let prmt_hi: [u8; 4] = prmt[4..8].try_into().unwrap();
let src = srcs[0].as_ssa().unwrap();
let src = srcs(0).to_ssa();
if src_bits == 64 {
if dst_bits == 64 {
*src
src.into()
} else {
b.prmt(src[0].into(), src[1].into(), prmt_lo).into()
}
@ -1364,26 +1367,26 @@ impl<'a> ShaderFromNir<'a> {
}
}
}
nir_op_iabs => b.iabs(srcs[0]).into(),
nir_op_iabs => b.iabs(srcs(0)).into(),
nir_op_iadd => match alu.def.bit_size {
32 => b.iadd(srcs[0], srcs[1], 0.into()).into(),
64 => b.iadd64(srcs[0], srcs[1], 0.into()),
32 => b.iadd(srcs(0), srcs(1), 0.into()).into(),
64 => b.iadd64(srcs(0), srcs(1), 0.into()),
x => panic!("unsupported bit size for nir_op_iadd: {x}"),
},
nir_op_iadd3 => match alu.def.bit_size {
32 => b.iadd(srcs[0], srcs[1], srcs[2]).into(),
64 => b.iadd64(srcs[0], srcs[1], srcs[2]),
32 => b.iadd(srcs(0), srcs(1), srcs(2)).into(),
64 => b.iadd64(srcs(0), srcs(1), srcs(2)),
x => panic!("unsupported bit size for nir_op_iadd3: {x}"),
},
nir_op_iand => b.lop2(LogicOp2::And, srcs[0], srcs[1]).into(),
nir_op_iand => b.lop2(LogicOp2::And, srcs(0), srcs(1)).into(),
nir_op_ieq => {
if alu.get_src(0).bit_size() == 1 {
b.lop2(LogicOp2::Xor, srcs[0], srcs[1].bnot()).into()
b.lop2(LogicOp2::Xor, srcs(0), srcs(1).bnot()).into()
} else if alu.get_src(0).bit_size() == 64 {
b.isetp64(IntCmpType::I32, IntCmpOp::Eq, srcs[0], srcs[1])
b.isetp64(IntCmpType::I32, IntCmpOp::Eq, srcs(0), srcs(1))
} else {
assert!(alu.get_src(0).bit_size() == 32);
b.isetp(IntCmpType::I32, IntCmpOp::Eq, srcs[0], srcs[1])
b.isetp(IntCmpType::I32, IntCmpOp::Eq, srcs(0), srcs(1))
.into()
}
}
@ -1392,7 +1395,7 @@ impl<'a> ShaderFromNir<'a> {
let dst = b.alloc_ssa(RegFile::GPR);
b.push_op(OpFlo {
dst: dst.into(),
src: srcs[0],
src: srcs(0),
signed: match alu.op {
nir_op_ifind_msb | nir_op_ifind_msb_rev => true,
nir_op_ufind_msb | nir_op_ufind_msb_rev => false,
@ -1407,8 +1410,8 @@ impl<'a> ShaderFromNir<'a> {
dst.into()
}
nir_op_ige | nir_op_ilt | nir_op_uge | nir_op_ult => {
let x = *srcs[0].as_ssa().unwrap();
let y = *srcs[1].as_ssa().unwrap();
let x = *srcs(0).as_ssa().unwrap();
let y = *srcs(1).as_ssa().unwrap();
let (cmp_type, cmp_op) = match alu.op {
nir_op_ige => (IntCmpType::I32, IntCmpOp::Ge),
nir_op_ilt => (IntCmpType::I32, IntCmpOp::Lt),
@ -1428,7 +1431,7 @@ impl<'a> ShaderFromNir<'a> {
let dst = b.alloc_ssa(RegFile::GPR);
b.push_op(OpIMad {
dst: dst.into(),
srcs: [srcs[0], srcs[1], srcs[2]],
srcs: [srcs(0), srcs(1), srcs(2)],
signed: false,
});
dst.into()
@ -1442,60 +1445,60 @@ impl<'a> ShaderFromNir<'a> {
_ => panic!("Not an integer min/max"),
};
assert!(alu.def.bit_size() == 32);
b.imnmx(tp, srcs[0], srcs[1], min.into()).into()
b.imnmx(tp, srcs(0), srcs(1), min.into()).into()
}
nir_op_imul => {
assert!(alu.def.bit_size() == 32);
b.imul(srcs[0], srcs[1]).into()
b.imul(srcs(0), srcs(1)).into()
}
nir_op_imul_2x32_64 | nir_op_umul_2x32_64 => {
let signed = alu.op == nir_op_imul_2x32_64;
b.imul_2x32_64(srcs[0], srcs[1], signed)
b.imul_2x32_64(srcs(0), srcs(1), signed)
}
nir_op_imul_high | nir_op_umul_high => {
let signed = alu.op == nir_op_imul_high;
let dst64 = b.imul_2x32_64(srcs[0], srcs[1], signed);
let dst64 = b.imul_2x32_64(srcs(0), srcs(1), signed);
dst64[1].into()
}
nir_op_ine => {
if alu.get_src(0).bit_size() == 1 {
b.lop2(LogicOp2::Xor, srcs[0], srcs[1]).into()
b.lop2(LogicOp2::Xor, srcs(0), srcs(1)).into()
} else if alu.get_src(0).bit_size() == 64 {
b.isetp64(IntCmpType::I32, IntCmpOp::Ne, srcs[0], srcs[1])
b.isetp64(IntCmpType::I32, IntCmpOp::Ne, srcs(0), srcs(1))
} else {
assert!(alu.get_src(0).bit_size() == 32);
b.isetp(IntCmpType::I32, IntCmpOp::Ne, srcs[0], srcs[1])
b.isetp(IntCmpType::I32, IntCmpOp::Ne, srcs(0), srcs(1))
.into()
}
}
nir_op_ineg => {
if alu.def.bit_size == 64 {
b.ineg64(srcs[0])
b.ineg64(srcs(0))
} else {
assert!(alu.def.bit_size() == 32);
b.ineg(srcs[0]).into()
b.ineg(srcs(0)).into()
}
}
nir_op_inot => {
if alu.def.bit_size() == 1 {
b.lop2(LogicOp2::PassB, true.into(), srcs[0].bnot()).into()
b.lop2(LogicOp2::PassB, true.into(), srcs(0).bnot()).into()
} else {
assert!(alu.def.bit_size() == 32);
b.lop2(LogicOp2::PassB, 0.into(), srcs[0].bnot()).into()
b.lop2(LogicOp2::PassB, 0.into(), srcs(0).bnot()).into()
}
}
nir_op_ior => b.lop2(LogicOp2::Or, srcs[0], srcs[1]).into(),
nir_op_ior => b.lop2(LogicOp2::Or, srcs(0), srcs(1)).into(),
nir_op_ishl => {
if alu.def.bit_size() == 64 {
let shift = if let Some(s) = nir_srcs[1].comp_as_uint(0) {
(s as u32).into()
} else {
srcs[1]
srcs(1)
};
b.shl64(srcs[0], shift)
b.shl64(srcs(0), shift)
} else {
assert!(alu.def.bit_size() == 32);
b.shl(srcs[0], srcs[1]).into()
b.shl(srcs(0), srcs(1)).into()
}
}
nir_op_ishr => {
@ -1503,17 +1506,17 @@ impl<'a> ShaderFromNir<'a> {
let shift = if let Some(s) = nir_srcs[1].comp_as_uint(0) {
(s as u32).into()
} else {
srcs[1]
srcs(1)
};
b.shr64(srcs[0], shift, true)
b.shr64(srcs(0), shift, true)
} else {
assert!(alu.def.bit_size() == 32);
b.shr(srcs[0], srcs[1], true).into()
b.shr(srcs(0), srcs(1), true).into()
}
}
nir_op_lea_nv => {
let src_a = srcs[1];
let src_b = srcs[0];
let src_a = srcs(1);
let src_b = srcs(0);
let shift = nir_srcs[2].comp_as_uint(0).unwrap() as u8;
match alu.def.bit_size {
32 => b.lea(src_a, src_b, shift).into(),
@ -1522,11 +1525,11 @@ impl<'a> ShaderFromNir<'a> {
}
}
nir_op_isub => match alu.def.bit_size {
32 => b.iadd(srcs[0], srcs[1].ineg(), 0.into()).into(),
64 => b.iadd64(srcs[0], srcs[1].ineg(), 0.into()),
32 => b.iadd(srcs(0), srcs(1).ineg(), 0.into()).into(),
64 => b.iadd64(srcs(0), srcs(1).ineg(), 0.into()),
x => panic!("unsupported bit size for nir_op_iadd: {x}"),
},
nir_op_ixor => b.lop2(LogicOp2::Xor, srcs[0], srcs[1]).into(),
nir_op_ixor => b.lop2(LogicOp2::Xor, srcs(0), srcs(1)).into(),
nir_op_pack_half_2x16_split | nir_op_pack_half_2x16_rtz_split => {
assert!(alu.get_src(0).bit_size() == 32);
@ -1540,7 +1543,7 @@ impl<'a> ShaderFromNir<'a> {
let result = b.alloc_ssa(RegFile::GPR);
b.push_op(OpF2FP {
dst: result.into(),
srcs: [srcs[1], srcs[0]],
srcs: [srcs(1), srcs(0)],
rnd_mode: rnd_mode,
});
@ -1551,7 +1554,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpF2F {
dst: low.into(),
src: srcs[0],
src: srcs(0),
src_type: FloatType::F32,
dst_type: FloatType::F16,
rnd_mode: rnd_mode,
@ -1565,7 +1568,7 @@ impl<'a> ShaderFromNir<'a> {
assert!(matches!(src_type, FloatType::F32));
b.push_op(OpF2F {
dst: high.into(),
src: srcs[1],
src: srcs(1),
src_type: FloatType::F32,
dst_type: FloatType::F16,
rnd_mode: rnd_mode,
@ -1581,8 +1584,8 @@ impl<'a> ShaderFromNir<'a> {
let dst = b.alloc_ssa(RegFile::GPR);
b.push_op(OpPrmt {
dst: dst.into(),
srcs: [srcs[1], srcs[2]],
sel: srcs[0],
srcs: [srcs(1), srcs(2)],
sel: srcs(0),
mode: PrmtMode::Index,
});
dst.into()
@ -1592,7 +1595,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpIDp4 {
dst: dst.into(),
src_types: [IntType::I8, IntType::I8],
srcs: [srcs[0], srcs[1], srcs[2]],
srcs: [srcs(0), srcs(1), srcs(2)],
});
dst.into()
}
@ -1601,7 +1604,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpIDp4 {
dst: dst.into(),
src_types: [IntType::I8, IntType::U8],
srcs: [srcs[0], srcs[1], srcs[2]],
srcs: [srcs(0), srcs(1), srcs(2)],
});
dst.into()
}
@ -1610,7 +1613,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpIDp4 {
dst: dst.into(),
src_types: [IntType::U8, IntType::U8],
srcs: [srcs[0], srcs[1], srcs[2]],
srcs: [srcs(0), srcs(1), srcs(2)],
});
dst.into()
}
@ -1621,7 +1624,7 @@ impl<'a> ShaderFromNir<'a> {
let dst = b.alloc_ssa_vec(RegFile::GPR, dst_bits.div_ceil(32));
b.push_op(OpI2F {
dst: dst.into(),
src: srcs[0],
src: srcs(0),
dst_type: dst_type,
src_type: IntType::from_bits(src_bits.into(), false),
rnd_mode: self.float_ctl[dst_type].rnd_mode,
@ -1629,8 +1632,8 @@ impl<'a> ShaderFromNir<'a> {
dst
}
nir_op_uadd_sat => {
let x = srcs[0].as_ssa().unwrap();
let y = srcs[1].as_ssa().unwrap();
let x = srcs(0).to_ssa();
let y = srcs(1).to_ssa();
let sum_lo = b.alloc_ssa(RegFile::GPR);
let ovf_lo = b.alloc_ssa(RegFile::Pred);
b.push_op(OpIAdd3 {
@ -1658,8 +1661,8 @@ impl<'a> ShaderFromNir<'a> {
}
}
nir_op_usub_sat => {
let x = srcs[0].as_ssa().unwrap();
let y = srcs[1].as_ssa().unwrap();
let x = srcs(0).to_ssa();
let y = srcs(1).to_ssa();
let sum_lo = b.alloc_ssa(RegFile::GPR);
let ovf_lo = b.alloc_ssa(RegFile::Pred);
// The result of OpIAdd3X is the 33-bit value
@ -1691,17 +1694,17 @@ impl<'a> ShaderFromNir<'a> {
}
}
nir_op_unpack_32_2x16_split_x => {
b.prmt(srcs[0], 0.into(), [0, 1, 4, 4]).into()
b.prmt(srcs(0), 0.into(), [0, 1, 4, 4]).into()
}
nir_op_unpack_32_2x16_split_y => {
b.prmt(srcs[0], 0.into(), [2, 3, 4, 4]).into()
b.prmt(srcs(0), 0.into(), [2, 3, 4, 4]).into()
}
nir_op_unpack_64_2x32_split_x => {
let src0_x = srcs[0].as_ssa().unwrap()[0];
let src0_x = srcs(0).as_ssa().unwrap()[0];
b.copy(src0_x.into()).into()
}
nir_op_unpack_64_2x32_split_y => {
let src0_y = srcs[0].as_ssa().unwrap()[1];
let src0_y = srcs(0).as_ssa().unwrap()[1];
b.copy(src0_y.into()).into()
}
nir_op_unpack_half_2x16_split_x
@ -1711,7 +1714,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpF2F {
dst: dst.into(),
src: srcs[0],
src: srcs(0),
src_type: FloatType::F16,
dst_type: FloatType::F32,
rnd_mode: FRndMode::NearestEven,
@ -1727,12 +1730,12 @@ impl<'a> ShaderFromNir<'a> {
let shift = if let Some(s) = nir_srcs[1].comp_as_uint(0) {
(s as u32).into()
} else {
srcs[1]
srcs(1)
};
b.shr64(srcs[0], shift, false)
b.shr64(srcs(0), shift, false)
} else {
assert!(alu.def.bit_size() == 32);
b.shr(srcs[0], srcs[1], false).into()
b.shr(srcs(0), srcs(1), false).into()
}
}
_ => panic!("Unsupported ALU instruction: {}", alu.info().name()),

View file

@ -686,6 +686,14 @@ impl Dst {
}
}
#[allow(dead_code)]
pub fn to_ssa(self) -> SSARef {
match self {
Dst::SSA(r) => r,
_ => panic!("Expected ssa"),
}
}
pub fn iter_ssa(&self) -> slice::Iter<'_, SSAValue> {
match self {
Dst::None | Dst::Reg(_) => &[],
@ -849,6 +857,13 @@ impl SrcRef {
}
}
pub fn to_ssa(self) -> SSARef {
match self {
SrcRef::SSA(r) => r,
_ => panic!(),
}
}
pub fn as_u32(&self) -> Option<u32> {
match self {
SrcRef::Zero => Some(0),
@ -1219,6 +1234,14 @@ impl Src {
}
}
pub fn to_ssa(&self) -> SSARef {
if self.src_mod.is_none() {
self.src_ref.to_ssa()
} else {
panic!("Did not expect src_mod");
}
}
pub fn as_bool(&self) -> Option<bool> {
match self.src_ref {
SrcRef::True => Some(!self.src_mod.is_bnot()),