#pragma once #include "../../dependencies/kerep/src/base/base.h" #include #include // typedef char* TKey; // typedef char* TVal; template class RBTree { enum class Color { Red, Black }; struct Node { TKey key; TVal value; Color color; Node* parent; Node* left = nullptr; Node* right = nullptr; Node(TKey _key, TVal _val, Color _color, Node* _parent) : key(_key), value(_val), color(_color), parent(_parent) { } ~Node(){ if(left != nullptr) delete left; if(right != nullptr) delete right; } inline Node* getSibling(){ if(parent == nullptr) return nullptr; else if(parent->left == this) return parent->right; else return parent->left; } inline Node* getGrandparent(){ if(parent == nullptr) return nullptr; else return parent->parent; } inline Node* getUncle(){ if(parent == nullptr) return nullptr; return parent->getSibling(); } // n should be not null Node* getMinChild(){ Node* n = this; while(n->left != nullptr) n = n->left; return n; } // n should be not null Node* getMaxChild(){ Node* n = this; while(n->right != nullptr) n = n->right; return n; } }; Node* root = nullptr; ///@returns null if root is null Node* findParentForKey(TKey key) const { Node* n = root; Node* parent = nullptr; while(n != nullptr){ parent = n; if(key < n->key) n = n->left; else if(key > n->key) n = n->right; else return n->parent; // key == n->key } return parent; } void rotateLeft(Node* x){ // 1. get right child of x Node* y = x->right; // 2. move y to the position of x y->parent = x->parent; if (x->parent != nullptr){ // x != root if(x == x->parent->left) x->parent->left = y; else x->parent->right = y; } else root = y; // 3. move y.left to x.right if it exists x->right = y->left; if (x->right != nullptr) x->right->parent = x; // 4. move x to y.left y->left = x; x->parent = y; } void rotateRight(Node* x){ // 1. get left child of x Node* y = x->left; // 2. move y up y->parent = x->parent; if (x->parent != nullptr){ // x != root if(x == x->parent->left) x->parent->left = y; else x->parent->right = y; } else root = y; // 3. move y.right to x.left if it exists x->left = y->right; if (x->left != nullptr) x->left->parent = x; // 4. move x to y.right y->right = x; x->parent = y; } void transplantNode(Node* old, Node* neww){ if(old->parent == nullptr) root = neww; else if(old->parent->left == old) old->parent->left = neww; else old->parent->right = neww; if(neww != nullptr) neww->parent = old->parent; } void fixupInsertion(Node* n){ // case 1: n is root -- root must be black if (n->parent == nullptr){ n->color = Color::Black; return; } // case 2: parent is black -- no requirements mismatch if (n->parent->color == Color::Black) return; // case 3: parent and uncle are red -- red nodes must have black parents Node* u = n->getUncle(); Node* g = n->getGrandparent(); if(u != nullptr && u->color == Color::Red){ n->parent->color = Color::Black; u->color = Color::Black; g->color = Color::Red; fixupInsertion(g); return; } // case 4: parent is red and uncle is black -- red nodes must have black parents if ((n == n->parent->right) && (n->parent == g->left)) { rotateLeft(n->parent); n = n->left; } else if ((n == n->parent->left) && (n->parent == g->right)) { rotateRight(n->parent); n = n->right; } // case 5 n->parent->color = Color::Black; g->color = Color::Red; if ((n == n->parent->left) && (n->parent == g->left)) rotateRight(g); else rotateLeft(g); } void fixupDeletion(Node* n){ // case 1 if(n->parent == nullptr) return; // case 2 Node* s = n->getSibling(); if(s->color == Color::Red){ n->parent->color = Color::Red; s->color = Color::Black; if (n == n->parent->left) rotateLeft(n->parent); else rotateRight(n->parent); } // case 3 if ((n->parent->color == Color::Black) && (s->color == Color::Black) && (s->left->color == Color::Black) && (s->right->color == Color::Black)) { s->color = Color::Red; fixupDeletion(n->parent); return; } // case 4 else if ((n->parent->color == Color::Red) && (s->color == Color::Black) && (s->left->color == Color::Black) && (s->right->color == Color::Black)) { s->color = Color::Red; n->parent->color = Color::Black; return; } // case 5 if(s->color == Color::Black) { if ((n == n->parent->left) && (s->right->color == Color::Black) && (s->left->color == Color::Red)) { s->color = Color::Red; s->left->color = Color::Black; rotateRight(s); } else if ((n == n->parent->right) && (s->left->color == Color::Black) && (s->right->color == Color::Red)) { s->color = Color::Red; s->right->color = Color::Black; rotateLeft(s); } } // case 6 s->color = n->parent->color; n->parent->color = Color::Black; if (n == n->parent->left) { s->right->color = Color::Black; rotateLeft(n->parent); } else { s->left->color = Color::Black; rotateRight(n->parent); } } template struct TreeIterator : std::iterator { Node* n; TreeIterator(TreeIterator const& src){ n = src.n; } TreeIterator(Node* ptr){ n = ptr; } bool operator!=(TreeIterator const& other) const { return n != other.n; } bool operator==(TreeIterator const& other) const { return n == other.n; } TIteratorValue& operator*() const { if(n == nullptr) throw "RBTree::TreeIterator::operator*() error: n == nullptr"; return *((TIteratorValue*)(void*)n); } void operator++() { if(n == nullptr) return; if(n->right) n = n->right->getMinChild(); else { Node* p = n->parent; while(p != nullptr && n == p->right){ n = p; p = p->parent; } n = p; } } }; public: using iterator = TreeIterator>; using const_iterator = TreeIterator>; RBTree() {} ~RBTree(){ delete root; } /// @param resultPtr nullable bool tryAdd(TKey key, TVal& value, TVal** resultPtr){ if(root == nullptr){ root = new Node(key, value, Color::Black, nullptr); if(resultPtr) *resultPtr = &root->value; return true; } Node* parent = findParentForKey(key); // ptr to parent->right or parent->left Node** nodePtrPtr; if(key < parent->key) nodePtrPtr = &parent->left; else nodePtrPtr = &parent->right; // if a child node already exists at this place, returns false if(*nodePtrPtr != nullptr){ if(resultPtr) *resultPtr = nullptr; return false; } // places newNode to left or right of the parent Node* newNode = new Node(key, value, Color::Red, parent); if(resultPtr) *resultPtr = &newNode->value; *nodePtrPtr = newNode; // auto-balancing fixupInsertion(newNode); return true; } /// @param resultPtr nullable bool tryAdd(TKey key, TVal&& value, TVal** resultPtr){ return tryAdd(key, value, resultPtr); } /// @param resultPtr nullable bool trySet(TKey key, TVal& value, TVal** resultPtr){ if(root == nullptr){ if(resultPtr) *resultPtr = nullptr; return false; } Node* parent = findParentForKey(key); // ptr to parent->right or parent->left Node** nodePtrPtr; if(key < parent->key) nodePtrPtr = &parent->left; else nodePtrPtr = &parent->right; // if a child node with the given key doesn't exist, returns false if(*nodePtrPtr == nullptr){ if(resultPtr) *resultPtr = nullptr; return false; } // replaces the value of left or right child of the parent (*nodePtrPtr)->value = value; if(resultPtr) *resultPtr = &(*nodePtrPtr)->value; return true; } /// @param resultPtr nullable bool trySet(TKey key, TVal&& value, TVal** resultPtr){ return trySet(key, value, resultPtr); } /// @param resultPtr nullable void addOrSet(TKey key, TVal& value, TVal** resultPtr){ if(root == nullptr){ root = new Node(key, value, Color::Black, nullptr); if(resultPtr != nullptr) *resultPtr = &root->value; return; } Node* parent = findParentForKey(key); // ptr to parent->right or parent->left Node** nodePtrPtr; if(key < parent->key) nodePtrPtr = &parent->left; else nodePtrPtr = &parent->right; // if a child node already exists at this place, sets it's value if(*nodePtrPtr != nullptr){ (*nodePtrPtr)->value = value; if(resultPtr) *resultPtr = &(*nodePtrPtr)->value; return; } // places newNode to left or right of the parent Node* newNode = new Node(key, value, Color::Red, parent); if(resultPtr != nullptr) *resultPtr = &newNode->value; *nodePtrPtr = newNode; // auto-balancing fixupInsertion(newNode); } /// @param resultPtr nullable void addOrSet(TKey key, TVal&& value, TVal** resultPtr){ addOrSet(key, value, resultPtr); } bool tryGet(TKey key, TVal** resultPtr) const { if(!resultPtr) return false; Node* parent = findParentForKey(key); Node* n = nullptr; if(parent == nullptr) n = root; else if(key < parent->key) n = parent->left; else n = parent->right; // if there is no node with the given key if(n == nullptr){ *resultPtr = nullptr; return false; } *resultPtr = &n->value; return true; } bool tryDelete(TKey key){ Node* parent = findParentForKey(key); Node* n = nullptr; if(parent == nullptr) n = root; else if(key < parent->key) n = parent->left; else n = parent->right; // key not found if(n == nullptr){ return false; } if(n->left == nullptr){ transplantNode(n, n->right); if(n->color == Color::Black && n->right != nullptr) fixupDeletion(n->right); } else if(n->right == nullptr){ transplantNode(n, n->left); if(n->color == Color::Black && n->left != nullptr) fixupDeletion(n->left); } else { Node* minNode = n->right->getMinChild(); if(minNode != n->right){ transplantNode(minNode, minNode->right); minNode->right = n->right; n->right->parent = minNode; } // else minNode->right->parent = minNode; // wtf??? transplantNode(n, minNode); minNode->left = n->left; n->left->parent = minNode; if(minNode->color == Color::Black) fixupDeletion(minNode->right); } delete n; return true; } iterator begin(){ if(root == nullptr) return iterator(nullptr); return iterator(root->getMinChild()); } iterator end(){ if(root == nullptr) return iterator(nullptr); return iterator(root->getMaxChild()); } const_iterator begin() const { if(root == nullptr) return const_iterator(nullptr); return const_iterator(root->getMinChild()); } const_iterator end() const { return const_iterator(nullptr); } void _generateGraphVizCodeForChildren(std::stringstream& ss, Node* n) const { if(!n) return; if(n->color == Color::Red) ss<<" \""<key<<"\" [color=red]\n"; if(n->left){ ss<<" \""<key<<"\" -> \""<left->key<<"\" [side=L]\n"; _generateGraphVizCodeForChildren(ss, n->left); } else ss<<" \""<key<<"\" -> \"null\" [side=L]\n"; if(n->right){ ss<<" \""<key<<"\" -> \""<right->key<<"\" [side=R]\n"; _generateGraphVizCodeForChildren(ss, n->right); } else ss<<" \""<key<<"\" -> \"null\" [side=R]\n"; } std::string generateGraphVizCode() const { std::stringstream ss; ss<<"digraph {\n" " node [style=filled,color=gray];\n"; if(root == nullptr) ss<<" \"null\"\n"; else { ss<<" \"null-parent\" -> \""<key<<"\"\n"; _generateGraphVizCodeForChildren(ss, root); } ss<<"}"; return ss.str(); } };