nak: Add support for instruction predicates

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24998>
This commit is contained in:
Faith Ekstrand 2023-01-30 20:53:17 -06:00 committed by Marge Bot
parent de073a10e6
commit ac2a56f56f
6 changed files with 143 additions and 25 deletions

View file

@ -13,7 +13,9 @@ use std::collections::HashMap;
struct TrivialRegAlloc {
next_reg: u8,
next_ureg: u8,
reg_map: HashMap<SSAValue, Ref>,
next_pred: u8,
next_upred: u8,
reg_map: HashMap<SSAValue, RegRef>,
}
impl TrivialRegAlloc {
@ -21,11 +23,13 @@ impl TrivialRegAlloc {
TrivialRegAlloc {
next_reg: 16, /* Leave some space for FS outputs */
next_ureg: 0,
next_pred: 0,
next_upred: 0,
reg_map: HashMap::new(),
}
}
fn alloc_reg(&mut self, file: RegFile, comps: u8) -> Ref {
fn alloc_reg(&mut self, file: RegFile, comps: u8) -> RegRef {
let align = comps.next_power_of_two();
let idx = match file {
RegFile::GPR => {
@ -38,22 +42,35 @@ impl TrivialRegAlloc {
self.next_ureg = idx + comps;
idx
}
RegFile::Pred | RegFile::UPred => panic!("Not handled"),
RegFile::Pred => {
let idx = self.next_pred.next_multiple_of(align);
self.next_pred = idx + comps;
idx
}
RegFile::UPred => {
let idx = self.next_upred.next_multiple_of(align);
self.next_upred = idx + comps;
idx
}
};
Ref::new_reg(file, idx, comps)
RegRef::new(file, idx, comps)
}
pub fn rewrite_ref(&mut self, r: &Ref) -> Ref {
if let Ref::SSA(ssa) = r {
if let Some(reg) = self.reg_map.get(ssa) {
*reg
} else {
let reg = self.alloc_reg(ssa.file(), ssa.comps());
self.reg_map.insert(*ssa, reg);
reg
}
pub fn rewrite_ssa(&mut self, ssa: SSAValue) -> RegRef {
if let Some(reg) = self.reg_map.get(&ssa) {
*reg
} else {
*r
let reg = self.alloc_reg(ssa.file(), ssa.comps());
self.reg_map.insert(ssa, reg);
reg
}
}
pub fn rewrite_ref(&mut self, r: Ref) -> Ref {
if let Ref::SSA(ssa) = r {
Ref::Reg(self.rewrite_ssa(ssa))
} else {
r
}
}
@ -61,13 +78,14 @@ impl TrivialRegAlloc {
for f in &mut s.functions {
for b in &mut f.blocks {
for instr in &mut b.instrs {
if let Pred::SSA(ssa) = instr.pred {
instr.pred = Pred::Reg(self.rewrite_ssa(ssa));
}
for dst in instr.dsts_mut() {
let new_dst = self.rewrite_ref(&dst);
*dst = new_dst;
*dst = self.rewrite_ref(*dst);
}
for src in instr.srcs_mut() {
let new_src = self.rewrite_ref(&src);
*src = new_src;
*src = self.rewrite_ref(*src);
}
}
}

View file

@ -36,6 +36,12 @@ fn encode_ureg(bs: &mut impl BitSetMut, range: Range<usize>, reg: RegRef) {
bs.set_field(range, reg.base_idx());
}
fn encode_pred(bs: &mut impl BitSetMut, range: Range<usize>, reg: RegRef) {
assert!(range.len() == 3);
assert!(reg.file() == RegFile::Pred);
bs.set_field(range, reg.base_idx());
}
fn encode_mod(bs: &mut impl BitSetMut, range: Range<usize>, src_mod: u8) {
assert!(range.len() == 2);
bs.set_field(range, src_mod);
@ -64,10 +70,23 @@ fn encode_cx(bs: &mut impl BitSetMut, range: Range<usize>, cb: &CBufRef) {
fn encode_instr_base(bs: &mut impl BitSetMut, instr: &Instr, opcode: u16) {
bs.set_field(0..12, opcode);
bs.set_field(12..16, 0x7_u8); /* TODO: Predicate */
if instr.pred.is_none() {
bs.set_field(12..15, 0x7_u8);
bs.set_bit(15, false);
} else {
encode_pred(bs, 12..15, *instr.pred.as_reg().unwrap());
bs.set_bit(15, instr.pred_inv);
}
if instr.num_dsts() > 0 {
assert!(instr.num_dsts() == 1);
encode_reg(bs, 16..24, *instr.dst(0).as_reg().unwrap());
let reg = instr.dst(0).as_reg().unwrap();
match reg.file() {
RegFile::GPR => encode_reg(bs, 16..24, *reg),
RegFile::Pred => encode_pred(bs, 81..84, *reg),
_ => panic!("Unsupported destination"),
}
}
bs.set_field(105..109, instr.deps.delay);

View file

@ -38,10 +38,14 @@ impl<'a> ShaderFromNir<'a> {
}
fn ref_for_nir_def(&self, def: &nir_def) -> Src {
assert!(def.bit_size == 32 || def.bit_size == 64);
let dwords = (def.bit_size / 32) * def.num_components;
//Src::new_ssa(def.index, dwords, !def.divergent)
Src::new_ssa(RegFile::GPR, def.index, dwords)
if def.bit_size == 1 {
Src::new_ssa(RegFile::Pred, def.index, def.num_components)
} else {
assert!(def.bit_size == 32 || def.bit_size == 64);
let dwords = (def.bit_size / 32) * def.num_components;
//Src::new_ssa(def.index, dwords, !def.divergent)
Src::new_ssa(RegFile::GPR, def.index, dwords)
}
}
fn get_src(&self, src: &nir_src) -> Src {
@ -53,10 +57,10 @@ impl<'a> ShaderFromNir<'a> {
}
fn get_alu_src(&mut self, alu_src: &nir_alu_src) -> Src {
assert!(alu_src.src.bit_size() == 32);
if alu_src.src.num_components() == 1 {
self.get_src(&alu_src.src)
} else {
assert!(alu_src.src.bit_size() == 32);
let vec_src = self.get_src(&alu_src.src);
let comp = self.alloc_ssa(vec_src.as_ssa().unwrap().file(), 1);
let mut dsts = Vec::new();

View file

@ -314,6 +314,60 @@ impl fmt::Display for Ref {
pub type Src = Ref;
pub type Dst = Ref;
pub enum Pred {
None,
SSA(SSAValue),
Reg(RegRef),
}
impl Pred {
pub fn new_ssa(file: RegFile, idx: u32, comps: u8) -> Pred {
Pred::SSA(SSAValue::new(file, idx, comps))
}
pub fn new_reg(file: RegFile, idx: u8, comps: u8) -> Pred {
Pred::Reg(RegRef::new(file, idx, comps))
}
pub fn as_reg(&self) -> Option<&RegRef> {
match self {
Pred::Reg(r) => Some(r),
_ => None,
}
}
pub fn as_ssa(&self) -> Option<&SSAValue> {
match self {
Pred::SSA(r) => Some(r),
_ => None,
}
}
pub fn is_none(&self) -> bool {
match self {
Pred::None => true,
_ => false,
}
}
}
impl fmt::Display for Pred {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Pred::None => (),
Pred::SSA(v) => {
if v.is_uniform() {
write!(f, "USSA{}@{}", v.idx(), v.comps())?;
} else {
write!(f, "SSA{}@{}", v.idx(), v.comps())?;
}
}
Pred::Reg(r) => r.fmt(f)?,
}
Ok(())
}
}
struct InstrRefArr {
num_dsts: u8,
num_srcs: u8,
@ -658,6 +712,8 @@ impl fmt::Display for InstrDeps {
pub struct Instr {
pub op: Opcode,
pub pred: Pred,
pub pred_inv: bool,
pub deps: InstrDeps,
refs: InstrRefs,
}
@ -666,6 +722,8 @@ impl Instr {
pub fn new(op: Opcode, dsts: &[Dst], srcs: &[Src]) -> Instr {
Instr {
op: op,
pred: Pred::None,
pred_inv: false,
refs: InstrRefs::new(dsts, srcs),
deps: InstrDeps::new(),
}
@ -906,6 +964,13 @@ impl fmt::Display for Opcode {
impl fmt::Display for Instr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if !self.pred.is_none() {
if self.pred_inv {
write!(f, "@!{}", self.pred)?;
} else {
write!(f, "@{}", self.pred)?;
}
}
write!(f, "{} {{", self.op)?;
if self.num_dsts() > 0 {
write!(f, " {}", self.dst(0))?;

View file

@ -50,6 +50,14 @@ impl CopyPropPass {
_ => (),
}
if let Pred::SSA(src_ssa) = &instr.pred {
if let Some(src_vec) = self.get_copy(src_ssa) {
if let Src::SSA(ssa) = src_vec[0] {
instr.pred = Pred::SSA(ssa);
}
}
}
for src in instr.srcs_mut() {
if let Ref::SSA(src_ssa) = src {
if src_ssa.comps() == 1 {

View file

@ -23,6 +23,10 @@ impl DeadCodePass {
}
fn mark_instr_live(&mut self, instr: &Instr) {
if let Pred::SSA(ssa) = &instr.pred {
self.mark_ssa_live(ssa);
}
for src in instr.srcs() {
if let Src::SSA(ssa) = src {
self.mark_ssa_live(ssa);