534 lines
15 KiB
C++
534 lines
15 KiB
C++
#pragma once
|
|
|
|
#include "../../dependencies/kerep/src/base/base.h"

#include <cstddef>
#include <iterator>
#include <map>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
|
|
|
|
// typedef char* TKey;
|
|
// typedef char* TVal;
|
|
/// Self-balancing binary search tree (red-black tree) mapping unique keys to values.
///
/// Keys are ordered with `operator<` only: two keys are treated as equal when
/// neither is less than the other. Every entry lives in its own heap-allocated
/// node owned by the tree. Lookup, insertion and deletion are O(log n) because
/// the red-black invariants keep the tree height logarithmic.
///
/// The tree is move-only: a copy would make two trees free the same nodes.
///
/// Iteration visits entries in ascending key order. Rebalancing relinks nodes
/// instead of moving entries between them, so value pointers returned through
/// `resultPtr` and iterators stay valid until their own entry is deleted.
template<typename TKey, typename TVal>
class RBTree {
    enum class Color {
        Red, Black
    };

    /// A tree node. It owns both subtrees: deleting a node frees everything below it.
    struct Node {
        /// Stored as a real std::pair so iterators can return a reference to it.
        std::pair<const TKey, TVal> kv;
        Color color;
        Node* parent;
        Node* left = nullptr;
        Node* right = nullptr;

        Node(TKey _key, TVal _val, Color _color, Node* _parent)
            : kv(std::move(_key), std::move(_val)), color(_color), parent(_parent)
        {}

        // Raw owning child pointers: copying a node would double-free its subtree.
        Node(const Node&) = delete;
        Node& operator=(const Node&) = delete;

        ~Node(){
            // delete on nullptr is a no-op, so no null checks are required
            delete left;
            delete right;
        }

        const TKey& key() const { return kv.first; }

        /// @returns the leftmost (smallest-key) node of the subtree rooted here
        Node* getMinChild(){
            Node* n = this;
            while(n->left != nullptr)
                n = n->left;
            return n;
        }
    };

    Node* root = nullptr;

    /// Searches for `key` starting at the root.
    /// @param parentOut nullable; receives the last node visited before the search
    ///        stopped, i.e. the node under which `key` must be attached when it is
    ///        absent (nullptr when the tree is empty)
    /// @returns the node holding `key`, or nullptr if there is none
    Node* findNode(const TKey& key, Node** parentOut = nullptr) const {
        Node* parent = nullptr;
        Node* cur = root;
        while(cur != nullptr){
            if(key < cur->key()){
                parent = cur;
                cur = cur->left;
            }
            else if(cur->key() < key){
                parent = cur;
                cur = cur->right;
            }
            else break; // neither is less: keys are equal
        }
        if(parentOut != nullptr)
            *parentOut = parent;
        return cur;
    }

    /// Left rotation around x; x->right must be non-null.
    void rotateLeft(Node* x){
        // 1. get right child of x
        Node* y = x->right;
        // 2. move y to the position of x
        y->parent = x->parent;
        if (x->parent != nullptr){ // x != root
            if(x == x->parent->left)
                x->parent->left = y;
            else x->parent->right = y;
        }
        else root = y;
        // 3. move y.left to x.right if it exists
        x->right = y->left;
        if (x->right != nullptr)
            x->right->parent = x;
        // 4. move x to y.left
        y->left = x;
        x->parent = y;
    }

    /// Right rotation around x; x->left must be non-null.
    void rotateRight(Node* x){
        // 1. get left child of x
        Node* y = x->left;
        // 2. move y up
        y->parent = x->parent;
        if (x->parent != nullptr){ // x != root
            if(x == x->parent->left)
                x->parent->left = y;
            else x->parent->right = y;
        }
        else root = y;
        // 3. move y.right to x.left if it exists
        x->left = y->right;
        if (x->left != nullptr)
            x->left->parent = x;
        // 4. move x to y.right
        y->right = x;
        x->parent = y;
    }

    /// Puts `neww` (nullable) in place of `old` in old's parent (or as root).
    /// old's own child pointers are left untouched.
    void transplantNode(Node* old, Node* neww){
        if(old->parent == nullptr)
            root = neww;
        else if(old->parent->left == old)
            old->parent->left = neww;
        else old->parent->right = neww;
        if(neww != nullptr)
            neww->parent = old->parent;
    }

