forked from karpathy/minbpe
Showing 5 changed files with 1,354 additions and 0 deletions.

README.md
@@ -0,0 +1,16 @@
# minbpe

Minimal, clean, educational code for the (byte-level) Byte Pair Encoding (BPE) algorithm commonly used in LLM tokenization. The BPE algorithm is "byte-level" because it runs on UTF-8 encoded strings.

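To make "byte-level" concrete, here is a small illustration in plain Python (not code from this repo): a string becomes a sequence of UTF-8 byte values, and BPE repeatedly replaces the most frequent adjacent pair with a new token id starting at 256.

```python
text = "aaab"
ids = list(text.encode("utf-8"))  # [97, 97, 97, 98] -- one integer per UTF-8 byte
# BPE counts adjacent pairs: (97, 97) occurs twice, (97, 98) once,
# so the most frequent pair (97, 97) is replaced by a new token id, e.g. 256:
# [97, 97, 97, 98] -> [256, 97, 98]
```
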
This algorithm was popularized for LLMs by the [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) and the associated GPT-2 [code release](https://github.com/openai/gpt-2) from OpenAI. Today, all modern LLMs (e.g. GPT, Llama, Mistral) use this algorithm to train their tokenizers.

There are two Tokenizers in this repository, both of which can perform the 3 primary functions of a Tokenizer: 1) train the tokenizer vocabulary and merges on a given text, 2) encode from text to tokens, 3) decode from tokens to text. The two tokenizers are listed below, followed by a short usage sketch:

1. [bpe_basic.py](bpe_basic.py): The simplest implementation of the BPE algorithm that runs directly on text.
2. [bpe_regex.py](bpe_regex.py): This implementation further splits the input text by a regex pattern, which is a preprocessing stage that splits up the input text by categories (think: letters, numbers, punctuation) before tokenization. This ensures that no merges will happen across category boundaries. This was introduced in the GPT-2 paper and continues to be in use as of GPT-4.

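A minimal usage sketch, mirroring the unit test that ships in [bpe_basic.py](bpe_basic.py) (the module and class names below are the ones in this commit; the toy string and expected ids come from its `__main__` block):

```python
from bpe_basic import Tokenizer

tokenizer = Tokenizer()
tokenizer.train("aaabdaaabac", vocab_size=256 + 3)  # 3 merges on top of the 256 byte tokens

ids = tokenizer.encode("aaabdaaabac")
print(ids)                    # [258, 100, 258, 97, 99]
print(tokenizer.decode(ids))  # "aaabdaaabac"
```
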
Finally, the script [train.py](train.py) trains both of these tokenizers on the input text [taylorswift.txt](taylorswift.txt) (this is the Wikipedia entry for her kek) and saves the vocab to disk for visualization. This script runs in about 25 seconds on my (M1) MacBook.

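The contents of train.py are not reproduced in this view, but a rough, hypothetical sketch of a script along those lines might look like the following; the vocab size of 512 and the `vocab_*.txt` output filenames are made-up placeholders, not taken from the actual train.py.

```python
# Hypothetical sketch only: train both tokenizers on taylorswift.txt and dump
# each learned vocab to disk. The vocab size and output filenames are invented
# for illustration.
import bpe_basic
import bpe_regex

with open("taylorswift.txt", "r", encoding="utf-8") as f:
    text = f.read()

for name, module in [("basic", bpe_basic), ("regex", bpe_regex)]:
    tokenizer = module.Tokenizer()
    tokenizer.train(text, vocab_size=512, verbose=True)
    # write out idx -> token bytes, one per line, for eyeballing the merges
    with open(f"vocab_{name}.txt", "w", encoding="utf-8") as f:
        for idx, token_bytes in tokenizer.vocab.items():
            f.write(f"{idx}: {token_bytes}\n")
```
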

# License

MIT

bpe_basic.py
@@ -0,0 +1,139 @@
""" | ||
Minimal (byte-level) Byte Pair Encoding tokenizer. | ||
Algorithmically follows along the GPT tokenizer: | ||
https://github.com/openai/gpt-2/blob/master/src/encoder.py | ||
But: | ||
- Does not handle the regular expression splitting pattern. | ||
- Does not handle any special tokens. | ||
""" | ||
|
||
def get_stats(ids): | ||
""" | ||
Given a list of integers, return a dictionary of counts of consecutive pairs | ||
Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1} | ||
""" | ||
counts = {} | ||
for pair in zip(ids, ids[1:]): # iterate consecutive elements | ||
counts[pair] = counts.get(pair, 0) + 1 | ||
return counts | ||
|
||
|
||
def merge(ids, pair, idx): | ||
""" | ||
In the list of integers (ids), replace all consecutive occurrences | ||
of pair with the new integer token idx | ||
Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4] | ||
""" | ||
newids = [] | ||
i = 0 | ||
while i < len(ids): | ||
# if not at the very last position AND the pair matches, replace it | ||
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]: | ||
newids.append(idx) | ||
i += 2 | ||
else: | ||
newids.append(ids[i]) | ||
i += 1 | ||
return newids | ||
|
||
|
||
class Tokenizer: | ||
|
||
def __init__(self): | ||
# by default, we have a vocab size of 256 (all bytes) and no merges | ||
self.merges = {} | ||
self.vocab = {idx: bytes([idx]) for idx in range(256)} | ||
|
||
def train(self, text, vocab_size, verbose=False): | ||
assert vocab_size >= 256 | ||
num_merges = vocab_size - 256 | ||
|
||
# input text preprocessing | ||
text_bytes = text.encode("utf-8") # raw bytes | ||
ids = list(text_bytes) # list of integers in range 0..255 | ||
|
||
# iteratively merge the most common pairs to create new tokens | ||
merges = {} # (int, int) -> int | ||
vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes | ||
for i in range(num_merges): | ||
# count up the number of times every consecutive pair appears | ||
stats = get_stats(ids) | ||
# find the pair with the highest count | ||
pair = max(stats, key=stats.get) | ||
# mint a new token: assign it the next available id | ||
idx = 256 + i | ||
# replace all occurences of pair in ids with idx | ||
ids = merge(ids, pair, idx) | ||
# save the merge | ||
merges[pair] = idx | ||
vocab[idx] = vocab[pair[0]] + vocab[pair[1]] | ||
# prints | ||
if verbose: | ||
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") | ||
|
||
# save class variables | ||
self.merges = merges # used in encode() | ||
self.vocab = vocab # used in decode() | ||
|
||
def decode(self, ids): | ||
# given ids (list of integers), return Python string | ||
text_bytes = b"".join(self.vocab[idx] for idx in ids) | ||
text = text_bytes.decode("utf-8", errors="replace") | ||
return text | ||
|
||
def encode(self, text): | ||
# given a string text, return the token ids | ||
text_bytes = text.encode("utf-8") # raw bytes | ||
ids = list(text_bytes) # list of integers in range 0..255 | ||
while len(ids) >= 2: | ||
# find the pair with the lowest merge index | ||
stats = get_stats(ids) | ||
pair = min(stats, key=lambda p: self.merges.get(p, float("inf"))) | ||
# subtle: if there are no more merges available, the key will | ||
# result in an inf for every single pair, and the min will be | ||
# just the first pair in the list, arbitrarily | ||
# we can detect this terminating case by a membership check | ||
if pair not in self.merges: | ||
break # nothing else can be merged anymore | ||
# otherwise let's merge the best pair (lowest merge index) | ||
idx = self.merges[pair] | ||
ids = merge(ids, pair, idx) | ||
return ids | ||
|
||
if __name__ == "__main__": | ||
|
||
""" | ||
Quick unit test, following along the Wikipedia example: | ||
https://en.wikipedia.org/wiki/Byte_pair_encoding | ||
According to Wikipedia, running bpe on the the input string: | ||
"aaabdaaabac" | ||
for 3 merges will result in string: | ||
"XdXac" | ||
where: | ||
X=ZY | ||
Y=ab | ||
Z=aa | ||
Keep in mind that for us a=97, b=98, c=99, d=100 (ASCII values) | ||
so Z will be 256, Y will be 257, X will be 258. | ||
So we expect the output list of ids to be [258, 100, 258, 97, 99] | ||
""" | ||
|
||
text = "aaabdaaabac" | ||
tokenizer = Tokenizer() | ||
|
||
# we do 3 merges | ||
tokenizer.train(text, 256 + 3) | ||
|
||
# verify the correct expected result | ||
ids = tokenizer.encode(text) | ||
print("OK" if ids == [258, 100, 258, 97, 99] else "FAIL") | ||
|
||
# verify that decode(encode(x)) == x | ||
print("OK" if tokenizer.decode(tokenizer.encode(text)) == text else "FAIL") |
bpe_regex.py
@@ -0,0 +1,182 @@
""" | ||
Minimal (byte-level) Byte Pair Encoding tokenizer. | ||
Algorithmically follows along the GPT tokenizer: | ||
https://github.com/openai/gpt-2/blob/master/src/encoder.py | ||
Unlike bpe_basic.py, this file also handles the regex splitting pattern. | ||
But: | ||
- Does not handle any special tokens. | ||
""" | ||
|
||
import regex as re | ||
|
||
# the GPT-4 text split pattern, see | ||
# https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py | ||
SPLIT_PATTERN = re.compile(r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""") | ||
|
||
|
||
def get_stats(ids): | ||
""" | ||
Given a list of integers, return a dictionary of counts of consecutive pairs | ||
Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1} | ||
""" | ||
counts = {} | ||
for pair in zip(ids, ids[1:]): # iterate consecutive elements | ||
counts[pair] = counts.get(pair, 0) + 1 | ||
return counts | ||
|
||
|
||
def merge(ids, pair, idx): | ||
""" | ||
In the list of integers (ids), replace all consecutive occurrences | ||
of pair with the new integer token idx | ||
Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4] | ||
""" | ||
newids = [] | ||
i = 0 | ||
while i < len(ids): | ||
# if not at the very last position AND the pair matches, replace it | ||
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]: | ||
newids.append(idx) | ||
i += 2 | ||
else: | ||
newids.append(ids[i]) | ||
i += 1 | ||
return newids | ||
|
||
|
||
class Tokenizer: | ||
|
||
def __init__(self): | ||
# by default, we have a vocab size of 256 (all bytes) and no merges | ||
self.merges = {} | ||
self.vocab = {idx: bytes([idx]) for idx in range(256)} | ||
|
||
def train(self, text, vocab_size, verbose=False): | ||
assert vocab_size >= 256 | ||
num_merges = vocab_size - 256 | ||
|
||
# split the text up into text chunks | ||
text_chunks = re.findall(SPLIT_PATTERN, text) | ||
|
||
# input text preprocessing | ||
ids = [list(ch.encode("utf-8")) for ch in text_chunks] | ||
|
||
# iteratively merge the most common pairs to create new tokens | ||
merges = {} # (int, int) -> int | ||
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes | ||
for i in range(num_merges): | ||
# count up the number of times every consecutive pair appears | ||
chunk_stats = [get_stats(chunk_ids) for chunk_ids in ids] | ||
# combine the pair counts from all chunks by summing them up | ||
stats = {} | ||
for chstat in chunk_stats: | ||
for pair, count in chstat.items(): | ||
stats[pair] = stats.get(pair, 0) + count | ||
# find the pair with the highest count | ||
pair = max(stats, key=stats.get) | ||
# mint a new token: assign it the next available id | ||
idx = 256 + i | ||
# replace all occurences of pair in ids with idx | ||
ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids] | ||
# save the merge | ||
merges[pair] = idx | ||
vocab[idx] = vocab[pair[0]] + vocab[pair[1]] | ||
# prints | ||
if verbose: | ||
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") | ||
|
||
# save class variables | ||
self.merges = merges # used in encode() | ||
self.vocab = vocab # used in decode() | ||
|
||
def decode(self, ids): | ||
# given ids (list of integers), return Python string | ||
text_bytes = b"".join(self.vocab[idx] for idx in ids) | ||
text = text_bytes.decode("utf-8", errors="replace") | ||
return text | ||
|
||
def _encode_chunk(self, text): | ||
# given a string text, return the token ids | ||
text_bytes = text.encode("utf-8") # raw bytes | ||
ids = list(text_bytes) # list of integers in range 0..255 | ||
while len(ids) >= 2: | ||
# find the pair with the lowest merge index | ||
stats = get_stats(ids) | ||
pair = min(stats, key=lambda p: self.merges.get(p, float("inf"))) | ||
# subtle: if there are no more merges available, the key will | ||
# result in an inf for every single pair, and the min will be | ||
# just the first pair in the list, arbitrarily | ||
# we can detect this terminating case by a membership check | ||
if pair not in self.merges: | ||
break # nothing else can be merged anymore | ||
# otherwise let's merge the best pair (lowest merge index) | ||
idx = self.merges[pair] | ||
ids = merge(ids, pair, idx) | ||
return ids | ||
|
||
def encode(self, text): | ||
# split text into chunks of text by categories defined in regex pattern | ||
text_chunks = re.findall(SPLIT_PATTERN, text) | ||
# all chunks of text are encoded separately, then results are joined | ||
ids = [] | ||
for chunk in text_chunks: | ||
chunk_ids = self._encode_chunk(chunk) | ||
ids.extend(chunk_ids) | ||
return ids | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
""" | ||
Quick unit test, following along the Wikipedia example: | ||
https://en.wikipedia.org/wiki/Byte_pair_encoding | ||
According to Wikipedia, running bpe on the the input string: | ||
"aaabdaaabac" | ||
for 3 merges will result in string: | ||
"XdXac" | ||
where: | ||
X=ZY | ||
Y=ab | ||
Z=aa | ||
Keep in mind that for us a=97, b=98, c=99, d=100 (ASCII values) | ||
so Z will be 256, Y will be 257, X will be 258. | ||
So we expect the output list of ids to be [258, 100, 258, 97, 99] | ||
""" | ||
|
||
text = "aaabdaaabac" | ||
tokenizer = Tokenizer() | ||
|
||
# we do 3 merges | ||
tokenizer.train(text, 256 + 3) | ||
|
||
# verify the correct expected result | ||
ids = tokenizer.encode(text) | ||
print("OK" if ids == [258, 100, 258, 97, 99] else "FAIL") | ||
|
||
# verify that decode(encode(x)) == x | ||
print("OK" if tokenizer.decode(tokenizer.encode(text)) == text else "FAIL") | ||
|
||
# take a bit more complex piece of text and train the tokenizer, chosen at random | ||
text = """ | ||
The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) (Lama glama) is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era. | ||
Llamas are social animals and live with others as a herd. Their wool is soft and contains only a small amount of lanolin.[2] Llamas can learn simple tasks after a few repetitions. When using a pack, they can carry about 25 to 30% of their body weight for 8 to 13 km (5–8 miles).[3] The name llama (in the past also spelled "lama" or "glama") was adopted by European settlers from native Peruvians.[4] | ||
The ancestors of llamas are thought to have originated from the Great Plains of North America about 40 million years ago, and subsequently migrated to South America about three million years ago during the Great American Interchange. By the end of the last ice age (10,000–12,000 years ago), camelids were extinct in North America.[3] As of 2007, there were over seven million llamas and alpacas in South America and over 158,000 llamas and 100,000 alpacas, descended from progenitors imported late in the 20th century, in the United States and Canada.[5] | ||
In Aymara mythology, llamas are important beings. The Heavenly Llama is said to drink water from the ocean and urinates as it rains.[6] According to Aymara eschatology, llamas will return to the water springs and ponds where they come from at the end of time.[6] | ||
""".strip() | ||
|
||
# do 64 merges | ||
tokenizer.train(text, 256 + 64) | ||
|
||
# verify that decode(encode(x)) == x | ||
print("OK" if tokenizer.decode(tokenizer.encode(text)) == text else "FAIL") | ||
|
||
# for fun if you like | ||
# print(tokenizer.vocab) |
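To see what the GPT-4 split pattern in bpe_regex.py does to a piece of text before any merging, here is a small sketch; the sample sentence is made up for illustration, and the commented output is what the call is expected to return for it.

```python
import regex as re
from bpe_regex import SPLIT_PATTERN  # the GPT-4 split pattern defined above

chunks = re.findall(SPLIT_PATTERN, "Hello world123 how's it going!!!")
print(chunks)
# ['Hello', ' world', '123', ' how', "'s", ' it', ' going', '!!!']
# Letters, digits (in groups of at most 3), and punctuation land in separate
# chunks, so merges trained on these chunks never cross category boundaries.
```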