线段树

线段树是一种用来维护区间性质的数据结构,此篇不讲解基础,而是讲解在算法竞赛中的应用。读者需要至少会写线段树懒标记维护区间求和操作。

目前使用的模板

我的懒标记的定义是: 已经维护好了当前节点,但子节点还没维护好.

普通版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
const int maxn = 1e5 + 5;

int n, m;

typedef long long typec;
#define lson (o<<1)
#define rson (o<<1|1)
#define mid ((l+r)>>1)
#define len(x,y) ((y)-(x)+1)
typec sumv[maxn*4], addv[maxn*4];
typec A[maxn], qans, v;
int ql, qr;


void build(int o, int l, int r) {
addv[o] = 0;
if (l == r) { sumv[o] = A[l]; return; }
build(lson, l, mid); build(rson, mid+1, r);
sumv[o] = sumv[lson] + sumv[rson];
}

inline void ad(int o, int l, int r, int v) {
sumv[o] += len(l, r) * v;
addv[o] += v;
}

inline void pushdown(int o, int l, int r) {
if (!addv[o]) return;
ad(lson, l, mid, addv[o]);
ad(rson, mid+1, r, addv[o]);
addv[o] = 0;
}

void update(int o, int l, int r) {
if (ql <= l && r <= qr) ad(o, l, r, v);
else {
pushdown(o, l, r);
if (ql <= mid) update(lson, l, mid);
if (qr > mid) update(rson, mid+1, r);
sumv[o] = sumv[lson] + sumv[rson];
}
}

typec query(int o, int l, int r) {
if (ql <= l && r <= qr) return sumv[o];
else {
pushdown(o, l, r);
return (ql <= mid ? query(lson, l, mid) : 0) + (qr > mid ? query(rson, mid+1, r) : 0);
}
}

取模版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
const int maxn = 1e5 + 5;

typedef long long typec;
#define lson (o<<1)
#define rson (o<<1|1)
#define mid ((l+r)>>1)
#define len(x,y) ((y)-(x)+1)
typec sumv[maxn*4], addv[maxn*4], mulv[maxn*4];
typec A[maxn], qans, v, p;
int ql, qr;

inline typec add(typec a, typec b) {
return (a + b) % p;
}

inline typec mul(typec a, typec b) {
return (a * b) % p;
}

void build(int o, int l, int r) {
addv[o] = 0; mulv[o] = 1;
if (l == r) { sumv[o] = A[l]; return; }
build(lson, l, mid); build(rson, mid+1, r);
sumv[o] = add(sumv[lson], sumv[rson]);
}

inline void ad(int o, int l, int r, int v) {
addv[o] = add(addv[o], v);
sumv[o] = add(sumv[o], mul(v, len(l, r)));
}

inline void ml(int o, int l, int r, int v){
sumv[o] = mul(sumv[o], v);
mulv[o] = mul(mulv[o], v);
addv[o] = mul(addv[o], v);
}

void pushdown(int o, int l, int r) {
if (l == r || (mulv[o] == 1 && !addv[o])) return;
ml(lson, l, mid, mulv[o]); ml(rson, mid+1, r, mulv[o]);
ad(lson, l, mid, addv[o]); ad(rson, mid+1, r, addv[o]);
mulv[o] = 1; addv[o] = 0;
}

void updatemul(int o, int l, int r) {
if (ql <= l && r <= qr) ml(o, l, r, v);
else {
pushdown(o, l, r);
if (ql <= mid) updatemul(lson, l, mid);
if (qr > mid) updatemul(rson, mid+1, r);
sumv[o] = add(sumv[lson], sumv[rson]);
}
}

void updateadd(int o, int l, int r) {
if (ql <= l && r <= qr) ad(o, l, r, v);
else {
pushdown(o, l, r);
if (ql <= mid) updateadd(lson, l, mid);
if (qr > mid) updateadd(rson, mid+1, r);
sumv[o] = add(sumv[lson], sumv[rson]);
}
}

typec query(int o, int l, int r) {
if (ql <= l && r <= qr) return sumv[o] % p;
else {
pushdown(o, l, r);
return add(ql <= mid ? query(lson, l, mid) : 0, qr > mid ? query(rson, mid+1, r) : 0);
}
}

点修改版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
typedef long long typec;
#define lson (o<<1)
#define rson (o<<1|1)
#define mid ((l+r)>>1)
#define len(x,y) ((y)-(x)+1)
typec prod[maxn*4];
typec A[maxn], qans, v, p;
int ql, qr;