    /// Null children count as black leaves.
    static bool isBlack(const Node* n){
        return n == nullptr || n->color == Color::Black;
    }

    /// Creates a red node under `parent` (or as root when parent is null) and rebalances.
    /// The caller must have checked that the key is absent.
    template<typename V>
    Node* attachNewNode(Node* parent, TKey key, V&& value){
        Node* created = new Node(std::move(key), std::forward<V>(value), Color::Red, parent);
        if(parent == nullptr)
            root = created;
        else if(created->key() < parent->key())
            parent->left = created;
        else parent->right = created;
        fixupInsertion(created);
        return created;
    }

    /// Restores the red-black invariants after `n` was inserted as a red node.
    void fixupInsertion(Node* n){
        // Only a red node with a red parent violates the invariants. A red parent is
        // never the root, because the root is recolored black below, so g is non-null.
        while(n != root && n->parent->color == Color::Red){
            Node* p = n->parent;
            Node* g = p->parent;
            Node* u = (p == g->left) ? g->right : g->left;

            // parent and uncle are red: recolor and continue from the grandparent
            if(u != nullptr && u->color == Color::Red){
                p->color = Color::Black;
                u->color = Color::Black;
                g->color = Color::Red;
                n = g;
                continue;
            }

            // uncle is black: rotate n into line with p, then rotate around g
            if(p == g->left){
                if(n == p->right){
                    rotateLeft(p);
                    n = p;
                    p = n->parent;
                }
                p->color = Color::Black;
                g->color = Color::Red;
                rotateRight(g);
            }
            else {
                if(n == p->left){
                    rotateRight(p);
                    n = p;
                    p = n->parent;
                }
                p->color = Color::Black;
                g->color = Color::Red;
                rotateLeft(g);
            }
            break;
        }
        root->color = Color::Black;
    }

    /// Unlinks and frees `z`, then restores the red-black invariants.
    void eraseNode(Node* z){
        Node* x;        // node that moves into the removed position (may be null)
        Node* xParent;  // parent of that position, tracked explicitly because x may be null
        Color removedColor = z->color;

        if(z->left == nullptr){
            x = z->right;
            xParent = z->parent;
            transplantNode(z, z->right);
        }
        else if(z->right == nullptr){
            x = z->left;
            xParent = z->parent;
            transplantNode(z, z->left);
        }
        else {
            // two children: the in-order successor y takes z's place and color
            Node* y = z->right->getMinChild();
            removedColor = y->color;
            x = y->right;
            if(y->parent == z){
                xParent = y;
            }
            else {
                xParent = y->parent;
                transplantNode(y, y->right);
                y->right = z->right;
                y->right->parent = y;
            }
            transplantNode(z, y);
            y->left = z->left;
            y->left->parent = y;
            y->color = z->color;
        }

        // detach before deleting, otherwise ~Node would free the relinked subtrees
        z->left = nullptr;
        z->right = nullptr;
        delete z;

        if(removedColor == Color::Black)
            fixupDeletion(x, xParent);
    }

    /// Removes the extra black carried by position x (node x, possibly null, under xParent).
    void fixupDeletion(Node* x, Node* xParent){
        while(x != root && isBlack(x)){
            // The sibling w is non-null: x's side lost one black, so the other side
            // still has a black-height of at least one.
            if(x == xParent->left){
                Node* w = xParent->right;
                if(w->color == Color::Red){
                    w->color = Color::Black;
                    xParent->color = Color::Red;
                    rotateLeft(xParent);
                    w = xParent->right;
                }
                if(isBlack(w->left) && isBlack(w->right)){
                    w->color = Color::Red;
                    x = xParent;
                    xParent = x->parent;
                }
                else {
                    if(isBlack(w->right)){
                        w->left->color = Color::Black;
                        w->color = Color::Red;
                        rotateRight(w);
                        w = xParent->right;
                    }
                    w->color = xParent->color;
                    xParent->color = Color::Black;
                    w->right->color = Color::Black;
                    rotateLeft(xParent);
                    x = root;
                }
            }
            else {
                Node* w = xParent->left;
                if(w->color == Color::Red){
                    w->color = Color::Black;
                    xParent->color = Color::Red;
                    rotateRight(xParent);
                    w = xParent->left;
                }
                if(isBlack(w->left) && isBlack(w->right)){
                    w->color = Color::Red;
                    x = xParent;
                    xParent = x->parent;
                }
                else {
                    if(isBlack(w->left)){
                        w->right->color = Color::Black;
                        w->color = Color::Red;
                        rotateLeft(w);
                        w = xParent->left;
                    }
                    w->color = xParent->color;
                    xParent->color = Color::Black;
                    w->left->color = Color::Black;
                    rotateRight(xParent);
                    x = root;
                }
            }
        }
        if(x != nullptr)
            x->color = Color::Black;
    }

