树上线段树

学了这么久的线段树,竟然从没写过树上线段树。其实本质还是线段树。只是把一开始的树结构用 dfs 序变成线段的形式,然后再用线段树加速。

问题是给你一颗树,每个节点都有权重,问从起点 0 开始必须经过节点 x 的权重和的最大值。经过 dfs,$L[x], R[x]$ 是以 $x$ 为根的节点。那么每次更新 $x$ 的值相当于给区间 $L[x], R[x]$ 加一个值。而最终我们的问题,其实就是求区间 $L[x], R[x]$ 的最大值。最后写线段树时要延迟更新。

例题 hdu

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
template<typename T>
void upmax(T &a,T b){ if(a<b) a=b;}
#define lrt rt<<1
#define rrt rt<<1|1
#define lson l,m,lrt
#define rson m+1,r,rrt
const int N=1e5+2;
LL mx[N*3],flag[N*3];
int head[N], sc, ncnt, L[N], R[N];
struct Node{
int ed;
int next;
}e[N<<1];
void init(int n){
sc = ncnt = 0;
for(int i=0;i<=n;++i) head[i]=-1;
}
void addedge(int u, int v){
e[sc].ed = v;
e[sc].next = head[u];
head[u] = sc++;
}
void dfs(int x, int fa){ // dfs order
L[x] = ++ncnt;
for(int i=head[x]; i!=-1; i=e[i].next){
if(e[i].ed != fa) dfs(e[i].ed, x);
}
R[x] = ncnt;
}
void pushUp(int rt){
mx[rt]=max(mx[lrt],mx[rrt]);
}
void pushDown(int rt){
if(flag[rt]){
flag[lrt]+=flag[rt];
flag[rrt]+=flag[rt];
mx[lrt]+=flag[rt];
mx[rrt]+=flag[rt];
flag[rt]=0;
}
}
void build(int l,int r,int rt){
flag[rt] = mx[rt] = 0;
if(l == r) return;
int m=(l+r)>>1;
build(lson);
build(rson);
}
void add(int p, int L, int R, int l, int r, int rt){
if(L <= l && r<= R){
mx[rt] += p;
flag[rt] += p;
return;
}
pushDown(rt);
int m = (l+r)>>1;
if(L <= m) add(p, L, R, lson);
if(R > m) add(p, L, R, rson);
pushUp(rt);
}
LL query(int L, int R, int l, int r, int rt){
if(L <=l && r <= R) return mx[rt];
pushDown(rt);
int m=(l+r)>>1;
LL ans= -1LL<<62;
if(L <= m) upmax(ans, query(L, R, lson));
if(R > m) upmax(ans, query(L, R, rson));
return ans;
}
int a[N];
int main(){
// freopen("/Users/dna049/Desktop/AC/in","r",stdin);
// freopen("/Users/dna049/Desktop/AC/out","w",stdout);
// openStack();
int T, Case = 0;
scanf("%d",&T);
while(T--){
printf("Case #%d:\n",++Case);
int n, m, u, v, op;
scanf("%d%d", &n, &m);
init(n);
for(int i=1; i<n; ++i){
scanf("%d%d", &u, &v);
addedge(u, v);
addedge(v, u);
}
dfs(0, -1);
build(1, n, 1);
for(int i=0; i<n; ++i){
scanf("%d",a+i);
add(a[i], L[i], R[i], 1, n, 1);
}
while(m--){
scanf("%d%d",&op, &u);
if(op == 0){
scanf("%d",&v);
v -= a[u];
add(v, L[u], R[u], 1, n, 1);
a[u] += v;
}else{
cout<<query(L[u], R[u], 1, n, 1)<<endl;
}
}
}
return 0;
}