Skip to content

Commit

Permalink
initial work on server-side token counting
Browse files Browse the repository at this point in the history
  • Loading branch information
cogentapps committed Apr 29, 2023
1 parent 6b271a4 commit 36e434f
Show file tree
Hide file tree
Showing 6 changed files with 348 additions and 0 deletions.
4 changes: 4 additions & 0 deletions server/src/endpoints/service-proxies/openai/basic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,8 @@ export async function basicHandler(req: express.Request, res: express.Response)
})

res.json(response.data);

const promptTokens = response.data.usage.prompt_tokens as number;
const completionTokens = response.data.usage.completion_tokens as number;
// console.log(`prompt tokens: ${promptTokens}, completion tokens: ${completionTokens}, model: ${req.body.model}`);
}
4 changes: 4 additions & 0 deletions server/src/endpoints/service-proxies/openai/message.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
/**
 * Minimal shape of a chat message sent to / received from the
 * OpenAI Chat Completions API.
 */
export interface OpenAIMessage {
    // Sender role; the API uses "system" / "user" / "assistant", but it is
    // kept as a plain string here — presumably so request bodies pass through
    // unchanged. TODO(review): confirm before narrowing to a union type.
    role: string;
    // Message text content; this is what the tokenizer counts.
    content: string;
}
45 changes: 45 additions & 0 deletions server/src/endpoints/service-proxies/openai/streaming.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import { EventSource } from "launchdarkly-eventsource";
import express from 'express';
import { apiKey } from ".";
import { countTokensForMessages } from "./tokenizer";

