Skip to content

Commit

Permalink
Refactor to use new Converse API
Browse files Browse the repository at this point in the history
  • Loading branch information
daixba committed Jun 4, 2024
1 parent 86e3db7 commit 6960390
Show file tree
Hide file tree
Showing 9 changed files with 498 additions and 680 deletions.
7 changes: 0 additions & 7 deletions src/api/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +0,0 @@
from api.models.bedrock import (
ClaudeModel,
SUPPORTED_BEDROCK_MODELS,
SUPPORTED_BEDROCK_EMBEDDING_MODELS,
get_model,
get_embeddings_model,
)
15 changes: 14 additions & 1 deletion src/api/models/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
import uuid
from abc import ABC, abstractmethod
from typing import AsyncIterable
Expand All @@ -19,6 +20,14 @@ class BaseChatModel(ABC):
Currently, only Bedrock models are supported, but this class may be used for SageMaker models if needed.
"""

def list_models(self) -> list[str]:
    """List the models supported by this chat model.

    Returns:
        A new, empty list; this implementation reports no models.
    """
    supported: list[str] = []
    return supported

def validate(self, chat_request: ChatRequest):
    """Validate a chat completion request.

    This implementation is a no-op: it does not inspect ``chat_request``
    and always returns ``None``.
    """
    return None

@abstractmethod
def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Handle a basic chat completion requests."""
Expand All @@ -38,7 +47,11 @@ def stream_response_to_bytes(
response: ChatStreamResponse | None = None
) -> bytes:
if response:
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
# to populate other fields when using exclude_unset=True
response.system_fingerprint = "fp"
response.object = "chat.completion.chunk"
response.created = int(time.time())
return "data: {}\n\n".format(response.model_dump_json(exclude_unset=True)).encode("utf-8")
return "data: [DONE]\n\n".encode("utf-8")


Expand Down
Loading

0 comments on commit 6960390

Please sign in to comment.