0%

「LOJ 3059」「HNOI 2019」序列

LOJ #3059. 「HNOI 2019」序列

题意

给定一个长为 \(n\) 的序列 \(A_1,\dotsc,A_n\),求一个长为 \(n\) 的不下降序列 \(B_1,\dotsc,B_n\),使得 \(\sum_{i=1}^n (A_i-B_i)^2\) 最小,只需要输出最小值

以及 \(m\)互相独立的修改,每次会更改一个位置的值,要求输出修改后的答案

\(998244353\)

\(n,m\le 10^5\)

做法

考虑没有修改的情况

显然如果一个 \(A_i>A_{i+1}\),那么必然有 \(B_i=B_{i+1}\),于是我们可以把这两个位置缩起来

容易发现缩起来的块中的 \(B\) 全部取 \(A\) 的平均值最优

用单调栈维护缩起来的块,块的平均值保持不下降,每次新加入一个元素,不断弹出栈顶和新元素合并直到满足不下降的性质

时间复杂度 \(\mathcal O(n)\)

容易发现上面缩的过程不受顺序影响

把询问离线,对于一个位置 \(i\),维护出 \(1,\dotsc,i-1\)\(i+1,\dotsc,n\) 的单调栈(栈顶朝向位置 \(i\)

改变 \(i\) 处的值,压入左侧的栈中,可以用二分计算弹栈的次数,考虑如果弹出 \(k\) 个元素后满足了不下降,弹出 \(k+1\) 个也可以满足,因此可以二分

如果此时所有块满足不下降,那么已经得到了答案(左侧的栈压入后必然满足不下降,唯一可能出现问题的位置是两个栈顶之间)

否则,考虑不断把右侧的栈顶取出压入左侧,直到满足不下降,正确性显然

同理我们也可以用二分优化这个过程,考虑如果右侧取出前 \(k\) 个元素压入左侧后,满足了不下降,那么取出 \(k+1\) 个元素也可以满足,因为,在左侧不断弹栈的过程中间块的平均值只会不断增加,并且始终没有超过右侧栈中第 \(k+1\) 个元素的平均值,因此加入右侧栈中第 \(k+1\) 个元素后,弹栈过程也不会超过第 \(k+1\) 个的平均值,更不会超过第 \(k+2\) 个,因此可以二分

维护栈可以记忆一下操作然后撤销

时间复杂度 \(\mathcal O(n+m\log^2 n)\),空间复杂度 \(\mathcal O(n)\)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#include<cstdio>
#include<algorithm>
#include<cctype>
#include<string.h>
#include<cmath>
#include<vector>

using namespace std;
#define ll long long

inline char read() {
static const int IN_LEN = 1000000;
static char buf[IN_LEN], *s, *t;
return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x) {
static bool iosig;
static char c;
for (iosig=false, c=read(); !isdigit(c); c=read()) {
if (c == '-') iosig=true;
if (c == -1) return;
}
for (x=0; isdigit(c); c=read()) x=((x+(x<<2))<<1)+(c^'0');
if (iosig) x=-x;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN], *ooh=obuf;
inline void print(char c) {
if (ooh==obuf+OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), ooh=obuf;
*ooh++=c;
}
template<class T>
inline void print(T x) {
static int buf[30], cnt;
if (x==0) print('0');
else {
if (x<0) print('-'), x=-x;
for (cnt=0; x; x/=10) buf[++cnt]=x%10+48;
while(cnt) print((char)buf[cnt--]);
}
}
inline void flush() { fwrite(obuf, 1, ooh - obuf, stdout); }

const int N = 100005, P = 998244353;
int n, m, sum, top, top2, a[N], f[N], g[N], inv[N], ans[N], stk[N], stk2[N];
ll s[N];
vector<int> b[N];
vector<pair<int,int>> q[N];
inline bool cmp(int l1, int r1, int l2, int r2, int x=0, int y=0){
return (s[r1]-s[l1-1]+x)*(r2-l2+1)>(s[r2]-s[l2-1]+y)*(r1-l1+1);
}
inline int calc(int l, int r, int x=0){
return P-(s[r]-s[l-1]+x)%P*((s[r]-s[l-1]+x)%P)%P*inv[r-l+1]%P;
}
int main() {
freopen("sequence.in", "r", stdin);
freopen("sequence.out", "w", stdout);
read(n), read(m);
for(int i=1; i<=n; ++i) read(a[i]), s[i]=s[i-1]+a[i], sum=(sum+(ll)a[i]*a[i])%P;
q[1].push_back(make_pair(a[1], 0)), ans[0]=sum;
for(int i=1, x=0, y=0; i<=m; ++i)
read(x), read(y), q[x].push_back(make_pair(y, i)),
ans[i]=(sum+(ll)(P-a[x])*a[x]+(ll)y*y)%P;
inv[1]=1;
for(int i=2; i<=n; ++i) inv[i]=(ll)(P-P/i)*inv[P%i]%P;

for(int i=1; i<=n; ++i){
while(top && cmp(stk[top-1]+1, stk[top], stk[top]+1, i)) b[i].push_back(stk[top--]);
stk[++top]=i, f[top]=(f[top-1]+calc(stk[top-1]+1, i))%P;
}
stk2[0]=n+1;
for(int i=n; i; --i){
--top;
reverse(b[i].begin(), b[i].end());
for(int j:b[i]) stk[++top]=j, f[top]=(f[top-1]+calc(stk[top-1]+1, j))%P;
if(i<n){
while(top2 && cmp(i+1, stk2[top2]-1, stk2[top2], stk2[top2-1]-1)) --top2;
stk2[++top2]=i+1, g[top2]=(g[top2-1]+calc(i+1, stk2[top2-1]-1))%P;
}
for(auto j:q[i]){
int l=1, r=top, now=0, d=j.first-a[i];
while(l<=r){
int mid=(l+r)>>1;
if(cmp(stk[mid-1]+1, stk[mid], stk[mid]+1, i, 0, d)) r=mid-1;
else now=mid, l=mid+1;
}
if(!top2 || !cmp(stk[now]+1, i, i+1, stk2[top2-1]-1, d))
ans[j.second]=(ans[j.second]+(ll)calc(stk[now]+1, i, d)+f[now]+g[top2])%P;
else{
l=0, r=top2-1;
int res=0, lans=0;
while(l<=r){
int mid=(l+r)>>1, L=1, R=now, Ans=0;
while(L<=R){
int Mid=(L+R)>>1;
if(cmp(stk[Mid-1]+1, stk[Mid], stk[Mid]+1, stk2[mid]-1, 0, d))
R=Mid-1;
else L=Mid+1, Ans=Mid;
}
if(mid && cmp(stk[Ans]+1, stk2[mid]-1, stk2[mid], stk2[mid-1]-1, d))
r=mid-1;
else l=mid+1, res=mid, lans=Ans;
}
ans[j.second]=(ans[j.second]+(ll)calc(stk[lans]+1, stk2[res]-1, d)+f[lans]+g[res])%P;
}
}
}
for(int i=0; i<=m; ++i) print(ans[i]), print('\n');
return flush(), 0;
}