「LOJ 6391」「THUPC2018」淘米神的树 / Tommy

LOJ #6391. 「THUPC2018」淘米神的树 / Tommy

题意

有一棵 \(n\) 个点的树,你要以一个顺序选择每个点恰好一次。

初始只有两个钦定点 \(a,b\) 可以被选,一个点被选后所有的相邻点就可以被选择。

求方案数,模 \(998244353\)

做法

考虑只有一个初始点怎么做,以这个点为根,计算每个子树的方案数,合并的时候就考虑不改变顺序的方式。

\(size_u\) 表示以 \(u\) 为根的子树的大小,\(son_u\) 表示 \(u\) 的儿子集合。

可以发现一个点 \(u\) 对答案的贡献是 \(\frac{(size_u-1)!}{\prod\limits_{v\in son_u} size_v!}\)

因此答案就是

\[ \prod_{u=1}^n \frac{(size_u-1)!}{\prod\limits_{v\in son_u} size_v!} = \frac{n!}{\prod\limits_{u=1}^n size_u} \]

然后考虑两个点的情况

新建一个点 \(s\),与 \(a,b\) 分别连边,因此问题可以等价转化到环套树上只有一个初始点 \(s\) 的情况

考虑枚举环上一个最后被选到的点 \(u\),我们可以分别断开与 \(u\) 相连的两条环边并按照树的情况计算答案,这两种各会恰好计算 \(u\) 是环上最后被选到的点的方案一次(包括 \(s=u\) 的情况方案数是 \(0\))。

因此我们只需要枚举每条环边断开后计算贡献,最后乘 \(\frac{1}{2}\) 即可。

首先断开环边时不会影响不在环上点对答案的贡献,可以预处理。

记断开所有环边后一个环点 \(u\) 所在树的点数为 \(f_u\),环依次是 \(a_0,a_1,\dotsc,a_m\),其中 \(a_0=s\)

如果断开了 \(a_i\)\(a_{i+1}\) 之间的边

  • 对于 \(j\le i\),贡献是 \(\sum_{x=j}^i f_{a_x}\)
  • 对于 \(j>i\) 贡献是 \(\sum_{x=i+1}^j f_{a_x}\)

\(f\) 的前缀和为 \(s\)

可以发现 \(a_j\) 的贡献即 \(|s_i-s_j|\),并且这对于 \(i=m\) (断开 \(a_m\)\(s\) 之间的边)仍然成立

于是我们需要对于每个 \(i\) 求出 \(\prod_{j\ne i} |s_i-s_j|\)

可以讨论负号会出现多少次去掉绝对值,只需要求每个 \(\prod_{j\ne i} (s_i-s_j)\)

\(g(x)=\prod_{j=0}^m (x-s_j)\),也就是求 \(x\to s_i\)\(f_i(x)=\frac{g(x)}{x-s_i}\) 的值

