1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
|
#include <bits/stdc++.h>
// Shorthand for memset over a whole array. `a` must be a real array (not a
// pointer), otherwise sizeof(a) is not the size of the data.
#define ms(a, x) memset(a, x, sizeof(a))
// Short alias for long long.
typedef long long LL;
using namespace std;
// Capacity constant (5e5 + 7) used to size the global buffers declared below.
const int N = 5e5 + 7;
// Suffix automaton over ASCII letters, stored in fixed-size arrays.
// States are numbered from 1; state `rt` is the root (empty string).
// Expected call order: build() -> topo() -> pre() -> solve().
struct SAM {
// Number of transition slots per state. idx() yields 0..25 for 'a'..'z' and
// 26..51 for 'A'..'Z', so 52 slots are required; with fewer slots an
// upper-case letter would index past the end of nxt[] and corrupt the state.
#define MAXALP 52
    struct state {
        int len;          // length of the longest string represented by the state
        int link;         // suffix link; -1 only for the root
        int nxt[MAXALP];  // transition per letter index; -1 means "no transition"
        int leftmost;     // smallest end position in this state's right set
        int rightmost;    // largest end position in this state's right set (final after pre())
        int Right;        // size of the right set (final after pre())
    };
    state st[N * 2];
    char S[N];                  // input string, NUL-terminated
    int sz, last, rt;           // number of states, state of the whole current prefix, root id
    char s[N];                  // not used by any method of this struct
    int cnt[2 * N], rk[2 * N];  // counting-sort buckets / states ordered by len

    // Maps a letter to its transition slot: 'a'..'z' -> 0..25, 'A'..'Z' -> 26..51.
    // Other characters are not validated.
    int idx(char c){
        if (c >= 'a' && c <= 'z')
            return c - 'a';
        return c - 'A' + 26;
    }

    // Allocates the next state id and resets every field of that state, so the
    // automaton never relies on the array having been cleared beforehand.
    int new_state(int len){
        int id = ++sz;
        state &t = st[id];
        t.len = len;
        t.link = 0;
        memset(t.nxt, -1, sizeof(t.nxt));
        t.leftmost = 0;
        t.rightmost = 0;
        t.Right = 0;
        return id;
    }

    // Resets the automaton to a single root state (id 1, len 0, link -1).
    // Only the root is touched; other states are initialised when created.
    void init(){
        sz = 0;
        last = rt = new_state(0);
        st[rt].link = -1;
    }

    // Appends the letter with slot index `c`; `head` is the 0-based position of
    // that letter in the input and becomes the end position of the new state.
    void extend(int c, int head){
        int cur = new_state(st[last].len + 1);
        st[cur].leftmost = st[cur].rightmost = head;
        st[cur].Right = 1;  // a non-clone state contributes exactly one end position
        int p = last;
        for (; p != -1 && st[p].nxt[c] == -1; p = st[p].link)
            st[p].nxt[c] = cur;
        if (p == -1) {
            st[cur].link = rt;
        } else {
            int q = st[p].nxt[c];
            if (st[p].len + 1 == st[q].len) {
                st[cur].link = q;
            } else {
                // Split q: the clone takes over the shorter strings. Its Right
                // stays 0 because it adds no end position of its own.
                int clone = new_state(st[p].len + 1);
                st[clone].link = st[q].link;
                memcpy(st[clone].nxt, st[q].nxt, sizeof(st[q].nxt));
                st[clone].leftmost = st[q].leftmost;
                st[clone].rightmost = st[q].rightmost;
                for (; p != -1 && st[p].nxt[c] == q; p = st[p].link)
                    st[p].nxt[c] = clone;
                st[q].link = st[cur].link = clone;
            }
        }
        last = cur;
    }

    // Builds the automaton for the NUL-terminated string currently held in S.
    void build(){
        init();
        for (int i = 0, _len = strlen(S); i < _len; i++)
            extend(idx(S[i]), i);
    }

    // Counting sort of the states by len: afterwards rk[1..sz] lists the state
    // ids in non-decreasing len order (rk[1] is the state with the smallest len).
    void topo(){
        // Every len is below sz, so only cnt[0..sz] is read or written.
        memset(cnt, 0, sizeof(cnt[0]) * (sz + 1));
        for (int i = 1; i <= sz; i++) cnt[st[i].len]++;
        for (int i = 1; i <= sz; i++) cnt[i] += cnt[i - 1];
        for (int i = 1; i <= sz; i++) rk[cnt[st[i].len]--] = i;
    }

    // Walks the states from longest to shortest (requires topo()) and pushes
    // each state's rightmost position and right-set size to its suffix-link parent.
    void pre(){
        for (int i = sz; i >= 2; i--) {
            int v = rk[i];
            int fa = st[v].link;
            if (fa == -1) continue;
            if (st[v].rightmost > st[fa].rightmost)
                st[fa].rightmost = st[v].rightmost;
            st[fa].Right += st[v].Right;
        }
    }

    // Prints the sum, over every non-root state, of
    //   Right * (Right + 1) / 2 * (len - len of its suffix-link parent).
    // The first factor counts the ways to combine occurrences of a string; the
    // second is the number of distinct strings the state stands for, all of
    // which occur the same number of times. Requires pre() to have run.
    void solve(){
        long long ans = 0;
        for (int v = 1; v <= sz; v++) {
            if (st[v].link == -1) continue;  // the root represents no substring
            long long occ = st[v].Right;
            ans += occ * (occ + 1) / 2 * (st[v].len - st[st[v].link].len);
        }
        printf("%lld\n", ans);
    }
} A;
// Extra global buffer of N chars; nothing in the visible code reads or writes it.
char B[N];
// Reads one whitespace-delimited token from stdin into A.S, builds its suffix
// automaton and prints the value computed by SAM::solve().
int main(){
    // Bound the %s conversion to the buffer capacity (minus the terminating
    // NUL) so an over-long token cannot overflow A.S. The width is derived
    // from the buffer itself, so it stays correct if the capacity changes.
    char fmt[32];
    snprintf(fmt, sizeof(fmt), "%%%ds", static_cast<int>(sizeof(A.S)) - 1);
    if (scanf(fmt, A.S) != 1)
        A.S[0] = '\0';  // no token available: treat the input as the empty string
    A.build();
    A.topo();
    A.pre();
    A.solve();
    return 0;
}
|