0%

基于变换合并的树上动态DP的链分治算法和全局平衡二叉树学习笔记

引入

在有些dp中,转移可以用一种具有结合律的变换描述,并且可以快速合并

因此我们使用数据结构维护,来支持修改并快速得到全局或某个子结构的dp值

直接看例题吧


例题

例题一 【模板】动态dp

Luogu P4719 【模板】动态dp

题意

给定一棵 \(n\) 个点的树,第 \(i\) 个点有点权 \(a_i\)

\(m\) 次操作,每次操作给定 \(x,y\),表示修改点 \(x\) 的权值为 \(y\)

你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

\(n,m\le10^5\)

分析

一个弱化

对于没有修改的情况,我们有一种简单的 \(\mathcal O(n)\) 树形dp

钦定一个根节点

\(f_{u,0/1}\) 表示以 \(u\) 为根的子树, \(u\) 号点不选/选时的最大权独立集的权值大小

\[ \begin{align} f_{u,0}&=\sum_{u\to v}max(f_{v,0},f_{v,1})\\ f_{u,1}&=a_u+\sum_{u\to v}f_{v,0} \end{align} \]

另一个弱化

对于树是一条链的情况,上述dp可以大大简化

树链剖分

考虑树剖,令 \(son_u\) 表示 \(u\) 的重儿子

dp转化为

\[ \begin{align} f_{u,0}&=max(f_{son_u,0},f_{son_u,1})+\sum_{u\to v,v\ne son_u}max(f_{v,0},f_{v,1})\\ f_{u,1}&=a_u+f_{son_u,0}+\sum_{u\to v,v\ne son_u}f_{v,0} \end{align} \]

\[ \begin{align} g_{u,0}&=\sum_{u\to v,v\ne son_u}max(f_{v,0},f_{v,1})\\ g_{u,1}&=a_u+\sum_{u\to v,v\ne son_u}f_{v,0} \end{align} \]

从而

\[ \begin{align} f_{u,0}&=max(f_{son_u,0},f_{son_u,1})+g_{u,0}\\ f_{u,1}&=f_{son_u,0}+g_{u,1} \end{align} \]

矩阵

我们用重新定义的线性变换来描述这个转移

  • 把乘法变成加法
  • 把加法变成取max

\[ \begin{bmatrix} g_{u,0} & g_{u,0} \\ g_{u,1} & -\infty \end{bmatrix} \begin{pmatrix} f_{son_u,0} \\ f_{son_u,1} \end{pmatrix} = \begin{pmatrix} f_{u,0} \\ f_{u,1} \end{pmatrix} \]

可以自己验证一下

并且可以证明这个这样的矩阵乘法是具有结合律的

需要求一个点的dp值的时候,只需要将这个点走重儿子走到底的矩阵乘起来就好了

考虑到一个点跳到根只会经过 \(\mathcal O(\log n)\) 条轻边,我们用线段树维护矩阵的积,修改的时候反复执行如下操作

  1. 修改当前点的矩阵
  2. 跳到重链的顶端,计算dp值,更新父亲的 \(g\),并跳到父亲处

单次修改复杂度是 \(\mathcal O(\log^2n)\),查询 \(\mathcal O(\log n)\)

单位矩阵是

\[ \begin{bmatrix} 0 & -\infty \\ -\infty & 0 \end{bmatrix} \]

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include<cstdio>
#include<algorithm>
#include<ctype.h>
#include<string.h>
#include<math.h>

using namespace std;
#define ll long long

inline char read() {
static const int IN_LEN = 1000000;
static char buf[IN_LEN], *s, *t;
return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x) {
static bool iosig;
static char c;
for (iosig=false, c=read(); !isdigit(c); c=read()) {
if (c == '-') iosig=true;
if (c == -1) return;
}
for (x=0; isdigit(c); c=read()) x=((x+(x<<2))<<1)+(c^'0');
if (iosig) x=-x;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN], *ooh=obuf;
inline void print(char c) {
if (ooh==obuf+OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), ooh=obuf;
*ooh++=c;
}
template<class T>
inline void print(T x) {
static int buf[30], cnt;
if (x==0) print('0');
else {
if (x<0) print('-'), x=-x;
for (cnt=0; x; x/=10) buf[++cnt]=x%10+48;
while(cnt) print((char)buf[cnt--]);
}
}
inline void flush() { fwrite(obuf, 1, ooh - obuf, stdout); }

