GraphC/src/RBTree.hpp

534 lines
15 KiB
C++

#pragma once
#include "../../dependencies/kerep/src/base/base.h"
#include <cstddef>
#include <iterator>
#include <map>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
// typedef char* TKey;
// typedef char* TVal;
/// Ordered key -> value map implemented as a red-black tree.
/// Keys are compared with `operator<` only; two keys are equal when neither is less.
template<typename TKey, typename TVal>
class RBTree {
    enum class Color {
        Red, Black
    };

    /// Tree node. A node OWNS its children: ~Node() deletes both subtrees, so a node
    /// that is being unlinked from the tree must have left/right cleared before delete.
    struct Node {
        /// Stored as a real std::pair so iterators can hand out references to it
        /// directly instead of reinterpreting the Node object.
        std::pair<const TKey, TVal> data;
        Color color;
        Node* parent;
        Node* left = nullptr;
        Node* right = nullptr;

        template<typename V>
        Node(TKey _key, V&& _val, Color _color, Node* _parent)
            : data(std::move(_key), std::forward<V>(_val)), color(_color), parent(_parent)
        {
        }
        Node(const Node&) = delete;
        Node& operator=(const Node&) = delete;
        ~Node(){
            delete left;
            delete right;
        }

        const TKey& key() const { return data.first; }

        Node* getSibling() const {
            if(parent == nullptr)
                return nullptr;
            return parent->left == this ? parent->right : parent->left;
        }
        Node* getGrandparent() const {
            return parent == nullptr ? nullptr : parent->parent;
        }
        Node* getUncle() const {
            return parent == nullptr ? nullptr : parent->getSibling();
        }
        /// @returns the leftmost node of this subtree (smallest key)
        Node* getMinChild(){
            Node* n = this;
            while(n->left != nullptr)
                n = n->left;
            return n;
        }
        /// @returns the rightmost node of this subtree (largest key)
        Node* getMaxChild(){
            Node* n = this;
            while(n->right != nullptr)
                n = n->right;
            return n;
        }
    };

    Node* root = nullptr;

    /// Null leaves count as black in every red-black rule.
    static bool isBlack(const Node* n){
        return n == nullptr || n->color == Color::Black;
    }

    /// @returns the node holding `key`, or nullptr if it is absent
    Node* findNode(const TKey& key) const {
        Node* n = root;
        while(n != nullptr){
            if(key < n->key())
                n = n->left;
            else if(n->key() < key)
                n = n->right;
            else return n;
        }
        return nullptr;
    }

    /// Locates the link (root, parent->left or parent->right) where `key` lives or
    /// would be inserted. Works for every key, including the root's own key.
    /// @param parentOut receives the parent of that link (nullptr for the root link)
    /// @returns the link; `*link` is the node with `key`, or nullptr if absent
    Node** findLink(const TKey& key, Node*& parentOut){
        Node* parent = nullptr;
        Node** link = &root;
        while(*link != nullptr){
            Node* n = *link;
            if(key < n->key()){
                parent = n;
                link = &n->left;
            }
            else if(n->key() < key){
                parent = n;
                link = &n->right;
            }
            else break; // key found
        }
        parentOut = parent;
        return link;
    }

    /// In-order successor of a non-null node, or nullptr after the last node.
    static Node* successor(Node* n){
        if(n->right != nullptr)
            return n->right->getMinChild();
        Node* p = n->parent;
        while(p != nullptr && n == p->right){
            n = p;
            p = p->parent;
        }
        return p;
    }

    /// In-order predecessor of a non-null node, or nullptr before the first node.
    static Node* predecessor(Node* n){
        if(n->left != nullptr)
            return n->left->getMaxChild();
        Node* p = n->parent;
        while(p != nullptr && n == p->left){
            n = p;
            p = p->parent;
        }
        return p;
    }

