Skip to content

Commit

Permalink
feat: prompt rewrite eval tool (TabbyML#406)
Browse files Browse the repository at this point in the history
* feat: prompt rewrite eval tool

* chore: readme

* chore: language

* chore: resolve comments

* chore: prefix generation

* chore: clean

* chore: dependency
  • Loading branch information
vodkaslime authored Sep 8, 2023
1 parent c86c6d8 commit 96cea7d
Show file tree
Hide file tree
Showing 5 changed files with 263 additions and 0 deletions.
23 changes: 23 additions & 0 deletions experimental/prompt-rewrite-eval/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Prompt rewriting evaluation tool

## Install dependencies
```
pip install -r requirements.txt
```

## Run backend rewriting script
1. tweak `eval.toml`
- tabby binary path
- index repo url (the repo you want tabby to index from)
- sample repo url (the repo you want to generate completion requests from)
- language
- prompt count

2. Run `python evaluator.py`

## Run dashboard to view prompts
```
streamlit run dashboard.py
```
- Tweak the slider bar to change how many recent prompts you want to review.
- Change the language to filter only the specific language you are interested in.
42 changes: 42 additions & 0 deletions experimental/prompt-rewrite-eval/dashboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
import jsonlines
import streamlit as st

# Languages tabby's completion events may be tagged with.
LANGUAGE_LIST = [
    "python",
    "rust",
    "go",
    "java",
    "javascript_typescript",
    "lua",
    "php",
]

st.title(":wave: Prompt rewriting dashboard")

st.divider()
st.subheader("Select your options")

# How many of the most recent prompts to display, and for which language.
entry_count = st.slider("How many entries to view", 0, 100, 10)
language = st.radio("Select the language you are working on", LANGUAGE_LIST)

# Tabby writes event logs under ~/.tabby/events; the lexicographically
# largest file name is assumed to be the most recent log.
events_path = os.path.expanduser("~/.tabby/events")
log_file_name = sorted(os.listdir(events_path))[-1]
log_file_path = os.path.join(events_path, log_file_name)

# Collect prompts from completion events matching the selected language.
prompts = []
with jsonlines.open(log_file_path) as log:
    for obj in log:
        if "completion" not in obj["event"]:
            continue
        if obj["event"]["completion"]["language"] != language:
            continue
        prompts.append(obj["event"]["completion"]["prompt"])

# BUG FIX: `prompts[-0:]` is the whole list, so a slider value of 0 used to
# show every prompt instead of none.
prompts = prompts[-entry_count:] if entry_count > 0 else []

# Streamlit's code viewer has no combined JS/TS mode; fall back to javascript.
code_language = "javascript" if language == "javascript_typescript" else language
for i, prompt in enumerate(prompts, start=1):
    st.divider()
    st.write(f"**[prompt {i}]**")
    st.code(prompt, language=code_language)
9 changes: 9 additions & 0 deletions experimental/prompt-rewrite-eval/eval.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

# Tabby config: absolute path to your locally built tabby binary (update for your machine)
tabby_path = "/home/jiachen/workspace/tabbyml/tabby/target/debug/tabby"

# Indexing / prompt rewriting config
index_repo_url = "https://github.com/huggingface/transformers"
sample_repo_url = "https://github.com/huggingface/transformers"
language = "python"
prompt_count = 10
186 changes: 186 additions & 0 deletions experimental/prompt-rewrite-eval/evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import jsonlines
import logging
import os
import random
import requests
import subprocess
import toml

import time

# Raise the root logger to INFO so progress messages below are visible.
logging.getLogger().setLevel(logging.INFO)

# Port the local tabby server listens on (health and completion endpoints).
PORT = 8080

def wait_for_online(timeout):
    """Poll the local tabby server's health endpoint until it responds.

    Args:
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True if the server answered with HTTP 200 within the timeout,
        False otherwise.
    """
    logging.info("Trying to connect to tabby")

    health_url = f"http://127.0.0.1:{PORT}/v1/health"

    is_online = False
    # BUG FIX: time.time() is in seconds; the original `timeout * 1000`
    # deadline turned a 5-second timeout into roughly 83 minutes.
    deadline = time.time() + timeout

    while time.time() < deadline:
        try:
            # NOTE(review): health is requested via POST here — confirm this
            # matches the tabby server version in use.
            r = requests.post(health_url)
            if r.status_code == 200:
                logging.info("Tabby is online now")
                is_online = True
                break
        except requests.exceptions.RequestException:
            # Server not accepting connections yet; a bare `except:` here
            # would also swallow KeyboardInterrupt.
            logging.info("Retrying to connect")
        time.sleep(1)

    return is_online


def index(args):
    """Configure tabby to index the target repo and run the scheduler once.

    Writes ~/.tabby/config.toml pointing at `index_repo_url` with the
    prompt-rewrite experiment enabled, then invokes `tabby scheduler --now`.

    Args:
        args: Parsed eval.toml config; uses `tabby_path` and `index_repo_url`.
    """
    tabby_binary = args["tabby_path"]

    config = {
        "repositories": [{"git_url": args["index_repo_url"]}],
        "experimental": {"enable_prompt_rewrite": True},
    }

    # Persist the indexing configuration where tabby expects it.
    config_path = os.path.expanduser("~/.tabby/config.toml")
    with open(config_path, "w+") as config_file:
        toml.dump(config, config_file)

    # Kick off a one-shot indexing run.
    subprocess.run([tabby_binary, "scheduler", "--now"])

def _load_dataset_contents(dataset_path, language):
    """Collect the `content` field of every dataset entry in `language`."""
    # dataset/ may hold multiple jsonl shards: data.jsonl, data.jsonl.1, ...
    contents = []
    for file_name in os.listdir(dataset_path):
        with jsonlines.open(os.path.join(dataset_path, file_name)) as dataset:
            for obj in dataset:
                if obj["language"] == language:
                    contents.append(obj["content"])
    return contents


def _random_segment(content):
    """Cut `content` at a random cursor into a {prefix, suffix} segment.

    The prefix spans up to 10 lines ending at (and including) the cursor
    character; the suffix spans up to 10 lines after the cursor.
    """
    cursor = random.randrange(len(content))

    # Walk backward until the 10th newline (or start of file) bounds the prefix.
    newlines = 0
    pc = cursor
    while pc >= 0:
        if content[pc] == "\n":
            newlines += 1
            if newlines == 10:
                break
        pc -= 1
    prefix = content[pc + 1: cursor + 1]

    # Walk forward until the 10th newline (or end of file) bounds the suffix.
    newlines = 0
    sc = cursor + 1
    while sc < len(content):
        if content[sc] == "\n":
            newlines += 1
            if newlines == 10:
                break
        sc += 1
    suffix = content[cursor + 1: sc]

    return {
        "prefix": prefix,
        "suffix": suffix,
    }


def generate_completion_segments(args):
    """Build `prompt_count` random {prefix, suffix} completion segments.

    Samples files of the configured language from the sample repo's tabby
    dataset (~/.tabby/eval_sample/dataset) and cuts each at a random cursor.

    Args:
        args: Parsed eval.toml config; uses `tabby_path`, `sample_repo_url`,
            `language` and `prompt_count`.

    Returns:
        A list of {"prefix": str, "suffix": str} dicts.

    Raises:
        ValueError: if the dataset holds no non-empty file of the language.
    """
    binary = args["tabby_path"]
    sample_repo_url = args["sample_repo_url"]
    language = args["language"]
    prompt_count = args["prompt_count"]

    # Write a scheduler config for the sample repo under its own TABBY_ROOT.
    sample_path = os.path.expanduser("~/.tabby/eval_sample")
    sample_config_file_path = os.path.join(sample_path, "config.toml")
    config = {
        "repositories": [
            {
                "git_url": sample_repo_url,
            }
        ]
    }

    if not os.path.exists(sample_path):
        os.mkdir(sample_path)

    with open(sample_config_file_path, "w+") as f:
        toml.dump(config, f)

    # Indexing of the sample repo is intentionally disabled here; run it
    # manually once if ~/.tabby/eval_sample/dataset does not exist yet.
    sample_index_command = [binary, "scheduler", "--now"]
    # subprocess.run(sample_index_command, env={"TABBY_ROOT": sample_path})

    contents = _load_dataset_contents(os.path.join(sample_path, "dataset"), language)

    # BUG FIX: the original retry loop spun forever when every candidate file
    # was empty, and `randrange(0)` raised opaquely on an empty dataset.
    # Filtering up-front also keeps sampling uniform over non-empty files,
    # matching the original retry-until-non-empty behavior.
    non_empty = [c for c in contents if c]
    if not non_empty:
        raise ValueError(
            f"no non-empty '{language}' files found in the sample dataset"
        )

    return [_random_segment(random.choice(non_empty)) for _ in range(prompt_count)]

def rewrite_prompt(args):
    """Spin up a tabby server and replay generated segments against it.

    Sends one completion request per generated segment so the server's
    prompt-rewriting pipeline processes (and logs) each prompt.

    Args:
        args: Parsed eval.toml config; uses `tabby_path` and `language`.
    """
    binary = args["tabby_path"]
    language = args["language"]

    # Build the completion payloads before starting the server.
    segments = generate_completion_segments(args)

    # Launch the tabby server as a child process.
    server = subprocess.Popen([binary, "serve", "--model", "TabbyML/T5P-220M"])

    try:
        # Give the server a short window to come up before bailing out.
        if not wait_for_online(5):
            logging.error("Tabby server is not online")
            return

        completion_url = f"http://127.0.0.1:{PORT}/v1/completions"
        for segment in segments:
            payload = {
                "language": language,
                "segments": segment,
            }
            response = requests.post(completion_url, json=payload)
            logging.info(response.status_code)
    finally:
        # Always shut the server down, even if requests above fail.
        server.terminate()

def main():
    """Entry point: load eval.toml and run the prompt-rewrite evaluation."""
    config = toml.load("eval.toml")
    # Indexing is a one-time setup step; uncomment to (re)build the index.
    # index(config)
    rewrite_prompt(config)

if __name__ == "__main__":
    main()
3 changes: 3 additions & 0 deletions experimental/prompt-rewrite-eval/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
requests
streamlit
jsonlines
toml

0 comments on commit 96cea7d

Please sign in to comment.