From fd47ef67eb11fe76ae5350c7a9fb35f12c9cb96e Mon Sep 17 00:00:00 2001
From: Brian Bargh
Date: Wed, 26 Jun 2024 02:39:52 +0000
Subject: [PATCH] Modify the BPE tokenizer to work for strings with limited
 character sets and on lists of strings instead of individual strings.

---
 minbpe/basic.py | 92 ++++++++++++++++++++++++++++++------------------
 1 file changed, 57 insertions(+), 35 deletions(-)

diff --git a/minbpe/basic.py b/minbpe/basic.py
index 9bc5ab76..be12588e 100644
--- a/minbpe/basic.py
+++ b/minbpe/basic.py
@@ -13,28 +13,40 @@ class BasicTokenizer(Tokenizer):
 
     def __init__(self):
         super().__init__()
 
-    def train(self, text, vocab_size, verbose=False):
-        assert vocab_size >= 256
-        num_merges = vocab_size - 256
-
+    def train(self, texts, vocab_size, verbose=False):
         # input text preprocessing
-        text_bytes = text.encode("utf-8") # raw bytes
-        ids = list(text_bytes) # list of integers in range 0..255
+        all_bytes = []
+        for text in texts:
+            all_bytes.extend(text.encode("utf-8")) # raw bytes
+        unique_bytes = sorted(set(all_bytes)) # sorted so byte ids are deterministic across runs
+
+        # create the initial vocabulary from the unique bytes in the training data
+        self.byte_to_id = {byte: i for i, byte in enumerate(unique_bytes)}
+        self.id_to_byte = {i: byte for byte, i in self.byte_to_id.items()}
+
+        ids = [self.byte_to_id[byte] for byte in all_bytes]
+
+        initial_vocab_size = len(unique_bytes)
+        assert vocab_size >= initial_vocab_size
+        num_merges = vocab_size - initial_vocab_size
+        if verbose:
+            print(f"Training BPE with {initial_vocab_size} unique bytes and {num_merges} merges")
 
         # iteratively merge the most common pairs to create new tokens
-        merges = {} # (int, int) -> int
-        vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
+        merges = {}  # (int, int) -> int
+        vocab = {idx: bytes([self.id_to_byte[idx]]) for idx in range(initial_vocab_size)}
+
         for i in range(num_merges):
             # count up the number of times every consecutive pair appears
             stats = get_stats(ids)
+            if not stats:
+                break # no more pairs to merge
+
             # find the pair with the highest count
             pair = max(stats, key=stats.get)
             # mint a new token: assign it the next available id
-            idx = 256 + i
+            idx = initial_vocab_size + i
             # replace all occurrences of pair in ids with idx
             ids = merge(ids, pair, idx)
             # save the merge
@@ -45,30 +57,40 @@ def train(self, text, vocab_size, verbose=False):
             print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
 
         # save class variables
-        self.merges = merges # used in encode()
-        self.vocab = vocab # used in decode()
+        self.merges = merges  # used in encode()
+        self.vocab = vocab  # used in decode()
 
-    def decode(self, ids):
-        # given ids (list of integers), return Python string
-        text_bytes = b"".join(self.vocab[idx] for idx in ids)
-        text = text_bytes.decode("utf-8", errors="replace")
-        return text
+    def decode(self, ids_list):
+        # given a list of lists of integers, return a list of Python strings
+        texts = []
+        for ids in ids_list:
+            text_bytes = b"".join(self.vocab[idx] for idx in ids)
+            text = text_bytes.decode("utf-8", errors="replace")
+            texts.append(text)
+        return texts
 
-    def encode(self, text):
-        # given a string text, return the token ids
-        text_bytes = text.encode("utf-8") # raw bytes
-        ids = list(text_bytes) # list of integers in range 0..255
-        while len(ids) >= 2:
-            # find the pair with the lowest merge index
-            stats = get_stats(ids)
-            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
-            # subtle: if there are no more merges available, the key will
-            # result in an inf for every single pair, and the min will be
-            # just the first pair in the list, arbitrarily
-            # we can detect this terminating case by a membership check
-            if pair not in self.merges:
-                break # nothing else can be merged anymore
-            # otherwise let's merge the best pair (lowest merge index)
-            idx = self.merges[pair]
-            ids = merge(ids, pair, idx)
-        return ids
+    def encode(self, texts):
+        # given a list of strings, return a list of token ids for each string
+        all_ids = []
+        for text in texts:
+            text_bytes = text.encode("utf-8") # raw bytes
+            # drop bytes unseen at training time: a fallback id of
+            # len(self.byte_to_id) would collide with the first merge token
+            ids = [self.byte_to_id[byte] for byte in text_bytes if byte in self.byte_to_id]
+
+            while len(ids) >= 2:
+                # find the pair with the lowest merge index
+                stats = get_stats(ids)
+                if not stats:
+                    break # no more pairs to merge
+
+                pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
+                # subtle: if there are no more merges available, the key will
+                # result in an inf for every single pair, and the min will be
+                # just the first pair in the list, arbitrarily
+                # we can detect this terminating case by a membership check
+                if pair not in self.merges:
+                    break # nothing else can be merged anymore
+                # otherwise let's merge the best pair (lowest merge index)
+                idx = self.merges[pair]
+                ids = merge(ids, pair, idx)
+            all_ids.append(ids)
+        return all_ids
\ No newline at end of file
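--
Reviewer note: a minimal usage sketch of the list-based API this patch
introduces. It assumes the package exports BasicTokenizer the way upstream
minbpe does; the tiny corpus and vocab_size below are made up for
illustration, not taken from the patch:

    from minbpe import BasicTokenizer  # assumed import path

    # limited character set: only 'a', 'b', and ' ' appear in the corpus
    texts = ["aab ab ab", "ab ab aab"]

    tokenizer = BasicTokenizer()
    # the initial vocab is 3 unique bytes here, so vocab_size=8 leaves 5 merges
    tokenizer.train(texts, vocab_size=8, verbose=True)

    ids_list = tokenizer.encode(texts)    # one list of token ids per input string
    decoded = tokenizer.decode(ids_list)  # one string per list of token ids
    assert decoded == texts               # round-trips: every byte was seen in training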