    /// Deep-copies a subtree; on exception everything allocated so far is freed.
    static Node* cloneSubtree(const Node* src, Node* parent){
        if(src == nullptr)
            return nullptr;
        Node* copy = new Node(src->data.first, src->data.second, src->color, parent);
        try {
            copy->left = cloneSubtree(src->left, copy);
            copy->right = cloneSubtree(src->right, copy);
        }
        catch(...){
            delete copy; // ~Node also frees any children attached so far
            throw;
        }
        return copy;
    }

    /// Replaces the subtree rooted at `old` with the one rooted at `neww` (nullable)
    /// in old's parent (or as the root). old's own child links are left untouched.
    void transplantNode(Node* old, Node* neww){
        if(old->parent == nullptr)
            root = neww;
        else if(old->parent->left == old)
            old->parent->left = neww;
        else old->parent->right = neww;
        if(neww != nullptr)
            neww->parent = old->parent;
    }

    /// Left rotation: x->right (must be non-null) takes x's place, x becomes its left child.
    void rotateLeft(Node* x){
        Node* y = x->right;
        transplantNode(x, y);
        // y's left subtree becomes x's right subtree
        x->right = y->left;
        if(x->right != nullptr)
            x->right->parent = x;
        y->left = x;
        x->parent = y;
    }

    /// Right rotation: x->left (must be non-null) takes x's place, x becomes its right child.
    void rotateRight(Node* x){
        Node* y = x->left;
        transplantNode(x, y);
        // y's right subtree becomes x's left subtree
        x->left = y->right;
        if(x->left != nullptr)
            x->left->parent = x;
        y->right = x;
        x->parent = y;
    }

    /// Restores the red-black invariants after `n` was linked in as a red node.
    void fixupInsertion(Node* n){
        // case 1: n is root -- root must be black
        if(n->parent == nullptr){
            n->color = Color::Black;
            return;
        }
        // case 2: parent is black -- no requirements mismatch
        if(n->parent->color == Color::Black)
            return;
        // the parent is red, hence not the root, so the grandparent exists
        Node* g = n->getGrandparent();
        Node* u = n->getUncle();
        // case 3: parent and uncle are red -- move the red up to g and fix g
        if(u != nullptr && u->color == Color::Red){
            n->parent->color = Color::Black;
            u->color = Color::Black;
            g->color = Color::Red;
            fixupInsertion(g);
            return;
        }
        // case 4: parent is red, uncle is black, n is an inner grandchild -- make it outer
        if(n == n->parent->right && n->parent == g->left){
            rotateLeft(n->parent);
            n = n->left;
        }
        else if(n == n->parent->left && n->parent == g->right){
            rotateRight(n->parent);
            n = n->right;
        }
        // case 5: n is an outer grandchild -- rotate g away from n
        n->parent->color = Color::Black;
        g->color = Color::Red;
        if(n == n->parent->left && n->parent == g->left)
            rotateRight(g);
        else rotateLeft(g);
    }

