线段树

线段树

参考:https://github.com/atcoder/ac-library

1. 模板

模板的接口文档见 AtCoder Library 中 segtree 的官方文档，使用方法可参考例题代码中的注释。本模板只支持单点修改、区间查询；若需要区间修改、单点查询，则需用到懒标记线段树，见懒标记线段树的模板。

代码（`<atcoder/segtree>`）：
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>


#ifdef _MSC_VER
#include <intrin.h>
#endif

#if __cplusplus >= 202002L
#include <bit>
#endif

namespace atcoder {

namespace internal {

// Small bit-manipulation helpers used by segtree.
//
// The non-template functions are declared `inline` (constexpr functions are
// implicitly inline) so that this header can be included from several
// translation units of one program without "multiple definition" link errors.

#if __cplusplus >= 202002L

using std::bit_ceil;

#else

// Returns the smallest power of two that is >= n, and 1 for n == 0, matching
// C++20 std::bit_ceil.
// Precondition: n <= 2^31; otherwise the doubling wraps to 0 and never ends.
inline unsigned int bit_ceil(unsigned int n) {
    unsigned int x = 1;
    while (x < n) x *= 2;
    return x;
}

#endif

// Returns the number of trailing zero bits of n (same as C++20
// std::countr_zero for nonzero n).
// Precondition: n != 0 (__builtin_ctz and _BitScanForward leave the result
// undefined for 0).
inline int countr_zero(unsigned int n) {
    assert(n != 0);
#ifdef _MSC_VER
    unsigned long index;
    _BitScanForward(&index, n);
    return static_cast<int>(index);
#else
    return __builtin_ctz(n);
#endif
}

// Compile-time counterpart of countr_zero.
// Precondition: n != 0 (otherwise the loop shifts past the width of n).
constexpr int countr_zero_constexpr(unsigned int n) {
    int x = 0;
    // Unsigned shift: 1u << 31 is well defined, unlike the signed form.
    while (!(n & (1u << x))) x++;
    return x;
}

}  // namespace internal

}  // namespace atcoder


