CF1276F - Asterisk Substrings

CF1276F - Asterisk Substrings

题目大意

给定串 \(S,|S|=n\) ,设一个串的子串集合为 \(Sub(S)\)

\(|Sub(S) \cup Sub(*+S[2:n])\cup Sub(S[1:1]+*+S[3:n])\cup \cdots|\)

其中*表示特殊字符而不是通配符


分析

对于不包含*的串,显然就是 \(Sub(S)\) ,可以通过后缀数组,后缀自动机来计算

对于包含*的串,考虑分两部分计算


1.对于后面接的串 \(T\) 分类

对于后面接的串 \(T\)\(T\) 在原串 \(S\) 出现的位置对应后缀数组上一段 \(\text{rank}\) 区间 \([l,r]\)

考虑按照原串后缀数组的 \(\text{height}\) 建立笛卡尔树,此时容易发现,不同的 \([l,r]\) 就是

笛卡尔树上每一个节点对应的区间,而这个 \([l,r]\) 出现的个数就是 \(height_u-height_{fa_u}\)


2.对于每一个 \([l,r]\) 计算前面接的串 \(R\) 的种类

那么在前面接的串 \(R\) 就是从 \([l,r]\)\(sa[i]-2\) 对应的所有前缀中

选择某一条后缀得到

在笛卡尔树上计算时,我们需要从儿子中合并两段 \([l,r],[l',r']\) ,计算不同串个数

也就是说我们需要动态维护一个集合 \(Set\) 为反串后缀的子集,并且计算这些后缀能够构成的串种类

对于 \(Set\) 为全集的情况,我们知道答案就是 \(\sum |suf_i|-\sum height_i\)

这条式子的意义实际上是:

按照 \(\text{rank}\) 考虑每一个后缀,减去前面已经出现过的所有串,就是减去和前面串最大的 \(\text{LCP}\)

由于 \(\text{LCP}(i,j)\) 取决于中间 \(height\) 的最小值,按 \(\text{rank}\) 加入时 \(\text{LCP}\) 的最大值就是 \(height_{i-1}\)


那么这个计算思路对于 \(Set\) 中元素不连续的情况显然依然成立

只需要动态维护出现位置的 \(\text{rank}\) ,不断减去相邻两个位置 \(i,j\)\(\text{LCP}\) 即可

\(\text{std::set}\) +启发式合并即可 \(O(n\log ^2n)\) 维护, \(\text{LCP}\) 用后缀数组 \(\text{RMQ}\) 即可 \(O(1)\) 求(实际上带一个 \(\log\) 也不影响总复杂度)

或许用线段树合并可以做到 \(O(n\log n)\)

代码的话 \(\downarrow\) ,有轻度封装

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define rep(i,a,b) for(int i=a,i##end=b;i<=i##end;++i)
#define drep(i,a,b) for(int i=a,i##end=b;i>=i##end;--i)

char IO;
template <class T=int> T rd(){
T s=0; int f=0;
while(!isdigit(IO=getchar())) f|=IO=='-';
do s=(s<<1)+(s<<3)+(IO^'0');
while(isdigit(IO=getchar()));
return f?-s:s;
}

enum{N=100010};
int n,m;
char s[N];
struct Suffix_Array{
int rk[N<<1],tmp[N],cnt[N],sa[N],lcp[N];
void Build() {
rep(i,1,n) cnt[s[i]-'a']++;
rep(i,1,25) cnt[i]+=cnt[i-1];
rep(i,1,n) rk[i]=cnt[s[i]-'a'];
drep(i,n,1) sa[cnt[s[i]-'a']--]=i;
for(int m=n,k=1;;k<<=1) {
int h=0;
rep(i,n-k+1,n) tmp[++h]=i;
rep(i,1,n) if(sa[i]>k) tmp[++h]=sa[i]-k;

rep(i,1,n) cnt[rk[sa[i]]]=i;
drep(i,n,1) sa[cnt[rk[tmp[i]]]--]=tmp[i];

rep(i,1,n) tmp[sa[i]]=tmp[sa[i-1]]+(rk[sa[i]]!=rk[sa[i-1]]||rk[sa[i-1]+k]!=rk[sa[i]+k]);
rep(i,1,n) rk[i]=tmp[i];
if((m=rk[sa[n]])==n) break;
}
int h=0;
rep(i,1,n) {
int j=sa[rk[i]-1];
if(h) h--;
while(s[i+h]==s[j+h]) h++;
lcp[rk[i]-1]=h;
}
}
} ;

struct LCPer:Suffix_Array{
int st[20][N],Log[N];
void Init() {
rep(i,2,n) Log[i]=Log[i>>1]+1;
rep(i,1,n) st[0][i]=lcp[i];
rep(i,1,Log[n]) {
int len=1<<(i-1);
rep(j,1,n-len+1) st[i][j]=min(st[i-1][j],st[i-1][j+len]);
}
}
int LCP(int i,int j) {
if(i==j) return n-sa[i]+1;
if(i>j) swap(i,j);
j--;
int d=Log[j-i+1];
return min(st[d][i],st[d][j-(1<<d)+1]);
}
} S;

struct SA_Solver:Suffix_Array{
int stk[N],top,ls[N],rs[N],mk[N];
ll ans,F[N*2];
set <int> st[N*2];
void dfs(int &u,int l,int r,int lst){
if(l==r) {
u=++m;
int p=sa[l];
if(p>2) {
int q=n-(p-2)+1;
F[u]=n-q+1;
st[u].insert(S.rk[q]);
}
if(p>1) ans+=1ll*(n-p+1-lst)*(F[u]+1);
return;
}
dfs(ls[u],l,u,lcp[u]),dfs(rs[u],u+1,r,lcp[u]);
if(st[ls[u]].size()>st[rs[u]].size()) swap(ls[u],rs[u]);
swap(st[u],st[rs[u]]),F[u]=F[ls[u]]+F[rs[u]];

int t=-1;
for(int i:st[ls[u]]) {
if(~t) F[u]+=S.LCP(t,i);
t=i;
auto r=st[u].upper_bound(i);
if(r!=st[u].end()) F[u]-=S.LCP(i,*r);
if(r!=st[u].begin()) {
auto l=r; l--;
if(r!=st[u].end()) F[u]+=S.LCP(*l,*r);
F[u]-=S.LCP(*l,i);
}
st[u].insert(i);
}
ans+=1ll*(lcp[u]-lst)*(F[u]+1);
}
void Solve(){
rep(i,1,n-1) {
while(top && lcp[stk[top]]>lcp[i]) ls[i]=stk[top--];
if(top) rs[stk[top]]=i;
stk[++top]=i;
}
rep(i,1,n-1) mk[ls[i]]=mk[rs[i]]=1;

rep(i,1,n) ans+=n-i+1-lcp[i];
ans++;
int lst=-1;
rep(i,1,n) if(S.sa[i]>1) {
ans+=n-S.sa[i]+1;
if(~lst) ans-=min(S.LCP(i,lst),min(n-S.sa[i]+1,n-S.sa[lst]+1));
lst=i;
}
ans++;
rep(i,1,n-1) if(!mk[i]) dfs(i,1,n,0);
printf("%lld\n",ans);
}
} T;

int main(){
scanf("%s",s+1),n=m=strlen(s+1);
if(n==1) return puts("3"),0;
T.Build(),reverse(s+1,s+n+1),S.Build(),S.Init();
T.Solve();
}