根据 洛必达法则 我们有 \(f_i(s_i)=g'(s_i)\)

于是分治求出 \(g(x)\) 后多点求值即可

复杂度 \(\mathcal O(n \log^2 n)\)


代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
#include<cstdio>
#include<algorithm>
#include<cctype>
#include<string.h>
#include<cmath>
#include<vector>

using namespace std;
#define ull unsigned long long

inline char read() {
static const int IN_LEN = 1000000;
static char buf[IN_LEN], *s, *t;
return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x) {
static bool iosig;
static char c;
for (iosig=false, c=read(); !isdigit(c); c=read()) {
if (c == '-') iosig=true;
if (c == -1) return;
}
for (x=0; isdigit(c); c=read()) x=((x+(x<<2))<<1)+(c^'0');
if (iosig) x=-x;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN], *ooh=obuf;
inline void print(char c) {
if (ooh==obuf+OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), ooh=obuf;
*ooh++=c;
}
template<class T>
inline void print(T x) {
static int buf[30], cnt;
if (x==0) print('0');
else {
if (x<0) print('-'), x=-x;
for (cnt=0; x; x/=10) buf[++cnt]=x%10+48;
while(cnt) print((char)buf[cnt--]);
}
}
inline void flush() { fwrite(obuf, 1, ooh - obuf, stdout); }

const int N = 234570, P = 998244353, inv2 = (P+1)/2, M = 1<<19;
int n, a, b, cnt, q[N], siz[N];
vector<int> e[N];
struct Z{
unsigned x;
Z(const unsigned _x=0):x(_x){}
inline Z operator +(const Z &rhs)const{ return x+rhs.x<P?x+rhs.x:x+rhs.x-P;}
inline Z operator -(const Z &rhs)const{ return x<rhs.x?x-rhs.x+P:x-rhs.x;}
inline Z operator -()const{ return x?P-x:0;}
inline Z operator *(const Z &rhs)const{ return static_cast<ull>(x)*rhs.x%P;}
inline Z operator +=(const Z &rhs){ return x=x+rhs.x<P?x+rhs.x:x+rhs.x-P, *this;}
inline Z operator -=(const Z &rhs){ return x=x<rhs.x?x-rhs.x+P:x-rhs.x, *this;}
inline Z operator *=(const Z &rhs){ return x=static_cast<ull>(x)*rhs.x%P, *this;}
} ans, ans0;
vector<Z> f, g;
namespace Poly{
Z w[M];// for DFT
vector<Z> ans;// for Evaluate()
vector<vector<Z>> p;// for Evaluate() & Interpolate()

inline Z Pow(Z x, int y=P-2){
Z ans=1;
for(; y; y>>=1, x=x*x) if(y&1) ans=ans*x;
return ans;
}
inline void Init(){
for(unsigned i=1; i<M; i<<=1){
w[i]=1;
Z t=Pow(3, (P-1)/i/2);
for(unsigned j=1; j<i; ++j) w[i+j]=w[i+j-1]*t;
}
}
inline int Get(int x){ int n=1; while(n<=x) n<<=1; return n;}
inline void DFT(vector<Z> &f, int n){
static ull F[M];
if((int)f.size()!=n) f.resize(n);
for(int i=0, j=0; i<n; ++i){
F[i]=f[j].x;
for(int k=n>>1; (j^=k)<k; k>>=1);
}
for(int i=1; i<n; i<<=1) for(int j=0; j<n; j+=i<<1){
Z *W=w+i;
ull *F0=F+j, *F1=F+j+i;
for(int k=j; k<j+i; ++k, ++W, ++F0, ++F1){
ull t=(*F1)*(W->x)%P;
(*F1)=*F0+P-t, (*F0)+=t;
}
}
for(int i=0; i<n; ++i) f[i]=F[i]%P;
}
inline void IDFT(vector<Z> &f, int n){
f.resize(n), reverse(f.begin()+1, f.end());
DFT(f, n);
Z I=Pow(n);
for(int i=0; i<n; ++i) f[i]=f[i]*I;
}
inline vector<Z> operator +(const vector<Z> &f, const vector<Z> &g){
vector<Z> ans=f;
for(unsigned i=0; i<f.size(); ++i) ans[i]+=g[i];
return ans;
}
inline vector<Z> operator *(const vector<Z> &f, const vector<Z> &g){
if((ull)f.size()*g.size()<=1000){
vector<Z> ans;
ans.resize(f.size()+g.size()-1);
for(unsigned i=0; i<f.size(); ++i) for(unsigned j=0; j<g.size(); ++j)
ans[i+j]+=f[i]*g[j];
return ans;
}
static vector<Z> F, G;
F=f, G=g;
int p=Get(f.size()+g.size()-2);
DFT(F, p), DFT(G, p);
for(int i=0; i<p; ++i) F[i]*=G[i];
IDFT(F, p);
return F.resize(f.size()+g.size()-1), F;
}
vector<Z> &PolyInv(const vector<Z> &f, int n=-1){
if(n==-1) n=f.size();
if(n==1){
static vector<Z> ans;
return ans.clear(), ans.push_back(Pow(f[0])), ans;
}
vector<Z> &ans=PolyInv(f, (n+1)/2);
vector<Z> tmp(&f[0], &f[0]+n);
int p=Get(n*2-2);
DFT(tmp, p), DFT(ans, p);
for(int i=0; i<p; ++i) ans[i]=((Z)2-ans[i]*tmp[i])*ans[i];
IDFT(ans, p);
return ans.resize(n), ans;
}
// a=d*b+r
inline void PolyDiv(const vector<Z> &a, const vector<Z> &b, vector<Z> &d, vector<Z> &r){
if(b.size()>a.size()) return d.clear(), (void)(r=a);

vector<Z> A=a, B=b, iB;
int n=a.size(), m=b.size();
reverse(A.begin(), A.end()), reverse(B.begin(), B.end());
B.resize(n-m+1), iB=PolyInv(B, n-m+1);
d=A*iB;
d.resize(n-m+1), reverse(d.begin(), d.end());

r=b*d, r.resize(m-1);
for(int i=0; i<m-1; ++i) r[i]=a[i]-r[i];
}
inline vector<Z> Derivative(const vector<Z> &a){
vector<Z> ans(a.size()-1);
for(unsigned i=1; i<a.size(); ++i) ans[i-1]=a[i]*i;
return ans;
}
void Evaluate_Interpolate_Init(int l, int r, int t, const vector<Z> &a){
if(l==r) return p[t].clear(), p[t].push_back(-a[l]), p[t].push_back(1);
int mid=(l+r)/2, k=t<<1;
Evaluate_Interpolate_Init(l, mid, k, a), Evaluate_Interpolate_Init(mid+1, r, k|1, a);
p[t]=p[k]*p[k|1];
}
void Evaluate(int l, int r, int t, const vector<Z> &f, const vector<Z> &a){
if(r-l+1<=512){
for(int i=l; i<=r; ++i){
Z x=0, a1=a[i], a2=a[i]*a[i], a3=a[i]*a2, a4=a[i]*a3, a5=a[i]*a4, a6=a[i]*a5, a7=a[i]*a6, a8=a[i]*a7;
int j=f.size();
while(j>=8) x=x*a8+f[j-1]*a7+f[j-2]*a6+f[j-3]*a5+f[j-4]*a4+f[j-5]*a3+f[j-6]*a2+f[j-7]*a1+f[j-8], j-=8;
while(j--) x=x*a[i]+f[j];
ans.push_back(x);
}
return;
}
vector<Z> tmp;
PolyDiv(f, p[t], tmp, tmp);
Evaluate(l, (l+r)/2, t<<1, tmp, a), Evaluate((l+r)/2+1, r, t<<1|1, tmp, a);
}
inline vector<Z> Evaluate(const vector<Z> &f, const vector<Z> &a, int flag=-1){
if(flag==-1) p.resize(a.size()<<2), Evaluate_Interpolate_Init(0, a.size()-1, 1, a);
return ans.clear(), Evaluate(0, a.size()-1, 1, f, a), ans;
}
inline vector<Z> Divide(int l, int r){
if(l==r) return {P-g[l].x, 1};
int mid=(l+r)>>1;
return Divide(l, mid)*Divide(mid+1, r);
}
}
inline bool find(int u, int fa=0){
siz[u]=1;
bool t=(u==b);
for(int i:e[u]) if(i!=fa) t|=find(i, u), siz[u]+=siz[i];
if(t) q[cnt++]=u; else ans0*=siz[u];
return t;
}
int main() {
Poly::Init();
read(n), read(a), read(b);
for(int i=1, x, y; i<n; ++i) read(x), read(y), e[x].push_back(y), e[y].push_back(x);
ans0=1, find(a);
g.push_back(1);
for(int i=0; i<cnt; ++i) g.push_back(siz[q[i]]+1);
f=Poly::Evaluate(Poly::Derivative(Poly::Divide(0, cnt)), g);
for(int i=0; i<=cnt; ++i) ans+=Poly::Pow(f[i]*((cnt^i)&1?P-1:1));
for(int i=1; i<=n; ++i) ans*=i;
return printf("%d", (ans*Poly::Pow(ans0)*inv2).x), 0;
}