const int N = 100005;
int n, m, num, cnt, dfn[N], idfn[N], a[N], fa[N], h[N], siz[N], top[N], last[N], e[N<<1], pre[N<<1], f[N][2], g[N][2];
struct matrix{
int a[2][2];
inline matrix(){ memset(a, 0, sizeof a);}
inline matrix(int x, int y){ a[0][0]=a[0][1]=x, a[1][0]=y, a[1][1]=-1e9;}
inline matrix operator *(const matrix &rhs)const{
matrix ans;
for(int i=0; i<2; ++i) for(int j=0; j<2; ++j)
ans.a[i][j]=max(a[i][0]+rhs.a[0][j], a[i][1]+rhs.a[1][j]);
return ans;
}
}s[N<<2];

inline void add(int x, int y){ e[++num]=y, pre[num]=h[x], h[x]=num;}
void dfs1(int u){
siz[u]=1, f[u][1]=a[u];
for(int i=h[u]; i; i=pre[i]) if(e[i]!=fa[u]){
fa[e[i]]=u, dfs1(e[i]), siz[u]+=siz[e[i]];
f[u][0]+=max(f[e[i]][0], f[e[i]][1]), f[u][1]+=f[e[i]][0];
}
}
void dfs2(int u){
idfn[dfn[u]=++cnt]=u;
int son=0;
for(int i=h[u]; i; i=pre[i]) if(e[i]!=fa[u] && siz[e[i]]>siz[son]) son=e[i];
if(son) top[son]=top[u], dfs2(son), last[u]=last[son]; else last[u]=cnt;
g[u][1]+=a[u];
for(int i=h[u]; i; i=pre[i]) if(e[i]!=fa[u] && e[i]!=son){
top[e[i]]=e[i], dfs2(e[i]);
g[u][0]+=max(f[e[i]][0], f[e[i]][1]), g[u][1]+=f[e[i]][0];
}
}
void build(int l, int r, int t){
if(l==r) return (void)(s[t]=matrix(g[idfn[l]][0], g[idfn[l]][1]));
int mid=l+r>>1, k=t<<1;
build(l, mid, k), build(mid+1, r, k|1);
s[t]=s[k]*s[k|1];
}
void modify(int l, int r, int t, int x){
if(l==r) return (void)(s[t]=matrix(g[idfn[l]][0], g[idfn[l]][1]));
int mid=l+r>>1, k=t<<1;
if(x<=mid) modify(l, mid, k, x); else modify(mid+1, r, k|1, x);
s[t]=s[k]*s[k|1];
}
matrix query(int l, int r, int t, int L, int R){
if(L<=l && r<=R) return s[t];
int mid=l+r>>1, k=t<<1;
if(R<=mid) return query(l, mid, k, L, R);
if(L>mid) return query(mid+1, r, k|1, L, R);
return query(l, mid, k, L, R)*query(mid+1, r, k|1, L, R);
}
int main() {
read(n), read(m);
for(int i=1; i<=n; ++i) read(a[i]);
for(int i=1; i<n; ++i){
static int x, y;
read(x), read(y), add(x, y), add(y, x);
}
dfs1(1), top[1]=1, dfs2(1);
build(1, n, 1);
while(m--){
static int x, y;
read(x), read(y);
g[x][1]+=y-a[x], a[x]=y;
while(x){
modify(1, n, 1, dfn[x]);
x=top[x];
matrix tmp=query(1, n, 1, dfn[x], last[x]);
g[fa[x]][0]-=max(f[x][0], f[x][1]), g[fa[x]][1]-=f[x][0];
f[x][0]=tmp.a[0][0], f[x][1]=tmp.a[1][0];
g[fa[x]][0]+=max(f[x][0], f[x][1]), g[fa[x]][1]+=f[x][0];
x=fa[x];
}
print(max(f[1][0], f[1][1])), print('\n');
}
return flush(), 0;
}

例题二 动态dp【加强版】

Luogu P4751 动态dp【加强版】

题意

同上

强制在线

\(n,m\le10^6\)

分析

数据范围要求了更优秀的复杂度,有人会想到 \(\mathcal O((n+q)\log n)\) 的LCT,但是LCT实际表现并不理想..