    template<typename V>
    bool addImpl(TKey key, V&& value, TVal** resultPtr){
        Node* parent = nullptr;
        if(findNode(key, &parent) != nullptr){
            if(resultPtr)
                *resultPtr = nullptr;
            return false;
        }
        Node* created = attachNewNode(parent, std::move(key), std::forward<V>(value));
        if(resultPtr)
            *resultPtr = &created->kv.second;
        return true;
    }

    template<typename V>
    bool setImpl(const TKey& key, V&& value, TVal** resultPtr){
        Node* found = findNode(key);
        if(found == nullptr){
            if(resultPtr)
                *resultPtr = nullptr;
            return false;
        }
        found->kv.second = std::forward<V>(value);
        if(resultPtr)
            *resultPtr = &found->kv.second;
        return true;
    }

    template<typename V>
    void addOrSetImpl(TKey key, V&& value, TVal** resultPtr){
        Node* parent = nullptr;
        Node* target = findNode(key, &parent);
        if(target != nullptr)
            target->kv.second = std::forward<V>(value);
        else target = attachNewNode(parent, std::move(key), std::forward<V>(value));
        if(resultPtr)
            *resultPtr = &target->kv.second;
    }

    /// Forward iterator over entries in ascending key order; a null node is past-the-end.
    /// @tparam TIteratorValue the referenced entry type (const-qualified for const_iterator)
    template<typename TIteratorValue>
    struct TreeIterator {
        // std::iterator is deprecated since C++17, so the traits are spelled out.
        // Only operator++ is provided, hence the forward (not bidirectional) category.
        using iterator_category = std::forward_iterator_tag;
        using value_type = std::remove_const_t<TIteratorValue>;
        using difference_type = std::ptrdiff_t;
        using pointer = TIteratorValue*;
        using reference = TIteratorValue&;

        Node* n;

        TreeIterator(Node* ptr = nullptr) : n(ptr) {}

        bool operator!=(TreeIterator const& other) const {
            return n != other.n;
        }

        bool operator==(TreeIterator const& other) const {
            return n == other.n;
        }

        reference operator*() const {
            if(n == nullptr)
                throw "RBTree::TreeIterator::operator*() error: n == nullptr";
            return n->kv;
        }

        pointer operator->() const {
            return &operator*();
        }

        /// Moves to the in-order successor; past-the-end stays past-the-end.
        TreeIterator& operator++() {
            if(n == nullptr)
                return *this;
            if(n->right)
                n = n->right->getMinChild();
            else {
                // climb until we arrive from a left child; that parent is the successor
                Node* p = n->parent;
                while(p != nullptr && n == p->right){
                    n = p;
                    p = p->parent;
                }
                n = p;
            }
            return *this;
        }

        TreeIterator operator++(int) {
            TreeIterator previous = *this;
            ++*this;
            return previous;
        }
    };

public:
    using iterator = TreeIterator<std::pair<const TKey, TVal>>;
    /// Like std::map: dereferences to `const std::pair<const TKey, TVal>&`.
    using const_iterator = TreeIterator<const std::pair<const TKey, TVal>>;

    RBTree() = default;

    RBTree(const RBTree&) = delete;
    RBTree& operator=(const RBTree&) = delete;

    RBTree(RBTree&& other) noexcept : root(std::exchange(other.root, nullptr)) {}

