Update to 3.2.3 (#77)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
edamamez and github-actions[bot] authored Dec 16, 2024
1 parent 0e649c6 commit f61a11f
Showing 13 changed files with 388 additions and 38 deletions.
5 changes: 2 additions & 3 deletions lamini/__init__.py
@@ -20,7 +20,7 @@
os.environ.get("GATE_PIPELINE_BATCH_COMPLETIONS", False)
)

__version__ = "3.1.3"
__version__ = "3.2.3"

# isort: off

@@ -30,9 +30,8 @@
from lamini.api.model_downloader import ModelDownloader
from lamini.api.model_downloader import ModelType
from lamini.api.model_downloader import DownloadedModel
from lamini.classify.lamini_classifier import LaminiClassifier
from lamini.generation.generation_node import GenerationNode
from lamini.generation.generation_pipeline import GenerationPipeline
from lamini.generation.base_prompt_object import PromptObject
from lamini.generation.split_response_node import SplitResponseNode
from lamini.api.streaming_completion import StreamingCompletion
from lamini.api.streaming_completion import StreamingCompletion
70 changes: 70 additions & 0 deletions lamini/api/lamini.py
@@ -1,3 +1,4 @@
import base64
import enum
import json
import logging
@@ -13,6 +14,8 @@
from lamini.api.rest_requests import get_version, make_web_request
from lamini.api.train import Train
from lamini.api.utils.completion import Completion
from lamini.api.utils.sql_completion import SQLCompletion
from lamini.api.utils.sql_token_cache import SQLTokenCache
from lamini.api.utils.upload_client import upload_to_blob
from lamini.error.error import DownloadingModelError

@@ -59,6 +62,8 @@ def __init__(
self.api_key = api_key
self.api_url = api_url
self.completion = Completion(api_key, api_url)
self.sql_completion = SQLCompletion(api_key, api_url)
self.sql_token_cache = SQLTokenCache(api_key, api_url)
self.trainer = Train(api_key, api_url)
self.upload_file_path = None
self.upload_base_path = None
@@ -77,6 +82,23 @@ def version(self) -> str:
"""
return get_version(self.api_key, self.api_url, self.config)

def generate_sql(
self,
prompt: Union[str, List[str]],
cache_id: str,
model_name: Optional[str] = None,
max_tokens: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> Union[str, Dict[str, Any]]:
result = self.sql_completion.generate(
prompt=prompt,
cache_id=cache_id,
model_name=model_name or self.model_name,
max_tokens=max_tokens,
max_new_tokens=max_new_tokens,
)
return result
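
A hedged usage sketch for the new method, assuming a configured API key; the model name, prompt, and cache id are placeholders, and the cache must already have been built (see add_sql_token_cache below):

from lamini import Lamini

llm = Lamini(model_name="meta-llama/Meta-Llama-3.1-8B-Instruct")  # placeholder model

# cache_id must reference a token cache created earlier via add_sql_token_cache
result = llm.generate_sql(
    prompt="How many orders shipped last month?",  # placeholder prompt
    cache_id="my-cache-id",                        # placeholder id
    max_new_tokens=256,
)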

def generate(
self,
prompt: Union[str, List[str]],
@@ -181,6 +203,7 @@ async def async_generate(

req_data = self.completion.make_llm_req_map(
prompt=prompt,
cache_id=cache_id,
model_name=model_name or self.model_name,
output_type=output_type,
max_tokens=max_tokens,
@@ -424,6 +447,53 @@ def download_model(
INTERVAL_SECONDS = 1
time.sleep(INTERVAL_SECONDS)

def add_sql_token_cache(
self,
col_val_file: Optional[str] = None,
wait: bool = False,
wait_time_seconds: int = 600,
):
col_val_str = None

if col_val_file:
with open(col_val_file, 'r') as f:
col_vals = json.load(f)
# TODO: in another PR, limit size of col_vals dict
col_val_str = json.dumps(col_vals)

start_time = time.time()

while True:
res = self.sql_token_cache.add_token_cache(
base_model_name=self.model_name,
col_vals=col_val_str,
)

if not wait:
return res
if res["status"] == "done":
return res
elif res["status"] == "failed":
raise Exception("SQL token cache build failed")

elapsed_time = time.time() - start_time
if elapsed_time > wait_time_seconds:
return res
INTERVAL_SECONDS = 1
time.sleep(INTERVAL_SECONDS)

def delete_sql_token_cache(self, cache_id):
while True:
res = self.sql_token_cache.delete_token_cache(cache_id)

if res["status"] == "done":
return res
elif res["status"] == "failed":
raise Exception("SQL token cache deletion failed")

INTERVAL_SECONDS = 1
time.sleep(INTERVAL_SECONDS)

def list_models(self) -> List[DownloadedModel]:
return self.model_downloader.list()

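Read together, add_sql_token_cache and delete_sql_token_cache imply a cache lifecycle roughly as sketched below. The diff confirms a "status" field in the responses, but the name of the cache-identifier field is an assumption here:

llm = Lamini(model_name="...")  # placeholder

# wait=True polls until status is "done" or "failed", or ~600 s elapse
res = llm.add_sql_token_cache(col_val_file="col_vals.json", wait=True)
cache_id = res.get("cache_id")  # assumed field name, not shown in this diff

sql = llm.generate_sql(prompt="...", cache_id=cache_id)  # placeholder prompt

# polls until the platform reports "done"; raises if deletion fails
llm.delete_sql_token_cache(cache_id)
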
11 changes: 6 additions & 5 deletions lamini/api/model_downloader.py
@@ -1,9 +1,10 @@
"""
A class to handle downloading models from the Lamini Platform.
"""

import enum
from typing import List, Union
from typing import List

import lamini
import numpy as np
from lamini.api.lamini_config import get_config, get_configured_key, get_configured_url
from lamini.api.rest_requests import make_web_request


@@ -58,7 +59,7 @@ def __init__(
api_url: str,
):
self.api_key = api_key
self.api_endpoint = api_url + "/v1/downloaded_models/"
self.api_endpoint = api_url + "/v1alpha/downloaded_models/"

def download(self, hf_model_name: str, model_type: ModelType) -> DownloadedModel:
"""Request to Lamini platform for an embedding encoding of the provided
78 changes: 63 additions & 15 deletions lamini/api/rest_requests.py
@@ -11,10 +11,12 @@
APIUnprocessableContentError,
AuthenticationError,
DownloadingModelError,
ModelNotFound,
DuplicateResourceError,
ModelNotFoundError,
RateLimitError,
RequestTimeoutError,
UnavailableResourceError,
ProjectNotFoundError,
UserError,
)

@@ -174,24 +176,33 @@ async def handle_error(resp: aiohttp.ClientResponse) -> None:
Raises
------
ModelNotFound
Raises from 594
RateLimitError
Raises from 429
UserError
Raises from 400
AuthenticationError
Raises from 401
UserError
Raises from 400
APIUnprocessableContentError
Raises from 422
RateLimitError
Raises from 429
DuplicateResourceError
Raises from 497
JobNotFoundError
Raises from 498
ProjectNotFoundError
Raises from 499
UnavailableResourceError
Raises from 503
ModelNotFoundError
Raises from 594
APIError
Raises from 200
@@ -206,7 +217,21 @@ async def handle_error(resp: aiohttp.ClientResponse) -> None:
json_response = await resp.json()
except Exception:
json_response = {}
raise ModelNotFound(json_response.get("detail", "ModelNotFound"))
raise ModelNotFoundError(json_response.get("detail", "ModelNotFound"))
if resp.status == 499:
try:
json_response = await resp.json()
except Exception:
json_response = {}
raise ProjectNotFoundError(json_response.get("detail", "ProjectNotFoundError"))
if resp.status == 497:
try:
json_response = await resp.json()
except Exception:
json_response = {}
raise DuplicateResourceError(
json_response.get("detail", "DuplicateResourceError")
)
if resp.status == 429:
try:
json_response = await resp.json()
@@ -253,7 +278,7 @@ async def handle_error(resp: aiohttp.ClientResponse) -> None:


def make_web_request(
key: str, url: str, http_method: str, json: Optional[Dict[str, Any]] = None
key: str, url: str, http_method: str, json: Optional[Dict[str, Any]] = None, stream: bool = False
) -> Dict[str, Any]:
"""Execute a web request
@@ -288,7 +313,7 @@ def make_web_request(
HTTPError
Raised from many possible reasons:
if resp.status_code == 594:
ModelNotFound
ModelNotFoundError
if resp.status_code == 429:
RateLimitError
if resp.status_code == 401:
@@ -326,10 +351,14 @@
pass
if http_method == "post":
resp = requests.post(url=url, headers=headers, json=json)
elif http_method == "get" and stream:
resp = requests.get(url=url, headers=headers, stream=True)
elif http_method == "get":
resp = requests.get(url=url, headers=headers)
elif http_method == "delete":
resp = requests.delete(url=url, headers=headers)
else:
raise Exception("http_method must be 'post' or 'get'")
raise Exception("http_method must be 'post' or 'get' or 'delete'")
try:
check_version(resp)
resp.raise_for_status()
@@ -339,7 +368,23 @@
json_response = resp.json()
except Exception:
json_response = {}
raise ModelNotFound(json_response.get("detail", "ModelNameError"))
raise ModelNotFoundError(json_response.get("detail", "ModelNameError"))
if resp.status_code == 499:
try:
json_response = resp.json()
except Exception:
json_response = {}
raise ProjectNotFoundError(
json_response.get("detail", "ProjectNotFoundError")
)
if resp.status_code == 497:
try:
json_response = resp.json()
except Exception:
json_response = {}
raise DuplicateResourceError(
json_response.get("detail", "DuplicateResourceError")
)
if resp.status_code == 429:
try:
json_response = resp.json()
@@ -401,4 +446,7 @@ def make_web_request
raise APIError("500 Internal Server Error")
raise APIError(f"API error {description}")

return resp.json()
if stream:
return resp
else:
return resp.json()
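
A sketch of the reworked helper in use; the endpoint path is a placeholder, and the key and URL are assumed valid. With stream=True on a GET the raw requests.Response is returned instead of parsed JSON, and the new 497/499 statuses surface as the new exception types:

from lamini.api.rest_requests import make_web_request
from lamini.error.error import DuplicateResourceError, ProjectNotFoundError

api_key = "..."                    # placeholder
api_url = "https://api.lamini.ai"  # placeholder

try:
    resp = make_web_request(api_key, api_url + "/v1alpha/example", "get", stream=True)
    for line in resp.iter_lines():  # raw Response object when stream=True
        print(line)
except DuplicateResourceError:  # HTTP 497
    pass
except ProjectNotFoundError:    # HTTP 499
    pass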
3 changes: 3 additions & 0 deletions lamini/api/utils/completion.py
@@ -139,6 +139,7 @@ def make_llm_req_map(
self,
model_name: str,
prompt: Union[str, List[str]],
cache_id: Optional[str] = None,
output_type: Optional[dict] = None,
max_tokens: Optional[int] = None,
max_new_tokens: Optional[int] = None,
@@ -186,4 +187,6 @@
req_data["max_tokens"] = max_tokens
if max_new_tokens is not None:
req_data["max_new_tokens"] = max_new_tokens
if cache_id is not None:
req_data["cache_id"] = cache_id
return req_data
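
With the new parameter, a request map built for the SQL path carries the cache id. A minimal sketch using only arguments visible in this diff:

completion = Completion(api_key, api_url)  # api_key/api_url assumed configured
req_data = completion.make_llm_req_map(
    model_name="...",        # placeholder
    prompt="...",            # placeholder
    cache_id="my-cache-id",  # added to req_data only when not None
)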
40 changes: 40 additions & 0 deletions lamini/api/utils/sql_completion.py
@@ -0,0 +1,40 @@
from typing import Any, Dict, List, Optional, Union

import aiohttp
import lamini
from lamini.api.lamini_config import get_config, get_configured_key, get_configured_url
from lamini.api.rest_requests import make_async_web_request, make_web_request
from lamini.api.utils.completion import Completion

class SQLCompletion(Completion):
def __init__(self, api_key, api_url) -> None:
self.config = get_config()

self.api_key = api_key or lamini.api_key or get_configured_key(self.config)
self.api_url = api_url or lamini.api_url or get_configured_url(self.config)
self.api_prefix = self.api_url + "/v1alpha/"

def generate(
self,
prompt: Union[str, List[str]],
cache_id: str,
model_name: str,
max_tokens: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> Dict[str, Any]:
req_data = self.make_llm_req_map(
prompt=prompt,
cache_id=cache_id,
model_name=model_name,
max_tokens=max_tokens,
max_new_tokens=max_new_tokens,
)
resp = make_web_request(
self.api_key, self.api_prefix + "sql", "post", req_data
)
return resp

async def async_generate(
self, params: Dict[str, Any], client: aiohttp.ClientSession = None
) -> Dict[str, Any]:
raise Exception("SQL streaming not implemented")
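
At a lower level, the new class can be driven directly. A sketch, assuming credentials resolve from lamini.api_key or the local config; note that async_generate intentionally raises, so only the synchronous path exists:

from lamini.api.utils.sql_completion import SQLCompletion

sql = SQLCompletion(api_key=None, api_url=None)  # None falls back to lamini.api_key / config
resp = sql.generate(
    prompt="...",            # placeholder
    cache_id="my-cache-id",  # placeholder
    model_name="...",        # placeholder
)
# POSTs to {api_url}/v1alpha/sql and returns the parsed JSON response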
(The remaining 7 of the 13 changed files are not shown.)