inline typec mul(typec a, typec b) {
return (a * b) % p;
}

void build(int o, int l, int r) {
if (l == r) { prod[o] = A[l]; return; }
build(lson, l, mid); build(rson, mid+1, r);
prod[o] = mul(prod[lson], prod[rson]);
}

void update(int o, int l, int r, int k) {
if (l == r) prod[o] = 1;
else {
if (k <= mid) update(lson, l, mid, k);
else update(rson, mid+1, r, k);
prod[o] = mul(prod[lson], prod[rson]);
}
}

typec query(int o, int l, int r) {
if (ql > qr) return 1;
if (ql <= l && r <= qr) return prod[o] % p;
else {
return mul(ql <= mid ? query(lson, l, mid) : 1, qr > mid ? query(rson, mid+1, r) : 1);
}
}

万用模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
template<typename T, class Lz>
struct SegTree {

#define M ((L+R)>>1)
#define lson (o<<1)
#define rson (o<<1|1)

vector<T> sumv;
vector<Lz> tag;
vector<T> A;

T (*add)(T, T);
T (*addt)(T, Lz, int, int);

SegTree(vector<T> &arr, int n,T (*fadd)(T,T), T (*faddt)(T, Lz, int, int)) {
sumv = vector<T>(n<<2);
tag = vector<Lz>(n<<2);
A = arr;
add = fadd;
addt = faddt;
build(1, 1, n);
}

inline void build(int o=1, int L=1, int R=n) {
tag[o] = 0; // -
if (L == R) sumv[o] = A[L];
else {
build(lson, L, M);
build(rson, M+1, R);
sumv[o] = add(sumv[lson], sumv[rson]);
}
}

inline void pushdown(int o, int L, int R) {
if (L == R || tag[o] == 0) return;
sumv[lson] = addt(sumv[lson], tag[o], L, M);
sumv[rson] = addt(sumv[rson], tag[o], M+1, R);
tag[lson] = tag[lson] + tag[o];
tag[rson] = tag[rson] + tag[o];
tag[o] = 0;
}

inline void update(int x, int y, Lz val, int o=1, int L=1, int R=n) {
if (x <= L && R <= y) sumv[o] = addt(sumv[o], val, L, R), tag[o] = tag[o] + val;
else {
pushdown(o, L, R);
if (x <= M) update(x, y, val, lson, L, M);
if (y > M) update(x, y, val, rson, M+1, R);
sumv[o] = add(sumv[lson], sumv[rson]);
}
}

inline T query(int x, int y, int o=1, int L=1, int R=n) {
pushdown(o, L, R);
if (x <= L && R <= y) return sumv[o];
else return add(x <= M ? query(x, y, lson, L, M) : 0, y > M ? query(x, y, rson, M+1, R) : 0);
}

void print() {
for (int i = 1;i <= n; ++i) {
cout << sumv[i] << " ";
}
cout << endl;
}
void printTag() {
cout << "tag: ";
for (int i = 1;i <= n; ++i) {
cout << tag[i].v << " ";
}
cout << endl;
}
};

struct Tag {
ll v;

Tag(ll a=0) {
v = a;
}
Tag operator=(const ll& num) {
v = num;
return *this;
}
Tag operator+(const Tag& a) {
/** 懒标记与懒标记 */
return Tag(v + a.v);
}
bool operator==(ll num) {
return v == num;
}
};

ll add(ll a, ll b) {
/** 本身加和逻辑 */
return a + b;
}

ll addt(ll a, ll t, int L, int R) {
/** 数组与懒标记加和逻辑 */
return a + t * (R-L+1);
}

使用注意事项

区间范围

build(1, 1, n) 后面2个参数是要维护的范围可以根据题目选择,比如可以改成 build(1, 0, n)​

读入优化

一般读入数据比较大,建议加上读入优化

1
2
ios_base::sync_with_stdio(0);
cin.tie(0);

基础应用

区间求和类

  • 区间求和

【模板】线段树1

  • 区间维护乘积与求和

【模板】线段树2

  • 区间覆盖问题

校门外的树(增强版)

  • 区间最大/最小值

I Hate It

  • 区间反转问题

可以转换为区间求和再对2取模

简单题

解题技巧

维护2个线段树

注意不能用1个懒标记同时维护2个线段树,很容易错

校门外的树(增强版)

通过区间加减来减少维护难度

剩余树苗=剩余树-没砍过的树

此题里面, 砍下的数苗=总种下数苗 - 剩余树苗

总种下树苗=每次种的树苗之和

