CSP-S 2023 复赛
考场上 T1, T3 做出来了
T2, T4 CF 上有原题
这里就记录一下 T2 正解
T2
做法1:dp
令 $f_i$ 从 $i$ 开头的可消去子串的数量
$nxt_i$ 是使得 $\left[i,j\right]$ 可消去的最小的 $j$
那么 $f_i=f_{nxt_i+1}+1$ 如果 $nxt_i$ 存在,否则是 $0$
$nxt_i$ 可以这么求($nxt_i$ 存在的话):
nxt[i] := i+1
while a[nxt[i]] != a[i]:
nxt[i] = nxt[nxt[i]] + 1
但是这么做是 $O(n^2)$ 的,不能拿满分
所以引入 $nxta_{i,x}$ 表示使得 $\left[i,j\right]$ 可消去且 $a_{j+1}=x$ 的最小的 $j$
那么就有这样的关系
若 $nxt_i$ 和 $nxta_{i,x}$ 都存在
当 $a_i==a_{i+1}$ 时 $nxt_i=i+1$ 否则 $nxt_i=nxta_{i+1,a_i}+1$
当 $x==a_{nxt_i+1}$ 时 $nxta_{i,x}=nxt_i$ 否则 $nxta_{i,x}=nxta_{nxt_i+1,x}$
原题的代码:(CCF版缩小了值域 $nxt$ 为 $1$,$nxta$ 为 $0$ 表示不存在
#include <bits/stdc++.h>
using namespace std;
const int N = 3e5 + 10;
int nQ;
int n;
int a[N];
int dp[N];
map<int, int> nxta[N];
void input() {
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &a[i]);
}
void solve() {
for (int i = n + 1; i >= 1; i--) {
nxta[i].clear();
dp[i] = 0;
}
for (int i = n - 1; i >= 1; i--) {
int nxt;
if (a[i] == a[i + 1]) {
nxt = i + 1;
nxta[i] = move(nxta[nxt + 1]);
if (nxt + 1 <= n) nxta[i][a[nxt + 1]] = nxt;
dp[i] = dp[nxt + 1] + 1;
} else {
nxt = nxta[i + 1][a[i]] + 1;
if (nxt != 1) {
nxta[i] = move(nxta[nxt + 1]);
if(nxt + 1 <= n) nxta[i][a[nxt + 1]] = nxt; //***
dp[i] = dp[nxt + 1] + 1;
}
}
}
long long ans = 0; //***
for (int i = 1; i < n; i++) {
ans += dp[i];
}
cout << ans << endl;
}
int main() {
scanf("%d", &nQ);
for (int iQ = 1; iQ <= nQ; iQ++) {
input();
solve();
}
return 0;
}
做法2:哈希
考虑一个栈 S,每次加入一个序列中的元素,如果该元素和栈顶的是相同的,那么就消掉
令 $S_i$ 表示到 $a_i$ 为止的栈, $f_i$ 表示 $a_i$ 结尾的可消去子序列的数量
如果有 $S_i=S_j$ 那么就说明 $i+1$ 到 $j$ 的区间是可以消去的(证明略,那么这时候 $f_j=f_i+1$
接下来状态压缩,令 $g_S$ 表示当前栈为 $S$ 时以当前位置结尾的可消去子序列数量,那么每次更新就是 $ans+=g_S,g_S++$
给 $S$ 做一个哈希,然后 $g$ 用 map 维护
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n;
string s;
string stk;
const ll b1=131,b2=151;
const ll mod=1e9+7;
ll p1[2000010];
ll p2[2000010];
ll h1,h2;
map<pair<ll,ll>,int> st;
ll ans;
int main(){
cin>>n>>s;
p1[0]=p2[0]=1;
for(int i=1;i<=n;i++){
p1[i]=p1[i-1]*b1%mod;
p2[i]=p2[i-1]*b2%mod;
}
stk='-';
h1=(int)stk[0]*p1[0]%mod;
h2=(int)stk[0]*p2[0]%mod;
st[make_pair(h1,h2)]++;
for(char c:s){
if(c==stk.back()){
h1-=(int)stk.back()*p1[stk.size()-1]%mod;
h1+=mod;
h1%=mod;
h2-=(int)stk.back()*p2[stk.size()-1]%mod;
h2+=mod;
h2%=mod;
stk.pop_back();
}else{
h1+=(int)c*p1[stk.size()]%mod;
h1%=mod;
h2+=(int)c*p2[stk.size()]%mod;
h2%=mod;
stk.push_back(c);
}
ans+=st[make_pair(h1,h2)];
st[make_pair(h1,h2)]++;
}
cout<<ans<<endl;
return 0;
}