    /// Restores the red-black invariants after a black node was removed (CLRS
    /// RB-DELETE-FIXUP). `x` is the node that took the removed node's place and
    /// carries an extra black; it may be nullptr, which is why its parent is passed
    /// separately.
    void fixupDeletion(Node* x, Node* xParent){
        while(x != root && isBlack(x)){
            if(x == xParent->left){
                // the sibling is non-null: its side has black height >= 1
                Node* w = xParent->right;
                if(w->color == Color::Red){
                    // red sibling: rotate so that x gets a black sibling
                    w->color = Color::Black;
                    xParent->color = Color::Red;
                    rotateLeft(xParent);
                    w = xParent->right;
                }
                if(isBlack(w->left) && isBlack(w->right)){
                    // both nephews black: push the extra black up one level
                    w->color = Color::Red;
                    x = xParent;
                    xParent = x->parent;
                }
                else {
                    if(isBlack(w->right)){
                        // near nephew red, far nephew black: make the far one red
                        w->left->color = Color::Black;
                        w->color = Color::Red;
                        rotateRight(w);
                        w = xParent->right;
                    }
                    // far nephew red: one rotation absorbs the extra black
                    w->color = xParent->color;
                    xParent->color = Color::Black;
                    w->right->color = Color::Black;
                    rotateLeft(xParent);
                    x = root;
                }
            }
            else {
                // mirror image of the branch above
                Node* w = xParent->left;
                if(w->color == Color::Red){
                    w->color = Color::Black;
                    xParent->color = Color::Red;
                    rotateRight(xParent);
                    w = xParent->left;
                }
                if(isBlack(w->left) && isBlack(w->right)){
                    w->color = Color::Red;
                    x = xParent;
                    xParent = x->parent;
                }
                else {
                    if(isBlack(w->left)){
                        w->right->color = Color::Black;
                        w->color = Color::Red;
                        rotateLeft(w);
                        w = xParent->left;
                    }
                    w->color = xParent->color;
                    xParent->color = Color::Black;
                    w->left->color = Color::Black;
                    rotateRight(xParent);
                    x = root;
                }
            }
        }
        if(x != nullptr)
            x->color = Color::Black;
    }

    /// Creates a node at the empty `link` under `parent` and rebalances.
    /// Rotations relink nodes but never move them, so the returned pointer stays valid.
    template<typename V>
    Node* insertAt(Node** link, Node* parent, TKey key, V&& value){
        Node* newNode = new Node(std::move(key), std::forward<V>(value), Color::Red, parent);
        *link = newNode;
        fixupInsertion(newNode);
        return newNode;
    }

    template<typename V>
    bool addImpl(TKey key, V&& value, TVal** resultPtr){
        Node* parent;
        Node** link = findLink(key, parent);
        if(*link != nullptr){ // key already present
            if(resultPtr)
                *resultPtr = nullptr;
            return false;
        }
        Node* added = insertAt(link, parent, std::move(key), std::forward<V>(value));
        if(resultPtr)
            *resultPtr = &added->data.second;
        return true;
    }

    template<typename V>
    bool setImpl(const TKey& key, V&& value, TVal** resultPtr){
        Node* n = findNode(key);
        if(n == nullptr){
            if(resultPtr)
                *resultPtr = nullptr;
            return false;
        }
        n->data.second = std::forward<V>(value);
        if(resultPtr)
            *resultPtr = &n->data.second;
        return true;
    }

    template<typename V>
    void addOrSetImpl(TKey key, V&& value, TVal** resultPtr){
        Node* parent;
        Node** link = findLink(key, parent);
        Node* n = *link;
        if(n != nullptr)
            n->data.second = std::forward<V>(value);
        else n = insertAt(link, parent, std::move(key), std::forward<V>(value));
        if(resultPtr)
            *resultPtr = &n->data.second;
    }

    /// @returns the black height of the subtree (null leaves count 1), or -1 if a
    ///          parent link is inconsistent, a red node has a red child, or the two
    ///          sides have different black heights
    static int checkedBlackHeight(const Node* n){
        if(n == nullptr)
            return 1;
        if(n->left != nullptr && n->left->parent != n)
            return -1;
        if(n->right != nullptr && n->right->parent != n)
            return -1;
        if(n->color == Color::Red && !(isBlack(n->left) && isBlack(n->right)))
            return -1;
        int leftHeight = checkedBlackHeight(n->left);
        int rightHeight = checkedBlackHeight(n->right);
        if(leftHeight < 0 || leftHeight != rightHeight)
            return -1;
        return leftHeight + (n->color == Color::Black ? 1 : 0);
    }

    /// In-order bidirectional iterator; end() is represented by a null node.
    template<typename TIteratorValue>
    struct TreeIterator {
        using iterator_category = std::bidirectional_iterator_tag;
        using value_type = typename std::remove_const<TIteratorValue>::type;
        using difference_type = std::ptrdiff_t;
        using pointer = TIteratorValue*;
        using reference = TIteratorValue&;