    RBTree& operator=(RBTree&& other) noexcept {
        if(this != &other){
            delete root;
            root = std::exchange(other.root, nullptr);
        }
        return *this;
    }

    ~RBTree(){
        delete root;
    }

    /// Inserts (key, value) if key is absent.
    /// @param resultPtr nullable; receives a pointer to the stored value, or nullptr on failure
    /// @returns false if the key already exists (the tree is left unchanged)
    bool tryAdd(TKey key, TVal& value, TVal** resultPtr){
        return addImpl(std::move(key), value, resultPtr);
    }

    /// @param resultPtr nullable
    bool tryAdd(TKey key, TVal&& value, TVal** resultPtr){
        return addImpl(std::move(key), std::move(value), resultPtr);
    }

    /// Replaces the value of an existing key.
    /// @param resultPtr nullable; receives a pointer to the stored value, or nullptr on failure
    /// @returns false if the key does not exist
    bool trySet(TKey key, TVal& value, TVal** resultPtr){
        return setImpl(key, value, resultPtr);
    }

    /// @param resultPtr nullable
    bool trySet(TKey key, TVal&& value, TVal** resultPtr){
        return setImpl(key, std::move(value), resultPtr);
    }

    /// Inserts (key, value), or replaces the value if the key already exists.
    /// @param resultPtr nullable; receives a pointer to the stored value
    void addOrSet(TKey key, TVal& value, TVal** resultPtr){
        addOrSetImpl(std::move(key), value, resultPtr);
    }

    /// @param resultPtr nullable
    void addOrSet(TKey key, TVal&& value, TVal** resultPtr){
        addOrSetImpl(std::move(key), std::move(value), resultPtr);
    }

    /// Looks up key.
    /// @param resultPtr must be non-null (returns false immediately otherwise);
    ///        receives a pointer to the stored value, or nullptr if absent
    bool tryGet(TKey key, TVal** resultPtr) const {
        if(!resultPtr)
            return false;
        Node* n = findNode(key);
        if(n == nullptr){
            *resultPtr = nullptr;
            return false;
        }
        *resultPtr = &n->kv.second;
        return true;
    }

    /// Removes key and its value.
    /// @returns false if the key does not exist
    bool tryDelete(TKey key){
        Node* n = findNode(key);
        if(n == nullptr)
            return false;
        eraseNode(n);
        return true;
    }

    iterator begin(){
        return iterator(root == nullptr ? nullptr : root->getMinChild());
    }

    /// Past-the-end iterator (the successor of the largest entry).
    iterator end(){
        return iterator(nullptr);
    }

    const_iterator begin() const {
        return const_iterator(root == nullptr ? nullptr : root->getMinChild());
    }

    const_iterator end() const {
        return const_iterator(nullptr);
    }

    /// Appends GraphViz edges for the subtree rooted at n (null children drawn as "null").
    void _generateGraphVizCodeForChildren(std::stringstream& ss, Node* n) const {
        if(!n)
            return;
        if(n->color == Color::Red)
            ss<<" \""<<n->kv.first<<"\" [color=red]\n";
        if(n->left){
            ss<<" \""<<n->kv.first<<"\" -> \""<<n->left->kv.first<<"\" [side=L]\n";
            _generateGraphVizCodeForChildren(ss, n->left);
        }
        else ss<<" \""<<n->kv.first<<"\" -> \"null\" [side=L]\n";
        if(n->right){
            ss<<" \""<<n->kv.first<<"\" -> \""<<n->right->kv.first<<"\" [side=R]\n";
            _generateGraphVizCodeForChildren(ss, n->right);
        }
        else ss<<" \""<<n->kv.first<<"\" -> \"null\" [side=R]\n";
    }

    /// @returns a GraphViz `digraph` of the tree; red nodes are marked red
    std::string generateGraphVizCode() const {
        std::stringstream ss;
        ss<<"digraph {\n"
            " node [style=filled,color=gray];\n";
        if(root == nullptr)
            ss<<" \"null\"\n";
        else {
            ss<<" \"null-parent\" -> \""<<root->kv.first<<"\"\n";
            _generateGraphVizCodeForChildren(ss, root);
        }
        ss<<"}";
        return ss.str();
    }
};
|