nak: Add a Label struct for branch targets

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24998>
This commit is contained in:
Faith Ekstrand 2023-09-28 18:08:26 -05:00 committed by Marge Bot
parent f2e07cbab9
commit e3fa6f3557
3 changed files with 53 additions and 13 deletions

View file

@ -1553,7 +1553,7 @@ impl SM75Instr {
&mut self,
op: &OpBra,
ip: usize,
block_offsets: &HashMap<u32, usize>,
labels: &HashMap<Label, usize>,
) {
let ip = u64::try_from(ip).unwrap();
assert!(ip < i64::MAX as u64);
@ -1653,7 +1653,7 @@ impl SM75Instr {
instr: &Instr,
sm: u8,
ip: usize,
block_offsets: &HashMap<u32, usize>,
block_offsets: &HashMap<Label, usize>,
) -> [u32; 4] {
assert!(sm >= 75);
@ -1736,7 +1736,7 @@ pub fn encode_shader(shader: &Shader) -> Vec<u32> {
let mut num_instrs = 0_usize;
let mut block_offsets = HashMap::new();
for b in &func.blocks {
block_offsets.insert(b.id, num_instrs);
block_offsets.insert(b.label, num_instrs);
num_instrs += b.instrs.len() * 4;
}

View file

@ -138,6 +138,8 @@ struct ShaderFromNir<'a> {
nir: &'a nir_shader,
info: ShaderInfo,
cfg: CFGBuilder<u32, BasicBlock>,
label_alloc: LabelAllocator,
block_label: HashMap<u32, Label>,
fs_out_regs: [SSAValue; 34],
end_block_id: u32,
ssa_map: HashMap<u32, Vec<SSAValue>>,
@ -150,6 +152,8 @@ impl<'a> ShaderFromNir<'a> {
nir: nir,
info: init_info_from_nir(nir, sm),
cfg: CFGBuilder::new(),
label_alloc: LabelAllocator::new(),
block_label: HashMap::new(),
fs_out_regs: [SSAValue::NONE; 34],
end_block_id: 0,
ssa_map: HashMap::new(),
@ -157,6 +161,13 @@ impl<'a> ShaderFromNir<'a> {
}
}
fn get_block_label(&mut self, block: &nir_block) -> Label {
*self
.block_label
.entry(block.index)
.or_insert_with(|| self.label_alloc.alloc())
}
fn get_ssa(&mut self, ssa: &nir_def) -> &[SSAValue] {
self.ssa_map.get(&ssa.index).unwrap()
}
@ -1992,7 +2003,7 @@ impl<'a> ShaderFromNir<'a> {
self.cfg.add_edge(nb.index, ni.first_else_block().index);
let mut bra = Instr::new_boxed(OpBra {
target: ni.first_else_block().index,
target: self.get_block_label(ni.first_else_block()),
});
let cond = self.get_ssa(&ni.condition.as_def())[0];
@ -2009,13 +2020,15 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpExit {});
} else {
self.cfg.add_edge(nb.index, s0.index);
b.push_op(OpBra { target: s0.index });
b.push_op(OpBra {
target: self.get_block_label(s0),
});
}
}
let mut bb = BasicBlock::new(nb.index);
let mut bb = BasicBlock::new(self.get_block_label(nb));
bb.instrs.append(&mut b.as_vec());
self.cfg.add_node(bb.id, bb);
self.cfg.add_node(nb.index, bb);
}
fn parse_if<'b>(

View file

@ -18,6 +18,33 @@ use std::iter::Zip;
use std::ops::{BitAnd, BitOr, Deref, DerefMut, Index, IndexMut, Not, Range};
use std::slice;
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
pub struct Label {
idx: u32,
}
impl fmt::Display for Label {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "L{}", self.idx)
}
}
pub struct LabelAllocator {
count: u32,
}
impl LabelAllocator {
pub fn new() -> LabelAllocator {
LabelAllocator { count: 0 }
}
pub fn alloc(&mut self) -> Label {
let idx = self.count;
self.count += 1;
Label { idx: idx }
}
}
/// Represents a register file
#[repr(u8)]
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
@ -3341,12 +3368,12 @@ impl fmt::Display for OpMemBar {
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpBra {
pub target: u32,
pub target: Label,
}
impl fmt::Display for OpBra {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "BRA B{}", self.target)
write!(f, "BRA {}", self.target)
}
}
@ -4310,14 +4337,14 @@ impl MappedInstrs {
}
pub struct BasicBlock {
pub id: u32,
pub label: Label,
pub instrs: Vec<Box<Instr>>,
}
impl BasicBlock {
pub fn new(id: u32) -> BasicBlock {
pub fn new(label: Label) -> BasicBlock {
BasicBlock {
id: id,
label: label,
instrs: Vec::new(),
}
}
@ -4450,7 +4477,7 @@ impl Function {
impl fmt::Display for Function {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for i in 0..self.blocks.len() {
write!(f, "block {}(id={}) [", i, self.blocks[i].id)?;
write!(f, "block {} {} [", i, self.blocks[i].label)?;
for (pi, p) in self.blocks.pred_indices(i).iter().enumerate() {
if pi > 0 {
write!(f, ", ")?;