「LOJ 2983」「WC2019」数树

LOJ #2983. 「WC2019」数树

题意

本题包含三个问题:

  • 问题 0:已知两棵 \(n\) 个节点的树的形态。要给予每个节点一个 \([1, y]\) 中的整数,使得对于任意两个节点 \(p, q\),如果存在边 \((p, q)\) 同时属于这两棵树,则 \(p, q\) 必须被给予相同的数。求给予数的方案数。

  • 问题 1:已知第一棵树,对于第二棵树的所有 \(n^{n−2}\) 种选择方案,求问题 0 的答案之和。

  • 问题 2:对于第一棵树的所有 \(n^{n−2}\) 种选择方案,求问题 1 的答案之和。

\(998244353\) 取模

做法

问题 0

答案是 \(y^{n-\text{公共边数}}\)

问题 1

\(z=y^{-1}\)

在最后乘上 \(y^n\),一种方案对答案的贡献是 \(z^{\text{公共边数}}\)

根据

\[ z^k=(z-1+1)^k=\sum_{i=0}^k \binom{k}{i}(z-1)^k \]

问题可以转化为:对于每个第一棵树边集的子集 \(S\),和至少包含这个集合的第二棵树的每个方案,有 \((z-1)^{|S|}\) 的贡献

假设 \(S\) 形成的 \(m=n-|S|\) 个联通块大小分别为 \(a_1,a_2,\dotsc,a_m\)

覆盖 \(S\) 的树的数量即把 \(m\) 个联通块按照树的结构连接起来的方案数,等于

\[ \begin{align} &\sum_{\substack{(\sum_{i=1}^m d_i)=2(m-1) \\ d_i\ge 1}} \frac{(m-2)!}{\prod_{i=1}^m (d_i-1)!} \times \prod_{i=1}^m a_i^{d_i} \\ = &\left(\prod_{i=1}^m a_i \right)\sum_{\substack{(\sum_{i=1}^m d_i)=2(m-1) \\ d_i\ge 1}} \frac{(m-2)!}{\prod_{i=1}^m (d_i-1)!} \times \prod_{i=1}^m a_i^{d_i-1} \end{align} \]

其中 \(d_i\) 表示第 \(i\) 个联通块的度数

而第 \(i\) 个联通块在 Prufer 序列中出现了 \(d_i-1\) 次,一种连边方案唯一对应了一种数量分别为 \(d_1-1,d_2-1,\dotsc,d_m-1\)\(m\) 种元素的排列,于是这部分方案数为 \(\frac{(m-2)!}{\prod_{i=1}^m (d_i-1)!}\),每个端点可以在 \(a_i\) 个点中任意选择,于是有 \(\prod_{i=1}^m a_i^{d_i}\) 的贡献。

上式可以转化为枚举 Prufer 序列的第 \(i\) 位,假设是 \(p_i\),有 \(a_{p_i}\) 的贡献,每位的方案数就是 \((\sum_{i=1}^m a_i)=n\),位之间独立,于是 \(m-2\) 位共 \(n^{m-2}\)

上式等于

\[ n^{m-2} \prod_{i=1}^m a_i \]

其中 \(\prod_{i=1}^m a_i\) 可以理解为在每个联通块中选择一个点的方案数

答案即为

\[ \begin{align} &\sum_{S\subseteq E} (z-1)^{n-m} n^{m-2} \prod_{i=1}^m a_i \\ =&n^{-2} (z-1)^n \sum_{S\subseteq E} (z-1)^{-m} n^m \prod_{i=1}^m a_i \end{align} \]

其中 \(E\) 表示第一棵树的边集

考虑 DP,令 \(f_{i,1/0}\) 表示以 \(i\) 为根的子树,\(i\) 所在的联通块中是否已经选出一个点的方案数,枚举一条边选和不选分别有一个系数,最终答案为 \(f_{1,1}\)

复杂度 \(\mathcal O(n)\)

问题 2

同样枚举一个边集 \(S\) 需要被覆盖

根据上面,方案数是

\[ \left(n^{m-2} \prod_{i=1}^m a_i\right)^2 \]

所有联通块数量为 \(m\) 的总方案数为

\[ \sum_{\substack{(\sum_{i=1}^m a_i)=n \\ a_i\ge 1}} \frac{n!}{m!\prod_{i=1}^m a_i!} \times \prod_{i=1}^m a_i^{a_i-2} \times \left(n^{m-2} \prod_{i=1}^m a_i\right)^2 \]

乘上 \((z-1)^{n-m}\) 后总答案为

