1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
| #include <bits/stdc++.h>
/// Weighted ("potential") disjoint-set union.
///
/// Every element v carries a potential p(v). Each non-root node stores its
/// offset p(node) - p(parent); every root stores T{} (a root's weight is never
/// written, and a node never becomes a root again once attached).
/// `merge(a, b, c)` records the constraint p(a) - p(b) == c, and
/// `relation(a, b)` reports p(a) - p(b) for any connected pair.
/// Uses union by size plus path compression.
///
/// @tparam T  Integer-like weight type. Signed and unsigned types are both
///            handled correctly when a modulus is set. With mod == 0 the plain
///            arithmetic of T is used (signed overflow is the caller's concern).
template <class T = int> class weighted_dsu {
  int _n;
  T _mod;  // 0 = no modulus; otherwise every stored weight lies in [0, _mod)
  struct node {
    int parent_or_size;  // parent index if >= 0, else -(size of this root's set)
    T weight;            // offset p(this) - p(parent); T{} for roots
  };
  std::vector<node> _nodes;

  // Reduces an arbitrary (possibly negative) value into [0, _mod);
  // identity when no modulus is set.
  T safe_mod(T x) const {
    if (_mod == T{}) return x;
    x %= _mod;
    if (x < T{}) x += _mod;
    return x;
  }

  // (x + y) mod _mod for operands already in [0, _mod). Never forms a value
  // >= _mod, so it cannot overflow T even when _mod exceeds half of T's range.
  T add(T x, T y) const {
    if (_mod == T{}) return x + y;
    return (x < _mod - y) ? x + y : x - (_mod - y);
  }

  // (x - y) mod _mod for operands already in [0, _mod). Never forms a negative
  // intermediate, which for unsigned T would wrap modulo 2^bits and give a
  // wrong residue.
  T sub(T x, T y) const {
    if (_mod == T{}) return x - y;
    return (x < y) ? _mod - (y - x) : x - y;
  }

 public:
  weighted_dsu() : weighted_dsu(0) {}

  /// @param n    Number of elements, labelled 0..n-1, each initially alone.
  /// @param mod  Modulus for all weights; must be non-negative. 0 disables
  ///             reduction.
  explicit weighted_dsu(int n, T mod = T{})
      : _n(n), _mod(mod), _nodes(n, node{-1, T{}}) {
    assert(!(mod < T{}));
  }

  /// Returns the representative of a's set. Compresses the path so every
  /// visited node points directly at the root and stores its offset to it.
  /// Recursion depth is at most log2(n) thanks to union by size.
  int leader(int a) {
    assert(0 <= a && a < _n);
    const int parent = _nodes[a].parent_or_size;
    if (parent < 0) return a;
    const int root = leader(parent);
    // `parent` is now the root itself (weight T{}) or hangs directly off it,
    // so its weight is its offset to the root.
    _nodes[a].weight = add(_nodes[a].weight, _nodes[parent].weight);
    _nodes[a].parent_or_size = root;
    return root;
  }

  /// Number of elements in a's set.
  int size(int a) {
    assert(0 <= a && a < _n);
    return -_nodes[leader(a)].parent_or_size;
  }

  /// p(a) - p(leader(a)): the potential of a relative to its representative.
  T weight(int a) {
    assert(0 <= a && a < _n);
    leader(a);
    return _nodes[a].weight;
  }

  /// p(a) - p(b) if a and b are connected, std::nullopt otherwise.
  std::optional<T> relation(int a, int b) {
    assert(0 <= a && a < _n);
    assert(0 <= b && b < _n);
    const int leader_a = leader(a);
    const int leader_b = leader(b);
    if (leader_a != leader_b) return std::nullopt;
    return sub(_nodes[a].weight, _nodes[b].weight);
  }

  /// Records the constraint p(a) - p(b) == c.
  /// @return true if the sets were joined, or if a and b were already
  ///         connected and the constraint is consistent; false if it
  ///         contradicts the existing relation (nothing is changed then).
  bool merge(int a, int b, T c) {
    assert(0 <= a && a < _n);
    assert(0 <= b && b < _n);
    c = safe_mod(c);
    int leader_a = leader(a);
    int leader_b = leader(b);
    if (leader_a == leader_b) return sub(_nodes[a].weight, _nodes[b].weight) == c;
    // Union by size: hang the smaller tree (leader_b) under the larger one.
    if (-_nodes[leader_a].parent_or_size < -_nodes[leader_b].parent_or_size) {
      std::swap(leader_a, leader_b);
      std::swap(a, b);
      c = sub(T{}, c);  // constraint flips: p(b) - p(a) == -c
    }
    _nodes[leader_a].parent_or_size += _nodes[leader_b].parent_or_size;
    _nodes[leader_b].parent_or_size = leader_a;
    // Choose W = offset(leader_b -> leader_a) so that
    // w[a] - (w[b] + W) == c, i.e. W = w[a] - w[b] - c.
    _nodes[leader_b].weight = sub(sub(_nodes[a].weight, _nodes[b].weight), c);
    return true;
  }
};
|