0%

「BZOJ 1921」「Ctsc2010」珠宝商

BZOJ 1921

题意

给一棵 $n$ 个点的树和长度为 $m$ 的特征串,树的每个节点有一个字符。

求随机两个点形成有向路径上构成的串在特征串里出现次数的期望

仅含小写字母,$n,m\le 5*10^4$


分析

首先可以发现两种算法

  • 暴力处理

对”特征串”建 SAM

枚举路径的一个端点, dfs 另一个端点,同时维护在 SAM 上的位置.

每到一个位置会有 SAM 上对应节点的 right 集合大小的贡献

复杂度 $\mathcal O({\rm size}^2)$

  • 处理经过一个点的所有路径

设这个点是 $u$ ,字符为 $a[u]$

需要建出正反特征串的后缀树

考虑从点 $u$ 出发的所有路径($a[u]$ 为字符串的开头),统计出以特征串的每一位为起始的这些串的数量

同理将这些串翻转($a[u]$为串的最后一位),统计出在特征串每一位结束的串的数量

对应位上两组串数量的乘积的和即为贡献,因为某正串在一位起始,某反串在这位结束,即可拼出一个路径

由于 dfs 时正串是每次在末尾加字符维护在特征串中起始位置,反串是每次开头加字符维护结束位置,将特征串和路径串翻转后即为同一个问题,我们只考虑 push_front 维护结束位置

后缀自动机的转移只能支持末尾插入,于是需要利用后缀树

后缀树上一条边会对应原串中的一段区间

转移时需要注意从上往下走了不满一条边的情况,此时大概需要走到儿子处,注意判无转移时无解

每次在节点上打标记,最后全部下放到叶子处 统计每个位置的出现次数

复杂度$\mathcal O({\rm size}+m)$


实现

然后使用点分治,若分治的大小$size>\sqrt{m}$使用方法2, 否则暴力做方法1

这里需要注意同一子树的去重 在去重的时候应使用对应的方法保证复杂度

易知复杂度为$\mathcal O((n+m)\sqrt{m})$


代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#include<cstdio>
#include<algorithm>
#include<ctype.h>
#include<string.h>
#include<math.h>

using namespace std;
#define ll long long
#define rep(i,x,y) for(int i=(x);i<=(y);++i)
#define travel(i,x) for(int i=h[x];i;i=pre[i])

