nak: Use a different inner struct type for each opcode

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24998>
This commit is contained in:
Faith Ekstrand 2023-01-30 20:53:18 -06:00 committed by Marge Bot
parent e994acdb36
commit b41b4bd7f5
5 changed files with 766 additions and 546 deletions

View file

@ -8,16 +8,6 @@ use crate::nak_ir::*;
use std::ops::Range;
trait SrcMod {
fn src_mod(&self) -> u8;
}
impl SrcMod for Src {
fn src_mod(&self) -> u8 {
0 /* TODO */
}
}
enum ALUSrc {
Imm(Immediate),
Reg(RegRef),
@ -94,11 +84,6 @@ impl SM75Instr {
self.set_field(range, reg.base_idx());
}
fn set_src_mod(&mut self, range: Range<usize>, src_mod: u8) {
assert!(range.len() == 2);
self.set_field(range, src_mod);
}
fn set_src_cb(&mut self, range: Range<usize>, cb: &CBufRef) {
let mut v = self.subset_mut(range);
v.set_field(0..16, cb.offset);
@ -150,82 +135,97 @@ impl SM75Instr {
fn encode_alu(
&mut self,
opcode: u16,
dst: Option<&Dst>,
src0: Option<&Src>,
src1: &Src,
src2: Option<&Src>,
dst: Option<Dst>,
src0: Option<ModSrc>,
src1: ModSrc,
src2: Option<ModSrc>,
) {
if let Some(dst) = dst {
self.set_dst_reg(*dst.as_reg().unwrap());
}
if let Some(src0) = src0 {
self.set_reg(24..32, *src0.as_reg().unwrap());
self.set_reg(24..32, *src0.src.as_reg().unwrap());
self.set_bit(72, src0.src_mod.has_neg());
self.set_bit(73, src0.src_mod.has_abs());
}
let form = match ALUSrc::from_src(src1, 1) {
let form = match ALUSrc::from_src(&src1.src, 1) {
ALUSrc::Reg(reg1) => {
if let Some(src2) = src2 {
match ALUSrc::from_src(src2, 1) {
match ALUSrc::from_src(&src2.src, 1) {
ALUSrc::Reg(reg2) => {
self.set_reg(32..40, reg1);
self.set_src_mod(62..64, src1.src_mod());
self.set_bit(62, src1.src_mod.has_abs());
self.set_bit(63, src1.src_mod.has_neg());
self.set_reg(64..72, reg2);
self.set_src_mod(74..76, src2.src_mod());
self.set_bit(74, src2.src_mod.has_abs());
self.set_bit(75, src2.src_mod.has_neg());
1_u8 /* form */
}
ALUSrc::UReg(reg2) => {
self.set_ureg(32..40, reg2);
self.set_src_mod(62..64, src2.src_mod());
self.set_bit(62, src2.src_mod.has_abs());
self.set_bit(63, src2.src_mod.has_neg());
self.set_reg(64..72, reg1);
self.set_src_mod(74..76, src1.src_mod());
self.set_bit(74, src1.src_mod.has_abs());
self.set_bit(75, src1.src_mod.has_neg());
7_u8 /* form */
}
ALUSrc::Imm(imm) => {
self.set_src_imm(32..64, &imm);
self.set_reg(64..72, reg1);
self.set_src_mod(74..76, src1.src_mod());
self.set_bit(74, src1.src_mod.has_abs());
self.set_bit(75, src1.src_mod.has_neg());
2_u8 /* form */
}
ALUSrc::CBuf(cb) => {
/* TODO set_src_cx */
self.set_src_cb(38..59, &cb);
self.set_src_mod(62..64, src2.src_mod());
self.set_bit(62, src2.src_mod.has_abs());
self.set_bit(63, src2.src_mod.has_neg());
self.set_reg(64..72, reg1);
self.set_src_mod(74..76, src1.src_mod());
self.set_bit(74, src1.src_mod.has_abs());
self.set_bit(75, src1.src_mod.has_neg());
3_u8 /* form */
}
_ => panic!("Invalid instruction form"),
}
} else {
self.set_reg(32..40, reg1);
self.set_src_mod(62..64, src1.src_mod());
self.set_bit(62, src1.src_mod.has_abs());
self.set_bit(63, src1.src_mod.has_neg());
1_u8 /* form */
}
}
ALUSrc::UReg(reg1) => {
self.set_ureg(32..40, reg1);
self.set_src_mod(62..64, src1.src_mod());
self.set_bit(62, src1.src_mod.has_abs());
self.set_bit(63, src1.src_mod.has_neg());
if let Some(src2) = src2 {
self.set_reg(64..72, *src2.as_reg().unwrap());
self.set_src_mod(74..76, src2.src_mod());
self.set_reg(64..72, *src2.src.as_reg().unwrap());
self.set_bit(74, src2.src_mod.has_abs());
self.set_bit(75, src2.src_mod.has_neg());
}
6_u8 /* form */
}
ALUSrc::Imm(imm) => {
self.set_src_imm(32..64, &imm);
if let Some(src2) = src2 {
self.set_reg(64..72, *src2.as_reg().unwrap());
self.set_src_mod(74..76, src2.src_mod());
self.set_reg(64..72, *src2.src.as_reg().unwrap());
self.set_bit(74, src2.src_mod.has_abs());
self.set_bit(75, src2.src_mod.has_neg());
}
4_u8 /* form */
}
ALUSrc::CBuf(cb) => {
self.set_src_cb(38..59, &cb);
self.set_src_mod(62..64, src1.src_mod());
self.set_bit(62, src1.src_mod.has_abs());
self.set_bit(63, src1.src_mod.has_neg());
if let Some(src2) = src2 {
self.set_reg(64..72, *src2.as_reg().unwrap());
self.set_src_mod(74..76, src2.src_mod());
self.set_reg(64..72, *src2.src.as_reg().unwrap());
self.set_bit(74, src2.src_mod.has_abs());
self.set_bit(75, src2.src_mod.has_neg());
}
5_u8 /* form */
}
@ -245,85 +245,19 @@ impl SM75Instr {
self.set_field(122..126, deps.reuse_mask);
}
fn encode_s2r(&mut self, instr: &Instr, idx: u8) {
self.set_opcode(0x919);
self.set_dst_reg(*instr.dst(0).as_reg().unwrap());
self.set_field(72..80, idx);
}
fn encode_mov(&mut self, instr: &Instr) {
assert!(instr.num_dsts() == 1);
assert!(instr.num_srcs() == 1);
self.encode_alu(0x002, Some(instr.dst(0)), None, instr.src(0), None);
self.set_field(72..76, 0xf_u32 /* TODO: Quad lanes */);
}
fn encode_sel(&mut self, instr: &Instr) {
assert!(instr.num_dsts() == 1);
assert!(instr.num_srcs() == 3);
self.encode_alu(
0x007,
Some(instr.dst(0)),
Some(instr.src(1)),
instr.src(2),
None,
);
self.set_pred_reg(87..90, *instr.src(0).as_reg().unwrap());
self.set_bit(90, false); /* not */
}
fn encode_iadd3(&mut self, instr: &Instr) {
assert!(instr.num_dsts() == 1);
assert!(instr.num_srcs() == 3);
fn encode_iadd3(&mut self, op: &OpIAdd3) {
self.encode_alu(
0x010,
Some(instr.dst(0)),
Some(instr.src(0)),
instr.src(1),
Some(instr.src(2)),
Some(op.dst),
Some(op.mod_src(0)),
op.mod_src(1),
Some(op.mod_src(2)),
);
self.set_field(81..84, 7_u32); /* pred */
self.set_field(84..87, 7_u32); /* pred */
}
fn encode_lop3(&mut self, instr: &Instr, op: &LogicOp) {
assert!(instr.num_dsts() == 1);
assert!(instr.num_srcs() == 3);
self.encode_alu(
0x012,
Some(instr.dst(0)),
Some(instr.src(0)),
instr.src(1),
Some(instr.src(2)),
);
self.set_field(72..80, op.lut);
self.set_bit(80, false); /* .PAND */
self.set_field(81..84, 7_u32); /* pred */
self.set_field(84..87, 7_u32); /* pred */
self.set_bit(90, true);
}
fn encode_plop3(&mut self, instr: &Instr, op: &LogicOp) {
assert!(instr.num_dsts() == 1);
assert!(instr.num_srcs() == 3);
self.set_opcode(0x81c);
self.set_field(64..67, op.lut & 0x7);
self.set_field(72..77, op.lut >> 3);
self.set_pred_reg(68..71, *instr.src(2).as_reg().unwrap());
self.set_bit(71, false); /* NOT(src2) */
self.set_pred_reg(77..80, *instr.src(1).as_reg().unwrap());
self.set_bit(80, false); /* NOT(src1) */
self.set_pred_reg(81..84, *instr.dst(0).as_reg().unwrap());
self.set_field(84..87, 7_u8); /* Def1 */
self.set_pred_reg(87..90, *instr.src(0).as_reg().unwrap());
self.set_bit(90, false); /* NOT(src0) */
// self.set_pred_reg(81..84, *op.carry[0].as_reg().unwrap());
// self.set_pred_reg(84..87, *op.carry[1].as_reg().unwrap());
self.set_field(81..84, 0x7_u8);
self.set_field(84..87, 0x7_u8);
}
fn set_cmp_op(&mut self, range: Range<usize>, op: &CmpOp) {
@ -341,10 +275,14 @@ impl SM75Instr {
);
}
fn encode_isetp(&mut self, instr: &Instr, op: &IntCmpOp) {
assert!(instr.num_dsts() == 1);
assert!(instr.num_srcs() == 2);
self.encode_alu(0x00c, None, Some(instr.src(0)), instr.src(1), None);
fn encode_isetp(&mut self, op: &OpISetP) {
self.encode_alu(
0x00c,
None,
Some(op.srcs[0].into()),
op.srcs[1].into(),
None,
);
self.set_field(
73..74,
@ -356,21 +294,35 @@ impl SM75Instr {
self.set_field(74..76, 0_u32); /* pred combine op */
self.set_cmp_op(76..79, &op.cmp_op);
self.set_pred_reg(81..84, *instr.dst(0).as_reg().unwrap());
self.set_pred_reg(81..84, *op.dst.as_reg().unwrap());
self.set_field(84..87, 7_u32); /* dst1 */
self.set_field(87..90, 7_u32); /* src pred */
self.set_bit(90, false); /* src pred neg */
}
fn encode_shl(&mut self, instr: &Instr) {
assert!(instr.num_dsts() == 1);
assert!(instr.num_srcs() == 2);
fn encode_lop3(&mut self, op: &OpLop3) {
self.encode_alu(
0x012,
Some(op.dst),
Some(op.srcs[0].into()),
op.srcs[1].into(),
Some(op.srcs[2].into()),
);
self.set_field(72..80, op.op.lut);
self.set_bit(80, false); /* .PAND */
self.set_field(81..84, 7_u32); /* pred */
self.set_field(84..87, 7_u32); /* pred */
self.set_bit(90, true);
}
fn encode_shl(&mut self, op: &OpShl) {
self.encode_alu(
0x019,
Some(instr.dst(0)),
Some(instr.src(0)),
instr.src(1),
Some(op.dst),
Some(op.srcs[0].into()),
op.srcs[1].into(),
None,
);
@ -380,36 +332,39 @@ impl SM75Instr {
self.set_bit(80, false /* HI */);
}
fn encode_ald(&mut self, instr: &Instr, attr: &AttrAccess) {
self.set_opcode(0x321);
assert!(instr.num_dsts() == 1);
assert!(instr.num_srcs() == 2);
self.set_dst_reg(*instr.dst(0).as_reg().unwrap());
self.set_reg(24..32, *instr.src(0).as_reg().unwrap());
self.set_reg(32..40, *instr.src(1).as_reg().unwrap());
self.set_field(40..50, attr.addr);
self.set_field(74..76, attr.comps - 1);
self.set_field(76..77, attr.patch);
self.set_field(77..78, attr.flags);
self.set_field(79..80, attr.out_load);
fn encode_mov(&mut self, op: &OpMov) {
self.encode_alu(0x002, Some(op.dst), None, op.src.into(), None);
self.set_field(72..76, op.quad_lanes);
}
fn encode_ast(&mut self, instr: &Instr, attr: &AttrAccess) {
self.set_opcode(0x322);
assert!(instr.num_dsts() == 0);
assert!(instr.num_srcs() == 3);
fn encode_sel(&mut self, op: &OpSel) {
self.encode_alu(
0x007,
Some(op.dst),
Some(op.srcs[0].into()),
op.srcs[1].into(),
None,
);
self.set_reg(32..40, *instr.src(0).as_reg().unwrap());
self.set_reg(24..32, *instr.src(1).as_reg().unwrap());
self.set_reg(64..72, *instr.src(2).as_reg().unwrap());
self.set_pred_reg(87..90, *op.cond.as_reg().unwrap());
self.set_bit(90, op.cond_mod.has_not());
}
self.set_field(40..50, attr.addr);
self.set_field(74..76, attr.comps - 1);
self.set_field(76..77, attr.patch);
self.set_field(77..78, attr.flags);
assert!(!attr.out_load);
fn encode_plop3(&mut self, op: &OpPLop3) {
self.set_opcode(0x81c);
self.set_field(64..67, op.op.lut & 0x7);
self.set_field(72..77, op.op.lut >> 3);
self.set_pred_reg(68..71, *op.srcs[2].as_reg().unwrap());
self.set_bit(71, op.src_mods[2].has_not());
self.set_pred_reg(77..80, *op.srcs[1].as_reg().unwrap());
self.set_bit(80, op.src_mods[1].has_not());
self.set_pred_reg(81..84, *op.dst.as_reg().unwrap());
self.set_field(84..87, 7_u8); /* Def1 */
self.set_pred_reg(87..90, *op.srcs[0].as_reg().unwrap());
self.set_bit(90, op.src_mods[0].has_not());
}
fn set_mem_access(&mut self, access: &MemAccess) {
@ -458,34 +413,56 @@ impl SM75Instr {
}
}
fn encode_ld(&mut self, instr: &Instr, access: &MemAccess) {
fn encode_ld(&mut self, op: &OpLd) {
self.set_opcode(0x980);
assert!(instr.num_dsts() == 1);
assert!(instr.num_srcs() == 1);
self.set_dst_reg(*instr.dst(0).as_reg().unwrap());
self.set_reg(24..32, *instr.src(0).as_reg().unwrap());
self.set_field(32..64, 0_u32 /* Immediate offset */);
self.set_dst_reg(*op.dst.as_reg().unwrap());
self.set_reg(24..32, *op.addr.as_reg().unwrap());
self.set_field(32..64, op.offset);
self.set_mem_access(access);
self.set_mem_access(&op.access);
}
fn encode_st(&mut self, instr: &Instr, access: &MemAccess) {
fn encode_st(&mut self, op: &OpSt) {
self.set_opcode(0x385);
assert!(instr.num_dsts() == 0);
assert!(instr.num_srcs() == 2);
self.set_reg(24..32, *instr.src(0).as_reg().unwrap());
self.set_field(32..64, 0_u32 /* Immediate offset */);
self.set_reg(64..72, *instr.src(1).as_reg().unwrap());
self.set_reg(24..32, *op.addr.as_reg().unwrap());
self.set_field(32..64, op.offset);
self.set_reg(64..72, *op.data.as_reg().unwrap());
self.set_mem_access(access);
self.set_mem_access(&op.access);
}
fn encode_exit(&mut self, instr: &Instr) {
fn encode_ald(&mut self, op: &OpALd) {
self.set_opcode(0x321);
self.set_dst_reg(*op.dst.as_reg().unwrap());
self.set_reg(24..32, *op.vtx.as_reg().unwrap());
self.set_reg(32..40, *op.offset.as_reg().unwrap());
self.set_field(40..50, op.access.addr);
self.set_field(74..76, op.access.comps - 1);
self.set_field(76..77, op.access.patch);
self.set_field(77..78, op.access.flags);
self.set_field(79..80, op.access.out_load);
}
fn encode_ast(&mut self, op: &OpASt) {
self.set_opcode(0x322);
self.set_reg(32..40, *op.data.as_reg().unwrap());
self.set_reg(24..32, *op.vtx.as_reg().unwrap());
self.set_reg(64..72, *op.offset.as_reg().unwrap());
self.set_field(40..50, op.access.addr);
self.set_field(74..76, op.access.comps - 1);
self.set_field(76..77, op.access.patch);
self.set_field(77..78, op.access.flags);
assert!(!op.access.out_load);
}
fn encode_exit(&mut self, op: &OpExit) {
self.set_opcode(0x94d);
assert!(instr.num_dsts() == 0);
assert!(instr.num_srcs() == 0);
/* ./.KEEPREFCOUNT/.PREEMPTED/.INVALID3 */
self.set_field(84..85, false);
@ -494,6 +471,12 @@ impl SM75Instr {
self.set_field(90..91, false); /* NOT */
}
fn encode_s2r(&mut self, op: &OpS2R) {
self.set_opcode(0x919);
self.set_dst_reg(*op.dst.as_reg().unwrap());
self.set_field(72..80, op.idx);
}
pub fn encode(instr: &Instr, sm: u8) -> [u32; 4] {
assert!(sm >= 75);
@ -503,19 +486,19 @@ impl SM75Instr {
};
match &instr.op {
Opcode::S2R(i) => si.encode_s2r(instr, *i),
Opcode::MOV => si.encode_mov(instr),
Opcode::SEL => si.encode_sel(instr),
Opcode::IADD3 => si.encode_iadd3(instr),
Opcode::LOP3(op) => si.encode_lop3(instr, &op),
Opcode::PLOP3(op) => si.encode_plop3(instr, &op),
Opcode::ISETP(op) => si.encode_isetp(instr, &op),
Opcode::SHL => si.encode_shl(instr),
Opcode::ALD(a) => si.encode_ald(instr, &a),
Opcode::AST(a) => si.encode_ast(instr, &a),
Opcode::LD(a) => si.encode_ld(instr, a),
Opcode::ST(a) => si.encode_st(instr, a),
Opcode::EXIT => si.encode_exit(instr),
Op::IAdd3(op) => si.encode_iadd3(&op),
Op::ISetP(op) => si.encode_isetp(&op),
Op::Lop3(op) => si.encode_lop3(&op),
Op::Shl(op) => si.encode_shl(&op),
Op::Mov(op) => si.encode_mov(&op),
Op::Sel(op) => si.encode_sel(&op),
Op::PLop3(op) => si.encode_plop3(&op),
Op::Ld(op) => si.encode_ld(&op),
Op::St(op) => si.encode_st(&op),
Op::ALd(op) => si.encode_ald(&op),
Op::ASt(op) => si.encode_ast(&op),
Op::Exit(op) => si.encode_exit(&op),
Op::S2R(op) => si.encode_s2r(&op),
_ => panic!("Unhandled instruction"),
}

View file

@ -90,9 +90,6 @@ impl<'a> ShaderFromNir<'a> {
self.instrs
.push(Instr::new_sel(dst, srcs[0], srcs[1], srcs[2]));
}
nir_op_fadd => {
self.instrs.push(Instr::new_fadd(dst, srcs[0], srcs[1]));
}
nir_op_iadd => {
self.instrs.push(Instr::new_iadd(dst, srcs[0], srcs[1]));
}

File diff suppressed because it is too large Load diff

View file

@ -174,3 +174,30 @@ pub fn derive_dsts_as_slice(input: TokenStream) -> TokenStream {
pub fn derive_src_mods_as_slice(input: TokenStream) -> TokenStream {
derive_as_slice(input, "SrcModsAsSlice", "src_mods", "SrcMod")
}
#[proc_macro_derive(Display)]
pub fn enum_derive_display(input: TokenStream) -> TokenStream {
let DeriveInput { ident, data, .. } = parse_macro_input!(input);
if let Data::Enum(e) = data {
let mut cases = TokenStream2::new();
for v in e.variants {
let case = v.ident;
cases.extend(quote! {
#ident::#case(x) => x.fmt(f),
});
}
quote! {
impl fmt::Display for #ident {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
#cases
}
}
}
}
.into()
} else {
panic!("Not an enum type");
}
}

View file

@ -29,19 +29,19 @@ impl CopyPropPass {
pub fn run(&mut self, f: &mut Function) {
for b in &mut f.blocks {
for instr in &mut b.instrs {
match instr.op {
Opcode::VEC => {
match &instr.op {
Op::Vec(vec) => {
self.add_copy(
instr.dst(0).as_ssa().unwrap(),
instr.srcs().to_vec(),
vec.dst.as_ssa().unwrap(),
vec.srcs.to_vec(),
);
}
Opcode::SPLIT => {
let src_ssa = instr.src(0).as_ssa().unwrap();
Op::Split(split) => {
let src_ssa = split.src.as_ssa().unwrap();
if let Some(src_vec) = self.get_copy(src_ssa).cloned() {
assert!(src_vec.len() == instr.num_dsts());
for i in 0..instr.num_dsts() {
if let Dst::SSA(ssa) = instr.dst(i) {
assert!(src_vec.len() == split.dsts.len());
for (i, dst) in split.dsts.iter().enumerate() {
if let Dst::SSA(ssa) = dst {
self.add_copy(ssa, vec![src_vec[i]]);
}
}
@ -51,7 +51,7 @@ impl CopyPropPass {
}
if let Pred::SSA(src_ssa) = &instr.pred {
if let Some(src_vec) = self.get_copy(src_ssa) {
if let Some(src_vec) = self.get_copy(&src_ssa) {
if let Src::SSA(ssa) = src_vec[0] {
instr.pred = Pred::SSA(ssa);
}
@ -61,7 +61,7 @@ impl CopyPropPass {
for src in instr.srcs_mut() {
if let Ref::SSA(src_ssa) = src {
if src_ssa.comps() == 1 {
if let Some(src_vec) = self.get_copy(src_ssa) {
if let Some(src_vec) = self.get_copy(&src_ssa) {
*src = src_vec[0];
}
}