        Node* n;
        /// Tree this iterator walks; needed so that --end() can find the last node.
        const RBTree* owner;

        TreeIterator(Node* ptr, const RBTree* tree = nullptr)
            : n(ptr), owner(tree)
        {
        }
        bool operator!=(TreeIterator const& other) const {
            return n != other.n;
        }
        bool operator==(TreeIterator const& other) const {
            return n == other.n;
        }
        TIteratorValue& operator*() const {
            if(n == nullptr)
                throw "RBTree::TreeIterator::operator*() error: n == nullptr";
            return n->data;
        }
        TIteratorValue* operator->() const {
            return &operator*();
        }
        /// Advances to the next key; incrementing end() leaves it at end().
        TreeIterator& operator++() {
            if(n != nullptr)
                n = successor(n);
            return *this;
        }
        TreeIterator operator++(int) {
            TreeIterator old = *this;
            ++*this;
            return old;
        }
        /// Steps to the previous key; decrementing end() yields the last element.
        TreeIterator& operator--() {
            if(n != nullptr)
                n = predecessor(n);
            else if(owner != nullptr && owner->root != nullptr)
                n = owner->root->getMaxChild();
            return *this;
        }
        TreeIterator operator--(int) {
            TreeIterator old = *this;
            --*this;
            return old;
        }
    };

public:
    using iterator = TreeIterator<std::pair<const TKey, TVal>>;
    using const_iterator = TreeIterator<const std::pair<const TKey, TVal>>;

    RBTree() {}
    /// Deep copy: the new tree owns its own nodes.
    RBTree(const RBTree& other) : root(cloneSubtree(other.root, nullptr)) {}
    RBTree(RBTree&& other) noexcept : root(other.root) {
        other.root = nullptr;
    }
    /// Copy-and-swap; serves both copy and move assignment.
    RBTree& operator=(RBTree other) noexcept {
        std::swap(root, other.root);
        return *this;
    }
    ~RBTree(){
        delete root;
    }

    /// Inserts key -> value if the key is absent.
    /// @param resultPtr nullable; receives the stored value, or nullptr if the key existed
    /// @returns false (tree unchanged) if the key already exists
    bool tryAdd(TKey key, const TVal& value, TVal** resultPtr = nullptr){
        return addImpl(std::move(key), value, resultPtr);
    }
    /// @param resultPtr nullable
    bool tryAdd(TKey key, TVal&& value, TVal** resultPtr = nullptr){
        return addImpl(std::move(key), std::move(value), resultPtr);
    }

    /// Overwrites the value of an existing key.
    /// @param resultPtr nullable; receives the stored value, or nullptr if the key is absent
    /// @returns false (tree unchanged) if the key does not exist
    bool trySet(TKey key, const TVal& value, TVal** resultPtr = nullptr){
        return setImpl(key, value, resultPtr);
    }
    /// @param resultPtr nullable
    bool trySet(TKey key, TVal&& value, TVal** resultPtr = nullptr){
        return setImpl(key, std::move(value), resultPtr);
    }

    /// Inserts key -> value, or overwrites the value if the key already exists.
    /// @param resultPtr nullable; receives the stored value
    void addOrSet(TKey key, const TVal& value, TVal** resultPtr = nullptr){
        addOrSetImpl(std::move(key), value, resultPtr);
    }
    /// @param resultPtr nullable
    void addOrSet(TKey key, TVal&& value, TVal** resultPtr = nullptr){
        addOrSetImpl(std::move(key), std::move(value), resultPtr);
    }

    /// Looks up `key`.
    /// @param resultPtr receives the stored value or nullptr; if resultPtr itself is
    ///        null the function returns false without searching
    /// @returns true if the key exists
    bool tryGet(TKey key, TVal** resultPtr) const {
        if(!resultPtr)
            return false;
        Node* n = findNode(key);
        *resultPtr = (n == nullptr) ? nullptr : &n->data.second;
        return n != nullptr;
    }

