Skip to content

Commit 1a05fcb

Browse files
committed
Add Rabin–Karp String Matching Algorithm (Fixes #13918)
1 parent 8934bab commit 1a05fcb

File tree

1 file changed

+97
-79
lines changed

1 file changed

+97
-79
lines changed

strings/rabin_karp.py

Lines changed: 97 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,109 @@
1-
# Numbers of alphabet which we call base
2-
alphabet_size = 256
3-
# Modulus to hash a string
4-
modulus = 1000003
1+
"""
2+
Rabin–Karp String Matching Algorithm
3+
https://en.wikipedia.org/wiki/Rabin%E2%80%93Karp_algorithm
4+
"""
55

6+
from typing import Dict, Iterable, List, Optional, Tuple
67

7-
def rabin_karp(pattern: str, text: str) -> bool:
8+
# Default modulus applied to every rolling-hash computation.
MOD: int = 1_000_000_007
# Default base of the polynomial rolling hash.
BASE: int = 257


def rabin_karp(
    text: str,
    pattern: str,
    *,
    base: Optional[int] = None,
    modulus: Optional[int] = None,
) -> List[int]:
    """
    Return all starting indices where `pattern` appears in `text`.

    A polynomial rolling hash of every window of ``len(pattern)`` characters
    is compared with the hash of `pattern`; a window is reported only after
    an exact string comparison, so hash collisions never produce false
    matches.

    Args:
        text: The string to search in.
        pattern: The string to search for.
        base: Base of the polynomial hash. Defaults to the module-level
            ``BASE``.
        modulus: Modulus of the hash, must be >= 1. Defaults to the
            module-level ``MOD``.

    Returns:
        Ascending list of every index ``i`` with
        ``text[i : i + len(pattern)] == pattern``. Overlapping matches are
        included, and an empty pattern matches at every position
        ``0..len(text)``.

    Raises:
        ValueError: If `modulus` is smaller than 1.

    >>> rabin_karp("abracadabra", "abra")
    [0, 7]
    >>> rabin_karp("aaaaa", "aa")  # overlapping matches
    [0, 1, 2, 3]
    >>> rabin_karp("hello", "")  # empty pattern matches everywhere
    [0, 1, 2, 3, 4, 5]
    >>> rabin_karp("", "abc")
    []
    >>> rabin_karp("abracadabra", "abra", modulus=1)  # every hash collides
    [0, 7]
    """
    if base is None:
        base = BASE
    if modulus is None:
        modulus = MOD
    if modulus < 1:
        raise ValueError("modulus must be a positive integer")

    n, m = len(text), len(pattern)
    if m == 0:
        return list(range(n + 1))
    if n < m:
        return []

    # Weight of the window's leading character: base^(m-1) mod modulus.
    power = pow(base, m - 1, modulus)

    # Hash the pattern and the first window together; zip stops after the
    # m characters of the (shorter or equal) pattern.
    hp = ht = 0
    for p_char, t_char in zip(pattern, text):
        hp = (hp * base + ord(p_char)) % modulus
        ht = (ht * base + ord(t_char)) % modulus

    results: List[int] = []
    last_start = n - m
    for i in range(last_start + 1):
        # Equal hashes only suggest a match; confirm it character by
        # character (startswith avoids allocating a slice per window).
        if hp == ht and text.startswith(pattern, i):
            results.append(i)

        if i < last_start:
            # Slide the window: drop text[i], shift, append text[i + m].
            ht = ((ht - ord(text[i]) * power) * base + ord(text[i + m])) % modulus

    return results
1653

17-
1) Calculate pattern hash
1854

