"""
Implements the GPT-4 Tokenizer with a light wrapper around the RegexTokenizer.
"""
import tiktoken
from bpe_regex import RegexTokenizer
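
# Usage sketch (see the __main__ block at the bottom for a fuller comparison
# against tiktoken):
#   tokenizer = GPT4Tokenizer()
#   ids = tokenizer.encode("hello world")
#   text = tokenizer.decode(ids)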

def bpe(mergeable_ranks, token, max_rank):
    # helper function used in recover_merges() below to reconstruct the merge forest
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        assert min_idx is not None
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts
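
# Illustrative sketch of what bpe() returns (hypothetical ranks, not real
# cl100k_base values): with mergeable_ranks = {b"a": 0, b"b": 1, b"ab": 2},
# bpe(mergeable_ranks, b"ab", max_rank=2) stops just before the merge that
# created b"ab" and returns [b"a", b"b"], i.e. the pair that was merged.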

def recover_merges(mergeable_ranks):
    # the `merges` are already the byte sequences in their merged state.
    # so we have to recover the original pairings. We can do this by doing
    # a small BPE training run on all the tokens, in their order.
    # also see https://github.com/openai/tiktoken/issues/60
    merges = {}
    for token, rank in mergeable_ranks.items():
        if len(token) == 1:
            continue # skip raw bytes
        pair = tuple(bpe(mergeable_ranks, token, max_rank=rank))
        assert len(pair) == 2
        # recover the integer ranks of the pair
        ix0 = mergeable_ranks[pair[0]]
        ix1 = mergeable_ranks[pair[1]]
        merges[(ix0, ix1)] = rank
    return merges
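
# Illustrative sketch (hypothetical ranks again, not real cl100k_base values):
# with mergeable_ranks = {b"a": 0, b"b": 1, b"ab": 2}, recover_merges returns
# {(0, 1): 2}, i.e. "merge token 0 with token 1 to create token 2", which is
# the (pair) -> new-token-id format that self.merges uses below.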

class GPT4Tokenizer(RegexTokenizer):
    """Lightweight wrapper on RegexTokenizer that matches GPT-4's tokenizer."""

    def __init__(self):
        super().__init__()
        # get the official tokenizer and its merges
        enc = tiktoken.get_encoding("cl100k_base")
        mergeable_ranks = enc._mergeable_ranks
        # the merges are those of gpt4, but we have to recover them
        self.merges = recover_merges(mergeable_ranks)
        # reconstruct the vocab from the merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        self.vocab = vocab
        # now here is another tricky part.
        # for some reason, the tokens corresponding to individual bytes
        # are permuted in a different order. This is completely nonsensical
        # and probably historical, but therefore we have to deal with it here.
        self.byte_shuffle = {i: mergeable_ranks[bytes([i])] for i in range(256)}
        self.inverse_byte_shuffle = {v: k for k, v in self.byte_shuffle.items()}
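        # As an illustration (the actual numbers depend on cl100k_base): if
        # mergeable_ranks[bytes([0])] happened to be 188, then
        # byte_shuffle[0] == 188 and inverse_byte_shuffle[188] == 0, so raw
        # byte 0 is encoded as if it were byte 188 and mapped back on decode.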

    def _encode_chunk(self, text_bytes):
        # before we start processing bytes, we have to permute them
        text_bytes = bytes(self.byte_shuffle[b] for b in text_bytes)
        ids = super()._encode_chunk(text_bytes)
        return ids

    def decode(self, ids):
        # we have to un-permute the bytes before we decode
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text_bytes = bytes(self.inverse_byte_shuffle[b] for b in text_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

if __name__ == "__main__":
    # let's take it for a spin!

    # tiktoken
    enc = tiktoken.get_encoding("cl100k_base")
    # vs.
    tokenizer = GPT4Tokenizer()
    # fight!

    text = "hello world!!!? (안녕하세요!) lol123 😉"
    print(text)
    print(enc.encode(text)) # tiktoken
    print(tokenizer.encode(text)) # ours
    print(tokenizer.decode(tokenizer.encode(text))) # ours back to text

    # two quick tests: equality (to tiktoken) and identity
    print("OK" if enc.encode(text) == tokenizer.encode(text) else "FAIL")
    print("OK" if text == tokenizer.decode(tokenizer.encode(text)) else "FAIL")

    # let's also tokenize all of taylor swift, a bigger document just to make sure
    text = open("taylorswift.txt", "r", encoding="utf-8").read()
    t1 = enc.encode(text) # tiktoken
    t2 = tokenizer.encode(text) # ours
    print("OK" if t1 == t2 else "FAIL")
    print("OK" if text == tokenizer.decode(tokenizer.encode(text)) else "FAIL")