Skip to content

Commit

Permalink
Enhance API interactions and logging across multiple modules
Browse files Browse the repository at this point in the history
- Updated `APILogger` to support `GenerateContentResponse` from Google, improving compatibility with new response types.
- Cleaned up imports in `helper.py` by removing unused ones for better code clarity.
- Enhanced message content handling in `SupervisionContext` to accommodate different content types, including images.
- Introduced new functions in `utils.py` for converting state messages to Gemini API format and for generating responses compatible with Gemini, enhancing integration capabilities.
- Improved error handling in message conversions to ensure robust logging of issues.

These changes collectively enhance the modularity, readability, and functionality of the API interactions and supervision processes.
  • Loading branch information
mlcocdav committed Jan 22, 2025
1 parent 54fc6aa commit 48b44b5
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 12 deletions.
6 changes: 3 additions & 3 deletions src/asteroid_sdk/api/api_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from openai.types.chat.chat_completion_message import (
ChatCompletionMessage,
)

from google.generativeai.types import GenerateContentResponse
from asteroid_sdk.api.generated.asteroid_api_client import Client
from asteroid_sdk.api.generated.asteroid_api_client.api.run.create_new_chat import (
sync_detailed as create_new_chat_sync_detailed,
Expand All @@ -35,7 +35,7 @@ def __init__(self, client: Client, model_provider_helper: ModelProviderHelper):

def log_llm_interaction(
self,
response: ChatCompletion | Message,
response: ChatCompletion | Message | GenerateContentResponse,
request_kwargs: Dict[str, Any],
run_id: UUID,
) -> ChatIds:
Expand Down Expand Up @@ -81,7 +81,7 @@ def _send_chats_to_asteroid_api(self, run_id: UUID, body: AsteroidChat) -> ChatI
raise

def _convert_to_json(
self, response: Any, request_kwargs: Any
self, response: ChatCompletion | Message | GenerateContentResponse, request_kwargs: Any
) -> tuple[str, str]:
"""
Convert the response and request data to JSON strings.
Expand Down
7 changes: 0 additions & 7 deletions src/asteroid_sdk/registration/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import copy
import logging
import json
import json

from asteroid_sdk.api.generated.asteroid_api_client.client import Client
from asteroid_sdk.api.generated.asteroid_api_client.models import CreateProjectBody, CreateTaskBody
Expand Down Expand Up @@ -51,7 +50,6 @@
from asteroid_sdk.utils.utils import get_function_code
from asteroid_sdk.settings import settings
from asteroid_sdk.supervision.config import SupervisionDecision, SupervisionDecisionType, ModifiedData
from asteroid_sdk.supervision.config import SupervisionDecision, SupervisionDecisionType, ModifiedData

class APIClientFactory:
"""Factory for creating API clients with proper authentication."""
Expand Down Expand Up @@ -616,20 +614,15 @@ def wait_for_human_decision(supervision_request_id: UUID, timeout: int = 86400)
if isinstance(status, Status) and status in [Status.FAILED, Status.COMPLETED, Status.TIMEOUT]:
# Map status to SupervisionDecision
logging.debug(f"Polling for human decision completed. Status: {status}")
logging.debug(f"Polling for human decision completed. Status: {status}")
return status
else:
logging.debug("Waiting for human supervisor decision...")
logging.debug("Waiting for human supervisor decision...")
else:
logging.warning(f"Unexpected response while polling for supervision status: {response}")
logging.warning(f"Unexpected response while polling for supervision status: {response}")
except Exception as e:
logging.error(f"Error while polling for supervision status: {e}")
logging.error(f"Error while polling for supervision status: {e}")

if time.time() - start_time > timeout:
logging.warning(f"Timed out waiting for human supervision decision. Timeout: {timeout} seconds")
logging.warning(f"Timed out waiting for human supervision decision. Timeout: {timeout} seconds")
return Status.TIMEOUT

Expand Down
13 changes: 12 additions & 1 deletion src/asteroid_sdk/supervision/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,18 @@ def _describe_openai_messages(self) -> str:
messages_text = []
for message in self.openai_messages:
role = message.get('role', 'Unknown').capitalize()
content = message.get('content', '').strip()
content = message.get('content', '')
if type(content) == str:
content = content.strip()
elif type(content) == list:
# TODO: Solve this - it happens when there is an image
content = ''
for m in content:
if m.type == 'text':
content += m.text
elif m.type == 'image_url':
content += f'Image' #{m.image_url}' #TODO: Add the image somehow

message_str = f"**{role}:**\n{content}" if content else f"**{role}:**"

# Handle tool calls if present
Expand Down
88 changes: 87 additions & 1 deletion src/asteroid_sdk/supervision/inspect_ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
from anthropic.types.tool_use_block import ToolUseBlock
from anthropic.types.usage import Usage

from google.ai.generativelanguage_v1beta import FunctionCall, Content, Part, Candidate, GenerateContentResponse as BetaContent
from google.generativeai.types import GenerateContentResponse

from asteroid_sdk.supervision.config import SupervisionDecision

from asteroid_sdk.supervision.config import SupervisionDecision


Expand Down Expand Up @@ -67,6 +72,9 @@ def convert_state_messages_to_openai_messages(state_messages: List[ChatMessage])
role = msg.role # 'system', 'user', 'assistant', etc.
content = msg.text # Extract the text content from the message

if hasattr(msg, 'error') and msg.error is not None:
content = f"{content}\n\nError: {msg.error.message}"

openai_msg = {
"role": role,
"content": content,
Expand Down Expand Up @@ -170,6 +178,9 @@ def convert_state_messages_to_anthropic_messages(state_messages: List[ChatMessag
}
content_blocks.append(text_block)

if hasattr(msg, 'error') and msg.error is not None:
content_blocks.append({'type': 'text', 'text': f"Error: {msg.error.message}"})

# Include tool calls as ToolUseBlocks
if hasattr(msg, 'tool_calls') and msg.tool_calls:
for tool_call in msg.tool_calls:
Expand Down Expand Up @@ -275,4 +286,79 @@ def transform_asteroid_approval_to_inspect_ai_approval(approval_decision: Superv
decision=inspect_ai_decision,
modified=modified,
explanation=approval_decision.explanation
)
)

def convert_state_messages_to_gemini_messages(state_messages: List[ChatMessage]) -> List[Dict]:
    """
    Convert Inspect AI state messages to a list of dictionaries compatible with Gemini's API.

    Gemini's generate_content endpoint only accepts the roles 'user' and
    'model', so the Inspect AI / OpenAI-style 'assistant' role is mapped to
    'model' here. Other roles (e.g. 'system') are passed through unchanged;
    Gemini expects system text via system_instruction — NOTE(review): confirm
    how callers supply system messages.

    Args:
        state_messages (List[ChatMessage]): List of Inspect AI chat messages.

    Returns:
        List[Dict]: List of messages formatted for Gemini API, each with a
        'role' and a 'parts' list of text / function_call parts.
    """
    # Gemini uses 'model' where Inspect AI uses 'assistant'.
    role_map = {'assistant': 'model'}

    gemini_messages = []
    for msg in state_messages:
        parts = []
        if msg.text:
            parts.append({'text': msg.text})

        # Include tool calls as Gemini function_call parts.
        if hasattr(msg, 'tool_calls') and msg.tool_calls:
            for tool_call in msg.tool_calls:
                parts.append({
                    'function_call': {
                        'name': tool_call.function,
                        'args': tool_call.arguments
                    }
                })

        # Surface any execution error as an extra text part so it stays
        # visible to the model and to supervisors.
        if hasattr(msg, 'error') and msg.error is not None:
            parts.append({'text': f"Error: {msg.error.message}"})

        gemini_messages.append({
            'role': role_map.get(msg.role, msg.role),
            'parts': parts
        })

    return gemini_messages

def convert_state_output_to_gemini_response(state_output) -> GenerateContentResponse:
    """
    Convert Inspect AI state output to a Gemini GenerateContentResponse instance.

    Only the first choice of the output is converted; its text (if any) and
    each of its tool calls become Parts of a single candidate's Content.

    Args:
        state_output (ModelOutput): The output from Inspect AI model.

    Returns:
        GenerateContentResponse: An instance representing the response as per Gemini's API.
    """
    # Mirror only the first choice into the Gemini response.
    first_message = state_output.choices[0].message  # ChatMessageAssistant

    parts = []
    if first_message.text:
        parts.append(Part(text=first_message.text))

    # Each tool call becomes a Gemini function-call part.
    for call in (getattr(first_message, 'tool_calls', None) or []):
        parts.append(
            Part(function_call=FunctionCall(name=call.function, args=call.arguments))
        )

    # Build the low-level (beta) response, then promote it to the public
    # GenerateContentResponse wrapper.
    candidate = Candidate(content=Content(parts=parts, role=first_message.role))
    return GenerateContentResponse.from_response(BetaContent(candidates=[candidate]))

0 comments on commit 48b44b5

Please sign in to comment.