19-
2) Step through the text one character at a time passing a window with the same
20-
length as the pattern
21-
calculating the hash of the text within the window compare it with the hash
22-
of the pattern. Only testing equality if the hashes match
55+
def rabin_karp_multi(
    text: str,
    patterns: Iterable[str],
    *,
    base: Optional[int] = None,
    modulus: Optional[int] = None,
) -> Dict[str, List[int]]:
    """
    Multiple-pattern Rabin–Karp.

    Patterns are grouped by length and `text` is scanned once per distinct
    pattern length with a rolling hash. A window is reported only after an
    exact string comparison, so hash collisions never produce false matches.

    Args:
        text: The string to search in.
        patterns: The strings to search for. Repeated patterns are searched
            once and appear once in the result.
        base: Base of the polynomial hash. Defaults to the module-level
            ``BASE``.
        modulus: Modulus of the hash, must be >= 1. Defaults to the
            module-level ``MOD``.

    Returns:
        A dict mapping every distinct pattern (in first-seen order) to the
        ascending list of indices where it starts in `text`. An empty
        pattern matches at every position ``0..len(text)``.

    Raises:
        ValueError: If `modulus` is smaller than 1.

    >>> rabin_karp_multi("abracadabra", ["abra", "bra", "cad"])
    {'abra': [0, 7], 'bra': [1, 8], 'cad': [4]}
    >>> rabin_karp_multi("aaaaa", ["aa", "aaa"])
    {'aa': [0, 1, 2, 3], 'aaa': [0, 1, 2]}
    >>> rabin_karp_multi("aaaaa", ["aa", "aa"])  # duplicates counted once
    {'aa': [0, 1, 2, 3]}
    """
    if base is None:
        base = BASE
    if modulus is None:
        modulus = MOD
    if modulus < 1:
        raise ValueError("modulus must be a positive integer")

    n = len(text)

    # Dict keys de-duplicate the patterns (keeping first-seen order), so a
    # repeated pattern cannot end up twice in a hash bucket and have every
    # match index appended more than once.
    result: Dict[str, List[int]] = {p: [] for p in patterns}

    # Group the distinct patterns by length: one text scan per length.
    groups: Dict[int, List[str]] = {}
    for p in result:
        groups.setdefault(len(p), []).append(p)

    for length, group in groups.items():
        if length == 0:
            for p in group:
                result[p] = list(range(n + 1))
            continue
        if length > n:
            # Patterns longer than the text cannot match; skip hashing them.
            continue

        # Bucket the patterns of this length by their hash.
        by_hash: Dict[int, List[str]] = {}
        for p in group:
            h = 0
            for c in p:
                h = (h * base + ord(c)) % modulus
            by_hash.setdefault(h, []).append(p)

        # Weight of the window's leading character.
        power = pow(base, length - 1, modulus)

        # Hash of the first window of the text.
        h = 0
        for c in text[:length]:
            h = (h * base + ord(c)) % modulus

        last_start = n - length
        for i in range(last_start + 1):
            for p in by_hash.get(h, ()):
                if text.startswith(p, i):
                    result[p].append(i)
                    # Bucket entries are distinct strings of equal length,
                    # so at most one of them can equal this window.
                    break

            if i < last_start:
                # Slide the window: drop text[i], shift, append the next char.
                h = (
                    (h - ord(text[i]) * power) * base + ord(text[i + length])
                ) % modulus

    return result
53109

54-
def test_rabin_karp() -> None:
55-
"""
56-
>>> test_rabin_karp()
57-
Success.
58-
"""
59-
# Test 1)
60-
pattern = "abc1abc12"
61-
text1 = "alskfjaldsabc1abc1abc12k23adsfabcabc"
62-
text2 = "alskfjaldsk23adsfabcabc"
63-
assert rabin_karp(pattern, text1)
64-
assert not rabin_karp(pattern, text2)
65-
66-
# Test 2)
67-
pattern = "ABABX"
68-
text = "ABABZABABYABABX"
69-
assert rabin_karp(pattern, text)
70-
71-
# Test 3)
72-
pattern = "AAAB"
73-
text = "ABAAAAAB"
74-
assert rabin_karp(pattern, text)
75-
76-
# Test 4)
77-
pattern = "abcdabcy"
78-
text = "abcxabcdabxabcdabcdabcy"
79-
assert rabin_karp(pattern, text)
80-
81-
# Test 5)
82-
pattern = "Lü"
83-
text = "Lüsai"
84-
assert rabin_karp(pattern, text)
85-
pattern = "Lue"
86-
assert not rabin_karp(pattern, text)
87-
print("Success.")
88-
89-
90-
if __name__ == "__main__":
91-
test_rabin_karp()

0 commit comments

Comments
 (0)