0%

「Luogu P4705」玩游戏

Luogu P4705 玩游戏

题意

有长度为 $n$ 的序列 $a$ 和长度为 $m$ 的序列 $b$

对于 $k=1,2,\dotsc,t$，从 $a$ 中等概率随机选取一个元素 $a_i$，从 $b$ 中等概率随机选取一个元素 $b_j$，求 $(a_i+b_j)^k$ 的期望

模 $998244353$

$n,m,t\le 10^5$

做法

有点神

首先二项式展开

$$
\begin{aligned}
& \sum_{i=1}^n \sum_{j=1}^m (a_i+b_j)^k \\
= & \sum_{i=1}^n \sum_{j=1}^m \sum_{x=0}^k \binom{k}{x} a_i^x b_j^{k-x} \\
= & k!\sum_{i=1}^n \sum_{j=1}^m \sum_{x=0}^k \frac{a_i^x}{x!} \times \frac{b_j^{k-x}}{(k-x)!} \\
= & k!\sum_{x=0}^k \left(\sum_{i=1}^n \frac{a_i^x}{x!} \right) \left(\sum_{j=1}^m \frac{b_j^{k-x}}{(k-x)!}\right)
\end{aligned}
$$

可以发现右边是卷积的形式（所求期望就是上式再除以 $nm$），因此只要对每一个 $x=0,1,\dotsc,t$ 求出幂和 $\sum\limits_{i=1}^n a_i^x$ 和 $\sum\limits_{i=1}^m b_i^x$，各自除以 $x!$ 后做一次卷积即可

构造多项式 $f(x)=\prod\limits_{i=1}^n (a_ix+1)$,这可以用分治 NTT 在 $\mathcal O(n\log^2n)$ 的复杂度内求出来

由于 $\ln (ax+1) = -\sum\limits_{i=1}^\infty \frac{(-1)^i}{i}a^ix^i$

证明如下

首先有

$$
\frac{1}{1-x}=\sum_{i=0}^\infty x^i
$$

所以（第二行由第一行两边积分得到，积分常数为 $\ln 1 = 0$）

$$
\begin{aligned}
(\ln(1-x))' &= -\frac{1}{1-x} = -\sum_{i=0}^\infty x^i \\
\ln(1-x) &= -\sum_{i=1}^\infty \frac{x^i}{i}
\end{aligned}
$$

用 $-ax$ 代入即可

所以 $\ln(f(x)) = -\sum\limits_{j=1}^n \sum\limits_{i=1}^\infty \frac{(-1)^i}{i}a_j^ix^i$

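即 $x^k$（$k\ge 1$）项的系数为 $\frac{(-1)^{k+1}}{k}\sum\limits_{j=1}^n a_j^k$，乘上 $(-1)^{k+1}k$ 就得到所需的幂和 $\sum\limits_{j=1}^n a_j^k$；$k=0$ 时幂和显然为 $n$。由于只需要 $k\le t$，只要对 $f$ 的前 $t+1$ 项求 $\ln$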

求幂和的另一种方法是从牛顿恒等式（Newton's identities）推导，本质上是一样的

总复杂度 $\mathcal O((n+m)\log^2(n+m) + t\log t)$（分治 NTT 加上对长度为 $t+1$ 的多项式求 $\ln$），$n,m,t$ 同阶时即 $\mathcal O(n\log^2n)$


代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
#include<cstdio>
#include<algorithm>
#include<cctype>
#include<string.h>
#include<cmath>
#include<vector>

using namespace std;
// 64-bit unsigned type for products of two residues modulo P (each below 2^30)
// and for the lazily-reduced values inside the NTT. A type alias rather than a
// macro, so it obeys scoping and works in every type context.
using ull = unsigned long long;

// Returns the next byte of stdin, refilling a 1,000,000-byte buffer with fread
// whenever it runs dry. Once fread yields nothing, -1 converted to char is
// returned; compare against static_cast<char>(-1), because plain char may be
// unsigned on some platforms. A genuine 0xFF input byte looks the same as EOF.
inline char read() {
    static const int IN_LEN = 1000000;
    static char buf[IN_LEN], *s, *t;
    return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}

// Parses the next decimal integer from stdin into x.
// Non-digit bytes before it are skipped; a '-' among them makes the value
// negative. If input ends before any digit is seen, x is left unchanged.
template<class T>
inline void read(T &x) {
    // The exact sentinel read() produces at end of input, whether char is
    // signed (-1) or unsigned (255).
    const char kEnd = static_cast<char>(-1);
    bool negative = false;
    char c = read();
    // isdigit() is undefined for negative arguments other than EOF, so every
    // byte goes through unsigned char first.
    while (!isdigit(static_cast<unsigned char>(c))) {
        if (c == '-') negative = true;
        if (c == kEnd) return;
        c = read();
    }
    for (x = 0; isdigit(static_cast<unsigned char>(c)); c = read()) x = x * 10 + (c - '0');
    if (negative) x = -x;
}
// Buffered output: everything printed is staged in obuf and written to stdout
// in large chunks. Call flush() before exiting or the staged tail is lost.
const int OUT_LEN = 10000000;
char obuf[OUT_LEN], *ooh=obuf;  // ooh points one past the last staged byte

// Appends one character, first draining the buffer to stdout if it is full.
inline void print(char c) {
    if (ooh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), ooh = obuf;
    *ooh++ = c;
}

// Appends the decimal form of integer x, with a leading '-' when negative.
template<class T>
inline void print(T x) {
    if (x == 0) { print('0'); return; }
    const bool negative = x < 0;
    if (negative) print('-');
    // Enough room for every digit of a 128-bit integer (at most 39 digits).
    char digits[48];
    int cnt = 0;
    // Digits are peeled without negating x: -x overflows for the most negative
    // value of a signed T. Division truncates toward zero, so x % 10 is in
    // (-10, 0] while x is negative.
    for (; x != 0; x /= 10) {
        const int d = static_cast<int>(x % 10);
        digits[cnt++] = static_cast<char>('0' + (negative ? -d : d));
    }
    while (cnt) print(digits[--cnt]);
}

// Writes all staged bytes to stdout and empties the buffer, so calling it
// more than once never repeats output.
inline void flush() { fwrite(obuf, 1, ooh - obuf, stdout), ooh = obuf; }


const unsigned P = 998244353;
// A residue modulo P. Every operator assumes 0 <= x < P and preserves it, so
// any raw value stored into x must already be reduced.
struct Z{
    unsigned x;
    // Implicit on purpose, so integers mix freely with Z in expressions.
    // The argument is NOT reduced modulo P.
    Z(const unsigned _x=0):x(_x){}
    // x + rhs.x < 2P < 2^31, so the sum cannot wrap in unsigned arithmetic.
    inline Z operator +(const Z &rhs)const{ return x+rhs.x<P?x+rhs.x:x+rhs.x-P;}
    inline Z operator -(const Z &rhs)const{ return x<rhs.x?x-rhs.x+P:x-rhs.x;}
    inline Z operator -()const{ return x?P-x:0;}
    // Widen before multiplying: the product of two residues needs up to 60 bits.
    inline Z operator *(const Z &rhs)const{ return static_cast<unsigned long long>(x)*rhs.x%P;}
    // Compound assignments return *this by reference, like built-in types, so
    // chained forms such as (a += b) *= c update a itself rather than a copy.
    inline Z &operator +=(const Z &rhs){ return *this = *this + rhs;}
    inline Z &operator -=(const Z &rhs){ return *this = *this - rhs;}
    inline Z &operator *=(const Z &rhs){ return *this = *this * rhs;}
};
// n, m: lengths of the input sequences a and b; t: largest exponent queried.
// All three are read in main.
int n, m, t;
// The input sequences. main later overwrites each with its power sums
// (index k holds the sum of k-th powers, afterwards scaled by 1/k!).
vector<Z> a, b;

namespace Poly{
const int MAX_LEN = 1<<18;

Z w[MAX_LEN], Inv[MAX_LEN];// for DFT

inline Z Pow(Z x, int y=P-2){
Z ans=1;
for(; y; y>>=1, x=x*x) if(y&1) ans=ans*x;
return ans;
}
inline void Init(){
for(int i=1; i<MAX_LEN; i<<=1){
w[i]=1;
Z t=Pow(3, (P-1)/i/2);
for(int j=1; j<i; ++j) w[i+j]=w[i+j-1]*t;
}
Inv[1]=1;
for(int i=2; i<MAX_LEN; ++i) Inv[i]=Inv[P%i]*(P-P/i);
}
inline int Get(int x){ int n=1; while(n<=x) n<<=1; return n;}
inline void DFT(vector<Z> &f, int n){
static ull F[MAX_LEN];
if((int)f.size()!=n) f.resize(n);
for(int i=0, j=0; i<n; ++i){
F[i]=f[j].x;
for(int k=n>>1; (j^=k)<k; k>>=1);
}
for(int i=1; i<n; i<<=1) for(int j=0; j<n; j+=i<<1){
Z *W=w+i;
ull *F0=F+j, *F1=F+j+i;
for(int k=j; k<j+i; ++k, ++W, ++F0, ++F1){
ull t=(*F1)*(W->x)%P;
(*F1)=*F0+P-t, (*F0)+=t;
}
}
for(int i=0; i<n; ++i) f[i]=F[i]%P;
}
inline void IDFT(vector<Z> &f, int n){
f.resize(n), reverse(f.begin()+1, f.end());
DFT(f, n);
Z I=Pow(n);
for(int i=0; i<n; ++i) f[i]=f[i]*I;
}
inline vector<Z> Add(const vector<Z> &f, const vector<Z> &g){
vector<Z> ans=f;
for(unsigned i=0; i<f.size(); ++i) ans[i]+=g[i];
return ans;
}
inline vector<Z> Mul(const vector<Z> &f, const vector<Z> &g){
if(f.size()*g.size()<=1000){
vector<Z> ans;
ans.resize(f.size()+g.size()-1);
for(unsigned i=0; i<f.size(); ++i) for(unsigned j=0; j<g.size(); ++j)
ans[i+j]+=f[i]*g[j];
return ans;
}
static vector<Z> F, G;
F=f, G=g;
int p=Get(f.size()+g.size()-2);
DFT(F, p), DFT(G, p);
for(int i=0; i<p; ++i) F[i]*=G[i];
IDFT(F, p);
return F.resize(f.size()+g.size()-1), F;
}
vector<Z> &PolyInv(const vector<Z> &f, int n=-1){
if(n==-1) n=f.size();
if(n==1){
static vector<Z> ans;
return ans.clear(), ans.push_back(Pow(f[0])), ans;
}
vector<Z> &ans=PolyInv(f, (n+1)/2);
vector<Z> tmp(&f[0], &f[0]+n);
int p=Get(n*2-2);
DFT(tmp, p), DFT(ans, p);
for(int i=0; i<p; ++i) ans[i]=((Z)2-ans[i]*tmp[i])*ans[i];
IDFT(ans, p);
return ans.resize(n), ans;
}
inline vector<Z> Derivative(const vector<Z> &a){
vector<Z> ans(a.size()-1);
for(unsigned i=1; i<a.size(); ++i) ans[i-1]=a[i]*i;
return ans;
}
inline vector<Z> Integral(const vector<Z> &a){
vector<Z> ans(a.size()+1);
for(unsigned i=0; i<a.size(); ++i) ans[i+1]=a[i]*Inv[i+1];
return ans;
}
inline vector<Z> PolyLn(const vector<Z> &f){
vector<Z> ans=Mul(Derivative(f), PolyInv(f));
ans.resize(f.size()-1);
return Integral(ans);
}
vector<Z> divide(int l, int r, const vector<Z> &f){
if(l==r) return vector<Z>{1, f[l]};
int mid=(l+r)>>1;
return Mul(divide(l, mid, f), divide(mid+1, r, f));
}
inline vector<Z> solve(const vector<Z> &f, int t){
vector<Z> ans=divide(0, f.size()-1, f);
ans.resize(t+1), ans=PolyLn(ans);
for(int i=1; i<=t; ++i) ans[i]*=(i&1?i:P-i);
return ans[0]=f.size(), ans;
}
}
int main() {
Poly::Init();
read(n), read(m), a.resize(n), b.resize(m);
for(int i=0; i<n; ++i) read(a[i].x);
for(int i=0; i<m; ++i) read(b[i].x);
read(t);
a=Poly::solve(a, t);
b=Poly::solve(b, t);
for(int i=1, k=1; i<=t; k=(ull)k*Poly::Inv[++i].x%P) a[i]=a[i]*k;
for(int i=1, k=1; i<=t; k=(ull)k*Poly::Inv[++i].x%P) b[i]=b[i]*k;
a=Poly::Mul(a, b), a.resize(t+1);
for(int i=1, k=Poly::Pow((ull)n*m%P).x; i<=t; k=(ull)k*++i%P) print((a[i]*k).x), print('\n');
return flush(), 0;
}