1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
| template <class S, S (*op)(S, S), S (*e)(), S (*diff)(S, S) = nullptr> class persistent_segtree { struct Node { S val; mutable int ref_cnt; const Node *left, *right; ~Node() { if (left != nullptr && !--left->ref_cnt) delete left; if (right != nullptr && !--right->ref_cnt) delete right; } };
int _n; const Node *_root, *_other; static const Node* _build(const std::vector<S>& v, int s, int t) { Node* node = new Node; node->ref_cnt = 1; if (s == t) { node->val = v[s]; node->left = node->right = nullptr; return node; } int m = s + ((t - s) >> 1); node->left = _build(v, s, m); node->right = _build(v, m + 1, t); node->val = op(node->left->val, node->right->val); return node; } static const Node* _set(const Node* root, int p, S x, int s, int t) { Node* node = new Node; node->ref_cnt = 1; if (s == t) { node->val = x; node->left = node->right = nullptr; return node; } int m = s + ((t - s) >> 1); if (p <= m) { node->left = _set(root->left, p, x, s, m); node->right = root->right, ++node->right->ref_cnt; } else { node->left = root->left, ++node->left->ref_cnt; node->right = _set(root->right, p, x, m + 1, t); } node->val = op(node->left->val, node->right->val); return node; } static S _get(const Node* root, const Node* other, int p, int s, int t) { if (s == t) return other == nullptr ? root->val : diff(root->val, other->val); int m = s + ((t - s) >> 1); if (p <= m) return _get(root->left, other == nullptr ? nullptr : other->left, p, s, m); return _get(root->right, other == nullptr ? nullptr : other->right, p, m + 1, t); } static S _prod(const Node* root, const Node* other, int l, int r, int s, int t) { if (l <= s && t <= r) return other == nullptr ? root->val : diff(root->val, other->val); int m = s + ((t - s) >> 1); S res = e(); if (l <= m) res = _prod(root->left, other == nullptr ? nullptr : other->left, l, r, s, m); if (r > m) res = op(res, _prod(root->right, other == nullptr ? nullptr : other->right, l, r, m + 1, t)); return res; } template <class F> static int _max_right(const Node* root, const Node* other, int l, F f, S& sm, int s, int t) { if (l == s) if (S nxt = op(sm, other == nullptr ? root->val : diff(root->val, other->val)); f(nxt)) return sm = nxt, t; if (s == t) return s - 1; int m = s + ((t - s) >> 1); if (l > m) return _max_right(root->right, other == nullptr ? nullptr : other->right, l, f, sm, m + 1, t); int r = _max_right(root->left, other == nullptr ? nullptr : other->left, l, f, sm, s, m); if (r < m) return r; return _max_right(root->right, other == nullptr ? nullptr : other->right, m + 1, f, sm, m + 1, t); } template <class F> static int _min_left(const Node* root, const Node* other, int r, F f, S& sm, int s, int t) { if (r == t) if (S nxt = op(other == nullptr ? root->val : diff(root->val, other->val), sm); f(nxt)) return sm = nxt, s; if (s == t) return t + 1; int m = s + ((t - s) >> 1); if (r <= m) return _min_left(root->left, other == nullptr ? nullptr : other->left, r, f, sm, s, m); int l = _min_left(root->right, other == nullptr ? nullptr : other->right, r, f, sm, m + 1, t); if (l > m + 1) return l; return _min_left(root->left, other == nullptr ? nullptr : other->left, m, f, sm, s, m); }
explicit persistent_segtree(int n, const Node* root, const Node* other = nullptr) : _n(n), _root(root), _other(other) {}
public: persistent_segtree() : persistent_segtree(0, nullptr) {} explicit persistent_segtree(int n) : persistent_segtree(std::vector<S>(n, e())) {} explicit persistent_segtree(const std::vector<S>& v) : persistent_segtree(v.size(), v.size() ? _build(v, 0, v.size() - 1) : nullptr) {} persistent_segtree(persistent_segtree&& rhs) noexcept : persistent_segtree(rhs._n, rhs._root) { rhs._root = nullptr; } persistent_segtree& operator=(persistent_segtree&& rhs) noexcept { if (this != &rhs) { _n = rhs._n; _root = rhs._root; rhs._root = nullptr; } return *this; } persistent_segtree(const persistent_segtree& rhs) : persistent_segtree(rhs._n, rhs._root) { ++_root->ref_cnt; } persistent_segtree& operator=(const persistent_segtree& rhs) { if (this != &rhs) { if (_root != nullptr && !--_root->ref_cnt) delete _root; _n = rhs._n; _root = rhs._root; ++_root->ref_cnt; } return *this; } ~persistent_segtree() { if (_root != nullptr && !--_root->ref_cnt) delete _root; if (_other != nullptr && !--_other->ref_cnt) delete _other; } int size() const { return _n; } persistent_segtree set(int p, S x) { assert(_other == nullptr); assert(0 <= p && p < _n); return persistent_segtree(_n, _set(_root, p, x, 0, _n - 1)); } S get(int p) const { assert(0 <= p && p < _n); return _get(_root, _other, p, 0, _n - 1); } S prod(int l, int r) const { assert(0 <= l && l <= r && r <= _n); if (l == r) return e(); return _prod(_root, _other, l, r - 1, 0, _n - 1); } persistent_segtree operator-(const persistent_segtree& other) const { static_assert(diff != nullptr, "Must specify diff function"); assert(_other == nullptr && other._other == nullptr); assert(_n == other._n); ++_root->ref_cnt, ++other._root->ref_cnt; return persistent_segtree(_n, _root, other._root); } S all_prod() const { assert(_root != nullptr); return _other == nullptr ? _root->val : diff(_root->val, _other->val); } template <bool (*f)(S)> int max_right(int l) const { return max_right(_root, l, [](S x) { return f(x); }); } template <class F> int max_right(int l, F f) const { assert(0 <= l && l < _n); assert(f(e())); S sm = e(); return _max_right(_root, _other, l, f, sm, 0, _n - 1); }
template <bool (*f)(S)> int min_left(int r) const { return min_left(_root, r, [](S x) { return f(x); }); } template <class F> int min_left(int r, F f) const { assert(0 <= r && r < _n); assert(f(e())); S sm = e(); return _min_left(_root, _other, r, f, sm, 0, _n - 1); } };
|