每次种的树苗=种之后的树-种之前的树

加粗的部分可以用线段树维护,下划线是要求的

对查询操作建树

  • 数学计算

    除之前的数相当于在序列上把那个值变成1, 然后求乘积即可

权值线段树

权值线段树是用线段树来维护某个元素出现的个数,由于值域范围通常较大,权值线段树会采用离散化或动态开点的策略优化空间。

离散化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
cin >> n;
for (int i = 1;i <= n; ++i) {
cin >> B[i];
C[i] = B[i];
}
sort(C+1, C+n+1);
int m = unique(C+1, C+n+1) - C-1;


for (int i = 1;i <= n; ++i) {
B[i] = lower_bound(C+1, C+m+1, B[i]) - C;
}
for (int i = 1;i <= n; ++i) {
cout << C[B[i]] << " ";
}
cout << endl;

离散化的时候如果需要反映射可以加一个map

线段树的每个节点存储 \([l,r]\) 区间的数有多少个.

求比某个区间里的数有几个

  • 逆序对

    不用开 \(n\) 个线段树,因为每个区间求过一次就不用再求了,所以可以从小到大依次使用

  • 三元上升子序列

    也是计数问题,注意在查询的时候 ql <= qr 不然会死循环. 这题也是线段树的重复使用。先统计左边,统计右边的时候依次把元素删除即可.

  • NOI冒泡排序

    这题需要推导一下. 先分析冒泡排序的性质。定义 \(d[i]\) 为前面比自己大的数的个数。发现每一轮冒泡排序,\(d[i]=0\) 的那些数会与后面比自己小的数交换(减少逆序对),直到比自己大的人出现。然后交换过的数的 \(d[i]\) 会减小1. (因为前面有一个比他大的数跑后面去了). 若 \(|x|\) 表示 \(x\) 的数量的话,所以每一轮减小逆序对的数量是 \(n - |d[i]=0|\) ,因为每排一轮 \(d[i]\) 会相应减小,所以下一轮\(d[i]=0\) 的数是前一轮 \(d[i]=1\) 的数. 那么第 \(k\)\(|d[i]=0|\) 就是第开始 \(|d[i]\leq k-1|\) 所以经过 \(k\) 轮排序减小逆序对的数量是 \[ \sum_{s=0}^{k-1}{n-|d[i]\leq s|} \\ =kn-\sum_{s=0}^{k-1}{|d[i]\leq s|} \]\(ans\) 是开始总的逆序对数量,那么我们要求的 \(k\) 轮后逆序对的数量是 \[ ans -kn+\sum_{s=0}^{k-1}{|d[i]\leq s|} \] 其中后面那个求和可以用权值线段树来维护. 交换操作的话可以用点修改来维护.

模拟平衡树

需要离散化或者动态开点.

查询前驱后继可以用map的lower_boundupper_bound

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 5;

typedef long long typec;
#define lson (o<<1)
#define rson (o<<1|1)
#define mid ((l+r)>>1)
#define len(x,y) ((y)-(x)+1)

typec sumv[maxn*4];
int n, m, op[maxn], q[maxn], A[maxn], C[maxn], d;
map<int, int> f;
map<int, int> ma;

void insert(int o, int l, int r, int x) {
if (l == r) sumv[o]++;
else {
if (x <= mid) insert(lson, l, mid, x);
else insert(rson, mid+1, r, x);
sumv[o] = sumv[lson] + sumv[rson];
}
}

void del(int o, int l, int r, int x) {
if (l == r) sumv[o]--;
else {
if (x <= mid) del(lson, l, mid, x);
else del(rson, mid+1, r, x);
sumv[o] = sumv[lson] + sumv[rson];
}
}

int query(int o, int l, int r, int ql, int qr) {
if (ql > qr) return 0;
if (ql <= l && r <= qr) return sumv[o];
else return (ql <= mid ? query(lson, l, mid, ql, qr) : 0) + (qr > mid ? query(rson, mid+1, r, ql, qr) : 0);
}

int qrank(int o, int l, int r, int rk) {
if (l == r) return C[l];
else {
if (sumv[lson] >= rk) return qrank(lson, l, mid, rk);
else return qrank(rson, mid+1, r, rk - sumv[lson]);
}
}

