diff --git a/src/.clang-format b/src/.clang-format index 0f0916c369e..802b782ed03 100644 --- a/src/.clang-format +++ b/src/.clang-format @@ -62,6 +62,7 @@ ForEachMacros: - rb_tree_foreach_rev - rb_tree_foreach_rev_safe - rb_tree_foreach_safe + - uinterval_tree_foreach - set_foreach - set_foreach_remove diff --git a/src/util/rb_tree.c b/src/util/rb_tree.c index a3bdd1b4912..98784129b17 100644 --- a/src/util/rb_tree.c +++ b/src/util/rb_tree.c @@ -39,6 +39,8 @@ #include #include +#include "macros.h" + static bool rb_node_is_black(struct rb_node *n) { @@ -118,7 +120,8 @@ rb_tree_splice(struct rb_tree *T, struct rb_node *u, struct rb_node *v) } static void -rb_tree_rotate_left(struct rb_tree *T, struct rb_node *x) +rb_tree_rotate_left(struct rb_tree *T, struct rb_node *x, + void (*update)(struct rb_node *)) { assert(x && x->right); @@ -129,10 +132,15 @@ rb_tree_rotate_left(struct rb_tree *T, struct rb_node *x) rb_tree_splice(T, x, y); y->left = x; rb_node_set_parent(x, y); + if (update) { + update(x); + update(y); + } } static void -rb_tree_rotate_right(struct rb_tree *T, struct rb_node *y) +rb_tree_rotate_right(struct rb_tree *T, struct rb_node *y, + void (*update)(struct rb_node *)) { assert(y && y->left); @@ -143,15 +151,23 @@ rb_tree_rotate_right(struct rb_tree *T, struct rb_node *y) rb_tree_splice(T, y, x); x->right = y; rb_node_set_parent(y, x); + if (update) { + update(y); + update(x); + } } void -rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent, - struct rb_node *node, bool insert_left) +rb_augmented_tree_insert_at(struct rb_tree *T, struct rb_node *parent, + struct rb_node *node, bool insert_left, + void (*update)(struct rb_node *node)) { /* This sets null children, parent, and a color of red */ memset(node, 0, sizeof(*node)); + if (update) + update(node); + if (parent == NULL) { assert(T->root == NULL); T->root = node; @@ -168,6 +184,14 @@ rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent, } rb_node_set_parent(node, parent); + if (update) { + struct rb_node *p = parent; + while (p) { + update(p); + p = rb_node_parent(p); + } + } + /* Now we do the insertion fixup */ struct rb_node *z = node; while (rb_node_is_red(rb_node_parent(z))) { @@ -185,7 +209,7 @@ rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent, } else { if (z == z_p->right) { z = z_p; - rb_tree_rotate_left(T, z); + rb_tree_rotate_left(T, z, update); /* We changed z */ z_p = rb_node_parent(z); assert(z == z_p->left || z == z_p->right); @@ -193,7 +217,7 @@ rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent, } rb_node_set_black(z_p); rb_node_set_red(z_p_p); - rb_tree_rotate_right(T, z_p_p); + rb_tree_rotate_right(T, z_p_p, update); } } else { struct rb_node *y = z_p_p->left; @@ -205,7 +229,7 @@ rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent, } else { if (z == z_p->left) { z = z_p; - rb_tree_rotate_right(T, z); + rb_tree_rotate_right(T, z, update); /* We changed z */ z_p = rb_node_parent(z); assert(z == z_p->left || z == z_p->right); @@ -213,7 +237,7 @@ rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent, } rb_node_set_black(z_p); rb_node_set_red(z_p_p); - rb_tree_rotate_left(T, z_p_p); + rb_tree_rotate_left(T, z_p_p, update); } } } @@ -221,7 +245,8 @@ rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent, } void -rb_tree_remove(struct rb_tree *T, struct rb_node *z) +rb_augmented_tree_remove(struct rb_tree *T, struct rb_node *z, + void (*update)(struct rb_node *)) { /* x_p is always the parent node of X. We have to track this * separately because x may be NULL. @@ -260,6 +285,14 @@ rb_tree_remove(struct rb_tree *T, struct rb_node *z) assert(x_p == NULL || x == x_p->left || x == x_p->right); + if (update) { + struct rb_node *p = x_p; + while (p) { + update(p); + p = rb_node_parent(p); + } + } + if (!y_was_black) return; @@ -270,7 +303,7 @@ rb_tree_remove(struct rb_tree *T, struct rb_node *z) if (rb_node_is_red(w)) { rb_node_set_black(w); rb_node_set_red(x_p); - rb_tree_rotate_left(T, x_p); + rb_tree_rotate_left(T, x_p, update); assert(x == x_p->left); w = x_p->right; } @@ -281,13 +314,13 @@ rb_tree_remove(struct rb_tree *T, struct rb_node *z) if (rb_node_is_black(w->right)) { rb_node_set_black(w->left); rb_node_set_red(w); - rb_tree_rotate_right(T, w); + rb_tree_rotate_right(T, w, update); w = x_p->right; } rb_node_copy_color(w, x_p); rb_node_set_black(x_p); rb_node_set_black(w->right); - rb_tree_rotate_left(T, x_p); + rb_tree_rotate_left(T, x_p, update); x = T->root; } } else { @@ -295,7 +328,7 @@ rb_tree_remove(struct rb_tree *T, struct rb_node *z) if (rb_node_is_red(w)) { rb_node_set_black(w); rb_node_set_red(x_p); - rb_tree_rotate_right(T, x_p); + rb_tree_rotate_right(T, x_p, update); assert(x == x_p->right); w = x_p->left; } @@ -306,13 +339,13 @@ rb_tree_remove(struct rb_tree *T, struct rb_node *z) if (rb_node_is_black(w->left)) { rb_node_set_black(w->right); rb_node_set_red(w); - rb_tree_rotate_left(T, w); + rb_tree_rotate_left(T, w, update); w = x_p->left; } rb_node_copy_color(w, x_p); rb_node_set_black(x_p); rb_node_set_black(w->left); - rb_tree_rotate_right(T, x_p); + rb_tree_rotate_right(T, x_p, update); x = T->root; } } @@ -378,6 +411,229 @@ rb_node_prev(struct rb_node *node) } } +/* Return the first node in an interval tree that intersects a given interval + * or point. The tests against the interval and the max field are abstracted + * via function pointers, so that this works for any type of interval. + */ +static struct rb_node * +rb_node_min_intersecting(struct rb_node *node, void *interval, + int (*cmp_interval)(const struct rb_node *node, + const void *interval), + bool (*cmp_max)(const struct rb_node *node, + const void *interval)) +{ + if (!cmp_max(node, interval)) + return NULL; + + while (node) { + int cmp = cmp_interval(node, interval); + + /* If the node's interval is entirely to the right of the interval + * we're searching for, all of its right descendants are also to the + * right and don't intersect so we have to search to the left. + */ + if (cmp > 0) { + node = node->left; + continue; + } + + /* The interval overlaps or is to the left. This must also be true for + * its left descendants because their start points are to the left of + * node's. We can use the max to tell if there is an interval in its + * left descendants which overlaps our interval, in which case we + * should descend to the left. + */ + if (node->left && cmp_max(node->left, interval)) { + node = node->left; + continue; + } + + /* Now the only possibilities are the node's interval intersects the + * interval or one of its right descendants does. + */ + if (cmp == 0) + return node; + + node = node->right; + if (node && !cmp_max(node, interval)) + return NULL; + } + + return NULL; +} + +/* Return the next node after "node" that intersects a given interval. + * + * Because rb_node_min_intersecting() takes O(log n) time and may be called up + * to O(log n) times, in addition to the O(log n) crawl up the tree, a naive + * runtime analysis would show that this takes O((log n)^2) time, but actually + * it's O(log n). Proving this is tricky: + * + * Call the rightmost node in the tree whose start is before the end of the + * interval we're searching for N. All nodes after N in the tree are to the + * right of the interval. We'll divide the search into two phases: in phase 1, + * "node" is *not* an ancestor of N, and in phase 2 it is. Because we always + * crawl up the tree, phase 2 cannot turn back into phase 1, but phase 1 may + * be followed by phase 2. We'll prove that the calls to + * rb_node_min_intersecting() take O(log n) time in both phases. + * + * Phase 1: Because "node" is to the left of N and N isn't a descendant of + * "node", the start of every interval in "node"'s subtree must be less than + * or equal to N's start which is less than or equal to the end of the search + * interval. Furthermore, either "node"'s max_end is less than the start of + * the interval, in which case rb_node_min_intersecting() immediately returns + * NULL, or some descendant has an end equal to "node"'s max_end which is + * greater than or equal to the search interval's start, and therefore it + * intersects the search interval and rb_node_min_intersecting() must return + * non-NULL which causes us to terminate. rb_node_min_intersecting() is called + * O(log n) times, with all but the last call taking constant time and the + * last call taking O(log n), so the overall runtime is O(log n). + * + * Phase 2: After the first call to rb_node_min_intersecting, we may crawl up + * the tree until we get to a node p where "node", and therefore N, is in p's + * left subtree. However this means that p is to the right of N in the tree + * and is therefore to the right of the search interval, and the search + * terminates on the first iteration of the loop so that + * rb_node_min_intersecting() is only called once. + */ +static struct rb_node * +rb_node_next_intersecting(struct rb_node *node, + void *interval, + int (*cmp_interval)(const struct rb_node *node, + const void *interval), + bool (*cmp_max)(const struct rb_node *node, + const void *interval)) +{ + while (true) { + /* The first place to search is the node's right subtree. */ + if (node->right) { + struct rb_node *next = + rb_node_min_intersecting(node->right, interval, cmp_interval, cmp_max); + if (next) + return next; + } + + /* If we don't find a matching interval there, crawl up the tree until + * we find an ancestor to the right. This is the next node after the + * right subtree which we determined didn't match. + */ + struct rb_node *p = rb_node_parent(node); + while (p && node == p->right) { + node = p; + p = rb_node_parent(node); + } + assert(p == NULL || node == p->left); + + /* Check if we've searched everything in the tree. */ + if (!p) + return NULL; + + int cmp = cmp_interval(p, interval); + + /* If it intersects, return it. */ + if (cmp == 0) + return p; + + /* If it's to the right of the interval, all following nodes will be + * to the right and we can bail early. + */ + if (cmp > 0) + return NULL; + + node = p; + } +} + +static int +uinterval_cmp(struct uinterval a, struct uinterval b) +{ + if (a.end < b.start) + return -1; + else if (b.end < a.start) + return 1; + else + return 0; +} + +static int +uinterval_node_cmp(const struct rb_node *_a, const struct rb_node *_b) +{ + const struct uinterval_node *a = rb_node_data(struct uinterval_node, _a, node); + const struct uinterval_node *b = rb_node_data(struct uinterval_node, _b, node); + + return (int) (b->interval.start - a->interval.start); +} + +static int +uinterval_search_cmp(const struct rb_node *_node, const void *_interval) +{ + const struct uinterval_node *node = rb_node_data(struct uinterval_node, _node, node); + const struct uinterval *interval = _interval; + + return uinterval_cmp(node->interval, *interval); +} + +static bool +uinterval_max_cmp(const struct rb_node *_node, const void *data) +{ + const struct uinterval_node *node = rb_node_data(struct uinterval_node, _node, node); + const struct uinterval *interval = data; + + return node->max_end >= interval->start; +} + +static void +uinterval_update_max(struct rb_node *_node) +{ + struct uinterval_node *node = rb_node_data(struct uinterval_node, _node, node); + node->max_end = node->interval.end; + if (node->node.left) { + struct uinterval_node *left = rb_node_data(struct uinterval_node, node->node.left, node); + node->max_end = MAX2(node->max_end, left->max_end); + } + if (node->node.right) { + struct uinterval_node *right = rb_node_data(struct uinterval_node, node->node.right, node); + node->max_end = MAX2(node->max_end, right->max_end); + } +} + +void +uinterval_tree_insert(struct rb_tree *tree, struct uinterval_node *node) +{ + rb_augmented_tree_insert(tree, &node->node, uinterval_node_cmp, + uinterval_update_max); +} + +void +uinterval_tree_remove(struct rb_tree *tree, struct uinterval_node *node) +{ + rb_augmented_tree_remove(tree, &node->node, uinterval_update_max); +} + +struct uinterval_node * +uinterval_tree_first(struct rb_tree *tree, struct uinterval interval) +{ + if (!tree->root) + return NULL; + + struct rb_node *node = + rb_node_min_intersecting(tree->root, &interval, uinterval_search_cmp, + uinterval_max_cmp); + + return node ? rb_node_data(struct uinterval_node, node, node) : NULL; +} + +struct uinterval_node * +uinterval_node_next(struct uinterval_node *node, + struct uinterval interval) +{ + struct rb_node *next = + rb_node_next_intersecting(&node->node, &interval, uinterval_search_cmp, + uinterval_max_cmp); + + return next ? rb_node_data(struct uinterval_node, next, node) : NULL; +} + static void validate_rb_node(struct rb_node *n, int black_depth) { diff --git a/src/util/rb_tree.h b/src/util/rb_tree.h index 5e00977b5ba..b5400306fa4 100644 --- a/src/util/rb_tree.h +++ b/src/util/rb_tree.h @@ -117,6 +117,36 @@ struct rb_node *rb_node_prev(struct rb_node *node); #define rb_node_data(type, node, field) \ ((type *)(((char *)(node)) - rb_tree_offsetof(type, field, node))) +/** Insert a node into a possibly augmented tree at a particular location + * + * This function should probably not be used directly as it relies on the + * caller to ensure that the parent node is correct. Use rb_tree_insert + * instead. + * + * If \p update is non-NULL, it will be called for the node being inserted as + * well as any nodes which have their children changed and all of their + * ancestors. The intent is that each node may contain some augmented data + * which is calculated recursively from the node itself and its children, and + * \p update should recalculate that data. It's assumed that the function used + * to calculate the node data is associative in order to avoid calling it + * redundantly after rebalancing the tree. + * + * \param T The red-black tree into which to insert the new node + * + * \param parent The node in the tree that will be the parent of the + * newly inserted node + * + * \param node The node to insert + * + * \param insert_left If true, the new node will be the left child of + * \p parent, otherwise it will be the right child + * + * \param update The optional function used to calculate per-node data + */ +void rb_augmented_tree_insert_at(struct rb_tree *T, struct rb_node *parent, + struct rb_node *node, bool insert_left, + void (*update)(struct rb_node *)); + /** Insert a node into a tree at a particular location * * This function should probably not be used directly as it relies on the @@ -133,20 +163,27 @@ struct rb_node *rb_node_prev(struct rb_node *node); * \param insert_left If true, the new node will be the left child of * \p parent, otherwise it will be the right child */ -void rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent, - struct rb_node *node, bool insert_left); +static inline void +rb_tree_insert_at(struct rb_tree *T, struct rb_node *parent, + struct rb_node *node, bool insert_left) +{ + rb_augmented_tree_insert_at(T, parent, node, insert_left, NULL); +} -/** Insert a node into a tree +/** Insert a node into a possibly augmented tree * * \param T The red-black tree into which to insert the new node * * \param node The node to insert * * \param cmp A comparison function to use to order the nodes. + * + * \param update Same meaning as in rb_augmented_tree_insert_at() */ static inline void -rb_tree_insert(struct rb_tree *T, struct rb_node *node, - int (*cmp)(const struct rb_node *, const struct rb_node *)) +rb_augmented_tree_insert(struct rb_tree *T, struct rb_node *node, + int (*cmp)(const struct rb_node *, const struct rb_node *), + void (*update)(struct rb_node *)) { /* This function is declared inline in the hopes that the compiler can * optimize away the comparison function pointer call. @@ -163,16 +200,47 @@ rb_tree_insert(struct rb_tree *T, struct rb_node *node, x = x->right; } - rb_tree_insert_at(T, y, node, left); + rb_augmented_tree_insert_at(T, y, node, left, update); } +/** Insert a node into a tree + * + * \param T The red-black tree into which to insert the new node + * + * \param node The node to insert + * + * \param cmp A comparison function to use to order the nodes. + */ +static inline void +rb_tree_insert(struct rb_tree *T, struct rb_node *node, + int (*cmp)(const struct rb_node *, const struct rb_node *)) +{ + rb_augmented_tree_insert(T, node, cmp, NULL); +} + +/** Remove a node from a possibly augmented tree + * + * \param T The red-black tree from which to remove the node + * + * \param node The node to remove + * + * \param update Same meaning as in rb_agumented_tree_insert_at() + * + */ +void rb_augmented_tree_remove(struct rb_tree *T, struct rb_node *z, + void (*update)(struct rb_node *)); + /** Remove a node from a tree * * \param T The red-black tree from which to remove the node * * \param node The node to remove */ -void rb_tree_remove(struct rb_tree *T, struct rb_node *z); +static inline void +rb_tree_remove(struct rb_tree *T, struct rb_node *z) +{ + rb_augmented_tree_remove(T, z, NULL); +} /** Search the tree for a node * @@ -332,6 +400,60 @@ rb_tree_search_sloppy(struct rb_tree *T, const void *key, __node = __prev, \ __prev = (type *)rb_node_prev_or_null((struct rb_node *)__node)) +/** Unsigned interval + * + * Intervals are closed by convention. + */ +struct uinterval { + unsigned start, end; +}; + +struct uinterval_node { + struct rb_node node; + + /* Must be filled in before inserting */ + struct uinterval interval; + + /* Managed internally by the tree */ + unsigned max_end; +}; + +/** Insert a node into an unsigned interval tree. */ +void uinterval_tree_insert(struct rb_tree *tree, struct uinterval_node *node); + +/** Remove a node from an unsigned interval tree. */ +void uinterval_tree_remove(struct rb_tree *tree, struct uinterval_node *node); + +/** Get the first node intersecting the given interval. */ +struct uinterval_node *uinterval_tree_first(struct rb_tree *tree, + struct uinterval interval); + +/** Get the next node after \p node intersecting the given interval. */ +struct uinterval_node *uinterval_node_next(struct uinterval_node *node, + struct uinterval interval); + +/** Iterate over the nodes in the tree intersecting the given interval + * + * The iteration itself should take O(k log n) time, where k is the number of + * iterations of the loop and n is the size of the tree. + * + * \param type The type of the containing data structure + * + * \param node The variable name for current node in the iteration; + * this will be declared as a pointer to \p type + * + * \param interval The interval to be tested against. + * + * \param T The red-black tree + * + * \param field The uinterval_node field in containing data structure + */ +#define uinterval_tree_foreach(type, iter, interval, T, field) \ + for (type *iter, *__node = (type *)uinterval_tree_first(T, interval); \ + __node != NULL && \ + (iter = rb_node_data(type, (struct uinterval_node *)__node, field), true); \ + __node = (type *)uinterval_node_next((struct uinterval_node *)__node, interval)) + /** Validate a red-black tree * * This function walks the tree and validates that this is a valid red- diff --git a/src/util/tests/rb_tree_test.cpp b/src/util/tests/rb_tree_test.cpp index 2676bd52b15..60680947637 100644 --- a/src/util/tests/rb_tree_test.cpp +++ b/src/util/tests/rb_tree_test.cpp @@ -28,6 +28,8 @@ #include #include +#include "macros.h" + /* A list of 100 random numbers from 1 to 100. The number 30 is explicitly * missing from this list. */ @@ -46,8 +48,6 @@ int test_numbers[] = { #define NON_EXISTANT_NUMBER 30 -#define ARRAY_SIZE(a) (sizeof(a) / sizeof(*a)) - struct rb_test_node { int key; struct rb_node node; @@ -283,3 +283,86 @@ TEST(RBTreeTest, FindFirstOfMiddle) EXPECT_NE(rb_test_node_cmp(prev, n), 0); } + +struct uinterval_test_node { + struct uinterval_node node; +}; + +static void +validate_interval_search(struct rb_tree *tree, + struct uinterval_test_node *nodes, + int first_node, int last_node, + unsigned start, + unsigned end) +{ + /* Count the number of intervals intersecting */ + unsigned actual_count = 0; + for (int i = first_node; i <= last_node; i++) { + if (nodes[i].node.interval.start <= end && + nodes[i].node.interval.end >= start) + actual_count++; + } + + /* iterate over matching intervals */ + struct uinterval interval = { start, end }; + unsigned max_val = 0; + struct uinterval_test_node *prev = NULL; + unsigned count = 0; + uinterval_tree_foreach (struct uinterval_test_node, n, interval, tree, node) { + assert(n->node.interval.start <= end && + n->node.interval.end >= start); + + /* Everything should be in increasing order */ + assert(n->node.interval.start >= max_val); + if (n->node.interval.start > max_val) { + max_val = n->node.interval.start; + } else { + /* Things should be stable, i.e., given equal keys, they should + * show up in the list in order of insertion. We insert them + * in the order they are in in the array. + */ + assert(prev == NULL || prev < n); + } + + prev = n; + count++; + } + + assert(count == actual_count); +} + +TEST(IntervalTreeTest, InsertAndSearch) +{ + struct uinterval_test_node nodes[ARRAY_SIZE(test_numbers) / 2]; + struct rb_tree tree; + + rb_tree_init(&tree); + + for (unsigned i = 0; 2 * i < ARRAY_SIZE(test_numbers); i++) { + nodes[i].node.interval.start = MIN2(test_numbers[2 * i], test_numbers[2 * i + 1]); + nodes[i].node.interval.end = MAX2(test_numbers[2 * i], test_numbers[2 * i + 1]); + uinterval_tree_insert(&tree, &nodes[i].node); + rb_tree_validate(&tree); + validate_interval_search(&tree, nodes, 0, i, 0, 100); + validate_interval_search(&tree, nodes, 0, i, 0, 50); + validate_interval_search(&tree, nodes, 0, i, 50, 100); + validate_interval_search(&tree, nodes, 0, i, 0, 2); + } + + for (unsigned i = 0; 2 * i < ARRAY_SIZE(test_numbers); i++) { + uinterval_tree_remove(&tree, &nodes[i].node); + rb_tree_validate(&tree); + validate_interval_search(&tree, nodes, i + 1, + ARRAY_SIZE(test_numbers) / 2 - 1, + 0, 100); + validate_interval_search(&tree, nodes, i + 1, + ARRAY_SIZE(test_numbers) / 2 - 1, + 0, 50); + validate_interval_search(&tree, nodes, i + 1, + ARRAY_SIZE(test_numbers) / 2 - 1, + 50, 100); + validate_interval_search(&tree, nodes, i + 1, + ARRAY_SIZE(test_numbers) / 2 - 1, + 0, 2); + } +}