nak/nir: Rework CRS handling

The new code tracks the whole call/return stack (CRS). This means we know the
size of the stack at all times. It also means that we can detect a number of
potential error cases.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/30402>
Faith Ekstrand 2024-07-26 16:33:32 -05:00 committed by Marge Bot
parent cb5e10d0aa
commit 9bbc692064
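
A rough standalone sketch of the stack discipline the commit message describes
(illustrative only: the SyncStack type, the block numbers, and the main()
driver below are invented for the example and are not NAK code). The CRS is
modeled as a stack of (block index, sync type) pairs that is pushed when a
sync/break/continue target is announced and popped when the corresponding
structured construct ends, so the depth is known at every point and mismatched
pops are caught immediately:

    // Standalone sketch: a loop containing a divergent if.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum SyncType {
        Sync,
        Brk,
        Cont,
    }

    #[derive(Default)]
    struct SyncStack {
        crs: Vec<(u32, SyncType)>,
        max_depth: usize,
    }

    impl SyncStack {
        fn push(&mut self, block: u32, s: SyncType) {
            self.crs.push((block, s));
            self.max_depth = self.max_depth.max(self.crs.len());
        }

        // A mismatched or missing entry is caught here instead of silently
        // producing a bogus SYNC/BRK/CONT later.
        fn pop(&mut self, block: u32, s: SyncType) {
            assert_eq!(self.crs.pop(), Some((block, s)));
        }
    }

    fn main() {
        let mut stack = SyncStack::default();
        stack.push(10, SyncType::Brk);  // PBK: block 10 is after the loop
        stack.push(2, SyncType::Cont);  // PCNT: block 2 is the loop header
        stack.push(7, SyncType::Sync);  // SSY: block 7 is after the if
        stack.pop(7, SyncType::Sync);   // end of the if
        stack.pop(2, SyncType::Cont);   // end of the loop body
        stack.pop(10, SyncType::Brk);   // after the loop
        assert_eq!(stack.max_depth, 3); // the size is known at all times
        assert!(stack.crs.is_empty());
    }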

@@ -288,6 +288,7 @@ impl Index<FloatType> for ShaderFloatControls {
     }
 }
 
+#[derive(Clone, Copy, Eq, Hash, PartialEq)]
 enum SyncType {
     Sync,
     Brk,
@@ -303,7 +304,8 @@ struct ShaderFromNir<'a> {
     label_alloc: LabelAllocator,
     block_label: HashMap<u32, Label>,
     bar_label: HashMap<u32, Label>,
-    block_sync: HashMap<u32, SyncType>,
+    sync_blocks: HashSet<u32>,
+    crs: Vec<(u32, SyncType)>,
     fs_out_regs: [SSAValue; 34],
     end_block_id: u32,
     ssa_map: HashMap<u32, Vec<SSAValue>>,
@@ -322,7 +324,8 @@ impl<'a> ShaderFromNir<'a> {
             label_alloc: LabelAllocator::new(),
             block_label: HashMap::new(),
             bar_label: HashMap::new(),
-            block_sync: HashMap::new(),
+            sync_blocks: HashSet::new(),
+            crs: Vec::new(),
             fs_out_regs: [SSAValue::NONE; 34],
             end_block_id: 0,
             ssa_map: HashMap::new(),
@@ -338,6 +341,54 @@ impl<'a> ShaderFromNir<'a> {
             .or_insert_with(|| self.label_alloc.alloc())
     }
 
+    fn push_crs(&mut self, target: &nir_block, sync_type: SyncType) {
+        self.sync_blocks.insert(target.index);
+        self.crs.push((target.index, sync_type));
+        let crs_depth = u32::try_from(self.crs.len()).unwrap();
+        self.info.max_crs_depth = max(self.info.max_crs_depth, crs_depth);
+    }
+
+    fn pop_crs(&mut self, target: &nir_block, sync_type: SyncType) {
+        if let Some((top_index, top_sync_type)) = self.crs.pop() {
+            assert!(top_index == target.index);
+            assert!(top_sync_type == sync_type);
+        } else {
+            panic!("Tried to pop an empty stack");
+        }
+    }
+
+    fn peek_crs(&self, target: &nir_block) -> Option<SyncType> {
+        for (i, (index, sync_type)) in self.crs.iter().enumerate().rev() {
+            if *index != target.index {
+                continue;
+            }
+
+            match sync_type {
+                SyncType::Sync => {
+                    // Sync must always be top-of-stack
+                    assert!(i == self.crs.len() - 1);
+                }
+                SyncType::Brk => {
+                    // Brk cannot skip over another Brk
+                    for (_, inner_sync) in &self.crs[(i + 1)..] {
+                        assert!(*inner_sync != SyncType::Brk);
+                    }
+                }
+                SyncType::Cont => {
+                    // Cont can only skip over Sync
+                    for (_, inner_sync) in &self.crs[(i + 1)..] {
+                        assert!(*inner_sync == SyncType::Sync);
+                    }
+                }
+            }
+
+            return Some(*sync_type);
+        }
+
+        assert!(!self.sync_blocks.contains(&target.index));
+        None
+    }
+
     fn get_ssa(&mut self, ssa: &nir_def) -> &[SSAValue] {
         self.ssa_map.get(&ssa.index).unwrap()
     }
@@ -3113,25 +3164,24 @@ impl<'a> ShaderFromNir<'a> {
         } else {
             self.cfg.add_edge(nb.index, target.index);
-            if let Some(sync) = self.block_sync.get(&target.index) {
-                match sync {
-                    SyncType::Sync => {
-                        b.push_op(OpSync {});
-                    }
-                    SyncType::Brk => {
-                        b.push_op(OpBrk {});
-                    }
-                    SyncType::Cont => {
-                        b.push_op(OpCont {});
-                    }
-                }
-            } else {
-                b.push_op(OpBra {
-                    target: self.get_block_label(target),
-                });
-            }
+            match self.peek_crs(target) {
+                Some(SyncType::Sync) => {
+                    b.push_op(OpSync {});
+                }
+                Some(SyncType::Brk) => {
+                    b.push_op(OpBrk {});
+                }
+                Some(SyncType::Cont) => {
+                    b.push_op(OpCont {});
+                }
+                None => {
+                    b.push_op(OpBra {
+                        target: self.get_block_label(target),
+                    });
+                }
+            }
         }
     }
 
     fn emit_pred_jump(
         &mut self,
@@ -3234,10 +3284,13 @@ impl<'a> ShaderFromNir<'a> {
             b.push_op(phi);
         }
 
-        if matches!(self.block_sync.get(&nb.index), Some(SyncType::Cont)) {
-            b.push_op(OpPCnt {
-                target: self.get_block_label(nb),
-            });
+        if self.sm.sm() < 75 && nb.cf_node.prev().is_none() {
+            if let Some(_) = nb.parent().as_loop() {
+                b.push_op(OpPCnt {
+                    target: self.get_block_label(nb),
+                });
+                self.push_crs(nb, SyncType::Cont);
+            }
         }
 
         let mut goto = None;
@@ -3291,19 +3344,17 @@ impl<'a> ShaderFromNir<'a> {
         if self.sm.sm() < 70 {
             if let Some(ni) = nb.following_if() {
-                if ni.condition.as_def().divergent {
-                    let fb = ni.following_block();
-                    self.block_sync.insert(fb.index, SyncType::Sync);
-                    b.push_op(OpSSy {
-                        target: self.get_block_label(fb),
-                    });
-                }
+                let fb = ni.following_block();
+                b.push_op(OpSSy {
+                    target: self.get_block_label(fb),
+                });
+                self.push_crs(fb, SyncType::Sync);
             } else if let Some(nl) = nb.following_loop() {
                 let fb = nl.following_block();
-                self.block_sync.insert(fb.index, SyncType::Brk);
                 b.push_op(OpPBk {
                     target: self.get_block_label(fb),
                 });
+                self.push_crs(fb, SyncType::Brk);
             }
         }
@@ -3401,10 +3452,6 @@ impl<'a> ShaderFromNir<'a> {
             assert!(succ[1].is_none());
             let s0 = succ[0].unwrap();
             self.emit_jump(&mut b, nb, s0);
-
-            if self.sm.sm() < 70 && nb.following_loop().is_some() {
-                self.block_sync.insert(s0.index, SyncType::Cont);
-            }
         }
     }
@@ -3424,6 +3471,11 @@ impl<'a> ShaderFromNir<'a> {
     ) {
         self.parse_cf_list(ssa_alloc, phi_map, ni.iter_then_list());
         self.parse_cf_list(ssa_alloc, phi_map, ni.iter_else_list());
+
+        if self.sm.sm() < 70 {
+            let next_block = ni.cf_node.next().unwrap().as_block().unwrap();
+            self.pop_crs(next_block, SyncType::Sync);
+        }
     }
 
     fn parse_loop(
@@ -3433,6 +3485,13 @@ impl<'a> ShaderFromNir<'a> {
         nl: &nir_loop,
     ) {
         self.parse_cf_list(ssa_alloc, phi_map, nl.iter_body());
+
+        if self.sm.sm() < 70 {
+            let header = nl.iter_body().next().unwrap().as_block().unwrap();
+            self.pop_crs(header, SyncType::Cont);
+            let next_block = nl.cf_node.next().unwrap().as_block().unwrap();
+            self.pop_crs(next_block, SyncType::Brk);
+        }
     }
 
     fn parse_cf_list(
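
The checks in peek_crs amount to three skip rules: a Sync entry must be
top-of-stack, a Brk may not skip over another Brk, and a Cont may only skip
over Sync entries. A small standalone illustration of those rules (not NAK
code; skip_is_legal and the block numbers are hypothetical, made up for the
example):

    #[derive(Clone, Copy, PartialEq, Eq)]
    enum SyncType {
        Sync,
        Brk,
        Cont,
    }

    // Would popping everything above the entry at `pos` be legal?
    fn skip_is_legal(crs: &[(u32, SyncType)], pos: usize) -> bool {
        let above = &crs[pos + 1..];
        match crs[pos].1 {
            SyncType::Sync => above.is_empty(),
            SyncType::Brk => above.iter().all(|(_, s)| *s != SyncType::Brk),
            SyncType::Cont => above.iter().all(|(_, s)| *s == SyncType::Sync),
        }
    }

    fn main() {
        // A loop (Brk, Cont) with a divergent if (Sync) inside it.
        let crs: [(u32, SyncType); 3] =
            [(10, SyncType::Brk), (2, SyncType::Cont), (7, SyncType::Sync)];
        assert!(skip_is_legal(&crs, 2)); // Sync is top-of-stack
        assert!(skip_is_legal(&crs, 1)); // Cont only skips the Sync above it
        assert!(skip_is_legal(&crs, 0)); // Brk skips Cont and Sync, no Brk above

        // A Cont that would have to skip an inner loop's Brk is rejected.
        let bad: [(u32, SyncType); 2] = [(2, SyncType::Cont), (20, SyncType::Brk)];
        assert!(!skip_is_legal(&bad, 0));
    }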