int main() {
cin >> n;
for (int i = 1;i <= n; ++i) {
cin >> op[i] >> q[i];
if (op[i] == 1 || op[i] == 3 || op[i] >= 5) {
A[++m] = q[i];
C[m] = A[m];
}
}

sort(C+1, C+m+1);
d = unique(C+1, C+m+1) - (C+1);
for (int i = 1;i <= m; ++i) {
A[i] = lower_bound(C+1, C+d+1, A[i]) - C;
f[C[A[i]]] = A[i];
}

for (int i = 1;i <= n; ++i) {
if (op[i] == 1) { // 插入
insert(1, 1, d, f[q[i]]);
ma[q[i]]++;
} else if (op[i] == 2) { // 删除
del(1, 1, d, f[q[i]]);
ma[q[i]]--;
if (ma[q[i]] == 0) ma.erase(q[i]);
} else if (op[i] == 3) { // 查排名
cout << query(1, 1, d, 1, f[q[i]]-1) + 1 << endl;
} else if (op[i] == 4) { // 根据排名查数
cout << qrank(1, 1, d, q[i]) << endl;
} else if (op[i] == 5) { // 前驱
printf("%d\n",(--ma.lower_bound(q[i]))->first);
} else { // 后继
printf("%d\n",(ma.upper_bound(q[i]))->first);
}
}
return 0;
}

线段树进阶

进阶的线段树可能需要维护节点的多个信息

分治思想

答案要么在左区间,要么在右区间,要么在中间.

最大子段和

\(ls\) 是区间左节点开始延伸的最大值,\(rs\) 是右节点向左延伸的最大值,\(ss\) 是区间的总最大值,\(sum\) 是区间所有元素和.

维护的时候

  • \(ls = max(ls_{lson}, sum_{lson} + ls_{rson})\)
  • \(rs = max(rs_{rson}, sum_{rson} + rs_{lson})\)
  • \(sum = sum_{lson} + sum_{rson}\)
  • \(ss = max(ss_{lson}, ss_{rson}, rs_{lson} + ls_{rson})\)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 5;

typedef long long typec;
#define lson (o<<1)
#define rson (o<<1|1)
#define mid ((l+r)>>1)
#define len(x,y) ((y)-(x)+1)

int n, m, A[maxn];

struct Node {
int v, ls, rs, ss, sum;
} seg[maxn * 4];

void build(int o, int l, int r) {
if (l == r) seg[o].ls = seg[o].rs = seg[o].ss = seg[o].sum = A[l];
else {
build(lson, l, mid);
build(rson, mid+1, r);
seg[o].sum = seg[lson].sum + seg[rson].sum;
seg[o].ls = max(seg[lson].ls, seg[lson].sum + seg[rson].ls);
seg[o].rs = max(seg[rson].rs, seg[rson].sum + seg[lson].rs);
seg[o].ss = max(max(seg[lson].ss, seg[rson].ss), seg[lson].rs + seg[rson].ls);
}
}

int main() {
cin >> n;
for (int i = 1;i <= n; ++i) cin >> A[i];
build(1, 1, n);
cout << seg[1].ss << endl;
return 0;
}

与最大子段和类似,但是在标签转移的时候考虑条件. 只有不相等的时候才转移。注意的是这题不需要把所有状态(给左右是L,R的情况都定义对应的标签)都定义出来,那样会非常复杂。可以把0和1的两种情况综合起来,(因为2个标签必然有1个是0),只需要考虑端点处是否不相等。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 5;

typedef long long ll;
#define lson (o<<1)
#define rson (o<<1|1)
#define mid ((l+r)>>1)
#define len(x,y) ((y)-(x)+1)

int n, m, A[maxn], q;

struct Node {
int ls, rs, ss;
} seg[maxn * 4];

void update(int o, int l, int r, int x) {
if (l != r) {
if (x <= mid) update(lson, l, mid, x);
else update(rson, mid+1, r, x);
seg[o].ss = max(max(seg[lson].ss, seg[rson].ss), (A[mid] != A[mid+1] ? seg[lson].rs + seg[rson].ls : 0));
seg[o].ls = max(seg[lson].ls, seg[lson].ls == len(l, mid) && A[mid] != A[mid+1] ? seg[lson].ls + seg[rson].ls : 0);
seg[o].rs = max(seg[rson].rs, seg[rson].rs == len(mid+1, r) && A[mid] != A[mid+1] ? seg[rson].rs + seg[lson].rs : 0);
}
}

int main() {
cin >> n >> q;
for (int i = 1;i <= n * 4; ++i) seg[i].ls = seg[i].rs = seg[i].ss = 1;
for (int i = 1, x;i <= q; ++i) {
cin >> x;
A[x] = (A[x] + 1) % 2;
update(1, 1, n, x);
cout << seg[1].ss << endl;
}

return 0;
}