主要作用

线段树是主要用来维护区间信息的数据结构
线段树可以在O(logN)的时间复杂度内实现单点修改、区间修改、区间查询(区间求和、求区间最大值、求区间最小值)等操作

基本结构

将数组a={10,11,12,13,14}转化为线段树:设线段树的根节点编号为1,用数组d来保存我们的线段树,did_i用来保存线段树上编号为i的节点的值,这里每个节点所维护的值就是这个节点所表示的区间总和。

did_i的左儿子节点就是d2id_{2*i},did_i的右儿子节点就是d2i+1d_{2*i+1}.如果did_i表示的是区间[s,t](即did_i=asa_s+as+1a_{s+1}+...+ata_t)的话,那么did_i的左儿子节点表示的是区间[s,s+t2\frac{s+t}{2}],did_i的右儿子表示的区间为s+t2\frac{s+t}{2}+1到t的闭区间

建树

在实现时,我们考虑递归建树。设当前的根节点为pp,如果根节点管辖的区间长度已经是1,则可以直接根据a数组上相应位置的值初始化该节点。否则我们将该区间从中点处分割为两个子区间,分别进入左右子节点递归建树,最后合并两个子节点的信息。

  • c++
1
2
3
4
5
6
7
8
9
10
11
12
13
void build(int s, int t, int p) {
// 对 [s,t] 区间建立线段树,当前根的编号为 p
if (s == t) {
d[p] = a[s];
return;
}
int m = s + ((t - s) >> 1);
// 移位运算符的优先级小于加减法,所以加上括号
// 如果写成 (s + t) >> 1 可能会超出 int 范围
build(s, m, p * 2), build(m + 1, t, p * 2 + 1);
// 递归对左右区间建树
d[p] = d[p * 2] + d[(p * 2) + 1];
}
  • python
1
2
3
4
5
6
7
8
9
10
11
12
def build(s, t, p):
# 对 [s,t] 区间建立线段树,当前根的编号为 p
if s == t:
d[p] = a[s]
return
m = s + ((t - s) >> 1)
# 移位运算符的优先级小于加减法,所以加上括号
# 如果写成 (s + t) >> 1 可能会超出 int 范围
build(s, m, p * 2)
build(m + 1, t, p * 2 + 1)
# 递归对左右区间建树
d[p] = d[p * 2] + d[(p * 2) + 1]
  • golang
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
package main

import "fmt"

// 假设数组 a 和线段树数组 d 已经定义
var a []int
var d []int

func build(s, t, p int) {
// 对 [s,t] 区间建立线段树,当前根的编号为 p
if s == t {
d[p] = a[s]
return
}
m := s + ((t - s) >> 1)
// 递归对左右区间建树
build(s, m, p*2)
build(m+1, t, p*2+1)
d[p] = d[p*2] + d[p*2+1]
}

func main() {
// 示例
a = []int{0, 1, 2, 3, 4, 5} // 下标从 1 开始
d = make([]int, 4*len(a))
build(1, len(a)-1, 1)
fmt.Println(d)
}

  • java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
public class SegmentTree {
static int[] a; // 原数组,下标从 1 开始
static int[] d; // 线段树数组

static void build(int s, int t, int p) {
// 对 [s,t] 区间建立线段树,当前根的编号为 p
if (s == t) {
d[p] = a[s];
return;
}
int m = s + ((t - s) >> 1);
// 递归对左右区间建树
build(s, m, p * 2);
build(m + 1, t, p * 2 + 1);
d[p] = d[p * 2] + d[p * 2 + 1];
}

public static void main(String[] args) {
a = new int[]{0, 1, 2, 3, 4, 5}; // 下标从 1 开始
d = new int[4 * a.length];
build(1, a.length - 1, 1);
for (int val : d) {
if (val != 0) System.out.print(val + " ");
}
}
}

