nak/legalize: Handle uniform sources in warp instructions

UGPRs in warp instructions are treated more like cbufs than GPRs.
You're only allowed to have one and it has to share space with the
possible cbuf or immediate.  This means we need to treat them as a "not
a register" case for warp instructions but as a register for uniform
instructions.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29591>
This commit is contained in:
Faith Ekstrand 2024-06-06 17:29:39 -05:00 committed by Marge Bot
parent 6ad49ca7d0
commit caf033b142

View file

@ -7,9 +7,62 @@ use crate::liveness::{BlockLiveness, Liveness, SimpleLiveness};
use std::collections::{HashMap, HashSet};
fn src_is_reg(src: &Src) -> bool {
/// Replaces `*ssa` with a copy of itself that lives in `reg_file`.
///
/// A new SSA value is allocated from `b` in `reg_file`, a copy from the
/// current value into the new one is emitted through `b`, and `*ssa` is then
/// overwritten so the caller refers to the copy from here on.
fn copy_ssa(b: &mut impl SSABuilder, ssa: &mut SSAValue, reg_file: RegFile) {
    let original = *ssa;
    let replacement = b.alloc_ssa(reg_file, 1)[0];
    b.copy_to(replacement.into(), original.into());
    *ssa = replacement;
}
/// Returns `true` when the predicate source `src` is an SSA value in the
/// `UPred` register file.
///
/// Constant `True`/`False` sources and SSA values in the `Pred` file yield
/// `false`.
///
/// # Panics
///
/// Panics if the source is a `SrcRef::Reg` (not in SSA form), if it is any
/// other non-predicate kind of source, if the SSA reference does not have
/// exactly one component, or if that component is in a register file other
/// than `Pred`/`UPred`.
fn src_is_upred_reg(src: &Src) -> bool {
    // Pull out the SSA reference first; constants can never be uniform
    // predicate registers.
    let ssa = match &src.src_ref {
        SrcRef::True | SrcRef::False => return false,
        SrcRef::SSA(ssa) => ssa,
        SrcRef::Reg(_) => panic!("Not in SSA form"),
        _ => panic!("Not a predicate source"),
    };

    assert!(ssa.comps() == 1);
    match ssa[0].file() {
        RegFile::UPred => true,
        RegFile::Pred => false,
        _ => panic!("Not a predicate source"),
    }
}
/// Ensures the predicate value `ssa` lives in the `Pred` register file.
///
/// A value already in `Pred` is left untouched.  A value in `UPred` is
/// replaced, via `copy_ssa`, by a copy allocated in `Pred`.
///
/// # Panics
///
/// Panics if `ssa` is in any register file other than `Pred` or `UPred`.
fn copy_pred_ssa_if_uniform(b: &mut impl SSABuilder, ssa: &mut SSAValue) {
    match ssa.file() {
        RegFile::UPred => copy_ssa(b, ssa, RegFile::Pred),
        RegFile::Pred => {}
        _ => panic!("Not a predicate value"),
    }
}
/// Rewrites an instruction predicate so that it does not reference a `UPred`
/// SSA value.
///
/// `PredRef::None` is left alone; an SSA predicate is handed to
/// `copy_pred_ssa_if_uniform`, which copies it into `Pred` when needed.
///
/// # Panics
///
/// Panics if the predicate is a `PredRef::Reg`, i.e. the IR is no longer in
/// SSA form.
fn copy_pred_if_upred(b: &mut impl SSABuilder, pred: &mut Pred) {
    match &mut pred.pred_ref {
        PredRef::SSA(ssa) => copy_pred_ssa_if_uniform(b, ssa),
        PredRef::None => {}
        PredRef::Reg(_) => panic!("Not in SSA form"),
    }
}
/// Rewrites a predicate source so that it does not reference a `UPred` SSA
/// value.
///
/// Constant `True`/`False` sources need no work.  For an SSA source, its
/// single component is passed to `copy_pred_ssa_if_uniform`, which replaces
/// it with a copy in `Pred` when it is currently in `UPred`.
///
/// # Panics
///
/// Panics if the source is a `SrcRef::Reg` (not in SSA form), if it is any
/// other non-predicate kind of source, or if the SSA reference does not have
/// exactly one component.
fn copy_src_if_upred(b: &mut impl SSABuilder, src: &mut Src) {
    let ssa = match &mut src.src_ref {
        // Constants are never uniform registers; nothing to copy.
        SrcRef::True | SrcRef::False => return,
        SrcRef::SSA(ssa) => ssa,
        SrcRef::Reg(_) => panic!("Not in SSA form"),
        _ => panic!("Not a predicate source"),
    };

    assert!(ssa.comps() == 1);
    copy_pred_ssa_if_uniform(b, &mut ssa[0]);
}
fn src_is_reg(src: &Src, reg_file: RegFile) -> bool {
match src.src_ref {
SrcRef::Zero | SrcRef::True | SrcRef::False | SrcRef::SSA(_) => true,
SrcRef::Zero | SrcRef::True | SrcRef::False => true,
SrcRef::SSA(ssa) => ssa.file() == Some(reg_file),
SrcRef::Imm32(_) | SrcRef::CBuf(_) => false,
SrcRef::Reg(_) => panic!("Not in SSA form"),
}
@ -100,7 +153,7 @@ fn copy_alu_src_if_not_reg(
reg_file: RegFile,
src_type: SrcType,
) {
if !src_is_reg(src) {
if !src_is_reg(src, reg_file) {
copy_alu_src(b, src, reg_file, src_type);
}
}
@ -111,7 +164,7 @@ fn copy_alu_src_if_not_reg_or_imm(
reg_file: RegFile,
src_type: SrcType,
) {
if !src_is_reg(src) && !matches!(&src.src_ref, SrcRef::Imm32(_)) {
if !src_is_reg(src, reg_file) && !matches!(&src.src_ref, SrcRef::Imm32(_)) {
copy_alu_src(b, src, reg_file, src_type);
}
}
@ -138,13 +191,13 @@ fn copy_alu_src_if_both_not_reg(
reg_file: RegFile,
src_type: SrcType,
) {
if !src_is_reg(src1) && !src_is_reg(src2) {
if !src_is_reg(src1, reg_file) && !src_is_reg(src2, reg_file) {
copy_alu_src(b, src2, reg_file, src_type);
}
}
fn swap_srcs_if_not_reg(x: &mut Src, y: &mut Src) -> bool {
if !src_is_reg(x) && src_is_reg(y) {
fn swap_srcs_if_not_reg(x: &mut Src, y: &mut Src, reg_file: RegFile) -> bool {
if !src_is_reg(x, reg_file) && src_is_reg(y, reg_file) {
std::mem::swap(x, y);
true
} else {
@ -225,6 +278,16 @@ fn copy_alu_src_if_fabs(
}
}
/// Replaces every uniform component of `ssa_ref` with a copy in the
/// corresponding warp register file.
///
/// Each component for which `is_uniform()` holds gets a freshly allocated
/// value in the file returned by `to_warp()` on its current file; a copy from
/// the old value is emitted through `b` and the component is overwritten
/// with the new value.  Other components are left unchanged.
fn copy_ssa_ref_if_uniform(b: &mut impl SSABuilder, ssa_ref: &mut SSARef) {
    for ssa in &mut ssa_ref[..] {
        if !ssa.is_uniform() {
            continue;
        }

        let uniform = *ssa;
        let warp_file = uniform.file().to_warp();
        let warp = b.alloc_ssa(warp_file, 1)[0];
        b.copy_to(warp.into(), uniform.into());
        *ssa = warp;
    }
}
fn legalize_sm50_instr(
b: &mut impl SSABuilder,
_bl: &impl BlockLiveness,
@ -252,17 +315,17 @@ fn legalize_sm50_instr(
}
Op::FAdd(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, GPR);
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::F32);
}
Op::FMul(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, GPR);
copy_alu_src_if_not_reg(b, &mut op.srcs[0], GPR, SrcType::F32);
}
Op::FSet(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if swap_srcs_if_not_reg(src0, src1) {
if swap_srcs_if_not_reg(src0, src1, GPR) {
op.cmp_op = op.cmp_op.flip();
}
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::F32);
@ -270,7 +333,7 @@ fn legalize_sm50_instr(
}
Op::FSetP(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if swap_srcs_if_not_reg(src0, src1) {
if swap_srcs_if_not_reg(src0, src1, GPR) {
op.cmp_op = op.cmp_op.flip();
}
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::F32);
@ -282,7 +345,7 @@ fn legalize_sm50_instr(
}
Op::ISetP(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if swap_srcs_if_not_reg(src0, src1) {
if swap_srcs_if_not_reg(src0, src1, GPR) {
op.cmp_op = op.cmp_op.flip();
}
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::ALU);
@ -290,7 +353,7 @@ fn legalize_sm50_instr(
}
Op::Lop2(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, GPR);
copy_alu_src_if_not_reg(b, &mut op.srcs[0], GPR, SrcType::ALU);
}
Op::Rro(op) => {
@ -302,7 +365,7 @@ fn legalize_sm50_instr(
}
Op::DAdd(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, GPR);
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::F64);
copy_alu_src_if_f20_overflow(b, src1, GPR, SrcType::F64);
}
@ -311,10 +374,10 @@ fn legalize_sm50_instr(
copy_alu_src_if_fabs(b, src0, SrcType::F64);
copy_alu_src_if_fabs(b, src1, SrcType::F64);
copy_alu_src_if_fabs(b, src2, SrcType::F64);
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, GPR);
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::F64);
copy_alu_src_if_f20_overflow(b, src1, GPR, SrcType::F64);
if src_is_reg(src1) {
if src_is_reg(src1, GPR) {
copy_alu_src_if_imm(b, src2, GPR, SrcType::F64);
} else {
copy_alu_src_if_not_reg(b, src2, GPR, SrcType::F64);
@ -322,7 +385,7 @@ fn legalize_sm50_instr(
}
Op::DMnMx(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, GPR);
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::F64);
copy_alu_src_if_f20_overflow(b, src1, GPR, SrcType::F64);
}
@ -330,13 +393,13 @@ fn legalize_sm50_instr(
let [ref mut src0, ref mut src1] = op.srcs;
copy_alu_src_if_fabs(b, src0, SrcType::F64);
copy_alu_src_if_fabs(b, src1, SrcType::F64);
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, GPR);
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::F64);
copy_alu_src_if_f20_overflow(b, src1, GPR, SrcType::F64);
}
Op::DSetP(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if swap_srcs_if_not_reg(src0, src1) {
if swap_srcs_if_not_reg(src0, src1, GPR) {
op.cmp_op = op.cmp_op.flip();
}
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::F64);
@ -344,7 +407,7 @@ fn legalize_sm50_instr(
}
Op::Sel(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if swap_srcs_if_not_reg(src0, src1) {
if swap_srcs_if_not_reg(src0, src1, GPR) {
op.cond = op.cond.bnot();
}
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::ALU);
@ -358,7 +421,7 @@ fn legalize_sm50_instr(
Op::Vote(_) => {}
Op::IAdd2(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, GPR);
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::I32);
}
Op::I2F(op) => {
@ -372,9 +435,9 @@ fn legalize_sm50_instr(
}
Op::IMad(op) => {
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, GPR);
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::ALU);
if src_is_reg(src1) {
if src_is_reg(src1, GPR) {
copy_alu_src_if_imm(b, src2, GPR, SrcType::ALU);
} else {
copy_alu_src_if_i20_overflow(b, src1, GPR, SrcType::ALU);
@ -383,7 +446,7 @@ fn legalize_sm50_instr(
}
Op::IMul(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if swap_srcs_if_not_reg(src0, src1) {
if swap_srcs_if_not_reg(src0, src1, GPR) {
op.signed.swap(0, 1);
}
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::ALU);
@ -393,7 +456,7 @@ fn legalize_sm50_instr(
}
Op::IMnMx(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, GPR);
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::ALU);
}
Op::Ipa(op) => {
@ -409,7 +472,7 @@ fn legalize_sm50_instr(
}
Op::FMnMx(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, GPR);
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::F32);
copy_alu_src_if_f20_overflow(b, src1, GPR, SrcType::F32);
}
@ -423,7 +486,7 @@ fn legalize_sm50_instr(
copy_alu_src_if_fabs(b, src0, SrcType::F32);
copy_alu_src_if_fabs(b, src1, SrcType::F32);
copy_alu_src_if_fabs(b, src2, SrcType::F32);
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, GPR);
copy_alu_src_if_not_reg(b, src0, GPR, SrcType::F32);
copy_alu_src_if_not_reg(b, src2, GPR, SrcType::F32);
copy_alu_src_if_f20_overflow(b, src1, GPR, SrcType::F32);
@ -457,7 +520,7 @@ fn legalize_sm50_instr(
assert!(src.as_ssa().is_some());
}
SrcType::GPR => {
assert!(src_is_reg(src));
assert!(src_is_reg(src, GPR));
}
SrcType::ALU
| SrcType::F16
@ -499,28 +562,28 @@ fn legalize_sm70_instr(
match &mut instr.op {
Op::FAdd(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::F32);
}
Op::FFma(op) => {
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::F32);
copy_alu_src_if_both_not_reg(b, src1, src2, gpr, SrcType::F32);
}
Op::FMnMx(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::F32);
}
Op::FMul(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::F32);
}
Op::FSet(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if !src_is_reg(src0) && src_is_reg(src1) {
if !src_is_reg(src0, gpr) && src_is_reg(src1, gpr) {
std::mem::swap(src0, src1);
op.cmp_op = op.cmp_op.flip();
}
@ -528,7 +591,7 @@ fn legalize_sm70_instr(
}
Op::FSetP(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if !src_is_reg(src0) && src_is_reg(src1) {
if !src_is_reg(src0, gpr) && src_is_reg(src1, gpr) {
std::mem::swap(src0, src1);
op.cmp_op = op.cmp_op.flip();
}
@ -536,12 +599,12 @@ fn legalize_sm70_instr(
}
Op::HAdd2(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::F16v2);
}
Op::HFma2(op) => {
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::F16v2);
copy_alu_src_if_not_reg(b, src1, gpr, SrcType::F16v2);
copy_alu_src_if_both_not_reg(b, src1, src2, gpr, SrcType::F16v2);
@ -553,12 +616,12 @@ fn legalize_sm70_instr(
}
Op::HMul2(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::F16v2);
}
Op::HSet2(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if !src_is_reg(src0) && src_is_reg(src1) {
if !src_is_reg(src0, gpr) && src_is_reg(src1, gpr) {
std::mem::swap(src0, src1);
op.cmp_op = op.cmp_op.flip();
}
@ -566,7 +629,7 @@ fn legalize_sm70_instr(
}
Op::HSetP2(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if !src_is_reg(src0) && src_is_reg(src1) {
if !src_is_reg(src0, gpr) && src_is_reg(src1, gpr) {
std::mem::swap(src0, src1);
op.cmp_op = op.cmp_op.flip();
}
@ -574,29 +637,29 @@ fn legalize_sm70_instr(
}
Op::HMnMx2(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::F16v2);
}
Op::MuFu(_) => (), // Nothing to do
Op::DAdd(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::F64);
}
Op::DFma(op) => {
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::F64);
copy_alu_src_if_both_not_reg(b, src1, src2, gpr, SrcType::F64);
}
Op::DMul(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::F64);
}
Op::DSetP(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if !src_is_reg(src0) && src_is_reg(src1) {
if !src_is_reg(src0, gpr) && src_is_reg(src1, gpr) {
std::mem::swap(src0, src1);
op.cmp_op = op.cmp_op.flip();
}
@ -609,8 +672,8 @@ fn legalize_sm70_instr(
Op::IAbs(_) => (),
Op::IAdd3(op) => {
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src2, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
swap_srcs_if_not_reg(src2, src1, gpr);
if !src0.src_mod.is_none() && !src1.src_mod.is_none() {
let val = b.alloc_ssa(gpr, 1);
b.push_op(OpIAdd3 {
@ -625,8 +688,8 @@ fn legalize_sm70_instr(
}
Op::IAdd3X(op) => {
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src2, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
swap_srcs_if_not_reg(src2, src1, gpr);
if !src0.src_mod.is_none() && !src1.src_mod.is_none() {
let val = b.alloc_ssa(gpr, 1);
b.push_op(OpIAdd3X {
@ -639,11 +702,15 @@ fn legalize_sm70_instr(
}
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::B32);
copy_alu_src_if_both_not_reg(b, src1, src2, gpr, SrcType::B32);
if !op.is_uniform() {
copy_src_if_upred(b, &mut op.carry[0]);
copy_src_if_upred(b, &mut op.carry[1]);
}
}
Op::IDp4(op) => {
let [ref mut src_type0, ref mut src_type1] = op.src_types;
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
if swap_srcs_if_not_reg(src0, src1) {
if swap_srcs_if_not_reg(src0, src1, gpr) {
std::mem::swap(src_type0, src_type1);
}
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::ALU);
@ -651,28 +718,32 @@ fn legalize_sm70_instr(
}
Op::IMad(op) => {
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::ALU);
copy_alu_src_if_both_not_reg(b, src1, src2, gpr, SrcType::ALU);
}
Op::IMad64(op) => {
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::ALU);
copy_alu_src_if_both_not_reg(b, src1, src2, gpr, SrcType::ALU);
}
Op::IMnMx(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
swap_srcs_if_not_reg(src0, src1);
swap_srcs_if_not_reg(src0, src1, gpr);
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::ALU);
}
Op::ISetP(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
if !src_is_reg(src0) && src_is_reg(src1) {
if !src_is_reg(src0, gpr) && src_is_reg(src1, gpr) {
std::mem::swap(src0, src1);
op.cmp_op = op.cmp_op.flip();
}
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::ALU);
if !op.is_uniform() {
copy_src_if_upred(b, &mut op.low_cmp);
copy_src_if_upred(b, &mut op.accum);
}
}
Op::Lop3(op) => {
// Fold constants and modifiers if we can
@ -690,11 +761,11 @@ fn legalize_sm70_instr(
}
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
if !src_is_reg(src0) && src_is_reg(src1) {
if !src_is_reg(src0, gpr) && src_is_reg(src1, gpr) {
std::mem::swap(src0, src1);
op.op = LogicOp3::new_lut(&|x, y, z| op.op.eval(y, x, z))
}
if !src_is_reg(src2) && src_is_reg(src1) {
if !src_is_reg(src2, gpr) && src_is_reg(src1, gpr) {
std::mem::swap(src2, src1);
op.op = LogicOp3::new_lut(&|x, y, z| op.op.eval(x, z, y))
}
@ -719,8 +790,11 @@ fn legalize_sm70_instr(
copy_alu_src_if_not_reg(b, &mut op.srcs[1], gpr, SrcType::ALU);
}
Op::Sel(op) => {
if !op.is_uniform() {
copy_src_if_upred(b, &mut op.cond);
}
let [ref mut src0, ref mut src1] = op.srcs;
if swap_srcs_if_not_reg(src0, src1) {
if swap_srcs_if_not_reg(src0, src1, gpr) {
op.cond = op.cond.bnot();
}
copy_alu_src_if_not_reg(b, src0, gpr, SrcType::ALU);
@ -741,6 +815,27 @@ fn legalize_sm70_instr(
src.src_ref = SrcRef::True;
}
}
if !op.is_uniform() {
// The warp form of plop3 allows a single uniform predicate in
// src2. If we have a uniform predicate anywhere, try to move it
// there.
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
if src_is_upred_reg(src0) && !src_is_upred_reg(src2) {
std::mem::swap(src0, src2);
for lop in &mut op.ops {
*lop = LogicOp3::new_lut(&|x, y, z| lop.eval(z, y, x))
}
}
if src_is_upred_reg(src1) && !src_is_upred_reg(src2) {
std::mem::swap(src1, src2);
for lop in &mut op.ops {
*lop = LogicOp3::new_lut(&|x, y, z| lop.eval(x, z, y))
}
}
copy_src_if_upred(b, src0);
copy_src_if_upred(b, src1);
}
}
Op::FSwzAdd(op) => {
let [ref mut src0, ref mut src1] = op.srcs;
@ -780,20 +875,27 @@ fn legalize_sm70_instr(
Op::OutFinal(op) => {
copy_alu_src_if_not_reg(b, &mut op.handle, gpr, SrcType::GPR);
}
Op::Ldc(_) => (), // Nothing to do
Op::Ldc(op) => {
copy_alu_src_if_not_reg(b, &mut op.offset, gpr, SrcType::GPR);
}
Op::BSync(_) => (),
Op::Vote(_) => (), // Nothing to do
Op::Vote(op) => {
copy_src_if_upred(b, &mut op.pred);
}
Op::Copy(_) => (), // Nothing to do
_ => {
let src_types = instr.src_types();
for (i, src) in instr.srcs_mut().iter_mut().enumerate() {
match src_types[i] {
SrcType::SSA => {
assert!(src.as_ssa().is_some());
}
SrcType::GPR => {
assert!(src_is_reg(src));
}
SrcType::SSA | SrcType::GPR => match &mut src.src_ref {
SrcRef::Zero | SrcRef::True | SrcRef::False => {
assert!(src_types[i] != SrcType::SSA);
}
SrcRef::SSA(ssa) => {
copy_ssa_ref_if_uniform(b, ssa);
}
_ => panic!("Unsupported source reference"),
},
SrcType::ALU
| SrcType::F16
| SrcType::F16v2
@ -819,6 +921,12 @@ fn legalize_instr(
ip: usize,
instr: &mut Instr,
) {
if matches!(&instr.op, Op::PhiDsts(_)) {
debug_assert!(instr.pred.is_true());
} else if !instr.is_uniform() {
copy_pred_if_upred(b, &mut instr.pred);
}
let src_types = instr.src_types();
for (i, src) in instr.srcs_mut().iter_mut().enumerate() {
*src = src.fold_imm(src_types[i]);