    /// Removes `key` and its value. Pointers and iterators to other entries stay valid.
    /// @returns false if the key does not exist
    bool tryDelete(TKey key){
        Node* z = findNode(key);
        if(z == nullptr)
            return false;
        Color removedColor = z->color;
        Node* x;        // node that moves into the vacated position (may be null)
        Node* xParent;  // parent of that position, tracked because x may be null
        if(z->left == nullptr){
            x = z->right;
            xParent = z->parent;
            transplantNode(z, z->right);
        }
        else if(z->right == nullptr){
            x = z->left;
            xParent = z->parent;
            transplantNode(z, z->left);
        }
        else {
            // two children: splice out the in-order successor y and put it in z's place
            Node* y = z->right->getMinChild();
            removedColor = y->color;
            x = y->right;
            if(y->parent == z){
                xParent = y;
            }
            else {
                xParent = y->parent;
                transplantNode(y, y->right);
                y->right = z->right;
                y->right->parent = y;
            }
            transplantNode(z, y);
            y->left = z->left;
            y->left->parent = y;
            y->color = z->color;
        }
        if(removedColor == Color::Black)
            fixupDeletion(x, xParent);
        // z's former children now belong to other nodes; detach them so ~Node keeps them
        z->left = nullptr;
        z->right = nullptr;
        delete z;
        return true;
    }

    iterator begin(){
        return iterator(root == nullptr ? nullptr : root->getMinChild(), this);
    }
    iterator end(){
        return iterator(nullptr, this);
    }
    const_iterator begin() const {
        return const_iterator(root == nullptr ? nullptr : root->getMinChild(), this);
    }
    const_iterator end() const {
        return const_iterator(nullptr, this);
    }

    /// Debug/test helper.
    /// @returns true when the root is black with no parent, parent links match child
    ///          links, no red node has a red child, every root-to-null path has the
    ///          same number of black nodes, and in-order keys are strictly increasing
    bool checkInvariants() const {
        if(root == nullptr)
            return true;
        if(root->color != Color::Black || root->parent != nullptr)
            return false;
        if(checkedBlackHeight(root) < 0)
            return false;
        Node* prev = nullptr;
        for(Node* n = root->getMinChild(); n != nullptr; n = successor(n)){
            if(prev != nullptr && !(prev->key() < n->key()))
                return false;
            prev = n;
        }
        return true;
    }

    void _generateGraphVizCodeForChildren(std::stringstream& ss, Node* n) const {
        if(!n)
            return;
        if(n->color == Color::Red)
            ss<<" \""<<n->key()<<"\" [color=red]\n";
        if(n->left){
            ss<<" \""<<n->key()<<"\" -> \""<<n->left->key()<<"\" [side=L]\n";
            _generateGraphVizCodeForChildren(ss, n->left);
        }
        else ss<<" \""<<n->key()<<"\" -> \"null\" [side=L]\n";
        if(n->right){
            ss<<" \""<<n->key()<<"\" -> \""<<n->right->key()<<"\" [side=R]\n";
            _generateGraphVizCodeForChildren(ss, n->right);
        }
        else ss<<" \""<<n->key()<<"\" -> \"null\" [side=R]\n";
    }

    /// @returns a Graphviz `digraph` of the tree; red nodes are marked color=red
    std::string generateGraphVizCode() const {
        std::stringstream ss;
        ss<<"digraph {\n"
            " node [style=filled,color=gray];\n";
        if(root == nullptr)
            ss<<" \"null\"\n";
        else {
            ss<<" \"null-parent\" -> \""<<root->key()<<"\"\n";
            _generateGraphVizCodeForChildren(ss, root);
        }
        ss<<"}";
        return ss.str();
    }
};