forked from karpathy/minbpe
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbpe_basic.py
139 lines (114 loc) · 4.63 KB
/
bpe_basic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Minimal (byte-level) Byte Pair Encoding tokenizer.
Algorithmically follows along the GPT tokenizer:
https://github.com/openai/gpt-2/blob/master/src/encoder.py
But:
- Does not handle the regular expression splitting pattern.
- Does not handle any special tokens.
"""
def get_stats(ids):
    """
    Count how often each pair of neighbouring elements occurs in ids.

    Returns a plain dict keyed by (left, right) tuples; keys appear in the
    order in which each pair is first seen.
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    """
    tally = {}
    followers = ids[1:]  # ids shifted by one, so zip yields neighbours
    for left, right in zip(ids, followers):
        key = (left, right)
        if key in tally:
            tally[key] += 1
        else:
            tally[key] = 1
    return tally
def merge(ids, pair, idx):
    """
    Build a new list from ids in which every non-overlapping, left-to-right
    occurrence of the two-element sequence pair is collapsed into the single
    token idx. The input list is not modified.
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    total = len(ids)
    out = []
    pos = 0
    while pos < total:
        token = ids[pos]
        # a match needs the current token, a following token, and that
        # following token to equal the second half of the pair
        is_match = token == pair[0] and pos + 1 < total and ids[pos + 1] == pair[1]
        if is_match:
            out.append(idx)
            pos += 2  # consume both halves of the pair
        else:
            out.append(token)
            pos += 1
    return out
class Tokenizer:
    """
    Minimal byte-level BPE tokenizer.

    Attributes:
        merges: dict mapping a pair of token ids (int, int) to the id of the
            token that replaces that pair.
        vocab: dict mapping every token id (int) to the bytes it stands for.
    """

    def __init__(self):
        # by default, we have a vocab size of 256 (all bytes) and no merges
        self.merges = {}
        self.vocab = {idx: bytes([idx]) for idx in range(256)}

    def train(self, text, vocab_size, verbose=False):
        """
        Learn up to (vocab_size - 256) merges from text and store the
        resulting merges and vocab on the instance, replacing any previous
        ones.

        The text is encoded as UTF-8 and every byte starts out as its own
        token. Each round replaces the most frequent consecutive pair with a
        newly minted token id (256, 257, ...). If the text runs out of pairs
        (fewer than two tokens remain) training stops early, so the final
        vocabulary may be smaller than vocab_size.

        With verbose=True, one line is printed per merge.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        # input text preprocessing
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255
        # iteratively merge the most common pairs to create new tokens
        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
        for i in range(num_merges):
            # with fewer than two tokens there are no pairs left; calling
            # max() on the empty stats dict would raise ValueError
            if len(ids) < 2:
                break
            # count up the number of times every consecutive pair appears
            stats = get_stats(ids)
            # find the pair with the highest count (on ties, max() returns
            # the first such pair in the dict's iteration order)
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = merge(ids, pair, idx)
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
        # save class variables
        self.merges = merges  # used in encode()
        self.vocab = vocab  # used in decode()

    def decode(self, ids):
        """
        Given ids (an iterable of token ids), return the Python string.

        The bytes of all tokens are concatenated and decoded as UTF-8; byte
        sequences that are not valid UTF-8 are replaced rather than raising.
        Raises KeyError if an id is not in the vocabulary.
        """
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def encode(self, text):
        """
        Given a string text, return the list of token ids.

        Starts from the raw UTF-8 bytes and repeatedly applies the learned
        merge with the lowest token id (i.e. the earliest learned one) until
        no pair present in the sequence has a merge.
        """
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break  # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
if __name__ == "__main__":
    # Quick unit test, following along the Wikipedia example:
    # https://en.wikipedia.org/wiki/Byte_pair_encoding
    #
    # According to Wikipedia, running bpe on the input string
    #     "aaabdaaabac"
    # for 3 merges will result in the string
    #     "XdXac"
    # where X=ZY, Y=ab and Z=aa.
    #
    # Keep in mind that for us a=97, b=98, c=99, d=100 (ASCII values),
    # so Z will be 256, Y will be 257 and X will be 258, which makes the
    # expected list of ids [258, 100, 258, 97, 99].
    text = "aaabdaaabac"
    expected = [258, 100, 258, 97, 99]

    # 256 byte tokens plus 3 merges
    tokenizer = Tokenizer()
    tokenizer.train(text, 256 + 3)

    # check 1: encoding reproduces the expected ids
    ids = tokenizer.encode(text)
    status = "OK" if ids == expected else "FAIL"
    print(status)

    # check 2: decode(encode(x)) == x
    roundtrip = tokenizer.decode(tokenizer.encode(text))
    status = "OK" if roundtrip == text else "FAIL"
    print(status)