【题目描述】
给定正整数m以及n个01串s1~sn,你需要求出长度为2m的反对称的包含这n个01串作为子串的01串的个数。对998244353取模。
一个01串s是反对称的当且仅当它对于1<=i<=|s|都满足s[i]≠s[|s|-i+1]。
【输入数据】
第一行两个整数n,m。接下来n行每行一个字符串s1~sn。
【输出数据】
一行一个整数表示答案。
【样例输入】
2 3
011
001
【样例输出】
4
【数据范围】
对于10%的数据,m<=15。
对于40%的数据,n<=4,|si|<=20。
对于60%的数据,n<=6,|si|<=30,m<=100。
对于另外20%的数据,n=1。
对于100%的数据,n<=6,|si|<=100,m<=500。
直接求出反对称串很不好求,我们考虑求出反对称串的一半。感觉上可以DP,但是发现要匹配到每一个01串,不是很好记状态。想到如果建立一个AC自动机,把自动机上跳到的节点加入状态。用dp[i][j][mask]表示到第i个位置,自动机上节点为j,已经满足了mask二进制状态压缩的这些字符串为子串的方案数。这样我们只用把每一个01串正着加入AC自动机,反着加入AC自动机,然后DP一下就好了。注意反串也要加入mask的状态。
那么问题就在中间衔接部分了,如果衔接部分满足了一个01串的条件,这个答案也是合法的。但是可以看到,串长最大是100,于是我们1e6预处理出每一个串左边比右边长并且右边向左边翻折合法的情况,同样加入自动机。只是在中途dp时不计答案,用另一个数组存放信息。最后特殊统计这些节点的答案就行了。
#include<cstdio>
#include<algorithm>
#include<queue>
#include<cstring>
using namespace std;
const int maxn = 1e3 + 5, mod = 998244353;
int dp[maxn][maxn][65], n, m, nxt, ans, flag;
char s[6][2][maxn];
namespace AC {
int trie[maxn][2], fail[maxn], str[maxn], csc[maxn];
/*trie树、fail指针、二进制原串信息、横跨中间且左边长的串信息*/
int cnt;
void insert(char * s, int len, int d, int f) { /*f=1 普通串 f=0处理串*/
int now = 0;
for(register int i = 0; i < len; ++i) {
int id = s[i] - '0';
if(!trie[now][id]) now = trie[now][id] = ++cnt;
else now = trie[now][id];
}
if(f) str[now] |= 1 << d;
else csc[now] |= 1 << d;
}
void create() {
queue<int > q;
q.push(0), fail[0] = -1;
while(!q.empty()) {
int now = q.front();
q.pop();
for(register int i = 0; i < 2; ++i) {
int v = trie[now][i];
if(!v) continue;
int p = fail[now];
while(~p) {
if(trie[p][i]) {
fail[v] = trie[p][i];
break;
}
p = fail[p];
}
q.push(v);
}
}
}
int trans(int now, int id) {
while(~now) {
if(trie[now][id]) return trie[now][id];
now = fail[now];
}
return 0;
}
}
void rev(char s[2][maxn]) {
int len = strlen(s[0]) - 1;
for(register int i = 0; i <= len / 2; ++i)
s[1][len - i] = s[0][i], s[1][i] = s[0][len - i];
for(register int i = 0; i <= len; ++i)
s[1][i] = s[1][i] == '0' ? '1' : '0';
}
int main() {
freopen("string.in", "r", stdin);
freopen("string.out", "w", stdout);
using namespace AC;
scanf("%d%d", &n, &m);
for(register int i = 0; i < n; ++i) {
scanf("%s", s[i][0]), insert(s[i][0], strlen(s[i][0]), i, 1);
rev(s[i]), insert(s[i][1], strlen(s[i][0]), i, 1);
}
for(register int i = 0; i < n; ++i) {
int len = strlen(s[i][0]) - 1;
for(register int t = 0; t < 2; ++t) {
for(register int j = 0; j <= len; ++j) {
flag = 0;
for(register int k = j + 1; k <= len; ++k)
if(j - (k - j) + 1 < 0 || s[i][t][k] == s[i][t][j - (k - j) + 1])
{flag = 1; break;}
if(!flag)
insert(s[i][t], j + 1, i, 0);
}
}
}
create();
for(register int i = 1; i <= cnt; ++i) csc[i] |= csc[fail[i]], str[i] |= str[fail[i]];
dp[0][0][0] = 1;
int stmax = (1 << n) - 1;
for(register int i = 0; i < m; ++i)
for(register int j = 0; j <= cnt; ++j)
for(register int k = 0; k <= stmax; ++k) {
nxt = trans(j, 0);
(dp[i + 1][nxt][k | str[nxt]] += dp[i][j][k]) %= mod;
nxt = trans(j, 1);
(dp[i + 1][nxt][k | str[nxt]] += dp[i][j][k]) %= mod;
}
for(register int i = 0; i <= cnt; ++i)
for(register int j = 0; j <= stmax; ++j)
if((j | csc[i]) == stmax)
(ans += dp[m][i][j]) %= mod;
printf("%d\n", ans);
return 0;
}
rockdu
没有帐号? 立即注册