题目描述
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。
输入输出格式
输入格式:
两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母
输出格式:
输出一个整数表示答案
输入输出样例
输入样例#1: 复制
aabb bbaa
输出样例#1: 复制
10
把两个串串起来加”#“建立”广义“后缀自动机。对于一个节点记录两个信息。一个维护在前一个字符串中出现的次数,一个维护在后一个字符串中出现的次数。最后每一个节点 i 的贡献就是:(len[i]-len[fail[i]]) * num[0][i] * num[1][i]。
#include<cstdio>
#include<cstring>
using namespace std;
const int maxn = 2e6 + 5;
int trie[maxn][27], fail[maxn], len[maxn], last, cnt;
int num[2][maxn], tot[maxn], ord[maxn], lcnt, none, lena, lenb;
char s[maxn], a[maxn], sp[maxn];
long long ans;
void add(int x, int c) {
int p = last, u = ++cnt;
++num[c][u];
int id = s[x] - 'a';
last = u, len[u] = x + 1;
while(p && !trie[p][id])
trie[p][id] = u, p = fail[p];
if(!p) fail[u] = 1;
else {
int q = trie[p][id];
if(len[p] + 1 == len[q]) fail[u] = q;
else {
int lca = ++cnt;
fail[lca] = fail[q];
fail[q] = fail[u] = lca;
len[lca] = len[p] + 1;
memcpy(trie[lca], trie[q], sizeof(trie[q]));
while(p && trie[p][id] == q) {
trie[p][id] = lca;
p = fail[p];
}
}
}
}
int main() {
scanf("%s%s", s, a);
sp[0] = 'z' + 1, last = cnt = 1;
lena = strlen(s), lenb = strlen(a);
strcat(s, sp), strcat(s, a);
for(register int i = 0; i <= lena; ++i) add(i, 0);
for(register int i = lena + 1; i < lena + 1 + lenb; ++i) add(i, 1);
for(register int i = 1; i <= cnt; ++i) ++tot[len[i]];
for(register int i = 1; i <= cnt; ++i) tot[i] += tot[i - 1];
for(register int i = cnt; i >= 1; --i) ord[tot[len[i]]--] = i;
for(register int i = cnt; i >= 1; --i) {
num[0][fail[ord[i]]] += num[0][ord[i]];
num[1][fail[ord[i]]] += num[1][ord[i]];
}
for(register int i = 1; i <= cnt; ++i) {
ans += 1LL * (len[i] - len[fail[i]]) * num[0][i] * num[1][i];
}
printf("%lld", ans);
return 0;
}
不用土方法建真广义自动机的版本:
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn = 1e6 + 5;
char s[maxn];
namespace SAM {
int trie[maxn << 1][26], len[maxn << 1], fa[maxn << 1];
int cnt = 1, last = 1, sum[maxn << 1][2];
void add(int x, int t) {
int id = s[x] - 'a';
int p = last, np, f = 0;
if((!trie[p][id]) || !(len[trie[p][id]] == len[p] + 1))
np = ++cnt;
else np = trie[p][id], f = 1;
++sum[np][t], last = np;
if(f) return;
len[np] = x + 1;
while(p && !trie[p][id])
trie[p][id] = np, p = fa[p];
if(!p) fa[np] = 1;
else {
int q = trie[p][id];
if(len[q] == len[p] + 1) fa[np] = q;
else {
int lca = ++cnt;
len[lca] = len[p] + 1;
memcpy(trie[lca], trie[q], sizeof(trie[q]));
fa[lca] = fa[q], fa[q] = fa[np] = lca;
while(p && trie[p][id] == q)
trie[p][id] = lca, p = fa[p];
}
}
}
}
int tot[maxn << 1], top[maxn << 1];
#define LL long long
int l, la, lb; LL ans;
char a[maxn], b[maxn];
int main() {
//freopen("finde.in", "r", stdin);
//freopen("finde.out", "w", stdout);
using namespace SAM;
scanf("%s", s), l = strlen(s);
memcpy(a, s, sizeof(s));
for(register int i = 0; i < l; ++i) add(i, 0);
last = 1;
scanf("%s", s), l = strlen(s);
for(register int i = 0; i < l; ++i) add(i, 1);
for(register int i = 1; i <= cnt; ++i) ++tot[len[i]];
for(register int i = 1; i <= cnt; ++i) tot[i] += tot[i - 1];
for(register int i = 1; i <= cnt; ++i)
top[tot[len[i]]--] = i;
for(register int i = cnt; i >= 1; --i) {
sum[fa[top[i]]][0] += sum[top[i]][0];
sum[fa[top[i]]][1] += sum[top[i]][1];
}
for(register int i = cnt; i >= 2; --i)
ans += 1ll * sum[i][0] * sum[i][1] * (len[i] - len[fa[i]]);
printf("%lld", ans);
return 0;
}
rockdu
没有帐号? 立即注册