\[ \begin{align} &\sum_{m=1}^n (z-1)^{n-m} \sum_{\substack{(\sum_{i=1}^m a_i)=n \\ a_i\ge 1}} \frac{n!}{m!\prod_{i=1}^m a_i!} \times \prod_{i=1}^m a_i^{a_i-2} \times \left(n^{m-2} \prod_{i=1}^m a_i\right)^2 \\ =& (z-1)^n n^{-4} n! \sum_{m=1}^n \frac{(z-1)^{-m} n^{2m}}{m!} \sum_{\substack{(\sum_{i=1}^m a_i)=n \\ a_i\ge 1}} \prod_{i=1}^m \frac{a_i^{a_i}}{a_i!} \\ =& (z-1)^n n^{-4} n! [x^n] \exp\left(\frac{n^2}{z-1} \sum_{i=1}^\infty \frac{i^i}{i!} x^i\right) \end{align} \]

最后一行的 \(\exp\) 感受一下就好了 反推还是能看出来的

复杂度 \(\mathcal O(n\log n)\)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#include<cstdio>
#include<algorithm>
#include<cctype>
#include<string.h>
#include<cmath>
#include<set>
#include<vector>

using namespace std;
#define ll long long
#define ull unsigned long long

inline char read() {
static const int IN_LEN = 1000000;
static char buf[IN_LEN], *s, *t;
return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x) {
static bool iosig;
static char c;
for (iosig=false, c=read(); !isdigit(c); c=read()) {
if (c == '-') iosig=true;
if (c == -1) return;
}
for (x=0; isdigit(c); c=read()) x=((x+(x<<2))<<1)+(c^'0');
if (iosig) x=-x;
}

