'''
Codex was deprecated by OpenAI, so the model has been changed to GPT-3.5.
'''
import json
import openai
import os
import time
from consts import CONSTS
import random


class CodeGen():
    def __init__(self, cache="cache.json"):
        self.cache_file = cache
        self.exponential_backoff = 1
        self.messages = []
        # Load the cache JSON file if it exists; otherwise start with an empty cache.
        if os.path.exists(cache):
            # Wait until no other process is writing the cache (temp/lock file present).
            while os.path.exists(self.cache_file + ".tmp") or os.path.exists(self.cache_file + ".lock"):
                time.sleep(0.1)
            with open(cache, "r") as f:
                self.cache = json.load(f)
        else:
            self.cache = {}

    def generate(self,
                 codex_in, num_completions=8, max_tokens=500, temperature=0.5, presence_penalty=0.0,
                 stop=["\ndef"], indented=True, indented_after_first_line=False, require=None, cache_key=None,
                 rate_limit_tokens=4000, verbose=False, logit_bias={}, model_name=None, is_test=False
                 ):
        if model_name is None:
            model_name = "gpt-3.5-turbo-0301"
        if verbose:
            print(codex_in)
            print("-----")
        assert isinstance(codex_in, str)
        # Build a cache key from the prompt and the sampling parameters.
        cache_key_base = codex_in if cache_key is None else cache_key
        cache_key_list = (cache_key_base, max_tokens, temperature, stop, indented, indented_after_first_line, require)
        if presence_penalty != 0.0:
            cache_key_list = cache_key_list + (presence_penalty,)
        if model_name != "code-davinci-002":
            cache_key_list = cache_key_list + (model_name,)
        cache_key = str(cache_key_list)
        if cache_key in self.cache:
            if len(self.cache[cache_key]) < num_completions:
                # Not enough cached completions; only generate the missing ones.
                num_completions -= len(self.cache[cache_key])
                results = self.cache[cache_key]
            else:
                cur_implementations = self.cache[cache_key].copy()
                if "shuffle_implementations" in CONSTS and CONSTS["shuffle_implementations"]:
                    random.shuffle(cur_implementations)
                return cur_implementations[:num_completions]
        else:
            results = []
        print(f"Using {model_name} model!")
        # raise Exception("Codex is not available")
        total_tokens = num_completions * max_tokens
        completions_per_call = rate_limit_tokens // max_tokens
        while total_tokens > 0:
            # num_completions = min(total_tokens // max_tokens, completions_per_call)
            print(f"Actually need: {num_completions} completions!")
            while True:
                try:
                    # time.sleep(8)
                    print(codex_in)
                    if not is_test:
                        # Wrap the prompt with instructions asking for pure, directly runnable Python code.
                        user_content = (
                            "Please return a Python function that meets the following requirements. "
                            "The function implementation should be consistent with the type annotations "
                            "in the function header, if present. Return only the pure code; omit "
                            "explanations or any additional text. Ensure that your code can be directly "
                            "compiled and run without errors.\n" + codex_in
                        )
                    else:
                        user_content = codex_in
                    messages = [
                        {"role": "system", "content": "You are an expert of the Python programming language."},
                        {"role": "user", "content": user_content},
                    ]
                    completions = openai.ChatCompletion.create(
                        model=model_name,
                        messages=messages,
                        temperature=temperature,
                        presence_penalty=presence_penalty,
                        max_tokens=max_tokens,
                        n=num_completions,
                        logit_bias=logit_bias
                    )['choices']
                    self.exponential_backoff = 1
                    break
                except openai.error.RateLimitError:
                    print("Rate limit reached. Waiting before retrying...")
                    time.sleep(16 * self.exponential_backoff)
                    self.exponential_backoff *= 2
            for completion in completions:
                result = []
                response = completion["message"]["content"]
                # Strip a Markdown code fence, if present, and keep only the code inside it.
                if '```' in response:
                    response = response.split('```')[1]
                    if response[:6] == "python":
                        response = response[6:]
                for line in response.split("\n"):
                    result.append(line)
                results.append(result)
            # Save updated cache - reopen in case multiple processes are running
            # Save to a temp file first, then rename
            # Check if a temp or lock file exists, and if so, wait for it to be deleted
            while os.path.exists(self.cache_file + ".tmp") or os.path.exists(self.cache_file + ".lock"):
                time.sleep(0.1)
            # Create an empty lock file to indicate that we are writing to the cache
            with open(self.cache_file + ".lock", "w") as f:
                pass
            if os.path.exists(self.cache_file):
                with open(self.cache_file, "r") as f:
                    self.cache = json.load(f)
            self.cache[cache_key] = results
            with open(self.cache_file + ".tmp", "w") as f:
                json.dump(self.cache, f)
            os.rename(self.cache_file + ".tmp", self.cache_file)
            os.remove(self.cache_file + ".lock")
            total_tokens -= num_completions * max_tokens
        return results
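

# Illustrative usage sketch (not part of the original file): it assumes a valid
# OPENAI_API_KEY is available to the openai library and that consts.py defines CONSTS.
# The prompt below is a made-up example; generate() returns one list of output
# lines per completion and caches them in cache.json keyed by prompt and parameters.
if __name__ == "__main__":
    gen = CodeGen(cache="cache.json")
    prompt = 'def add(a: int, b: int) -> int:\n    """Return the sum of a and b."""\n'
    completions = gen.generate(prompt, num_completions=2, max_tokens=200, verbose=True)
    for lines in completions:
        print("\n".join(lines))
        print("=" * 40)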