「LOJ 6509」「雅礼集训 2018 Day7」C

LOJ #6509. 「雅礼集训 2018 Day7」C

题意

给定一棵 \(n\) 个点的树,树上每个点初始有一个 \(0\)\(1\) 的数字。

考虑这样一个过程:

  1. 等概率随机选择一个点作为起点
  2. 等概率随机选择一个新点并沿着树上的路径移动过去,最后反转这个新点上的数字(注意只反转这个新点上的数字而非经过的所有点的数字)
  3. 如果此时整棵树上的所有数字相同,则过程结束;否则回到步骤 \(2\)

求出期望的移动距离,对 \(10^9 + 7\) 取模。

\(n\le 10^5\)

做法

考虑每次选出的点构成的序列

一个点只要不是在最后出现,由于下一步是等概率随机的,所以产生的贡献是固定的,是该点与树上所有点的平均距离

那么我们只要计算每个数不在最后一个位置出现的期望次数就可以计算答案

可以发现期望出现次数只与 \(0\)\(1\) 的数量有关

\(f_{i,0/1}\) 表示有 \(i\)\(1\) 时,一个位置\(0/1\) 期望的出现次数

显然有 \(f_{0, * } = 0\)\(f_{n, * } = 0\)

考虑方程

\[ \begin{align} f_{i,0} &= \frac{i}{n} f_{i-1, 0} + \frac{n-i-1}{n} f_{i+1, 0} + \frac{1}{n} f_{i+1,1} + \frac{1}{n} \\ f_{i,1} &= \frac{n-i}{n} f_{i+1, 1} + \frac{i-1}{n} f_{i-1, 1} + \frac{1}{n} f_{i-1,0} + \frac{1}{n} \end{align} \]

即在这一次选择了

  • 与钦定位置不同色
  • 钦定位置
  • 同色的非钦定位置

以及每次在 \(n\) 个点中选到的概率 \(\frac{1}{n}\)

注意这里的 \(\frac{1}{n}\) 要保证这不是最后一个位置才能存在,于是在 \(f_{1,1}\)\(f_{n-1,0}\) 处的式子不能加

直接高斯消元是 \(\mathcal O(n^3)\)

考虑到上面的式子很优美

保留 \(f_{1,0},f_{1,1}\) 两个未知数计算,从小到大推

假设已经知道了对于 \(0\le i\le k\)\(f_{i, * }\),我们可以用 \(f_{k,1}\) 的式子直接计算 \(f_{k+1,1}\),再用 \(f_{k,0}\) 的式子和 \(f_{k+1,1}\) 算出 \(f_{k+1,0}\)

这样可以推出所有 \(i<n\)\(f_{i, * }\)

利用 \(f_{n-1,* }\) 的两个式子解出两个未知数,代入可以求得我们需要的 \(f_{s, * }\) 其中 \(s\) 是初始 \(1\) 的个数

注意这里的 \(f\) 没有计算到随机选择的起点,每个点的期望次数要加另外的 \(\frac{1}{n}\)

复杂度 \(\mathcal O(n)\)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include<cstdio>
#include<algorithm>
#include<cctype>
#include<string.h>
#include<cmath>

using namespace std;
#define ll long long

inline char read() {
static const int IN_LEN = 1000000;
static char buf[IN_LEN], *s, *t;
return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x) {
static bool iosig;
static char c;
for (iosig=false, c=read(); !isdigit(c); c=read()) {
if (c == '-') iosig=true;
if (c == -1) return;
}
for (x=0; isdigit(c); c=read()) x=((x+(x<<2))<<1)+(c^'0');
if (iosig) x=-x;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN], *ooh=obuf;
inline void print(char c) {
if (ooh==obuf+OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), ooh=obuf;
*ooh++=c;
}
template<class T>
inline void print(T x) {
static int buf[30], cnt;
if (x==0) print('0');
else {
if (x<0) print('-'), x=-x;
for (cnt=0; x; x/=10) buf[++cnt]=x%10+48;
while(cnt) print((char)buf[cnt--]);
}
}
inline void flush() { fwrite(obuf, 1, ooh - obuf, stdout); }

const int N = 100005, P = 1000000007;

int n, num, ans, b, w, tot, x, y, inv[N], siz[N], h[N], e[N], pre[N];
ll f[N];
char s[N];
struct p{
int a, b, c;
inline p(){ a=b=c=0;}
inline p(int x){ a=b=0, c=x;}
template<class T> inline p(T x, T y, T z){ a=x, b=y, c=z;}
inline p operator +(const p &rhs)const{ return {(a+rhs.a)%P, (b+rhs.b)%P, (c+rhs.c)%P};}
inline p operator -(const p &rhs)const{ return {(a-rhs.a+P)%P, (b-rhs.b+P)%P, (c-rhs.c+P)%P};}
inline p operator *(int rhs)const{ return {(ll)a*rhs%P, (ll)b*rhs%P, (ll)c*rhs%P};}
} A, B, a[N][2];
inline void add(int x, int y){ e[++num]=y, pre[num]=h[x], h[x]=num;}
void dfs1(int u){
siz[u]=1;
for(int i=h[u]; i; i=pre[i])
dfs1(e[i]), siz[u]+=siz[e[i]], f[u]+=f[e[i]]+siz[e[i]];
}
void dfs2(int u){
for(int i=h[u]; i; i=pre[i]) f[e[i]]=f[u]+n-siz[e[i]]*2, dfs2(e[i]);
}
inline int Pow(ll x, int y=P-2){
int ans=1;
for(x%=P; y; y>>=1, x=x*x%P) if(y&1) ans=ans*x%P;
return ans;
}
int main() {
read(n);
inv[0]=inv[1]=1;
for(int i=2; i<=n; ++i) inv[i]=(ll)(P-P/i)*inv[P%i]%P;
while(isspace(s[1]=read()));
for(int i=2; i<=n; ++i) s[i]=read();
for(int i=1; i<=n; ++i) tot+=(s[i]=='1');
for(int i=2, x; i<=n; ++i) read(x), add(x, i);
dfs1(1), dfs2(1);
a[1][0].a=a[1][1].b=1;
for(int i=1; i<n-1; ++i){
a[i+1][1]=(a[i][1]*n-a[i-1][1]*(i-1)-a[i-1][0]-(i!=1))*inv[n-i];
a[i+1][0]=(a[i][0]*n-a[i-1][0]*i-a[i+1][1]-1)*inv[n-i-1];
}
A=a[n-2][0]*(n-1)-a[n-1][0]*n;
B=a[n-2][1]*(n-2)-a[n-1][1]*n+a[n-2][0]+1;
x=((ll)A.c*B.b-(ll)A.b*B.c)%P*Pow((ll)A.b*B.a-(ll)A.a*B.b)%P;
swap(A.a, A.b), swap(B.a, B.b);
y=((ll)A.c*B.b-(ll)A.b*B.c)%P*Pow((ll)A.b*B.a-(ll)A.a*B.b)%P;
b=((ll)a[tot][1].a*x+(ll)a[tot][1].b*y+a[tot][1].c)%P+inv[n];
w=((ll)a[tot][0].a*x+(ll)a[tot][0].b*y+a[tot][0].c)%P+inv[n];
for(int i=1; i<=n; ++i) ans=(ans+f[i]%P*(s[i]=='1'?b:w))%P;
return printf("%lld", (ll)(ans+P)*inv[n]%P), 0;
}
0%