namespace atcoder {

// Iterative (bottom-up) segment tree over a monoid (S, op, e).
// op is expected to be associative and e() to be its identity element.
//
// Layout: d has 2 * size slots, where size = bit_ceil(_n). d[1] is the root,
// node k has children 2k and 2k+1, and leaf i is d[size + i]. Leaves past _n
// (and d[0], which is never read) hold e().
#if __cplusplus >= 201703L

template <class S, auto op, auto e> struct segtree {
// Compile-time checks that op / e are callable with the expected signatures.
static_assert(std::is_convertible_v<decltype(op), std::function<S(S, S)>>,
"op must work as S(S, S)");
static_assert(std::is_convertible_v<decltype(e), std::function<S()>>,
"e must work as S()");

#else

// Pre-C++17 form: op and e must be plain function pointers.
template <class S, S (*op)(S, S), S (*e)()> struct segtree {

#endif

public:
// Empty tree (n == 0).
segtree() : segtree(0) {}
// n elements, all initialised to e().
explicit segtree(int n) : segtree(std::vector<S>(n, e())) {}
// Elements copied from v. Internal nodes are then built from the leaves up.
explicit segtree(const std::vector<S>& v) : _n(int(v.size())) {
size = (int)internal::bit_ceil((unsigned int)(_n));
// size is a power of two, so log is the tree height above the leaves.
log = internal::countr_zero((unsigned int)size);
d = std::vector<S>(2 * size, e());
for (int i = 0; i < _n; i++) d[size + i] = v[i];
for (int i = size - 1; i >= 1; i--) {
update(i);
}
}

// Sets a[p] = x, then recomputes the log ancestors of leaf p.
void set(int p, S x) {
assert(0 <= p && p < _n);
p += size;
d[p] = x;
for (int i = 1; i <= log; i++) update(p >> i);
}

// Returns a[p].
S get(int p) const {
assert(0 <= p && p < _n);
return d[p + size];
}

// Returns op(a[l], ..., a[r - 1]) over the half-open range [l, r), or e() if
// l == r.
S prod(int l, int r) const {
assert(0 <= l && l <= r && r <= _n);
// sml collects nodes from the left edge and smr from the right edge.
// They are kept separate so op is applied in left-to-right order.
S sml = e(), smr = e();
l += size;
r += size;

// Each iteration moves both ends one level up (l, r halve every step).
while (l < r) {
// l is a right child: its node is fully inside, so take it and step right.
if (l & 1) sml = op(sml, d[l++]);
// r is exclusive. If it is a right child, its left sibling d[r - 1] is inside.
if (r & 1) smr = op(d[--r], smr);
l >>= 1;
r >>= 1;
}
return op(sml, smr);
}

// Returns op over all elements (the root). For an empty tree this is e().
S all_prod() const { return d[1]; }

// max_right with the predicate given as a function-pointer template argument.
template <bool (*f)(S)> int max_right(int l) const {
return max_right(l, [](S x) { return f(x); });
}
// Returns an index r in [l, _n] such that f(op(a[l], ..., a[r - 1])) is true
// and either r == _n or f(op(a[l], ..., a[r])) is false. If f is monotone,
// this is the largest such r. Requires f(e()) to be true.
template <class F> int max_right(int l, F f) const {
assert(0 <= l && l <= _n);
assert(f(e()));
if (l == _n) return _n;
l += size;
// sm is the product of the elements already accepted, starting at a[l].
S sm = e();
do {
// Climb while l is a left child, so d[l] is the largest block starting here.
while (l % 2 == 0) l >>= 1;
if (!f(op(sm, d[l]))) {
// This block breaks f: descend to the first leaf that does.
while (l < size) {
l = (2 * l);
// If the left child can be absorbed, absorb it and move to its sibling.
if (f(op(sm, d[l]))) {
sm = op(sm, d[l]);
l++;
}
}
return l - size;
}
sm = op(sm, d[l]);
l++;
// l becomes a power of two only after stepping past the right edge.
} while ((l & -l) != l);
return _n;
}

// min_left with the predicate given as a function-pointer template argument.
template <bool (*f)(S)> int min_left(int r) const {
return min_left(r, [](S x) { return f(x); });
}
// Returns an index l in [0, r] such that f(op(a[l], ..., a[r - 1])) is true
// and either l == 0 or f(op(a[l - 1], ..., a[r - 1])) is false. If f is
// monotone, this is the smallest such l. Requires f(e()) to be true.
template <class F> int min_left(int r, F f) const {
assert(0 <= r && r <= _n);
assert(f(e()));
if (r == 0) return 0;
r += size;
// sm is the product of the elements already accepted, ending at a[r - 1].
S sm = e();
do {
r--;
// Climb while r is a right child, so d[r] is the largest block ending here.
while (r > 1 && (r % 2)) r >>= 1;
if (!f(op(d[r], sm))) {
// This block breaks f: descend toward the rightmost leaf that does.
while (r < size) {
r = (2 * r + 1);
// If the right child can be absorbed, absorb it and move to its sibling.
if (f(op(d[r], sm))) {
sm = op(d[r], sm);
r--;
}
}
return r + 1 - size;
}
sm = op(d[r], sm);
// r is a power of two only once the scan has reached the left edge.
} while ((r & -r) != r);
return 0;
}

private:
// _n: element count; size: leaf count (power of two); log: log2(size).
int _n, size, log;
// d: tree storage, see the layout note above the struct.
std::vector<S> d;

// Recomputes internal node k from its two children.
void update(int k) { d[k] = op(d[2 * k], d[2 * k + 1]); }
};

} // namespace atcoder

以及一个与其它模板一样、以递归方式实现 `prod` 的版本：

代码（`<atcoder/segtree>`，递归版 `prod`）：
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>


#ifdef _MSC_VER
#include <intrin.h>
#endif

#if __cplusplus >= 202002L
#include <bit>
#endif

namespace atcoder {

namespace internal {

#if __cplusplus >= 202002L

using std::bit_ceil;

#else

unsigned int bit_ceil(unsigned int n) {
unsigned int x = 1;
while (x < (unsigned int)(n)) x *= 2;
return x;
}

#endif

int countr_zero(unsigned int n) {
#ifdef _MSC_VER
unsigned long index;
_BitScanForward(&index, n);
return index;
#else
return __builtin_ctz(n);
#endif
}

constexpr int countr_zero_constexpr(unsigned int n) {
int x = 0;
while (!(n & (1 << x))) x++;
return x;
}

} // namespace internal

} // namespace atcoder


namespace atcoder {

#if __cplusplus >= 201703L

template <class S, auto op, auto e> struct segtree {
static_assert(std::is_convertible_v<decltype(op), std::function<S(S, S)>>,
"op must work as S(S, S)");
static_assert(std::is_convertible_v<decltype(e), std::function<S()>>,
"e must work as S()");

#else

template <class S, S (*op)(S, S), S (*e)()> struct segtree {

#endif

public:
segtree() : segtree(0) {}
explicit segtree(int n) : segtree(std::vector<S>(n, e())) {}
explicit segtree(const std::vector<S>& v) : _n(int(v.size())) {
size = (int)internal::bit_ceil((unsigned int)(_n));
log = internal::countr_zero((unsigned int)size);
d = std::vector<S>(2 * size, e());
for (int i = 0; i < _n; i++) d[size + i] = v[i];
for (int i = size - 1; i >= 1; i--) {
update(i);
}
}

void set(int p, S x) {
assert(0 <= p && p < _n);
p += size;
d[p] = x;
for (int i = 1; i <= log; i++) update(p >> i);
}

S get(int p) const {
assert(0 <= p && p < _n);
return d[p + size];
}

private:
S _prod(int l, int r, int s, int t, int p) const {
// [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
if (l <= s && t <= r) return d[p];
int m = s + ((t - s) >> 1);
S res = e();
if (l <= m) res = op(res, _prod(l, r, s, m, (p << 1)));
if (r > m) res = op(res, _prod(l, r, m + 1, t, (p << 1) | 1));
return res;
}

public:
S prod(int l, int r) const {
assert(0 <= l && l <= r && r <= _n);
if (l == r) return e();
return _prod(l, r - 1, 0, size - 1, 1);
}

S all_prod() const { return d[1]; }

// TODO: vvv 改为以递归方式实现 vvv
template <bool (*f)(S)> int max_right(int l) const {
return max_right(l, [](S x) { return f(x); });
}
template <class F> int max_right(int l, F f) const {
assert(0 <= l && l <= _n);
assert(f(e()));
if (l == _n) return _n;
l += size;
S sm = e();
do {
while (l % 2 == 0) l >>= 1;
if (!f(op(sm, d[l]))) {
while (l < size) {
l = (2 * l);
if (f(op(sm, d[l]))) {
sm = op(sm, d[l]);
l++;
}
}
return l - size;
}
sm = op(sm, d[l]);
l++;
} while ((l & -l) != l);
return _n;
}

template <bool (*f)(S)> int min_left(int r) const {
return min_left(r, [](S x) { return f(x); });
}
template <class F> int min_left(int r, F f) const {
assert(0 <= r && r <= _n);
assert(f(e()));
if (r == 0) return 0;
r += size;
S sm = e();
do {
r--;
while (r > 1 && (r % 2)) r >>= 1;
if (!f(op(d[r], sm))) {
while (r < size) {
r = (2 * r + 1);
if (f(op(d[r], sm))) {
sm = op(d[r], sm);
r--;
}
}
return r + 1 - size;
}
sm = op(d[r], sm);
} while ((r & -r) != r);
return 0;
}

private:
int _n, size, log;
std::vector<S> d;

void update(int k) { d[k] = op(d[2 * k], d[2 * k + 1]); }
};

} // namespace atcoder

2. 例题

2.1 洛谷P3374 单点修改,查询区间和