export async function streamingHandler(req: express.Request, res: express.Response) {
res.set({
Expand All @@ -10,6 +11,11 @@ export async function streamingHandler(req: express.Request, res: express.Respon
Connection: 'keep-alive',
});

const messages = req.body.messages;
const promptTokens = countTokensForMessages(messages);

let completion = '';

const eventSource = new EventSource('https://api.openai.com/v1/chat/completions', {
method: "POST",
headers: {
Expand All @@ -30,6 +36,26 @@ export async function streamingHandler(req: express.Request, res: express.Respon
if (event.data === '[DONE]') {
res.end();
eventSource.close();

const totalTokens = countTokensForMessages([
...messages,
{
role: "assistant",
content: completion,
},
]);
const completionTokens = totalTokens - promptTokens;
// console.log(`prompt tokens: ${promptTokens}, completion tokens: ${completionTokens}, model: ${req.body.model}`);
return;
}

try {
const chunk = parseResponseChunk(event.data);
if (chunk.choices && chunk.choices.length > 0) {
completion += chunk.choices[0]?.delta?.content || '';
}
} catch (e) {
console.error(e);
}
});

Expand All @@ -48,4 +74,23 @@ export async function streamingHandler(req: express.Request, res: express.Respon
res.on('error', e => {
eventSource.close();
});
}

/**
 * Parses one server-sent-event payload from the OpenAI streaming API.
 *
 * Returns `{ done: true }` for the terminal "[DONE]" sentinel, otherwise the
 * parsed chunk's id/choices/model with `done: false`.
 *
 * @param buffer raw event data (string or Buffer-like).
 * @throws SyntaxError if the payload is neither "[DONE]" nor valid JSON.
 */
function parseResponseChunk(buffer: any) {
    // Normalize to string, then strip only a *leading* SSE "data: " prefix.
    // FIX: the original used replace('data: ', ''), which removes the first
    // occurrence anywhere — including inside the JSON body (e.g. a completion
    // whose text contains "data: "), corrupting the chunk.
    let chunk = buffer.toString().trim();
    if (chunk.startsWith('data: ')) {
        chunk = chunk.slice('data: '.length).trim();
    }

    if (chunk === '[DONE]') {
        return {
            done: true,
        };
    }

    const parsed = JSON.parse(chunk);

    return {
        id: parsed.id,
        done: false,
        choices: parsed.choices,
        model: parsed.model,
    };
}
241 changes: 241 additions & 0 deletions server/src/endpoints/service-proxies/openai/tokenizer/bpe.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
const MAX_NUM_THREADS = 128;

type MergeRange = { start: number, end: number };

/**
 * Maps byte sequences to integer BPE ranks.
 *
 * Keys are stored as the UTF-8 string decoding of the bytes.
 * NOTE(review): Buffer#toString() is lossy for invalid UTF-8 (bad bytes decode
 * to U+FFFD), which could collide distinct byte sequences — confirm the
 * vocabulary file only contains valid UTF-8 keys.
 */
export class RankMap {
    private values = new Map<string, number>();

    /** Builds a RankMap where each text's rank is its index in the array. */
    public static from(texts: string[]) {
        const map = new RankMap();
        texts.forEach((text, rank) => {
            map.values.set(text, rank);
        });
        return map;
    }

    public set(bytes: Uint8Array, rank: number) {
        this.values.set(Buffer.from(bytes).toString(), rank);
    }

    public get(bytes: Uint8Array) {
        return this.values.get(Buffer.from(bytes).toString());
    }

    /** All keys re-encoded as Buffers (UTF-8 bytes of the stored strings). */
    public keys() {
        return [...this.values.keys()].map((key) => Buffer.from(key));
    }

    /** Reverse mapping: rank → byte sequence. */
    public inverted() {
        const result = new Map<number, Uint8Array>();
        this.values.forEach((rank, key) => {
            result.set(rank, new Uint8Array(Buffer.from(key)));
        });
        return result;
    }
}

/**
 * Greedy BPE merge: starting from one range per byte, repeatedly merges the
 * adjacent pair whose combined bytes have the lowest rank, until no adjacent
 * pair exists in the rank table. Ties resolve to the leftmost pair.
 */
function bytePairMerge(piece: Uint8Array, ranks: RankMap): MergeRange[] {
    // Seed with a single-byte range at every position.
    const parts: MergeRange[] = [];
    for (let i = 0; i < piece.length; i++) {
        parts.push({ start: i, end: i + 1 });
    }

    while (parts.length > 1) {
        // Scan adjacent pairs for the lowest-ranked mergeable candidate.
        let bestRank = Infinity;
        let bestIndex = -1;
        for (let i = 0; i + 1 < parts.length; i++) {
            const candidate = ranks.get(piece.slice(parts[i].start, parts[i + 1].end));
            if (candidate !== undefined && candidate < bestRank) {
                bestRank = candidate;
                bestIndex = i;
            }
        }
        if (bestIndex < 0) {
            break; // nothing left to merge
        }
        // Merge pair (bestIndex, bestIndex + 1) into one range.
        parts[bestIndex] = { start: parts[bestIndex].start, end: parts[bestIndex + 1].end };
        parts.splice(bestIndex + 1, 1);
    }
    return parts;
}

/**
 * Encodes a byte sequence into BPE token ranks.
 * Single bytes are looked up directly; longer pieces are merged first.
 */
function bytePairEncode(piece: Uint8Array, ranks: RankMap): number[] {
    if (piece.length === 1) {
        // NOTE(review): assumes every single byte exists in the vocabulary —
        // the non-null assertion would yield undefined otherwise.
        return [ranks.get(piece)!];
    }
    const ranges = bytePairMerge(piece, ranks);
    return ranges.map((range) => ranks.get(piece.slice(range.start, range.end))!);
}

/**
 * Splits a byte sequence into the byte sub-sequences produced by BPE merging
 * (same segmentation as bytePairEncode, but returns bytes instead of ranks).
 */
function bytePairSplit(piece: Uint8Array, ranks: RankMap): Uint8Array[] {
    if (piece.length === 1) {
        return [piece];
    }
    const ranges = bytePairMerge(piece, ranks);
    return ranges.map((range) => piece.slice(range.start, range.end));
}

/**
 * Byte-pair encoder modeled on OpenAI's tiktoken CoreBPE: encodes text to
 * token ranks and decodes ranks back to bytes, with support for special
 * control tokens (e.g. "<|endoftext|>").
 */
export class CoreBPE {
    encoder: RankMap;
    specialTokensEncoder: Map<string, number>;
    decoder: Map<number, Uint8Array>;
    specialTokensDecoder: Map<number, Uint8Array>;
    regexTls: RegExp[];
    specialRegexTls: RegExp[];
    sortedTokenBytes: Uint8Array[];

    constructor(
        encoder: RankMap,
        specialTokensEncoder: Map<string, number>,
        regex: RegExp
    ) {
        // Alternation matching any special token literally (regex
        // metacharacters escaped).
        // NOTE(review): with an empty specialTokensEncoder this builds an
        // empty pattern that matches everywhere — confirm callers always pass
        // at least one special token.
        const specialRegex = new RegExp(
            Array.from(specialTokensEncoder.keys())
                .map((s) => s.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&"))
                .join("|")
        );

        const decoder: Map<number, Uint8Array> = encoder.inverted();

        const specialTokensDecoder: Map<number, Uint8Array> = new Map(
            Array.from(specialTokensEncoder.entries()).map(([k, v]) => [v, new Uint8Array(Buffer.from(k))])
        );
        const sortedTokenBytes: Uint8Array[] = Array.from(encoder.keys());
        sortedTokenBytes.sort((a, b) => Buffer.compare(a, b));

        this.encoder = encoder;
        this.specialTokensEncoder = specialTokensEncoder;
        this.decoder = decoder;
        this.specialTokensDecoder = specialTokensDecoder;
        // Every slot holds the SAME RegExp instance (shared lastIndex state).
        // Harmless single-threaded: each exec loop runs to null, resetting
        // lastIndex; we also reset defensively before scanning.
        this.regexTls = Array(MAX_NUM_THREADS).fill(regex);
        this.specialRegexTls = Array(MAX_NUM_THREADS).fill(specialRegex);
        this.sortedTokenBytes = sortedTokenBytes;
    }

    private _getTlRegex(): RegExp {
        return this.regexTls[Math.floor(Math.random() * MAX_NUM_THREADS)];
    }

    private _getTlSpecialRegex(): RegExp {
        return this.specialRegexTls[Math.floor(Math.random() * MAX_NUM_THREADS)];
    }

    // Decodes tokens to raw bytes; special tokens decode to their text bytes.
    private _decodeNative(tokens: number[]): Uint8Array {
        const ret: number[] = [];
        for (const token of tokens) {
            const tokenBytes = this.decoder.get(token) || this.specialTokensDecoder.get(token)!;
            ret.push(...Array.from(tokenBytes));
        }
        return new Uint8Array(ret);
    }

    // Encodes text treating special-token strings as ordinary text.
    private _encodeOrdinaryNative(text: string): number[] {
        const regex = this._getTlRegex();
        regex.lastIndex = 0; // defensive: regex instance is shared
        const ret: number[] = [];
        let match: RegExpExecArray | null;
        while ((match = regex.exec(text)) !== null) {
            const piece = new Uint8Array(Buffer.from(match[0]));
            const token = this.encoder.get(piece);
            if (token !== undefined) {
                ret.push(token);
                continue;
            }
            ret.push(...bytePairEncode(piece, this.encoder));
        }
        return ret;
    }

    // Encodes text, emitting dedicated ranks for allowed special tokens.
    // Returns [tokens, length of the final ordinary piece's token run].
    private _encodeNative(text: string, allowedSpecial: Set<string>): [number[], number] {
        const specialRegex = this._getTlSpecialRegex();
        const regex = this._getTlRegex();
        const ret: number[] = [];

        let start = 0;
        let lastPieceTokenLen = 0;
        while (true) {
            // Find the next *allowed* special token at or after `start`.
            // FIX: exec() runs on a slice, so match.index is relative to that
            // slice; it must be converted back to an absolute position in
            // `text`. The original used the relative index directly, which
            // mis-tokenized any input with more than one special-token match.
            let nextSpecial: RegExpExecArray | null = null;
            let specialStart = text.length;
            let searchFrom = start;
            while (true) {
                nextSpecial = specialRegex.exec(text.slice(searchFrom));
                if (nextSpecial === null) {
                    break;
                }
                specialStart = searchFrom + nextSpecial.index;
                if (allowedSpecial.has(nextSpecial[0])) {
                    break;
                }
                // Disallowed special token: keep searching past it.
                searchFrom = specialStart + 1;
            }
            const end = nextSpecial === null ? text.length : specialStart;

            // Tokenize the ordinary text before the special token. The slice
            // is hoisted out of the loop (the original re-sliced the text on
            // every exec call — accidental O(n^2)).
            const segment = text.slice(start, end);
            regex.lastIndex = 0;
            let match: RegExpExecArray | null;
            while ((match = regex.exec(segment)) !== null) {
                const piece = new Uint8Array(Buffer.from(match[0]));
                const token = this.encoder.get(piece);
                if (token !== undefined) {
                    lastPieceTokenLen = 1;
                    ret.push(token);
                    continue;
                }
                const tokens = bytePairEncode(piece, this.encoder);
                lastPieceTokenLen = tokens.length;
                ret.push(...tokens);
            }

            if (nextSpecial === null) {
                break;
            }
            const piece = nextSpecial[0];
            const token = this.specialTokensEncoder.get(piece)!;
            ret.push(token);
            start = specialStart + piece.length;
            lastPieceTokenLen = 0;
        }
        return [ret, lastPieceTokenLen];
    }

    /** Encodes text with special-token strings treated as plain text. */
    encodeOrdinary(text: string): number[] {
        return this._encodeOrdinaryNative(text);
    }

    /** Encodes text; members of allowedSpecial encode as special tokens. */
    encode(text: string, allowedSpecial: Set<string>): number[] {
        return this._encodeNative(text, allowedSpecial)[0];
    }

    encodeWithUnstable(text: string, allowedSpecial: Set<string>): [number[], Set<number[]>] {
        throw new Error("Not implemented");
    }

    /** Rank for an exact piece (ordinary first, then special). Throws if absent. */
    encodeSingleToken(piece: Uint8Array): number {
        const token = this.encoder.get(piece);
        if (token !== undefined) {
            return token;
        }
        const pieceStr = Buffer.from(piece).toString("utf-8");
        if (this.specialTokensEncoder.has(pieceStr)) {
            return this.specialTokensEncoder.get(pieceStr)!;
        }
        throw new Error("Key not found");
    }

    /** BPE token run for one piece, bypassing the split regex. */
    encodeSinglePiece(piece: Uint8Array): number[] {
        const token = this.encoder.get(piece);
        if (token !== undefined) {
            return [token];
        }
        return bytePairEncode(piece, this.encoder);
    }

    /** Decodes a token sequence back to its bytes. */
    decodeBytes(tokens: number[]): Uint8Array {
        return this._decodeNative(tokens);
    }

    /** Bytes for a single token (ordinary or special). Throws if unknown. */
    decodeSingleTokenBytes(token: number): Uint8Array {
        const bytes = this.decoder.get(token) || this.specialTokensDecoder.get(token);
        if (bytes !== undefined) {
            return bytes;
        }
        throw new Error("Key not found");
    }

    /** All ordinary token byte-values, sorted bytewise. */
    tokenByteValues(): Uint8Array[] {
        return this.sortedTokenBytes;
    }
}

Large diffs are not rendered by default.

53 changes: 53 additions & 0 deletions server/src/endpoints/service-proxies/openai/tokenizer/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import { CoreBPE, RankMap } from "./bpe";
import fs from 'fs';
import path from 'path';
import { OpenAIMessage } from "../message";

// cl100k_base vocabulary (token text → rank), loaded synchronously once at
// module load. NOTE(review): assumes cl100k_base.json is deployed alongside
// the compiled file — confirm the build copies it into __dirname.
const ranks = JSON.parse(fs.readFileSync(path.join(__dirname, './cl100k_base.json'), 'utf8'));

// Special control tokens for the cl100k_base encoding.
// (typed Record<string, number> instead of `any` — same shape, checked keys)
const special_tokens: Record<string, number> = {
    "<|endoftext|>": 100257,
    "<|fim_prefix|>": 100258,
    "<|fim_middle|>": 100259,
    "<|fim_suffix|>": 100260,
    "<|endofprompt|>": 100276,
};

const special_tokens_map = new Map<string, number>(Object.entries(special_tokens));

// Pre-tokenization split pattern for cl100k_base: contractions, letter runs,
// digit runs of up to 3, punctuation, and whitespace.
const pattern = /('s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/giu;

const tokenizer = new CoreBPE(RankMap.from(ranks), special_tokens_map, pattern);

// Fixed token overhead added per message and once per prompt.
// NOTE(review): approximation — OpenAI documents slightly different
// per-message overheads per model; confirm against the target model family.
const overheadTokens = {
    perMessage: 5,
    perPrompt: 2,
}

// Memoizes text → token count. NOTE(review): grows without bound; consider
// an LRU cap for long-running servers.
const tokenCache = new Map<string, number>();

/**
 * Counts BPE tokens in `text` using the cl100k_base tokenizer.
 * Results are memoized in `tokenCache`, keyed by the full text.
 */
export function countTokensForText(text: string) {
    // Single lookup instead of has()+get(); a cached 0 is still valid
    // because we compare against undefined, not falsiness.
    const cached = tokenCache.get(text);
    if (cached !== undefined) {
        return cached;
    }
    // (removed dead `t1 = Date.now()` left over from timing experiments)
    const tokens = tokenizer.encodeOrdinary(text).length;
    tokenCache.set(text, tokens);
    return tokens;
}

/** Token count for a single chat message, including per-message overhead. */
export function countTokensForMessage(message: OpenAIMessage) {
    const contentTokens = countTokensForText(message.content);
    return contentTokens + overheadTokens.perMessage;
}

/**
 * Total prompt token count for a list of chat messages: per-message counts
 * (each including per-message overhead) plus the fixed per-prompt overhead.
 */
export function countTokensForMessages(messages: OpenAIMessage[]) {
    return messages.reduce(
        (total, message) => total + countTokensForMessage(message),
        overheadTokens.perPrompt
    );
}

0 comments on commit 36e434f

Please sign in to comment.