bpe_basic.py (forked from karpathy/minbpe)
"""
Minimal (byte-level) Byte Pair Encoding tokenizer.
Algorithmically follows along the GPT tokenizer:
https://github.com/openai/gpt-2/blob/master/src/encoder.py
But:
- Does not handle the regular expression splitting pattern.
- Does not handle any special tokens.
"""
from bpe_base import Tokenizer, get_stats, merge
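
# For orientation: the two helpers imported from bpe_base are assumed to behave
# like the corresponding helpers in minbpe's base module (a sketch of their
# contract, not the actual implementation):
#
#   get_stats(ids) returns a dict counting consecutive pairs, e.g.
#       get_stats([1, 2, 3, 1, 2]) == {(1, 2): 2, (2, 3): 1, (3, 1): 1}
#
#   merge(ids, pair, idx) returns a new list with every occurrence of `pair`
#   replaced by the single token `idx`, e.g.
#       merge([1, 2, 3, 1, 2], (1, 2), 4) == [4, 3, 4]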


class BasicTokenizer(Tokenizer):

    def __init__(self):
        super().__init__()

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # input text preprocessing
        text_bytes = text.encode("utf-8") # raw bytes
        ids = list(text_bytes) # list of integers in range 0..255

        # iteratively merge the most common pairs to create new tokens
        merges = {} # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
        for i in range(num_merges):
            # count up the number of times every consecutive pair appears
            stats = get_stats(ids)
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = merge(ids, pair, idx)
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        # save class variables
        self.merges = merges # used in encode()
        self.vocab = vocab   # used in decode()
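
    # Illustrative example (an addition, not part of the original file):
    # training on "abababab" with vocab_size=258 performs two merges and
    # leaves the tokenizer with
    #   self.merges == {(97, 98): 256, (256, 256): 257}
    #   self.vocab[256] == b"ab" and self.vocab[257] == b"abab"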

    def decode(self, ids):
        # given ids (list of integers), return Python string
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text
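
    # For instance, with the illustrative merges above, decode([257, 99]) gives
    # "ababc", and with no merges at all, decode([104, 105]) gives "hi" (raw
    # bytes). errors="replace" only matters if the ids split a multi-byte
    # UTF-8 character; valid round-trips are unaffected.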

    def encode(self, text):
        # given a string text, return the token ids
        text_bytes = text.encode("utf-8") # raw bytes
        ids = list(text_bytes) # list of integers in range 0..255
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
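
    # Illustrative trace (an addition, continuing the example above with
    # merges {(97, 98): 256, (256, 256): 257}): encode("ababab") starts from
    # [97, 98, 97, 98, 97, 98], applies the earliest-learned merge first to
    # get [256, 256, 256], then [257, 256], and stops because (257, 256) was
    # never learned, returning [257, 256].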


if __name__ == "__main__":
    """
    Quick unit test, following along the Wikipedia example:
    https://en.wikipedia.org/wiki/Byte_pair_encoding

    According to Wikipedia, running bpe on the input string:

    "aaabdaaabac"

    for 3 merges will result in string:

    "XdXac"

    where:
      X=ZY
      Y=ab
      Z=aa
    Keep in mind that for us a=97, b=98, c=99, d=100 (ASCII values),
    so Z ("aa") will be 256 and X ("aaab") will be 258; which string gets id
    257 depends on how the tie between the two second-most-common pairs is
    broken, but the final encoding comes out the same either way.

    So we expect the output list of ids to be [258, 100, 258, 97, 99]
    """
text = "aaabdaaabac"
tokenizer = BasicTokenizer()
# we do 3 merges
tokenizer.train(text, 256 + 3)
# verify the correct expected result
ids = tokenizer.encode(text)
print("OK" if ids == [258, 100, 258, 97, 99] else "FAIL")
# verify that decode(encode(x)) == x
print("OK" if tokenizer.decode(tokenizer.encode(text)) == text else "FAIL")