
Commit

meaty change: adding special tokens handling, so now we have full parity with the GPT-4 Tokenizer
karpathy committed Feb 19, 2024
1 parent e0ed1bc commit e82c123
Showing 5 changed files with 127 additions and 33 deletions.
5 changes: 2 additions & 3 deletions README.md
@@ -8,8 +8,8 @@ There are two Tokenizers in this repository, both of which can perform the 3 pri

1. [minbpe/base.py](minbpe/base.py): Implements the `Tokenizer` class, which is the base class. It contains the `train`, `encode`, and `decode` stubs, save/load functionality, and a few common utility functions. This class is not meant to be used directly, but rather to be inherited from.
2. [minbpe/basic.py](minbpe/basic.py): Implements the `BasicTokenizer`, the simplest implementation of the BPE algorithm that runs directly on text.
3. [minbpe/regex.py](minbpe/regex.py): Implements the `RegexTokenizer` that further splits the input text by a regex pattern, which is a preprocessing stage that splits up the input text by categories (think: letters, numbers, punctuation) before tokenization. This ensures that no merges will happen across category boundaries. This was introduced in the GPT-2 paper and continues to be in use as of GPT-4.
4. [minbpe/gpt4.py](minbpe/gpt4.py): Implements the `GPT4Tokenizer`. This class is a light wrapper around the `RegexTokenizer` (3, above) that exactly reproduces the tokenization of GPT-4 in the [tiktoken](https://github.com/openai/tiktoken) library. The wrapping handles some details around recovering the exact merges in the tokenizer, and some unfortunate (and likely historical?) 1-byte token permutations. Note that the parity is not fully complete yet because we do not handle special tokens.
3. [minbpe/regex.py](minbpe/regex.py): Implements the `RegexTokenizer` that further splits the input text by a regex pattern, which is a preprocessing stage that splits up the input text by categories (think: letters, numbers, punctuation) before tokenization. This ensures that no merges will happen across category boundaries. This was introduced in the GPT-2 paper and continues to be in use as of GPT-4. This class also handles special tokens, if any.
4. [minbpe/gpt4.py](minbpe/gpt4.py): Implements the `GPT4Tokenizer`. This class is a light wrapper around the `RegexTokenizer` (3, above) that exactly reproduces the tokenization of GPT-4 in the [tiktoken](https://github.com/openai/tiktoken) library. The wrapping handles some details around recovering the exact merges in the tokenizer, and some unfortunate (and likely historical?) 1-byte token permutations.

Finally, the script [train.py](train.py) trains the two major tokenizers on the input text [tests/taylorswift.txt](tests/taylorswift.txt) (this is the Wikipedia entry for her kek) and saves the vocab to disk for visualization. This script runs in about 25 seconds on my (M1) MacBook.
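For orientation, a minimal usage sketch of the special-token handling added in this commit; the toy training text, vocab size, and the single special token below are illustrative, not taken from the repo:

```python
from minbpe import RegexTokenizer

# one illustrative special token; its id is chosen well above the merge id range
special_tokens = {'<|endoftext|>': 100257}

tokenizer = RegexTokenizer(special_tokens=special_tokens)
tokenizer.train("hello world hello world " * 50, 256 + 8)  # tiny toy corpus, 8 merges

text = "<|endoftext|>hello world"
ids = tokenizer.encode(text, allowed_special="all")  # the special token becomes id 100257
assert tokenizer.decode(ids) == text                 # round-trips, special token included
plain_ids = tokenizer.encode_ordinary(text)          # treats "<|endoftext|>" as ordinary text
```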

@@ -66,7 +66,6 @@ to run the tests.
- write an even more optimized C or Rust version (think through)
- rename GPT4Tokenizer to GPTTokenizer and support GPT-2 as well?
- write a LlamaTokenizer similar to GPT4Tokenizer (i.e. attempt sentencepiece equivalent)
- handle special tokens
- video coming soon ;)

## License
19 changes: 17 additions & 2 deletions minbpe/base.py
@@ -69,8 +69,9 @@ class Tokenizer:
def __init__(self):
# default: vocab size of 256 (all bytes), no merges, no patterns
self.merges = {} # (int, int) -> int
self.pattern = "" # str
self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
self.vocab = self._build_vocab() # int -> bytes
self.pattern = ""

def train(self, text, vocab_size, verbose=False):
# Tokenizer can train a vocabulary of size vocab_size from text
@@ -89,6 +90,8 @@ def _build_vocab(self):
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in self.merges.items():
vocab[idx] = vocab[p0] + vocab[p1]
for special, idx in self.special_tokens.items():
vocab[idx] = special.encode("utf-8")
return vocab

def save(self, file_prefix):
@@ -104,6 +107,11 @@ def save(self, file_prefix):
# write the version, pattern, special tokens and merges, that's all that's needed
f.write(f"minbpe v1\n")
f.write(f"{self.pattern}\n")
# write the special tokens, first the number of them, then each one
f.write(f"{len(self.special_tokens)}\n")
for special, idx in self.special_tokens.items():
f.write(f"{special} {idx}\n")
# the merges dict
for idx1, idx2 in self.merges:
f.write(f"{idx1} {idx2}\n")
# write the vocab: for the human to look at
@@ -134,17 +142,24 @@ def load(self, model_file):
assert model_file.endswith(".model")
# read the model file
merges = {}
special_tokens = {}
idx = 256
with open(model_file, 'r') as f:
with open(model_file, 'r', encoding="utf-8") as f:
# read the version
version = f.readline().strip()
assert version == "minbpe v1"
# read the pattern
self.pattern = f.readline().strip()
# read the special tokens
num_special = int(f.readline().strip())
for _ in range(num_special):
special, special_idx = f.readline().strip().split()
special_tokens[special] = int(special_idx)
# read the merges
for line in f:
idx1, idx2 = map(int, line.split())
merges[(idx1, idx2)] = idx
idx += 1
self.merges = merges
self.special_tokens = special_tokens
self.vocab = self._build_vocab()
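Taken together, the `save`/`load` changes above give the `.model` file this shape: the version string, the split pattern on one line, the number of special tokens, one `<token> <id>` line per special token, then the merge pairs. A rough illustration (the pattern placeholder, ids, and merge pairs are made up):

```
minbpe v1
<regex split pattern on one line>
2
<|endoftext|> 100257
<|fim_prefix|> 100258
104 101
101 114
```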
10 changes: 9 additions & 1 deletion minbpe/gpt4.py
@@ -44,12 +44,20 @@ def recover_merges(mergeable_ranks):

return merges

GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
GPT4_SPECIAL_TOKENS = {
'<|endoftext|>': 100257,
'<|fim_prefix|>': 100258,
'<|fim_middle|>': 100259,
'<|fim_suffix|>': 100260,
'<|endofprompt|>': 100276
}

class GPT4Tokenizer(RegexTokenizer):
"""Lightweight wrapper on RegexTokenizer that matches GPT-4's tokenizer."""

def __init__(self):
super().__init__()
super().__init__(pattern=GPT4_SPLIT_PATTERN, special_tokens=GPT4_SPECIAL_TOKENS)
# get the official tokenizer and its merges
enc = tiktoken.get_encoding("cl100k_base")
mergeable_ranks = enc._mergeable_ranks
68 changes: 60 additions & 8 deletions minbpe/regex.py
@@ -4,10 +4,9 @@
Algorithmically follows along the GPT tokenizer:
https://github.com/openai/gpt-2/blob/master/src/encoder.py
Unlike bpe_basic.py, this file also handles the regex splitting pattern.
But:
- Does not handle any special tokens.
Unlike BasicTokenizer:
- RegexTokenizer handles an optional regex splitting pattern.
- RegexTokenizer handles optional special tokens.
"""

import regex as re
@@ -22,10 +21,17 @@

class RegexTokenizer(Tokenizer):

def __init__(self):
def __init__(self, pattern=None, special_tokens=None):
"""
- pattern: optional string to override the default (GPT-4 split pattern)
- special_tokens: str -> int dictionary of special tokens
example: {'<|endoftext|>': 100257}
"""
super().__init__()
self.pattern = GPT4_SPLIT_PATTERN
self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
self.special_tokens = {} if special_tokens is None else special_tokens
self.compiled_pattern = re.compile(self.pattern)
self.inverse_special_tokens = {v: k for k, v in self.special_tokens.items()}

def train(self, text, vocab_size, verbose=False):
assert vocab_size >= 256
@@ -65,7 +71,15 @@ def train(self, text, vocab_size, verbose=False):

def decode(self, ids):
# given ids (list of integers), return Python string
text_bytes = b"".join(self.vocab[idx] for idx in ids)
part_bytes = []
for idx in ids:
if idx in self.vocab:
part_bytes.append(self.vocab[idx])
elif idx in self.inverse_special_tokens:
part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
else:
raise ValueError(f"invalid token id: {idx}")
text_bytes = b"".join(part_bytes)
text = text_bytes.decode("utf-8", errors="replace")
return text

@@ -88,7 +102,8 @@ def _encode_chunk(self, text_bytes):
ids = merge(ids, pair, idx)
return ids

def encode(self, text):
def encode_ordinary(self, text):
"""Encoding that ignores any special tokens."""
# split text into chunks of text by categories defined in regex pattern
text_chunks = re.findall(self.compiled_pattern, text)
# all chunks of text are encoded separately, then results are joined
@@ -98,3 +113,40 @@
chunk_ids = self._encode_chunk(chunk_bytes)
ids.extend(chunk_ids)
return ids

def encode(self, text, allowed_special="all"):
"""
Unlike encode_ordinary, this function handles special tokens.
allowed_special: can be "all"|"none" or a custom set of special tokens; anything else raises a ValueError
"""
# decode the user desire w.r.t. handling of special tokens
special = None
if allowed_special == "all":
special = self.special_tokens
elif allowed_special == "none":
special = {}
elif isinstance(allowed_special, set):
special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
else:
raise ValueError(f"allowed_special={allowed_special} not understood")
if not special:
# shortcut: if no special tokens, just use the ordinary encoding
return self.encode_ordinary(text)
# otherwise, we have to be careful with potential special tokens in text
# we handle special tokens by splitting the text
# based on the occurrence of any exact match with any of the special tokens
# we can use re.split for this. note that surrounding the pattern with ()
# makes it into a capturing group, so the special tokens will be included
special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
special_chunks = re.split(special_pattern, text)
# now all the special characters are separated from the rest of the text
# all chunks of text are encoded separately, then results are joined
ids = []
for part in special_chunks:
if part in special:
# this is a special token, encode it separately as a special case
ids.append(special[part])
else:
# this is an ordinary sequence, encode it normally
ids.extend(self.encode_ordinary(part))
return ids
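The comments in `encode` above describe splitting on special tokens with a capturing group; here is that `re.split` trick in isolation, with illustrative tokens and text:

```python
import regex as re  # the same module regex.py imports; stdlib `re` behaves the same here

specials = {'<|endoftext|>': 100257, '<|fim_prefix|>': 100258}
text = "<|endoftext|>Hello<|fim_prefix|> world"

# wrapping the alternation in () makes it a capturing group, so re.split
# keeps the matched special tokens in the output instead of dropping them
special_pattern = "(" + "|".join(re.escape(k) for k in specials) + ")"
print(re.split(special_pattern, text))
# -> ['', '<|endoftext|>', 'Hello', '<|fim_prefix|>', ' world']
```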
58 changes: 39 additions & 19 deletions tests/test_tokenizer.py
@@ -4,6 +4,9 @@

from minbpe import BasicTokenizer, RegexTokenizer, GPT4Tokenizer

# -----------------------------------------------------------------------------
# common test data

# a few strings to test the tokenizers on
dirname = os.path.dirname(os.path.abspath(__file__))
taylorswift_file = os.path.join(dirname, "taylorswift.txt")
@@ -13,6 +16,28 @@
"hello world!!!? (안녕하세요!) lol123 😉", # fun small string
open(taylorswift_file, "r", encoding="utf-8").read(), # big string
]
specials_string = """
<|endoftext|>Hello world this is one document
<|endoftext|>And this is another document
<|endoftext|><|fim_prefix|>And this one has<|fim_suffix|> tokens.<|fim_middle|> FIM
<|endoftext|>Last document!!! 👋<|endofprompt|>
""".strip()
special_tokens = {
'<|endoftext|>': 100257,
'<|fim_prefix|>': 100258,
'<|fim_middle|>': 100259,
'<|fim_suffix|>': 100260,
'<|endofprompt|>': 100276
}
llama_text = """
<|endoftext|>The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) (Lama glama) is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era.
Llamas are social animals and live with others as a herd. Their wool is soft and contains only a small amount of lanolin.[2] Llamas can learn simple tasks after a few repetitions. When using a pack, they can carry about 25 to 30% of their body weight for 8 to 13 km (5–8 miles).[3] The name llama (in the past also spelled "lama" or "glama") was adopted by European settlers from native Peruvians.[4]
The ancestors of llamas are thought to have originated from the Great Plains of North America about 40 million years ago, and subsequently migrated to South America about three million years ago during the Great American Interchange. By the end of the last ice age (10,000–12,000 years ago), camelids were extinct in North America.[3] As of 2007, there were over seven million llamas and alpacas in South America and over 158,000 llamas and 100,000 alpacas, descended from progenitors imported late in the 20th century, in the United States and Canada.[5]
<|fim_prefix|>In Aymara mythology, llamas are important beings. The Heavenly Llama is said to drink water from the ocean and urinates as it rains.[6] According to Aymara eschatology,<|fim_suffix|> where they come from at the end of time.[6]<|fim_middle|> llamas will return to the water springs and ponds<|endofprompt|>
""".strip()

# -----------------------------------------------------------------------------
# tests

# test encode/decode identity for a few different strings
@pytest.mark.parametrize("tokenizer_factory", [BasicTokenizer, RegexTokenizer, GPT4Tokenizer])
@@ -32,6 +57,14 @@ def test_gpt4_tiktoken_equality(text):
gpt4_tokenizer_ids = tokenizer.encode(text)
assert gpt4_tokenizer_ids == tiktoken_ids

# test the handling of special tokens
def test_gpt4_tiktoken_equality_special_tokens():
tokenizer = GPT4Tokenizer()
enc = tiktoken.get_encoding("cl100k_base")
tiktoken_ids = enc.encode(specials_string, allowed_special="all")
gpt4_tokenizer_ids = tokenizer.encode(specials_string, allowed_special="all")
assert gpt4_tokenizer_ids == tiktoken_ids

# reference test to add more tests in the future
@pytest.mark.parametrize("tokenizer_factory", [BasicTokenizer, RegexTokenizer])
def test_wikipedia_example(tokenizer_factory):
@@ -62,39 +95,26 @@ def test_wikipedia_example(tokenizer_factory):
assert ids == [258, 100, 258, 97, 99]
assert tokenizer.decode(tokenizer.encode(text)) == text

def test_save_load():
@pytest.mark.parametrize("special_tokens", [{}, special_tokens])
def test_save_load(special_tokens):
# take a bit more complex piece of text and train the tokenizer, chosen at random
text = """
The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) (Lama glama) is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era.
Llamas are social animals and live with others as a herd. Their wool is soft and contains only a small amount of lanolin.[2] Llamas can learn simple tasks after a few repetitions. When using a pack, they can carry about 25 to 30% of their body weight for 8 to 13 km (5–8 miles).[3] The name llama (in the past also spelled "lama" or "glama") was adopted by European settlers from native Peruvians.[4]
The ancestors of llamas are thought to have originated from the Great Plains of North America about 40 million years ago, and subsequently migrated to South America about three million years ago during the Great American Interchange. By the end of the last ice age (10,000–12,000 years ago), camelids were extinct in North America.[3] As of 2007, there were over seven million llamas and alpacas in South America and over 158,000 llamas and 100,000 alpacas, descended from progenitors imported late in the 20th century, in the United States and Canada.[5]
In Aymara mythology, llamas are important beings. The Heavenly Llama is said to drink water from the ocean and urinates as it rains.[6] According to Aymara eschatology, llamas will return to the water springs and ponds where they come from at the end of time.[6]
""".strip()

tokenizer = RegexTokenizer()

# do 64 merges
text = llama_text
# create a Tokenizer and do 64 merges
tokenizer = RegexTokenizer(special_tokens=special_tokens)
tokenizer.train(text, 256 + 64)

# verify that decode(encode(x)) == x
assert tokenizer.decode(tokenizer.encode(text)) == text

# verify that save/load work as expected
ids = tokenizer.encode(text)

# TODO use a proper temporary directory for I/O things below
# save the tokenizer
# save the tokenizer (TODO use a proper temporary directory)
tokenizer.save("test_tokenizer_tmp")

# re-load the tokenizer
tokenizer = RegexTokenizer()
tokenizer.load("test_tokenizer_tmp.model")

# verify that decode(encode(x)) == x
assert tokenizer.decode(ids) == text
assert tokenizer.decode(tokenizer.encode(text)) == text
assert tokenizer.encode(text) == ids

# delete the temporary files
for file in ["test_tokenizer_tmp.model", "test_tokenizer_tmp.vocab"]:
os.remove(file)