考虑到这里不需要一些link, cut和换根操作,我们可以构造一种类似Link-Cut Trees的静态结构

全局平衡二叉树

概述

像LCT一样,把每条重链用一棵辅助二叉树维护,辅助树之间用虚边连接,重链之间也构成了一棵有根树

每个节点需要维护自己所在的重链的辅助树的子树矩阵积

事实上前面的线段树也是一种类似的结构,只是每棵都保持了绝对的平衡,导致复杂度 \(\mathcal O(\log^2n)\)

而我们需要找到一种给辅助树定制合适形态的方法,做到全局平衡

构造

定义点 \(u\) 的权重 \(w_u=size_u-size_{son_u}\),即所有轻儿子的size和+1

构造一条重链的辅助树的时候,令带权重心为根,左右递归构造即可

复杂度

可以证明这样总的一棵全局平衡二叉树的深度是 \(\mathcal O(\log n)\)

具体可以参考文末的资料1

这样我们只需要在这棵树上向上跳并更新,在跳虚边的时候注意类似的特判(更新父亲的 \(g\)

代码

由于是第一次写,这里构造全局平衡二叉树的写法和下面的例题三略有不同

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include<cstdio>
#include<algorithm>
#include<ctype.h>
#include<string.h>
#include<math.h>

using namespace std;
#define ll long long

inline char read() {
static const int IN_LEN = 1000000;
static char buf[IN_LEN], *s, *t;
return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x) {
static bool iosig;
static char c;
for (iosig=false, c=read(); !isdigit(c); c=read()) {
if (c == '-') iosig=true;
if (c == -1) return;
}
for (x=0; isdigit(c); c=read()) x=((x+(x<<2))<<1)+(c^'0');
if (iosig) x=-x;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN], *ooh=obuf;
inline void print(char c) {
if (ooh==obuf+OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), ooh=obuf;
*ooh++=c;
}
template<class T>
inline void print(T x) {
static int buf[30], cnt;
if (x==0) print('0');
else {
if (x<0) print('-'), x=-x;
for (cnt=0; x; x/=10) buf[++cnt]=x%10+48;
while(cnt) print((char)buf[cnt--]);
}
}
inline void flush() { fwrite(obuf, 1, ooh - obuf, stdout); }

const int N = 1000005;
int n, m, num, root, lastans, top[N], son[N], s[N], b[N], a[N], w[N], siz[N], fa[N], h[N], e[N<<1], pre[N<<1], ch[N][2], f[N][2], g[N][2];
struct matrix{
int a[2][2];
inline matrix(){ memset(a, 0, sizeof a);}
inline matrix(int x, int y){ a[0][0]=a[0][1]=x, a[1][0]=y, a[1][1]=-1e9;}
inline matrix operator *(const matrix &rhs)const{
matrix ans;
for(int i=0; i<2; ++i) for(int j=0; j<2; ++j)
ans.a[i][j]=max(a[i][0]+rhs.a[0][j], a[i][1]+rhs.a[1][j]);
return ans;
}
}F[N], G[N];

inline void add(int x, int y){ e[++num]=y, pre[num]=h[x], h[x]=num;}
void update(int x){ F[x]=F[ch[x][0]]*G[x]*F[ch[x][1]];}
int build(int l, int r){
if(l>r) return 0;
int x=(s[l-1]+s[r]+1)/2, L=l, R=r, t=r;
while(L<=R){
int mid=L+R>>1;
if(s[mid]>=x) t=mid, R=mid-1; else L=mid+1;
}
int u=b[t];
fa[ch[u][0]=build(l, t-1)]=u, fa[ch[u][1]=build(t+1, r)]=u;
return update(u), u;
}
void dfs1(int u){
siz[u]=1, f[u][1]=a[u];
for(int i=h[u]; i; i=pre[i]) if(e[i]!=fa[u]){
fa[e[i]]=u, dfs1(e[i]), siz[u]+=siz[e[i]];
if(siz[e[i]]>siz[son[u]]) son[u]=e[i];
f[u][0]+=max(f[e[i]][0], f[e[i]][1]), f[u][1]+=f[e[i]][0];
}
w[u]=siz[u]-siz[son[u]];
}
void dfs2(int u){
if(son[u]) top[son[u]]=top[u], dfs2(son[u]);
g[u][1]+=a[u];
for(int i=h[u]; i; i=pre[i]) if(e[i]!=fa[u] && e[i]!=son[u]){
top[e[i]]=e[i], dfs2(e[i]);
g[u][0]+=max(f[e[i]][0], f[e[i]][1]), g[u][1]+=f[e[i]][0];
}
G[u]=matrix(g[u][0], g[u][1]);

if(top[u]==u){
int cnt=0;
for(int j=u; j; j=son[j]) b[++cnt]=j, s[cnt]=s[cnt-1]+w[j];
int tmp=fa[u];
fa[root=build(1, cnt)]=tmp;
}
}
int main() {
read(n), read(m);
for(int i=1; i<=n; ++i) read(a[i]);
for(int i=1; i<n; ++i){
static int x, y;
read(x), read(y), add(x, y), add(y, x);
}
F[0].a[0][1]=F[0].a[1][0]=-1e9;
dfs1(1), top[1]=1, dfs2(1);

while(m--){
static int x, y, v;
read(x), read(y);
x^=lastans;
g[x][1]+=y-a[x], a[x]=y;
G[x]=matrix(g[x][0], g[x][1]);
while(x){
v=fa[x];
if(ch[v][0]!=x && ch[v][1]!=x)
g[v][0]-=max(F[x].a[0][0], F[x].a[1][0]), g[v][1]-=F[x].a[0][0];
update(x);
if(ch[v][0]!=x && ch[v][1]!=x)
g[v][0]+=max(F[x].a[0][0], F[x].a[1][0]), g[v][1]+=F[x].a[0][0],
G[v]=matrix(g[v][0], g[v][1]);
x=v;
}
print(lastans=max(F[root].a[0][0], F[root].a[1][0])), print('\n');
}
return flush(), 0;
}

例题三 洪水

BZOJ 4712 洪水

题意

你有一棵 \(n\) 个点的树,第 \(i\) 个点的点权是 \(a_i\)

\(q\) 次操作

  • C x y表示修改第 \(x\) 个点的点权为 \(y\)

  • Q x表示询问删除一些点使得 \(x\) 的子树中的每个叶子与 \(x\) 不连通的最小代价

其中删除一个点的代价是它的点权,总代价是每个删除的点的代价的和

\(n,q\le2*10^5\)

分析

dp

考虑静态的dp,令 \(f_u\) 表示以 \(u\) 为根的子树的答案,有

\[f_u=min(a_u,\sum_{u\to v}f_v)\]

\(son_u\) 表示 \(u\) 的重儿子,则

\[f_u=min(a_u,f_{son_u}+\sum_{u\to v,v\ne son_u}f_v)\]

\[g_u=\sum_{u\to v,v\ne son_u}f_v\]

特殊地令叶子节点 \(g_i=\infty\)

\[f_u=min(a_u,f_{son_u}+g_u)\]

变换

这可以看做一个变换 \(trans_{a,b}(x)=min(a,x+b)\) ,用这种方式可以解决在重链上的转移

一个节点的答案就是这个点到所在重链的底端的变换反顺序作用在 \(0\)

而两个这样的变换合并后仍然是同样的形式

\[ \begin{align} &trans_{c,d}(trans_{a,b}(x))\\ =&trans_{c,d}(min(a,x+b))\\ =&min(c,min(a,x+b)+d)\\ =&min(c,min(a+d,x+b+d))\\ =&min(min(c,a+d),x+b+d)\\ =&trans_{min(c,a+d),b+d}(x) \end{align} \]

复杂度

如果用例题一的方法直接线段树维护,可以做到修改 \(\mathcal O(\log^2n)\),询问 \(\mathcal O(\log n)\)

而使用全局平衡二叉树可以做到 \(\mathcal O(\log n)\)

询问

由于这里的询问不是全局,我们要把重链上深度不比 \(x\) 小的所有点的变换合并起来,这在二叉树上走一下就好了

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include<cstdio>
#include<algorithm>
#include<ctype.h>
#include<string.h>
#include<math.h>

using namespace std;
#define ll long long

inline char read() {
static const int IN_LEN = 1000000;
static char buf[IN_LEN], *s, *t;
return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x) {
static bool iosig;
static char c;
for (iosig=false, c=read(); !isdigit(c); c=read()) {
if (c == '-') iosig=true;
if (c == -1) return;
}
for (x=0; isdigit(c); c=read()) x=((x+(x<<2))<<1)+(c^'0');
if (iosig) x=-x;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN], *ooh=obuf;
inline void print(char c) {
if (ooh==obuf+OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), ooh=obuf;
*ooh++=c;
}
template<class T>
inline void print(T x) {
static int buf[30], cnt;
if (x==0) print('0');
else {
if (x<0) print('-'), x=-x;
for (cnt=0; x; x/=10) buf[++cnt]=x%10+48;
while(cnt) print((char)buf[cnt--]);
}
}
inline void flush() { fwrite(obuf, 1, ooh - obuf, stdout); }

const int N = 200005;
const ll inf = 1e15;
int n, m, num, root, son[N], w[N], a[N], b[N], siz[N], h[N], fa[N], e[N<<1], pre[N<<1], ch[N][2];
ll f[N];
bool isr[N];
struct tf{
ll a, b;
inline tf(){}
inline tf(ll x, ll y){ a=x, b=y;}
inline tf operator *(const tf &rhs)const{
return tf(min(a, b+rhs.a), b+rhs.b);
}
}F[N], G[N];

inline void add(int x, int y){ e[++num]=y, pre[num]=h[x], h[x]=num;}
inline void update(int u){ F[u]=F[ch[u][0]]*G[u]*F[ch[u][1]];}
int divide(int l, int r){
if(l>r) return 0;
int sum=0, t, x=0;
for(int i=l; i<=r; ++i) sum+=w[b[i]];
for(t=l; t<=r; ++t) if((x+=w[b[t]])*2>=sum) break;
int u=b[t];
fa[ch[u][0]=divide(l, t-1)]=u, fa[ch[u][1]=divide(t+1, r)]=u;
return update(u), u;
}
inline int build(int u){
int cnt=0;
for(; u; u=son[u]) b[++cnt]=u;
int r=divide(1, cnt);
return isr[r]=1, r;
}
void dfs(int u){
siz[u]=1;
for(int i=h[u]; i; i=pre[i]) if(e[i]!=fa[u]){
fa[e[i]]=u, dfs(e[i]), siz[u]+=siz[e[i]], f[u]+=f[e[i]];
if(siz[e[i]]>siz[son[u]]) son[u]=e[i];
}
w[u]=siz[u]-siz[son[u]];
if(son[u]) G[u]=tf(a[u], f[u]-f[son[u]]), f[u]=min(f[u], (ll)a[u]);
else G[u]=tf(a[u], inf), f[u]=a[u];
for(int i=h[u]; i; i=pre[i]) if(e[i]!=fa[u] && e[i]!=son[u])
fa[build(e[i])]=u;
}
int main() {
read(n);
for(int i=1; i<=n; ++i) read(a[i]);
for(int i=1; i<n; ++i){
static int x, y;
read(x), read(y), add(x, y), add(y, x);
}
F[0]=tf(inf, 0);
dfs(1), fa[root=build(1)]=0;

read(m);
while(m--){
static char opt;
static int x, y;
while(isspace(opt=read()));
read(x);
if(opt=='Q'){
tf ans=G[x]*F[ch[x][1]];
while(!isr[x]){
if(ch[fa[x]][0]==x) ans=ans*G[fa[x]]*F[ch[fa[x]][1]];
x=fa[x];
}
print(ans.a), print('\n');
}
else{
read(y);
G[x].a+=y;
while(x){
if(isr[x]) G[fa[x]].b-=F[x].a;
update(x);
if(isr[x]) G[fa[x]].b+=F[x].a;
x=fa[x];
}
}
}
return flush(), 0;
}

例题四 切树游戏

题解


总结

动态dp还是很神奇的

全局平衡二叉树,zx2003说拿这个卡掉多一个 \(\log\) 的不好,那就不好吧..

不过代码大概确实比树剖线段树短


参考资料

  1. 杨哲《SPOJ375 QTREE 解法的一些研究》

  2. 陈俊锟《〈神奇的子图〉命题报告及其拓展》

  3. 基于变换合并的树上动态 DP 的链分治算法 & SDOI2017 切树游戏(cut)解题报告