diff --git a/src/compiler/rust/cfg.rs b/src/compiler/rust/cfg.rs index bb2b11005a1..d8a82da92aa 100644 --- a/src/compiler/rust/cfg.rs +++ b/src/compiler/rust/cfg.rs @@ -200,27 +200,42 @@ fn calc_dominance(nodes: &mut Vec>) { debug_assert!(dom_dfs.count == nodes.len() * 2); } -fn loop_detect_dfs( - nodes: &[CFGNode], - id: usize, - pre: &mut BitSet, - post: &mut BitSet, - back_edges: &mut Vec<(usize, usize)>, -) { - if pre.contains(id) { - return; +struct BackEdgesDFS<'a, N> { + nodes: &'a [CFGNode], + pre: BitSet, + post: BitSet, + back_edges: Vec<(usize, usize)>, +} + +impl<'a, N> DepthFirstSearch for BackEdgesDFS<'a, N> { + type ChildIter = Cloned>; + + fn pre(&mut self, id: usize) -> Self::ChildIter { + self.pre.insert(id); + + self.nodes[id].succ.iter().cloned() } - pre.insert(id); - - for &s in nodes[id].succ.iter() { - if pre.contains(s) && !post.contains(s) { - back_edges.push((id, s)); + fn edge(&mut self, parent: usize, child: usize) { + if self.pre.contains(child) && !self.post.contains(child) { + self.back_edges.push((parent, child)); } - loop_detect_dfs(nodes, s, pre, post, back_edges); } - post.insert(id); + fn post(&mut self, id: usize) { + self.post.insert(id); + } +} + +fn find_back_edges(nodes: &[CFGNode]) -> Vec<(usize, usize)> { + let mut be_dfs = BackEdgesDFS { + nodes, + pre: Default::default(), + post: Default::default(), + back_edges: Default::default(), + }; + dfs(&mut be_dfs, 0); + be_dfs.back_edges } /// Computes the set of nodes that reach the given node without going through @@ -245,11 +260,7 @@ fn reaches_dfs( } fn detect_loops(nodes: &mut Vec>) -> bool { - let mut dfs_pre = BitSet::new(); - let mut dfs_post = BitSet::new(); - let mut back_edges = Vec::new(); - loop_detect_dfs(nodes, 0, &mut dfs_pre, &mut dfs_post, &mut back_edges); - + let back_edges = find_back_edges(nodes); if back_edges.is_empty() { return false; } @@ -266,9 +277,8 @@ fn detect_loops(nodes: &mut Vec>) -> bool { loops.insert(h); // re-use dfs_pre for our reaches set - dfs_pre.clear(); - let reaches = &mut dfs_pre; - reaches_dfs(nodes, c, h, reaches); + let mut reaches = BitSet::new(); + reaches_dfs(nodes, c, h, &mut reaches); for n in reaches.iter() { node_loops[n].insert(h);