mesa/src/nouveau/compiler/nak/union_find.rs

293 lines
8.5 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright © 2024 Mel Henning
// SPDX-License-Identifier: MIT
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash};
#[derive(Copy, Clone)]
struct Root<X: Copy> {
size: usize,
representative: X,
}
#[derive(Copy, Clone)]
enum Node<X: Copy> {
Child { parent_idx: usize },
Root(Root<X>),
}
/// Union-find structure
///
/// This implementation follows Tarjan and van Leeuwen - specifically the
/// "link by size" and "halving" variant.
///
/// Robert E. Tarjan and Jan van Leeuwen. 1984. Worst-case Analysis of Set
/// Union Algorithms. J. ACM 31, 2 (April 1984), 245281.
/// <https://doi.org/10.1145/62.2160>
pub struct UnionFind<X: Copy + Hash + Eq, H: BuildHasher + Default> {
idx_map: HashMap<X, usize, H>,
nodes: Vec<Node<X>>,
}
impl<X: Copy + Hash + Eq, H: BuildHasher + Default> UnionFind<X, H> {
/// Create a new union-find structure
///
/// At initialization, each possible value is in its own set
pub fn new() -> Self {
UnionFind {
idx_map: Default::default(),
nodes: Vec::new(),
}
}
fn find_root(&mut self, mut idx: usize) -> (usize, Root<X>) {
loop {
match self.nodes[idx] {
Node::Child { parent_idx } => {
match self.nodes[parent_idx] {
Node::Child {
parent_idx: grandparent_idx,
} => {
// "Halving" in Tarjan and van Leeuwen
self.nodes[idx] = Node::Child {
parent_idx: grandparent_idx,
};
idx = grandparent_idx;
}
Node::Root(parent_root) => {
return (parent_idx, parent_root)
}
}
}
Node::Root(root) => return (idx, root),
}
}
}
/// Find the representative element for x
pub fn find(&mut self, x: X) -> X {
match self.idx_map.get(&x) {
Some(&idx) => {
let (_, Root { representative, .. }) = self.find_root(idx);
representative
}
None => x,
}
}
fn map_or_create(&mut self, x: X) -> usize {
*self.idx_map.entry(x).or_insert_with(|| {
self.nodes.push(Node::Root(Root {
size: 1,
representative: x,
}));
self.nodes.len() - 1
})
}
/// Union the sets containing a and b
///
/// The representative for a will become the representative of
/// the combined set
pub fn union(&mut self, a: X, b: X) {
if a == b {
return;
}
let a_idx = self.map_or_create(a);
let b_idx = self.map_or_create(b);
let (a_root_idx, a_root) = self.find_root(a_idx);
let (b_root_idx, b_root) = self.find_root(b_idx);
if a_root_idx != b_root_idx {
// Keep the tree balanced
let (new_root_idx, new_child_idx) = if a_root.size >= b_root.size {
(a_root_idx, b_root_idx)
} else {
(b_root_idx, a_root_idx)
};
self.nodes[new_root_idx] = Node::Root(Root {
size: a_root.size + b_root.size,
representative: a_root.representative,
});
self.nodes[new_child_idx] = Node::Child {
parent_idx: new_root_idx,
};
}
}
/// Return true if find() is the identity mapping
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
}
#[cfg(test)]
mod tests {
use crate::union_find::Node;
use crate::union_find::UnionFind;
use std::cmp::max;
use std::hash::BuildHasher;
use std::hash::Hash;
use std::hash::RandomState;
fn ceil_log2(x: usize) -> u32 {
assert!(x > 0);
usize::BITS - (x - 1).leading_zeros()
}
struct HeightInfo {
height: u32,
size: usize,
}
pub struct HeightCalc<'a, X: Copy + Hash + Eq, H: BuildHasher + Default> {
uf: &'a UnionFind<X, H>,
downward_edges: Vec<Vec<usize>>,
}
impl<'a, X: Copy + Hash + Eq, H: BuildHasher + Default> HeightCalc<'a, X, H> {
fn new(uf: &'a UnionFind<X, H>) -> Self {
let mut downward_edges: Vec<Vec<usize>> =
uf.nodes.iter().map(|_| Vec::new()).collect();
for (i, node) in uf.nodes.iter().enumerate() {
if let Node::Child { parent_idx } = node {
downward_edges[*parent_idx].push(i);
}
}
HeightCalc { uf, downward_edges }
}
fn calc_info(&self, idx: usize) -> HeightInfo {
let mut result = HeightInfo { height: 0, size: 1 };
for child in &self.downward_edges[idx] {
let child_result = self.calc_info(*child);
result.height = max(result.height, child_result.height + 1);
result.size += child_result.size;
}
result
}
fn check_roots(&self) -> u32 {
let mut total_size = 0;
let mut max_height = 0;
for (i, node) in self.uf.nodes.iter().enumerate() {
if let Node::Root(root) = node {
let info = self.calc_info(i);
assert_eq!(root.size, info.size);
total_size += info.size;
max_height = max(max_height, info.height);
let max_expected_height = ceil_log2(root.size + 1) - 1;
if info.height > max_expected_height {
eprintln!(
"height {}\t max_expected_height {}\t size {}",
info.height, max_expected_height, info.size
);
}
assert!(info.height <= max_expected_height);
}
}
assert_eq!(total_size, self.uf.nodes.len());
assert_eq!(total_size, self.uf.idx_map.len());
max_height
}
pub fn check(uf: &'a UnionFind<X, H>) -> u32 {
HeightCalc::new(uf).check_roots()
}
}
#[test]
fn test_basic() {
let mut f = UnionFind::<u32, RandomState>::new();
assert_eq!(f.find(10), 10);
assert_eq!(f.find(12), 12);
f.union(10, 12);
f.union(11, 13);
HeightCalc::check(&f);
assert_eq!(f.find(13), 11);
assert_eq!(f.find(12), 10);
assert_eq!(f.find(11), 11);
assert_eq!(f.find(10), 10);
f.union(12, 13);
HeightCalc::check(&f);
assert_eq!(f.find(13), 10);
assert_eq!(f.find(12), 10);
assert_eq!(f.find(11), 10);
assert_eq!(f.find(10), 10);
assert_eq!(f.find(14), 14);
HeightCalc::check(&f);
// Union the set with itself
f.union(11, 10);
HeightCalc::check(&f);
assert_eq!(f.find(13), 10);
assert_eq!(f.find(12), 10);
assert_eq!(f.find(11), 10);
assert_eq!(f.find(10), 10);
}
#[test]
fn test_chain_a_height() {
let mut f = UnionFind::<u32, RandomState>::new();
for i in 0..1000 {
f.union(i, i + 1);
HeightCalc::check(&f);
}
assert_eq!(f.find(1000), 0);
}
#[test]
fn test_chain_b_height() {
let mut f = UnionFind::<u32, RandomState>::new();
for i in 0..1000 {
f.union(i + 1, i);
HeightCalc::check(&f);
}
assert_eq!(f.find(0), 1000);
}
#[test]
fn test_binary_tree_height() {
let height = 8;
let count = 1 << height;
let mut f = UnionFind::<usize, RandomState>::new();
for current_height in 0..height {
let stride = 1 << current_height;
for i in (0..count).step_by(2 * stride) {
f.union(i, i + stride);
}
let actual_height = HeightCalc::check(&f);
// actual_height can vary based on tiebreaker condition
assert!(
actual_height == current_height
|| actual_height == current_height + 1
);
}
// Check path halving
let actual_height_before = HeightCalc::check(&f);
for i in 0..count {
assert_eq!(f.find(i), 0);
}
let actual_height_after = HeightCalc::check(&f);
assert!(actual_height_after <= actual_height_before.div_ceil(2));
}
}