forked from TabbyML/tabby
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: prompt rewrite eval tool (TabbyML#406)
* feat: prompt rewrite eval tool * chore: readme * chore: language * chore: resolve comments * chore: prefix generation * chore: clean * chore: dependency
- Loading branch information
1 parent
c86c6d8
commit 96cea7d
Showing
5 changed files
with
263 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Prompt rewriting evaluation tool | ||
|
||
## Install dependencies | ||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Run backend rewriting script | ||
1. tweak `eval.toml` | ||
- tabby binary path | ||
- index repo url (the repo you want tabby to index from) | ||
- sample repo url (the repo you want to generate completion requests from) | ||
- language | ||
- prompt count | ||
|
||
2. run `python evaluation.py` | ||
|
||
## Run dashboard to view prompts | ||
``` | ||
streamlit run dashboard.py | ||
``` | ||
- Tweak the slider bar to change how many recent prompts you want to review. | ||
- Change the language to filter only the specific language you are interested in. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import os | ||
import jsonlines | ||
import streamlit as st | ||
|
||
LANGUAGE_LIST = [ | ||
"python", | ||
"rust", | ||
"go", | ||
"java", | ||
"javascript_typescript", | ||
"lua", | ||
"php" | ||
] | ||
|
||
st.title(":wave: Prompt rewriting dashboard") | ||
|
||
st.divider() | ||
st.subheader("Select your options") | ||
|
||
entry_count = st.slider("How many entries to view", 0, 100, 10) | ||
language = st.radio("Select the language you are working on", LANGUAGE_LIST) | ||
|
||
events_path = os.path.expanduser("~/.tabby/events") | ||
log_file_name = sorted(os.listdir(events_path))[-1] | ||
log_file_path = os.path.join(events_path, log_file_name) | ||
|
||
prompts = [] | ||
with jsonlines.open(log_file_path) as log: | ||
for obj in log: | ||
if "completion" not in obj["event"]: | ||
continue | ||
if obj["event"]["completion"]["language"] != language: | ||
continue | ||
prompts.append(obj["event"]["completion"]["prompt"]) | ||
|
||
prompts = prompts[-entry_count:] | ||
code_language = language if language != "javascript_typescript" else "javascript" | ||
for i in range(len(prompts)): | ||
st.divider() | ||
prompt = prompts[i] | ||
st.write(f"**[prompt {i+1}]**") | ||
st.code(prompt, language=code_language) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
|
||
# Tabby config | ||
tabby_path = "/home/jiachen/workspace/tabbyml/tabby/target/debug/tabby" | ||
|
||
# Indexing / prompt rewriting config | ||
index_repo_url = "https://github.com/huggingface/transformers" | ||
sample_repo_url = "https://github.com/huggingface/transformers" | ||
language = "python" | ||
prompt_count = 10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
import jsonlines | ||
import logging | ||
import os | ||
import random | ||
import requests | ||
import subprocess | ||
import toml | ||
|
||
import time | ||
|
||
logging.getLogger().setLevel(logging.INFO) | ||
|
||
PORT = 8080 | ||
|
||
def wait_for_online(timeout): | ||
logging.info("Trying to connect to tabby") | ||
|
||
health_url = f"http://127.0.0.1:{PORT}/v1/health" | ||
|
||
is_online = False | ||
till = time.time() + timeout * 1000 | ||
|
||
while time.time() < till: | ||
try: | ||
r = requests.post(health_url) | ||
if r.status_code == 200: | ||
logging.info("Tabby is online now") | ||
is_online = True | ||
break | ||
except: | ||
logging.info("Retrying to connect") | ||
time.sleep(1) | ||
|
||
return is_online | ||
|
||
|
||
def index(args): | ||
binary = args["tabby_path"] | ||
index_repo_url = args["index_repo_url"] | ||
|
||
# Write to config.toml | ||
config_file_path = os.path.expanduser("~/.tabby/config.toml") | ||
config = { | ||
"repositories": [ | ||
{ | ||
"git_url": index_repo_url, | ||
} | ||
], | ||
"experimental": { | ||
"enable_prompt_rewrite": True, | ||
} | ||
} | ||
with open(config_file_path, "w+") as f: | ||
toml.dump(config, f) | ||
|
||
# Start indexing | ||
cmd = [binary, "scheduler", "--now"] | ||
subprocess.run(cmd) | ||
|
||
def generate_completion_segments(args): | ||
binary = args["tabby_path"] | ||
sample_repo_url = args["sample_repo_url"] | ||
language = args["language"] | ||
prompt_count = args["prompt_count"] | ||
|
||
segments = [] | ||
|
||
# Index the sample repo | ||
sample_path = os.path.expanduser("~/.tabby/eval_sample") | ||
sample_config_file_path = os.path.join(sample_path, "config.toml") | ||
config = { | ||
"repositories": [ | ||
{ | ||
"git_url": sample_repo_url, | ||
} | ||
] | ||
} | ||
|
||
if not os.path.exists(sample_path): | ||
os.mkdir(sample_path) | ||
|
||
with open(sample_config_file_path, "w+") as f: | ||
toml.dump(config, f) | ||
|
||
sample_index_command = [binary, "scheduler", "--now"] | ||
# subprocess.run(sample_index_command, env={"TABBY_ROOT": sample_path}) | ||
|
||
# Read in dataset.jsonl and build segments | ||
contents = [] | ||
dataset_path = os.path.join(sample_path, "dataset") | ||
# in dir dataset/, could have multiple jsonl files: | ||
# data.jsonl, data.jsonl.1, data.jsonl.2, etc | ||
files = os.listdir(dataset_path) | ||
for file_name in files: | ||
dataset_file_path = os.path.join(dataset_path, file_name) | ||
with jsonlines.open(dataset_file_path) as dataset: | ||
for obj in dataset: | ||
if obj["language"] != language: | ||
continue | ||
contents.append(obj["content"]) | ||
|
||
# Generate random segments | ||
for _ in range(prompt_count): | ||
# Randomly pick a file content | ||
content = "" | ||
|
||
# We are only interested in files that have content, | ||
# So we have this while loop to retry-and-fence | ||
while not content: | ||
file_content = random.randrange(len(contents)) | ||
content = contents[file_content] | ||
|
||
# Randomly pick a cursor | ||
cursor = random.randrange(len(content)) | ||
|
||
# Look backward to generate prefix | ||
lb = 0 | ||
pc = cursor | ||
while True: | ||
if pc < 0: | ||
break | ||
if content[pc] == "\n": | ||
lb += 1 | ||
if lb == 10: | ||
break | ||
pc -= 1 | ||
prefix = content[pc + 1: cursor + 1] | ||
|
||
# Look forward to generate suffix | ||
lb = 0 | ||
sc = cursor + 1 | ||
while True: | ||
if sc >= len(content): | ||
break | ||
if content[sc] == "\n": | ||
lb += 1 | ||
if lb == 10: | ||
break | ||
sc += 1 | ||
suffix = content[cursor + 1: sc] | ||
|
||
segments.append({ | ||
"prefix": prefix, | ||
"suffix": suffix | ||
}) | ||
|
||
# Generate query segment | ||
return segments | ||
|
||
def rewrite_prompt(args): | ||
binary = args["tabby_path"] | ||
language = args["language"] | ||
|
||
# Generate segments | ||
segments = generate_completion_segments(args) | ||
|
||
# Start tabby server | ||
serve_command = [binary, "serve", "--model", "TabbyML/T5P-220M"] | ||
process = subprocess.Popen(serve_command) | ||
|
||
try: | ||
# Wait for tabby server to be up online | ||
if not wait_for_online(5): | ||
logging.error("Tabby server is not online") | ||
return | ||
|
||
# Generate completion request messages | ||
completion_url = f"http://127.0.0.1:{PORT}/v1/completions" | ||
for s in segments: | ||
req = { | ||
"language": language, | ||
"segments": s, | ||
} | ||
|
||
r = requests.post(completion_url, json=req) | ||
logging.info(r.status_code) | ||
finally: | ||
process.terminate() | ||
|
||
def main(): | ||
args = toml.load("eval.toml") | ||
# index(args) | ||
rewrite_prompt(args) | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
requests | ||
streamlit | ||
jsonlines |