inline char read() {
static const int IN_LEN = 1000000;
static char buf[IN_LEN], *s, *t;
return (s == t ? t = (s = buf) + fread(buf, 1, IN_LEN, stdin), (s == t ? -1 : *s++) : *s++);
}
template<class T>
inline void read(T &x) {
static bool iosig;
static char c;
for (iosig = false, c = read(); !isdigit(c); c = read()) {
if (c == '-') iosig = true;
if (c == -1) return;
}
for (x = 0; isdigit(c); c = read()) x = ((x + (x << 2)) << 1) + (c ^ '0');
if (iosig) x = -x;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN], *ooh = obuf;
inline void print(char c) {
if (ooh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), ooh = obuf;
*ooh++ = c;
}
template<class T>
inline void print(T x) {
static int buf[30], cnt;
if (x == 0) print('0');
else {
if (x < 0) print('-'), x = -x;
for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;
while (cnt) print((char)buf[cnt--]);
}
}
inline void flush() { fwrite(obuf, 1, ooh - obuf, stdout); }
const int N = 50005;
int n, m, num, tot1[N], tot2[N], h[N], pre[N<<1], e[N<<1];
bool vis[N];
char a[N], s[N];
inline void add(int x, int y){ e[++num]=y, pre[num]=h[x], h[x]=num;}
struct sam{
int last, cnt, b[N], str[N], t[N<<1], lazy[N<<1], q[N<<1], siz[N<<1], len[N<<1], fa[N<<1], ch[N<<1][26], son[N<<1][26];
bool isl[N<<1];
inline sam(){ last=cnt=1;}
inline void ins(int c){
int p=last, np=++cnt;
last=np, str[len[np]=len[p]+1]=c, t[np]=len[np];
while(p && !ch[p][c]) ch[p][c]=np, p=fa[p];
if(!p) fa[np]=1;
else{
int q=ch[p][c];
if(len[q]==len[p]+1) fa[np]=q;
else{
int nq=++cnt;
len[nq]=len[p]+1, memcpy(ch[nq], ch[q], sizeof ch[0]);
t[nq]=t[q], fa[nq]=fa[q], fa[q]=fa[np]=nq;
while(ch[p][c]==q) ch[p][c]=nq, p=fa[p];
}
}
siz[np]=isl[np]=1;
}
inline void init(){
rep(i, 1, cnt) ++b[len[i]];
rep(i, 1, m) b[i]+=b[i-1];
rep(i, 1, cnt) q[b[len[i]]--]=i;
for(int i=cnt; i>1; --i) son[fa[q[i]]][str[t[q[i]]-len[fa[q[i]]]]]=q[i], siz[fa[q[i]]]+=siz[q[i]];
}
inline void trans(int &p, int c){ p=ch[p][c];}
void dfs5(int u, int fa, int p, int l){
if(!p) return;
if(l==len[p]) p=son[p][a[u]];
else if(str[t[p]-l]!=a[u]) p=0;
if(!p) return;
++lazy[p];
travel(i, u) if(e[i]!=fa && !vis[e[i]]) dfs5(e[i], u, p, l+1);
}
inline void work(int *tot){
rep(i, 2, cnt) lazy[q[i]]+=lazy[fa[q[i]]];
rep(i, 1, cnt) if(isl[i]) tot[len[i]]=lazy[i];
memset(lazy, 0, sizeof lazy);
}
}sam1, sam2;
ll ans;
int root, ctr, Siz, mn, top, lim, siz[N], stk[N];
void dfs1(int u, int fa=0){
siz[u]=1;
int mx=0;
travel(i, u) if(!vis[e[i]] && e[i]!=fa) dfs1(e[i], u), siz[u]+=siz[e[i]], mx=max(mx, siz[e[i]]);
mx=max(mx, Siz-siz[u]);
if(mx<mn) mn=mx, ctr=u;
}
inline int getctr(int u, int size){ return Siz=mn=size, dfs1(u), ctr;}
void dfs3(int u, int p, int W, int fa=0){
sam1.trans(p, a[u]);
if(p){
ans+=W*sam1.siz[p];
travel(i, u) if(!vis[e[i]] && e[i]!=fa) dfs3(e[i], p, W, u);
}
}
void dfs2(int u, int fa=0){
stk[++top]=a[u];

int p=1;
for(int i=top; i; --i) sam1.trans(p, stk[i]);
dfs3(root, p, -1);
travel(i, u) if(e[i]!=fa && !vis[e[i]]) dfs2(e[i], u);
--top;
}
void dfs4(int u, int fa=0){
dfs3(u, 1, 1);
travel(i, u) if(!vis[e[i]] && e[i]!=fa) dfs4(e[i], u);
}
void solve(int u, int fa=0, int v=0){
int size=Siz;
if(size<=lim){
if(fa){
stk[top=1]=a[fa];
root=v, dfs2(v);
}
dfs4(u);
}
else{
if(fa){
sam1.dfs5(v, fa, sam1.son[1][a[fa]], 1), sam2.dfs5(v, fa, sam2.son[1][a[fa]], 1);
sam1.work(tot1), sam2.work(tot2);
rep(i, 1, m) ans-=tot1[i]*tot2[m-i+1];
}
sam1.dfs5(u, 0, 1, 0), sam2.dfs5(u, 0, 1, 0);
sam1.work(tot1), sam2.work(tot2);
rep(i, 1, m) ans+=tot1[i]*tot2[m-i+1];
vis[u]=1;
travel(i, u) if(!vis[e[i]]) solve(getctr(e[i], siz[e[i]]<siz[u]?siz[e[i]]:size-siz[u]), u, e[i]);
}
}
int main() {
read(n), read(m);
lim=sqrt(m);
rep(i, 2, n){
static int x, y;
read(x), read(y);
add(x, y), add(y, x);
}
while(isspace(a[1]=read()));
rep(i, 2, n) a[i]=read();
rep(i, 1, n) a[i]-='a';
while(isspace(s[1]=read()));
rep(i, 2, m) s[i]=read();
rep(i, 1, m) sam1.ins(s[i]-='a');
for(int i=m; i; --i) sam2.ins(s[i]);

sam1.init(), sam2.init();
solve(getctr(1, n));
return printf("%lld", ans), 0;
}