const int N = 100005, P = 998244353;
int n, y, op;
namespace subtask1{
int num, h[N], e[N<<1], pre[N<<1], f[N][2];
inline int Pow(ll x, int y=P-2){
int ans=1;
for(; y; y>>=1, x=x*x%P) if(y&1) ans=ans*x%P;
return ans;
}
inline void add(int x, int y){ e[++num]=y, pre[num]=h[x], h[x]=num;}
void dfs(int u, int fa=0){
f[u][0]=1, f[u][1]=n;
for(int i=h[u]; i; i=pre[i]) if(e[i]!=fa){
dfs(e[i], u);
f[u][1]=((ll)f[u][1]*f[e[i]][1]+((ll)f[u][1]*f[e[i]][0]+(ll)f[u][0]*f[e[i]][1])%P*(y-1))%P;
f[u][0]=((ll)f[u][0]*f[e[i]][1]+(ll)f[u][0]*f[e[i]][0]%P*(y-1))%P;
}
}
void main(){
y=Pow(y);
for(int i=1, x, y; i<n; ++i) read(x), read(y), add(x, y), add(y, x);
dfs(1);
printf("%lld", (ll)f[1][1]*Pow(n, P-3)%P*Pow(y, P-n-1)%P);
}
}
namespace subtask2{
const int M = 1<<18;
struct Z{
unsigned x;
Z(const unsigned _x=0):x(_x){}
inline Z operator +(const Z &rhs)const{ return x+rhs.x<P?x+rhs.x:x+rhs.x-P;}
inline Z operator -(const Z &rhs)const{ return x<rhs.x?x-rhs.x+P:x-rhs.x;}
inline Z operator -()const{ return x?P-x:0;}
inline Z operator *(const Z &rhs)const{ return static_cast<ull>(x)*rhs.x%P;}
inline Z operator +=(const Z &rhs){ return x=x+rhs.x<P?x+rhs.x:x+rhs.x-P, *this;}
inline Z operator -=(const Z &rhs){ return x=x<rhs.x?x-rhs.x+P:x-rhs.x, *this;}
inline Z operator *=(const Z &rhs){ return x=static_cast<ull>(x)*rhs.x%P, *this;}
} w[M], Inv[M], fac[M];
vector<Z> f;

inline Z Pow(Z x, int y=P-2){
Z ans=1;
for(; y; y>>=1, x=x*x) if(y&1) ans=ans*x;
return ans;
}
inline void Init(){
for(unsigned i=1; i<M; i<<=1){
w[i]=1;
Z t=Pow(3, (P-1)/i/2);
for(unsigned j=1; j<i; ++j) w[i+j]=w[i+j-1]*t;
}
Inv[1]=1;
for(unsigned i=2; i<M; ++i) Inv[i]=Inv[P%i]*(P-P/i);
}
inline int Get(int x){ int n=1; while(n<=x) n<<=1; return n;}
inline void DFT(vector<Z> &f, int n){
static ull F[M];
if((int)f.size()!=n) f.resize(n);
for(int i=0, j=0; i<n; ++i){
F[i]=f[j].x;
for(int k=n>>1; (j^=k)<k; k>>=1);
}
for(int i=1; i<n; i<<=1) for(int j=0; j<n; j+=i<<1){
Z *W=w+i;
ull *F0=F+j, *F1=F+j+i;
for(int k=j; k<j+i; ++k, ++W, ++F0, ++F1){
ull t=(*F1)*(W->x)%P;
(*F1)=*F0+P-t, (*F0)+=t;
}
}
for(int i=0; i<n; ++i) f[i]=F[i]%P;
}
inline void IDFT(vector<Z> &f, int n){
f.resize(n), reverse(f.begin()+1, f.end());
DFT(f, n);
Z I=Pow(n);
for(int i=0; i<n; ++i) f[i]=f[i]*I;
}
inline vector<Z> operator +(const vector<Z> &f, const vector<Z> &g){
vector<Z> ans=f;
for(unsigned i=0; i<f.size(); ++i) ans[i]+=g[i];
return ans;
}
inline vector<Z> operator *(const vector<Z> &f, const vector<Z> &g){
if((ull)f.size()*g.size()<=1000){
vector<Z> ans;
ans.resize(f.size()+g.size()-1);
for(unsigned i=0; i<f.size(); ++i) for(unsigned j=0; j<g.size(); ++j)
ans[i+j]+=f[i]*g[j];
return ans;
}
static vector<Z> F, G;
F=f, G=g;
int p=Get(f.size()+g.size()-2);
DFT(F, p), DFT(G, p);
for(int i=0; i<p; ++i) F[i]*=G[i];
IDFT(F, p);
return F.resize(f.size()+g.size()-1), F;
}
vector<Z> &PolyInv(const vector<Z> &f, int n=-1){
if(n==-1) n=f.size();
if(n==1){
static vector<Z> ans;
return ans.clear(), ans.push_back(Pow(f[0])), ans;
}
vector<Z> &ans=PolyInv(f, (n+1)/2);
vector<Z> tmp(&f[0], &f[0]+n);
int p=Get(n*2-2);
DFT(tmp, p), DFT(ans, p);
for(int i=0; i<p; ++i) ans[i]=((Z)2-ans[i]*tmp[i])*ans[i];
IDFT(ans, p);
return ans.resize(n), ans;
}
inline vector<Z> Derivative(const vector<Z> &a){
vector<Z> ans(a.size()-1);
for(unsigned i=1; i<a.size(); ++i) ans[i-1]=a[i]*i;
return ans;
}
inline vector<Z> Integral(const vector<Z> &a){
vector<Z> ans(a.size()+1);
for(unsigned i=0; i<a.size(); ++i) ans[i+1]=a[i]*Inv[i+1];
return ans;
}
inline vector<Z> PolyLn(const vector<Z> &f){
vector<Z> ans=Derivative(f)*PolyInv(f);
ans.resize(f.size()-1);
return Integral(ans);
}
vector<Z> PolyExp(const vector<Z> &f, int n=-1){
if(n==-1) n=f.size();
if(n==1) return {1};
vector<Z> ans=PolyExp(f, (n+1)/2), tmp;
ans.resize(n), tmp=PolyLn(ans);
for(Z &i:tmp) i=-i;
++tmp[0].x;
ans=ans*(tmp+f);
return ans.resize(n), ans;
}
void main(){
if(y==1) { printf("%u", Pow(n, (n-2)*2).x); return;}
Init();
y=Pow(y).x, fac[0]=1, f.resize(n+1);
Z k=Pow(y-1)*n*n;
for(int i=1; i<=n; ++i) f[i]=k*Pow(i, i)*Pow(fac[i]=fac[i-1]*i);
f=PolyExp(f);
printf("%u", (f[n]*Pow(y-1, n)*Pow(n, P-5)*fac[n]*Pow(y, P-n-1)).x);
}
}
int main() {
freopen("tree.in", "r", stdin);
freopen("tree.out", "w", stdout);
read(n), read(y), read(op);
if(op==0){
static set<pair<int,int>> _;
static pair<int,int> __;
for(int i=1; i<n; ++i) read(__.first), read(__.second), _.insert(__);
int cnt=0;
for(int i=1; i<n; ++i) read(__.first), read(__.second), cnt+=_.count(__);
printf("%d", subtask1::Pow(y, n-cnt));
}
else if(op==1) subtask1::main();
else subtask2::main();
return 0;
}
0%