【题目描述】
给定正整数m以及n个01串s1~sn,你需要求出长度为2m的反对称的包含这n个01串作为子串的01串的个数。对998244353取模。
一个01串s是反对称的当且仅当它对于1<=i<=|s|都满足s[i]≠s[|s|-i+1]。
【输入数据】
第一行两个整数n,m。接下来n行每行一个字符串s1~sn。
【输出数据】
一行一个整数表示答案。
【样例输入】
2 3
011
001
【样例输出】
4
【数据范围】
对于10%的数据,m<=15。
对于40%的数据,n<=4,|si|<=20。
对于60%的数据,n<=6,|si|<=30,m<=100。
对于另外20%的数据,n=1。
对于100%的数据,n<=6,|si|<=100,m<=500。
直接求出反对称串很不好求,我们考虑求出反对称串的一半。感觉上可以DP,但是发现要匹配到每一个01串,不是很好记状态。想到如果建立一个AC自动机,把自动机上跳到的节点加入状态。用dp[i][j][mask]表示到第i个位置,自动机上节点为j,已经满足了mask二进制状态压缩的这些字符串为子串的方案数。这样我们只用把每一个01串正着加入AC自动机,反着加入AC自动机,然后DP一下就好了。注意反串也要加入mask的状态。
那么问题就在中间衔接部分了,如果衔接部分满足了一个01串的条件,这个答案也是合法的。但是可以看到,串长最大是100,于是我们1e6预处理出每一个串左边比右边长并且右边向左边翻折合法的情况,同样加入自动机。只是在中途dp时不计答案,用另一个数组存放信息。最后特殊统计这些节点的答案就行了。
#include<cstdio> #include<algorithm> #include<queue> #include<cstring> using namespace std; const int maxn = 1e3 + 5, mod = 998244353; int dp[maxn][maxn][65], n, m, nxt, ans, flag; char s[6][2][maxn]; namespace AC { int trie[maxn][2], fail[maxn], str[maxn], csc[maxn]; /*trie树、fail指针、二进制原串信息、横跨中间且左边长的串信息*/ int cnt; void insert(char * s, int len, int d, int f) { /*f=1 普通串 f=0处理串*/ int now = 0; for(register int i = 0; i < len; ++i) { int id = s[i] - '0'; if(!trie[now][id]) now = trie[now][id] = ++cnt; else now = trie[now][id]; } if(f) str[now] |= 1 << d; else csc[now] |= 1 << d; } void create() { queue<int > q; q.push(0), fail[0] = -1; while(!q.empty()) { int now = q.front(); q.pop(); for(register int i = 0; i < 2; ++i) { int v = trie[now][i]; if(!v) continue; int p = fail[now]; while(~p) { if(trie[p][i]) { fail[v] = trie[p][i]; break; } p = fail[p]; } q.push(v); } } } int trans(int now, int id) { while(~now) { if(trie[now][id]) return trie[now][id]; now = fail[now]; } return 0; } } void rev(char s[2][maxn]) { int len = strlen(s[0]) - 1; for(register int i = 0; i <= len / 2; ++i) s[1][len - i] = s[0][i], s[1][i] = s[0][len - i]; for(register int i = 0; i <= len; ++i) s[1][i] = s[1][i] == '0' ? '1' : '0'; } int main() { freopen("string.in", "r", stdin); freopen("string.out", "w", stdout); using namespace AC; scanf("%d%d", &n, &m); for(register int i = 0; i < n; ++i) { scanf("%s", s[i][0]), insert(s[i][0], strlen(s[i][0]), i, 1); rev(s[i]), insert(s[i][1], strlen(s[i][0]), i, 1); } for(register int i = 0; i < n; ++i) { int len = strlen(s[i][0]) - 1; for(register int t = 0; t < 2; ++t) { for(register int j = 0; j <= len; ++j) { flag = 0; for(register int k = j + 1; k <= len; ++k) if(j - (k - j) + 1 < 0 || s[i][t][k] == s[i][t][j - (k - j) + 1]) {flag = 1; break;} if(!flag) insert(s[i][t], j + 1, i, 0); } } } create(); for(register int i = 1; i <= cnt; ++i) csc[i] |= csc[fail[i]], str[i] |= str[fail[i]]; dp[0][0][0] = 1; int stmax = (1 << n) - 1; for(register int i = 0; i < m; ++i) for(register int j = 0; j <= cnt; ++j) for(register int k = 0; k <= stmax; ++k) { nxt = trans(j, 0); (dp[i + 1][nxt][k | str[nxt]] += dp[i][j][k]) %= mod; nxt = trans(j, 1); (dp[i + 1][nxt][k | str[nxt]] += dp[i][j][k]) %= mod; } for(register int i = 0; i <= cnt; ++i) for(register int j = 0; j <= stmax; ++j) if((j | csc[i]) == stmax) (ans += dp[m][i][j]) %= mod; printf("%d\n", ans); return 0; }
没有帐号? 立即注册