考场上 T1, T3 做出来了

T2, T4 CF 上有原题

这里就记录一下 T2 正解

T2

原题

做法1:dp

令 $f_i$ 从 $i$ 开头的可消去子串的数量

$nxt_i$ 是使得 $\left[i,j\right]$ 可消去的最小的 $j$

那么 $f_i=f_{nxt_i+1}+1$ 如果 $nxt_i$ 存在,否则是 $0$

$nxt_i$ 可以这么求($nxt_i$ 存在的话):

nxt[i] := i+1
while a[nxt[i]] != a[i]:
    nxt[i] = nxt[nxt[i]] + 1

但是这么做是 $O(n^2)$ 的,不能拿满分

所以引入 $nxta_{i,x}$ 表示使得 $\left[i,j\right]$ 可消去且 $a_{j+1}=x$ 的最小的 $j$

那么就有这样的关系

若 $nxt_i$ 和 $nxta_{i,x}$ 都存在

当 $a_i==a_{i+1}$ 时 $nxt_i=i+1$ 否则 $nxt_i=nxta_{i+1,a_i}+1$

当 $x==a_{nxt_i+1}$ 时 $nxta_{i,x}=nxt_i$ 否则 $nxta_{i,x}=nxta_{nxt_i+1,x}$

原题的代码:(CCF版缩小了值域 $nxt$ 为 $1$,$nxta$ 为 $0$ 表示不存在

#include <bits/stdc++.h>
using namespace std;

const int N = 3e5 + 10;

int nQ;
int n;
int a[N];
int dp[N];
map<int, int> nxta[N];

void input() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        scanf("%d", &a[i]);
}

void solve() {
    for (int i = n + 1; i >= 1; i--) {
        nxta[i].clear();
        dp[i] = 0;
    }
    for (int i = n - 1; i >= 1; i--) {
        int nxt;
        if (a[i] == a[i + 1]) {
            nxt = i + 1;
            nxta[i] = move(nxta[nxt + 1]);
            if (nxt + 1 <= n) nxta[i][a[nxt + 1]] = nxt;
            dp[i] = dp[nxt + 1] + 1;
        } else {
            nxt = nxta[i + 1][a[i]] + 1;
            if (nxt != 1) {
                nxta[i] = move(nxta[nxt + 1]);
                if(nxt + 1 <= n) nxta[i][a[nxt + 1]] = nxt; //***
                dp[i] = dp[nxt + 1] + 1;
            }
        }
    }
    long long ans = 0; //***
    for (int i = 1; i < n; i++) {
        ans += dp[i];
    }
    cout << ans << endl;
}

int main() {
    scanf("%d", &nQ);
    for (int iQ = 1; iQ <= nQ; iQ++) {
        input();
        solve();
    }
    return 0;
}

做法2:哈希

考虑一个栈 S,每次加入一个序列中的元素,如果该元素和栈顶的是相同的,那么就消掉

令 $S_i$ 表示到 $a_i$ 为止的栈, $f_i$ 表示 $a_i$ 结尾的可消去子序列的数量

如果有 $S_i=S_j$ 那么就说明 $i+1$ 到 $j$ 的区间是可以消去的(证明略,那么这时候 $f_j=f_i+1$

接下来状态压缩,令 $g_S$ 表示当前栈为 $S$ 时以当前位置结尾的可消去子序列数量,那么每次更新就是 $ans+=g_S,g_S++$

给 $S$ 做一个哈希,然后 $g$ 用 map 维护

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;

int n;
string s;
string stk;

const ll b1=131,b2=151;
const ll mod=1e9+7;
ll p1[2000010];
ll p2[2000010];
ll h1,h2;
map<pair<ll,ll>,int> st;

ll ans;

int main(){
    cin>>n>>s;
    p1[0]=p2[0]=1;
    for(int i=1;i<=n;i++){
        p1[i]=p1[i-1]*b1%mod;
        p2[i]=p2[i-1]*b2%mod;
    }
    stk='-';
    h1=(int)stk[0]*p1[0]%mod;
    h2=(int)stk[0]*p2[0]%mod;
    st[make_pair(h1,h2)]++;
    for(char c:s){
        if(c==stk.back()){
            h1-=(int)stk.back()*p1[stk.size()-1]%mod;
            h1+=mod;
            h1%=mod;
            h2-=(int)stk.back()*p2[stk.size()-1]%mod;
            h2+=mod;
            h2%=mod;
            stk.pop_back();
        }else{
            h1+=(int)c*p1[stk.size()]%mod;
            h1%=mod;
            h2+=(int)c*p2[stk.size()]%mod;
            h2%=mod;
            stk.push_back(c);
        }
        ans+=st[make_pair(h1,h2)];
        st[make_pair(h1,h2)]++;
    }
    cout<<ans<<endl;
    return 0;
}

Tags: