KMP & exKMP
其实就考虑计算:后缀函数(\pi
) 和 Z函数
后缀函数的定义是一个串的最长公共前真后缀长度
Z函数的定义是一个串以某一位开始的后缀与原串的最长公共前缀长
KMP
O(n)
计算 后缀函数
O(n^2)
的算法很容易考虑,枚举每个可能的公共前缀开头向后匹配更新即可
考虑修改一下这一过程
假定已计算出了前 k
位的后缀函数,现考虑第 k + 1
位的
第 k + 1
位将在原有某个后缀的基础上增加一位,这一后缀必然是包含在第 k
位的最长公共前后缀内的
所以首先考虑在第 k
位的最长公共前后缀基础上添上一位,即尝试匹配 S_{k + 1}
与 S_{\pi(k) + 1}
当失配时即继续考虑更短的公共前后缀,那显然继续考虑的就是第 \pi(k)
位的最长公共前后缀基础上添加一位(其实也就是第 k
位稍短一些的最长公共前后缀)
依次这样迭代下去直到第一位
这个过程复杂度为 O(n)
不那么显然,需要分析一下
其实也就是看总的匹配需求次数,好吧其实挺显然的
因为每次成功的匹配至多使后续可能的失配次数+1,而出现失配后会使后续可能的失配次数变少
所以总的匹配次数至多为 2n
字符串匹配
后缀函数最常见的应用就是做字符串匹配了(还有各种简单字符串题)
考虑正在匹配模式串的第 j
位,原串的第 i
位
若失配,由于模式串的前 j - 1
位均匹配过了,则模式串的前 \pi(j - 1)
还是匹配的,将 j
跳转过去就行了
若匹配,则考虑下一位,i,j
均后移一位
特别的,有成功匹配处,则有 i
指向模式串的后一位,可作为答案拿出,并视作失配继续进行下去直到匹配串走完
找个板子题做做
洛谷 P3375 【模板】KMP字符串匹配
#include <cstring>
#include <iostream>
using namespace std;
const int MAX_N = 1E6 + 100;
int n, m;
char a[MAX_N], b[MAX_N];
int pi[MAX_N];
int main() {
ios::sync_with_stdio(false);
cin >> (a + 1);
cin >> (b + 1);
n = strlen(a + 1);
m = strlen(b + 1);
// init
pi[0] = -1;
for (int i = 1; i <= m; ++i) {
int j = pi[i - 1];
while (j >= 0 && b[j + 1] != b[i]) j = pi[j];
pi[i] = j + 1;
}
for (int i = 1, j = 1; i <= n; ++i, ++j) {
while (j && a[i] != b[j]) {
j = pi[j - 1] + 1;
}
if (j == m) {
cout << i - m + 1 << '\n';
}
}
for (int i = 1; i <= m; ++i) {
cout << pi[i] << ' ';
}
cout << endl;
return 0;
}
Z Algorithm
O(n)
计算 Z函数
O(n^2)
的算法很容易考虑,暴力向后直到失配即可
考虑能否模仿一下kmp的过程,利用原先匹配的结果
假定已经计算出了前 k
位的Z函数,考虑第 k + 1
位的Z函数
首先考虑最长的可能,去掉第 k
位的那个字符,等同于去掉第 1
位的那个字符
也就是 Z(k + 1)
至少有 min(Z(k) - 1, Z(1))
(下标从 0 开始记)
可以从 min(Z(k) - 1, Z(1)) + 1
开始暴力匹配
但是这一优化显然不够,还是有可能会多次匹配同一字符
容易想到可以把 1
扩展到任意的 x
,有
Z(k + x) \geq min\{Z(k) - x, Z(x)\}
且若 Z(x) < Z(k) - x
,有 S_{Z(x)} \neq S_{x + Z(x)} = S_{k + x + Z(x)}
,即有 Z(k + x) = Z(x)
故不妨每次从使 k + Z(k) - 1
最大的一个 k
处考虑
当 Z(x) < Z(k) - x
时可以 O(1)
计算
否则直接向后暴力
这样每个字符显然只被匹配一次,很显然是 O(n)
的
计算模式串与匹配串每个后缀的LCP
因为当求得的LCP为模式串长时即为找到了匹配,故字符串匹配为该问题的一个子集
所以该算法也称为扩展KMP
考虑匹配过程中对匹配串最远访问到了 k + lcp(k) - 1
处
即从 k
开始匹配了 lcp(k)
位
则有 lcp(k + x) \geq min\{lcp(k) - x, Z(x)\}
类似算Z数组的时候弄弄就好了
有两道题可以一做
P5410 【模板】扩展 KMP(Z 函数)
#include <cstring>
#include <iostream>
using namespace std;
using LL = long long;
const int MAX_N = 2E7 + 100;
int n, m;
int z[MAX_N], p[MAX_N];
char a[MAX_N], b[MAX_N];
int main() {
ios::sync_with_stdio(false);
cin >> a;
n = strlen(a);
cin >> b;
m = strlen(b);
// z[0] = 0;
for (int i = 1, k = 0; i < m; ++i) {
int x = i - k;
if (z[x] < z[k] - x) {
z[i] = z[x];
} else {
z[i] = max(0, z[k] - x);
while (b[z[i]] == b[i + z[i]]) z[i] += 1;
if (k + z[k] < i + z[i]) k = i;
}
}
LL ans = m + 1;
for (int i = 1; i < m; ++i) {
ans ^= (i + 1ll) * (z[i] + 1);
}
cout << ans << endl;
for (int i = 0, k = 0; i < n; ++i) {
int x = i - k;
if (z[x] < p[k] - x) {
p[i] = z[x];
} else {
p[i] = max(0, p[k] - x);
while (i + p[i] < n && b[p[i]] == a[i + p[i]]) p[i] += 1;
if (k + p[k] < i + p[i]) k = i;
}
}
ans = 0;
for (int i = 0; i < n; ++i) {
ans ^= (i + 1ll) * (p[i] + 1);
}
cout << ans << endl;
return 0;
}
Codeforces Round #741 (Div. 2) E. Rescue Niwen!
题意
求一个字符串的所有子串的“最长上升子序列”长度,其中子串先按起始下标顺序,再按终止下标倒序排序排列,子串大小即按字符串比较来排列。
如abac,生成的序列为 a ab aba abac b ba bac a ac c
数据范围
多组数据,卡log
思路
可以很简单的构造 O(n^3)
的dp
因为子串的排列方式,可以看出整个序列可以被划分为 n
段单调增的部分,显然,如果要取一段,那么必然从某一位开始取完剩下的这段
那么将状态设为 dp[i]
,取了第 i
组的某一个后缀后,最长上升子序列的长度,转移方程为
dp[j] = Max\{dp[i] + n - i + 1 - f(i, j)\}
其中 f(i, j)
为 s[i..n],s[j..n]
的最长公共前缀长度,且有 s[i + f(i, j)] > s[j + f(i, j)]
暴力求最长公共前缀,是 O(n^3)
的,也就是现有瓶颈
要考虑 O(n^2)
求出这个东西
显然,可以裸套Z Algorithm
但是其实这玩意非常愚蠢,转移如下
若 s[i] = s[j], f(i, j) = f(i + 1, j + 1) + 1
否则 f(i, j) = 0
AC代码(Z函数版本)
#include <cstring>
#include <iostream>
using namespace std;
const int MAX_N = 5000 + 7;
int T, n;
char s[MAX_N];
int z[MAX_N][MAX_N];
int dp[MAX_N];
void getZ(int n, char s[], int z[]) {
z[0] = 0;
for (int i = 1, k = 0; i < n; ++i) {
int x = i - k;
z[i] = 0;
if (z[x] < z[k] - x) {
z[i] = z[x];
} else {
z[i] = max(0, z[k] - x);
while (s[z[i]] == s[i + z[i]]) z[i] += 1;
if (k + z[k] < i + z[i]) k = i;
}
}
}
int main() {
ios::sync_with_stdio(false);
cin >> T;
while (T--) {
cin >> n;
cin >> (s + 1);
for (int i = 1; i <= n; ++i) {
dp[i] = 0;
getZ(n - i + 1, s + i, z[i] + i);
}
int ans = 0;
for (int i = 1; i <= n; ++i) {
for (int j = 0; j < i; ++j) {
if (s[i + z[j][i]] > s[j + z[j][i]]) {
dp[i] = max(dp[i], dp[j] + n - i + 1 - z[j][i]);
}
}
ans = max(ans, dp[i]);
}
cout << ans << '\n';
}
return 0;
}