KMP 算法介绍

简介

kmp算法是一种串的匹配算法。

kmp算法在蛮力算法的基础上,通过引入额外的next表,优化失配后的回退“长度”。将时间复杂度从O(m*n)减少为O(n)。

蛮力算法

对于串匹配,蛮力算法是一种很容易想到的算法,其正确性也显而易见。具体的,将主串与模式串一一比较,一旦某一字符不同(失配),则将模式串回退到与主串下一紧邻的字符,继续比较。

一种蛮力算法的实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
static int brute_force(String source, String target) {

final int n = source.length();
final int m = target.length();

if (m > n) return -1;

int i = 0;
int j = 0;

while (i < n && j < m) {
if (source.charAt(i) == target.charAt(j)) {
i++;
j++;
} else {
i -= j - 1;
j = 0;
}
}

if (j == m) return i-j;
else return -1;

}

KMP算法

不难看出,在蛮力算法中,一旦发生失配,主、模式串将(可能)会有大幅度的回退。而这些回退都是有必要的吗?

考虑发生大幅度回退的情况。具体的,在发生失配,即 source.charAt(i) != target.charAt(j)的时候。此时,target[0, j) = source[i-j, i),我们当然希望主串指针 i 能够不回退,而模式串指针 j 直接归为 0 与 i 对其。即让 i “从哪里跌倒,就从哪里爬起来”。但我们之所以不能简单地这样做,原因在于在已经匹配过的目标串的子串source[i-j, i)中,或许存在后缀source[x, i),其与模式串前i-x位相同,为了不遗漏这一种情况,我们应该让模式串指针 j 回退到第i-x+1。更进一步的,为了不遗漏任何情况,j的回退距离应该尽可能的小,即j的值尽可能的大,也即后缀source[x, i)尽可能的长。

而由于target[0, j) = source[i-j, i),所以j的值可以定义为串 target[0, j)既是前缀又是后缀的最长缀的长度。注意,j的最优值或者说下一个取值居然与目标串无关!因此,我们完全可以通过预处理,基于模式串target构造一个数组 next[j],表示每当模式串在j处发生失配后,j的下一个取值。

得到next数组后,KMP算法只需对蛮力算法进行少量改动,一种实现的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
static int kmp(String source, String target) {
final int n = source.length();
final int m = target.length();

if (m > n) return -1;

int[] next = buildNext(target);

int i = 0;
int j = 0;

while (i < n && j < m) {
// 让j=next[0]=-1 让 i,j 同时++,造成 j=0, i++
if (j == -1 || source.charAt(i) == target.charAt(j)) {
i++;
j++;
} else {
// i 不变
j = next[j];
}
}

if (j == m) return i-j;
else return -1;
}

注意这里定义 next[0]=-1:因为当主、模串第一位都不匹配时,应该让i++,j不变,但为了不引入难看的if判断,将j定义为-1,并修改循环内第一个if,使其等价为i++,j=0不变。

next表生成

next[j]=t,则target[0, j)自匹配的前、后缀的最大长度为t。

next[j+1] <= (t = next[j]) + 1,当且仅当target[j] == target[t]取等,

否则,next[j+1] <= (t = next[j-1]) + 1,当且仅当target[j] == target[t]取等,

由上述,next表可以递推的产生。一种实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
static private int[] buildNext(String target) {
final int m = target.length();
int[] next = new int[m];
int j = 0;
int t = next[0] = -1;
while (j < m - 1) {
if (t == -1 || target.charAt(j) == target.charAt(t)) {
j++; t++;
next[j] = t;
} else
t = next[t];
}

return next;

}