题目描述
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。
输入输出格式
输入格式:
两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母
输出格式:
输出一个整数表示答案
输入输出样例
输入样例#1: 复制
aabb bbaa
输出样例#1: 复制
10
把两个串串起来加”#“建立”广义“后缀自动机。对于一个节点记录两个信息。一个维护在前一个字符串中出现的次数,一个维护在后一个字符串中出现的次数。最后每一个节点 i 的贡献就是:(len[i]-len[fail[i]]) * num[0][i] * num[1][i]。
#include<cstdio> #include<cstring> using namespace std; const int maxn = 2e6 + 5; int trie[maxn][27], fail[maxn], len[maxn], last, cnt; int num[2][maxn], tot[maxn], ord[maxn], lcnt, none, lena, lenb; char s[maxn], a[maxn], sp[maxn]; long long ans; void add(int x, int c) { int p = last, u = ++cnt; ++num[c][u]; int id = s[x] - 'a'; last = u, len[u] = x + 1; while(p && !trie[p][id]) trie[p][id] = u, p = fail[p]; if(!p) fail[u] = 1; else { int q = trie[p][id]; if(len[p] + 1 == len[q]) fail[u] = q; else { int lca = ++cnt; fail[lca] = fail[q]; fail[q] = fail[u] = lca; len[lca] = len[p] + 1; memcpy(trie[lca], trie[q], sizeof(trie[q])); while(p && trie[p][id] == q) { trie[p][id] = lca; p = fail[p]; } } } } int main() { scanf("%s%s", s, a); sp[0] = 'z' + 1, last = cnt = 1; lena = strlen(s), lenb = strlen(a); strcat(s, sp), strcat(s, a); for(register int i = 0; i <= lena; ++i) add(i, 0); for(register int i = lena + 1; i < lena + 1 + lenb; ++i) add(i, 1); for(register int i = 1; i <= cnt; ++i) ++tot[len[i]]; for(register int i = 1; i <= cnt; ++i) tot[i] += tot[i - 1]; for(register int i = cnt; i >= 1; --i) ord[tot[len[i]]--] = i; for(register int i = cnt; i >= 1; --i) { num[0][fail[ord[i]]] += num[0][ord[i]]; num[1][fail[ord[i]]] += num[1][ord[i]]; } for(register int i = 1; i <= cnt; ++i) { ans += 1LL * (len[i] - len[fail[i]]) * num[0][i] * num[1][i]; } printf("%lld", ans); return 0; }
不用土方法建真广义自动机的版本:
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int maxn = 1e6 + 5; char s[maxn]; namespace SAM { int trie[maxn << 1][26], len[maxn << 1], fa[maxn << 1]; int cnt = 1, last = 1, sum[maxn << 1][2]; void add(int x, int t) { int id = s[x] - 'a'; int p = last, np, f = 0; if((!trie[p][id]) || !(len[trie[p][id]] == len[p] + 1)) np = ++cnt; else np = trie[p][id], f = 1; ++sum[np][t], last = np; if(f) return; len[np] = x + 1; while(p && !trie[p][id]) trie[p][id] = np, p = fa[p]; if(!p) fa[np] = 1; else { int q = trie[p][id]; if(len[q] == len[p] + 1) fa[np] = q; else { int lca = ++cnt; len[lca] = len[p] + 1; memcpy(trie[lca], trie[q], sizeof(trie[q])); fa[lca] = fa[q], fa[q] = fa[np] = lca; while(p && trie[p][id] == q) trie[p][id] = lca, p = fa[p]; } } } } int tot[maxn << 1], top[maxn << 1]; #define LL long long int l, la, lb; LL ans; char a[maxn], b[maxn]; int main() { //freopen("finde.in", "r", stdin); //freopen("finde.out", "w", stdout); using namespace SAM; scanf("%s", s), l = strlen(s); memcpy(a, s, sizeof(s)); for(register int i = 0; i < l; ++i) add(i, 0); last = 1; scanf("%s", s), l = strlen(s); for(register int i = 0; i < l; ++i) add(i, 1); for(register int i = 1; i <= cnt; ++i) ++tot[len[i]]; for(register int i = 1; i <= cnt; ++i) tot[i] += tot[i - 1]; for(register int i = 1; i <= cnt; ++i) top[tot[len[i]]--] = i; for(register int i = cnt; i >= 1; --i) { sum[fa[top[i]]][0] += sum[top[i]][0]; sum[fa[top[i]]][1] += sum[top[i]][1]; } for(register int i = cnt; i >= 2; --i) ans += 1ll * sum[i][0] * sum[i][1] * (len[i] - len[fa[i]]); printf("%lld", ans); return 0; }
没有帐号? 立即注册