「HDU 6057」Kanade's convolution

HDU 6057

题意

给你两个数组 \(A[0..2^m-1]\)\(B[0..2^m-1]\)

你需要计算 \[C[k]=\sum_{i\ and\ j=k}A[i\ xor\ j]*B[i\ or\ j]\]

输出 \(\sum_{i=0}^{2^m-1}C[i]*1526^i\ mod\ 998244353\)

\(m\le 19\)


做法

由于 \(i\ and\ j=k\),所以有 \(i\ or\ j=i\ xor\ j\ xor\ k\),(\(i\ xor\ j\)\(k\) 无交)

那么可以转化

\[ \begin{align} C[k]&=\sum_{x\ xor\ y=k}[x\ and\ y=x]*A[x]*B[y]*2^{cnt_x} \\ &=\sum_{x\ xor\ y=k}[cnt_y-cnt_x=cnt_k]*A[x]*B[y]*2^{cnt_x} \end{align} \]

其中 \(cnt_x\) 等于 \(x\) 的二进制表示中 \(1\) 的个数

定义\(a_{i,j}=[cnt_j=i]*A[j],b_{i,j}=[cnt_j=i]*B[j]\)

也就是增加一维集合大小,那么

\[c_{i,k}=\sum_{j=0}^i\sum_{x\ xor\ y=k}a_{j,x}*b_{i-j,y}*2^j\]

\(C[i]=c_{cnt_i,i}\),其他多余的位置是没有意义的

那么我们对每一个 \(a_i,b_i\)FWT,再暴力枚举第一维,加到对应的 \(c_i\)

时间复杂度 \(\mathcal O(2^mm^2)\)


代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include<cstdio>
#include<algorithm>
#include<ctype.h>
#include<string.h>
#include<math.h>

using namespace std;
#define ll long long

inline char read() {
static const int IN_LEN = 1000000;
static char buf[IN_LEN], *s, *t;
return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x) {
static bool iosig;
static char c;
for (iosig=false, c=read(); !isdigit(c); c=read()) {
if (c == '-') iosig=true;
if (c == -1) return;
}
for (x=0; isdigit(c); c=read()) x=((x+(x<<2))<<1)+(c^'0');
if (iosig) x=-x;
}

const int N = 20, M = 1<<19, P = 998244353;
int ans, m, cnt[M], a[N][M], b[N][M], c[N][M];
inline int Pow(ll x, int y=P-2){
int ans=1;
for(; y; y>>=1, x=x*x%P) if(y&1) ans=ans*x%P;
return ans;
}
inline void FWT(int *f, int g){
for(int i=1; i<1<<m; i<<=1) for(int j=0; j<1<<m; j+=i<<1)
for(int k=j; k<j+i; ++k){
int x=f[k], y=f[k+i];
f[k]=(x+y)%P, f[k+i]=(x-y+P)%P;
}
if(g==-1) for(int i=0, I=Pow(1<<m); i<1<<m; ++i) f[i]=(ll)f[i]*I%P;
}
int main() {
read(m);
for(int i=1; i<1<<m; ++i) cnt[i]=cnt[i^(i&-i)]+1;
for(int i=0; i<1<<m; ++i) read(a[cnt[i]][i]);
for(int i=0; i<1<<m; ++i) read(b[cnt[i]][i]);
for(int i=0; i<=m; ++i) FWT(a[i], 1), FWT(b[i], 1);
for(int i=0, p=1; i<=m; ++i, p<<=1) for(int j=i; j<=m; ++j) for(int k=0; k<1<<m; ++k)
c[j-i][k]=(c[j-i][k]+(ll)a[i][k]*b[j][k]%P*p)%P;
for(int i=0; i<=m; ++i) FWT(c[i], -1);
for(int i=0, k=1; i<1<<m; ++i, k=k*1526ll%P) ans=(ans+(ll)k*c[cnt[i]][i])%P;
return printf("%d\n", ans), 0;
}