  • 1 x k 将第 x 个数加上 k
  • 2 l r 输出区间 [l, r] 内每个数的和
代码：
#include <bits/stdc++.h>

#include <atcoder/segtree>

using namespace std;

// S: the monoid element type. long long keeps running sums from overflowing
// int.
using S = long long;
// op: the monoid's binary operation; it must be associative and closed over S.
S op(S a, S b) { return a + b; }
// e: the identity element of op.
S e() { return 0; }

using segtree = atcoder::segtree<S, op, e>;

// Luogu P3374: point update, range sum.
//   "1 x k": add k to the x-th element (1-indexed)
//   "2 l r": print the sum of elements l..r (1-indexed, inclusive)
int main() {
    ios::sync_with_stdio(false), cin.tie(nullptr);
    int N, M;
    cin >> N >> M;
    vector<S> a(N);
    for (auto &x : a) cin >> x;
    segtree seg(a);
    while (M--) {
        // Named `type` rather than `op` so it does not shadow the monoid
        // operation op() declared above.
        int type, x, y;
        cin >> type >> x >> y;
        if (type == 1) {
            seg.set(x - 1, seg.get(x - 1) + y);
        } else {
            // Closed 1-indexed [x, y] maps to half-open 0-indexed [x - 1, y).
            cout << seg.prod(x - 1, y) << "\n";
        }
    }
}

线段树
https://blog.fredbill.eu.org/2023/12/01/算法/数据结构/线段树/
作者
FredBill
发布于
2023年12月1日
许可协议