
Commit

optimization: allow get_stats to update an existing counts dict. train.py runtime goes from 32s to 22s doing this. ty @gklab for original suggestion in a PR

karpathy committed Feb 18, 2024
1 parent ff20c92 commit 7843c96
Showing 3 changed files with 12 additions and 8 deletions.

minbpe/base.py (5 changes: 3 additions & 2 deletions)

@@ -10,12 +10,13 @@
 # -----------------------------------------------------------------------------
 # a few helper functions useful for both BasicTokenizer and RegexTokenizer
 
-def get_stats(ids):
+def get_stats(ids, counts=None):
     """
     Given a list of integers, return a dictionary of counts of consecutive pairs
     Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
+    Optionally allows to update an existing dictionary of counts
     """
-    counts = {}
+    counts = {} if counts is None else counts
     for pair in zip(ids, ids[1:]): # iterate consecutive elements
         counts[pair] = counts.get(pair, 0) + 1
     return counts
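
For illustration (not part of the commit), here is the updated helper in both modes; the function body is copied verbatim from minbpe/base.py above, and the first call reproduces the docstring's example:

    def get_stats(ids, counts=None):
        # copied from minbpe/base.py above
        counts = {} if counts is None else counts
        for pair in zip(ids, ids[1:]): # iterate consecutive elements
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    # standalone call: allocates and returns a fresh dict
    print(get_stats([1, 2, 3, 1, 2]))  # {(1, 2): 2, (2, 3): 1, (3, 1): 1}

    # accumulating call: counts from several id lists land in one shared dict
    stats = {}
    get_stats([1, 2, 3], stats)
    get_stats([2, 3, 4], stats)
    print(stats)  # {(1, 2): 1, (2, 3): 2, (3, 4): 1}
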
minbpe/regex.py (10 changes: 4 additions & 6 deletions)

@@ -41,13 +41,11 @@ def train(self, text, vocab_size, verbose=False):
         merges = {} # (int, int) -> int
         vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
         for i in range(num_merges):
-            # count up the number of times every consecutive pair appears
-            chunk_stats = [get_stats(chunk_ids) for chunk_ids in ids]
-            # combine the pair counts from all chunks by summing them up
+            # count the number of times every consecutive pair appears
             stats = {}
-            for chstat in chunk_stats:
-                for pair, count in chstat.items():
-                    stats[pair] = stats.get(pair, 0) + count
+            for chunk_ids in ids:
+                # passing in stats will update it in place, adding up counts
+                get_stats(chunk_ids, stats)
             # find the pair with the highest count
             pair = max(stats, key=stats.get)
             # mint a new token: assign it the next available id
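
A minimal sketch of where the speedup comes from (illustrative, not part of the commit; it assumes the minbpe package is importable, and the chunk data is made up): the old path builds one temporary dict per chunk and then merges them pair by pair, while the new path writes every count straight into a single shared dict, roughly halving the dict traffic.

    import time
    from minbpe.base import get_stats  # the helper from the first diff above

    chunks = [[1, 2, 3, 1, 2]] * 100_000  # hypothetical per-chunk token ids

    t0 = time.time()
    chunk_stats = [get_stats(c) for c in chunks]  # old: one temp dict per chunk
    stats_old = {}
    for chstat in chunk_stats:
        for pair, count in chstat.items():
            stats_old[pair] = stats_old.get(pair, 0) + count
    t1 = time.time()

    stats_new = {}
    for c in chunks:
        get_stats(c, stats_new)  # new: update one shared dict in place
    t2 = time.time()

    assert stats_old == stats_new
    print(f"old merge loop: {t1 - t0:.3f}s, in-place update: {t2 - t1:.3f}s")
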
train.py (5 changes: 5 additions & 0 deletions)

@@ -4,6 +4,7 @@
"""

import os
import time
from minbpe import BasicTokenizer, RegexTokenizer

# open some text and train a vocab of 512 tokens
@@ -12,6 +13,7 @@
 # create a directory for models, so we don't pollute the current directory
 os.makedirs("models", exist_ok=True)
 
+t0 = time.time()
 for TokenizerClass, name in zip([BasicTokenizer, RegexTokenizer], ["basic", "regex"]):
 
     # construct the Tokenizer object and kick off verbose training
@@ -20,3 +22,6 @@
     # writes two files in the models directory: name.model, and name.vocab
     prefix = os.path.join("models", name)
     tokenizer.save(prefix)
+t1 = time.time()
+
+print(f"Training took {t1 - t0:.2f} seconds")