关于线段树的空间:如果采用堆式存储,2p2ppp的左儿子,2p+12p+1pp的右儿子,若有n个叶子节点,则d数组的范围最大为2logn+12^{\lceil log_n \rceil +1}

容易知道线段树的深度是logn\lceil log_n \rceil的,则在堆式存储情况下叶子节点(包括无用的叶子节点)数量为2logn2^{\left\lceil\log{n}\right\rceil}个,又由于其为一棵完全二叉树,则其总节点个2logn+1 2^{\left\lceil\log{n}\right\rceil+1}-1。当然如果你懒得计算的话可以直接把数组长度设为4n 4n,因为 2logn+11n\frac{2^{\left\lceil\log{n}\right\rceil+1}-1}{n}的最大值在n=2x+1(xN+)n=2^{x}+1(x\in N_{+})时取到,此时节点数为2logn+11=2x+21=4n52^{\left\lceil\log{n}\right\rceil+1}-1=2^{x+2}-1=4n-5

而堆式存储存在无用的叶子节点,可以考虑使用内存池管理线段树节点,每当需要新建节点时从池中获取。自底向上考虑,必有每两个底层节点合并为一个上层节点,因此可以类似哈夫曼树地证明,如果有nn个叶子节点,这样的线段树总共有2n12n-1个节点。其空间效率优于堆式存储,并且是可能的最优情况

这样的线段树可以自底向上维护

线段树的区间查询

区间查询,比如求区间[l,r]的总和(即ala_l+al+1a_{l+1}+…+ara_r)、求区间最大值/最小值等操作

  • cpp
1
2
3
4
5
6
7
8
9
10
11
int getsum(int l, int r, int s, int t, int p) {
// [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
if (l <= s && t <= r)
return d[p]; // 当前区间为询问区间的子集时直接返回当前区间的和
int m = s + ((t - s) >> 1), sum = 0;
if (l <= m) sum += getsum(l, r, s, m, p * 2);
// 如果左儿子代表的区间 [s, m] 与询问区间有交集, 则递归查询左儿子
if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1);
// 如果右儿子代表的区间 [m + 1, t] 与询问区间有交集, 则递归查询右儿子
return sum;
}
  • python
1
2
3
4
5
6
7
8
9
10
11
12
13
def getsum(l, r, s, t, p):
# [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
if l <= s and t <= r:
return d[p] # 当前区间为询问区间的子集时直接返回当前区间的和
m = s + ((t - s) >> 1)
sum = 0
if l <= m:
sum = sum + getsum(l, r, s, m, p * 2)
# 如果左儿子代表的区间 [s, m] 与询问区间有交集, 则递归查询左儿子
if r > m:
sum = sum + getsum(l, r, m + 1, t, p * 2 + 1)
# 如果右儿子代表的区间 [m + 1, t] 与询问区间有交集, 则递归查询右儿子
return sum
  • golang
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// d 是全局数组,存储线段树节点的值
func getsum(l, r, s, t, p int, d []int) int {
// [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
if l <= s && t <= r {
return d[p] // 当前区间为询问区间的子集时直接返回
}
m := s + ((t - s) >> 1)
sum := 0
if l <= m {
sum += getsum(l, r, s, m, p*2, d)
}
if r > m {
sum += getsum(l, r, m+1, t, p*2+1, d)
}
return sum
}

  • java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class SegmentTree {
int[] d; // 存储线段树节点的数组

public SegmentTree(int size) {
d = new int[size * 4]; // 足够容纳线段树节点
}

public int getsum(int l, int r, int s, int t, int p) {
// [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
if (l <= s && t <= r) {
return d[p]; // 当前区间为询问区间的子集时直接返回
}
int m = s + ((t - s) >> 1);
int sum = 0;
if (l <= m) {
sum += getsum(l, r, s, m, p * 2);
}
if (r > m) {
sum += getsum(l, r, m + 1, t, p * 2 + 1);
}
return sum;
}
}