Refactor supervision handling and enhance provider support
- Updated `AsteroidChatSupervisionManager` to streamline message updates with provider-specific logic.
- Improved `SupervisionRunner` to clarify tool call handling and ensure compatibility with single tool calls.
- Enhanced `supervisors.py` to integrate Gemini support, including new conversion functions for state messages and responses.
- Refactored `match_tool_call_ids` to accommodate different provider matching criteria, improving accuracy in tool interactions.
- Cleaned up comments and improved code clarity across multiple files for better maintainability.

These changes collectively enhance the modularity and functionality of the supervision process, ensuring better integration with various API providers.
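
As a rough illustration of the provider-specific matching this commit introduces, here is a minimal, self-contained sketch: Gemini responses expose no tool-call IDs, so matching falls back to the function name plus arguments, while OpenAI and Anthropic tool calls are matched by ID. The `ResponseToolCall` dataclass and `match_tool_call` helper below are illustrative stand-ins, not types or functions from the SDK.

```python
# Simplified sketch (not the SDK's exact code) of provider-aware tool-call matching.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class ResponseToolCall:
    """Minimal stand-in for a tool call extracted from a provider response."""
    id: Optional[str]          # None for Gemini, which returns no tool-call IDs
    tool_name: str
    tool_params: Dict[str, Any]


def match_tool_call(
    response_tool_calls: List[ResponseToolCall],
    call_id: Optional[str],
    function: str,
    arguments: Dict[str, Any],
    provider: str,
) -> Tuple[int, ResponseToolCall]:
    """Return the index and tool call that corresponds to an Inspect AI call."""
    for idx, tool_call in enumerate(response_tool_calls):
        if provider == "google":
            # Gemini: no IDs available, so match on function name + arguments.
            if tool_call.tool_name == function and tool_call.tool_params == arguments:
                return idx, tool_call
        elif tool_call.id == call_id:
            # OpenAI / Anthropic: IDs are present, so match directly.
            return idx, tool_call
    raise ValueError("Tool call not found in response tool calls")


if __name__ == "__main__":
    calls = [ResponseToolCall(id=None, tool_name="get_weather", tool_params={"city": "Paris"})]
    idx, matched = match_tool_call(calls, None, "get_weather", {"city": "Paris"}, provider="google")
    print(idx, matched.tool_name)  # 0 get_weather
```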
mlcocdav committed Jan 22, 2025
1 parent 48b44b5 commit 0e2548a
Showing 3 changed files with 67 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/asteroid_sdk/api/asteroid_chat_supervision_manager.py
@@ -86,7 +86,7 @@ async def handle_language_model_interaction(
supervision_context = run.supervision_context
# Update messages on the supervision context
supervision_context.update_messages(request_kwargs,
provider=self.model_provider_helper.get_provider(), # TODO Change this to get it from the provider helper
provider=self.model_provider_helper.get_provider(),
system_message=request_kwargs.get('system', None))

response, response_data_tool_calls = self.get_tool_calls_and_modify_response_if_necessary(
3 changes: 1 addition & 2 deletions src/asteroid_sdk/api/supervision_runner.py
@@ -69,8 +69,7 @@ async def handle_tool_calls_from_llm_response(

new_response = copy.deepcopy(response)
# TODO - Check if this is still relevant
# !IMPORTANT! - We're only accepting 1 tool_call ATM. There's code that is called within this
# this loop that assumes this.
# We do not allow multiple tool calls with resampling; it should work without resampling. Tested only with Gemini
decisions: List[Dict] = []
for idx, tool_call in enumerate(response_data_tool_calls):
tool_id = UUID(choice_ids[0].tool_call_ids[idx].tool_id)
110 changes: 65 additions & 45 deletions src/asteroid_sdk/supervision/inspect_ai/supervisors.py
@@ -3,8 +3,9 @@
"""

from functools import wraps
from typing import List, Dict, Optional, Callable, Union
from typing import List, Optional, Callable, Union
from uuid import UUID
import json
from anthropic.types.message import Message as AnthropicMessage

from asteroid_sdk.api.api_logger import APILogger
@@ -30,54 +31,54 @@
from anthropic.types.tool_use_block import ToolUseBlock
from asteroid_sdk.supervision.helpers.anthropic_helper import AnthropicSupervisionHelper
from asteroid_sdk.supervision.helpers.openai_helper import OpenAiSupervisionHelper
from asteroid_sdk.supervision.helpers.gemini_helper import GeminiHelper
from .utils import (
transform_asteroid_approval_to_inspect_ai_approval,
convert_state_messages_to_openai_messages,
convert_state_output_to_openai_response,
convert_state_messages_to_anthropic_messages,
convert_state_output_to_anthropic_response,
convert_state_messages_to_gemini_messages,
convert_state_output_to_gemini_response,
)
import logging
from asteroid_sdk.supervision.helpers.model_provider_helper import Provider

# Mappings for model provider helpers and conversion functions
MODEL_PROVIDER_HELPERS = {
"openai": OpenAiSupervisionHelper,
"anthropic": AnthropicSupervisionHelper,
# Add "google": GoogleSupervisionHelper when implemented
"google": GeminiHelper,
}

CONVERT_STATE_MESSAGES_TO_MESSAGES = {
"openai": convert_state_messages_to_openai_messages,
"anthropic": convert_state_messages_to_anthropic_messages,
# "google": convert_state_messages_to_google_messages when available
"google": convert_state_messages_to_gemini_messages,
}

CONVERT_STATE_OUTPUT_TO_RESPONSE = {
"openai": convert_state_output_to_openai_response,
"anthropic": convert_state_output_to_anthropic_response,
# "google": convert_state_output_to_google_response when available
"google": convert_state_output_to_gemini_response,
}

EXTRACT_TOOL_CALLS_FROM_RESPONSE = {
"openai": lambda response: response.choices[0].message.tool_calls,
"anthropic": lambda response: [
content_block for content_block in response.content if isinstance(content_block, ToolUseBlock)
],
# "google": lambda response: ... when available
"google": lambda response: GeminiHelper().get_tool_call_from_response(response),
}

def with_asteroid_supervision(
supervisor_name_param: Optional[str] = None,
n: Optional[int] = None
):
"""
Decorator for common Asteroid API interactions during supervision.
Args:
supervisor_name_param (Optional[str]): Name of the supervisor to use.
If not provided, the function's name will be used.
n (Optional[int]): Number of tool call suggestions to generate for human approval.
@@ -120,8 +121,7 @@ async def wrapper(
elif state.model.api == "anthropic":
model_provider_helper = AnthropicSupervisionHelper()
elif state.model.api == "google":
# model_provider_helper = GoogleSupervisionHelper()
raise Exception(f"Model API {state.model.api} not supported")
model_provider_helper = GeminiHelper()
else:
raise Exception(f"Model API {state.model.api} not supported")

@@ -146,14 +146,23 @@ async def wrapper(
if last_message.tool_calls:
# Match Asteroid's tool call ID to the Inspect AI tool call ID
for idx, _tool_call in enumerate(last_message.tool_calls):
if _tool_call.call_id == call.id:
tool_call_idx = idx
tool_id = _tool_call.tool_id
tool_call_id = _tool_call.id
tool_call_data = _tool_call
tool = supervision_runner.get_tool(tool_id)
break

if state.model.api == "google":
if _tool_call.name == call.function and json.loads(_tool_call.arguments) == call.arguments:
tool_call_idx = idx
tool_id = _tool_call.tool_id
tool_call_id = _tool_call.id
tool_call_data = _tool_call # TODO: This might need fixing; we may have to instantiate a new ToolCall rather than pass AsteroidToolCall
tool = supervision_runner.get_tool(tool_id)
break
else:
if _tool_call.id == call.id:
tool_call_idx = idx
tool_id = _tool_call.tool_id
tool_call_id = _tool_call.id
tool_call_data = _tool_call
tool = supervision_runner.get_tool(tool_id)
break

# If no existing messages or tool call not found, log the first message
if len(asteroid_messages) == 0 or tool_call_idx is None:
# Handle provider-specific logic
@@ -179,7 +188,7 @@ async def wrapper(

# Match Asteroid's tool call ID to the Inspect AI tool call ID
tool_call_idx, tool_id, tool_call_id = match_tool_call_ids(
response_tool_calls, call, choice_ids
response_tool_calls, call, choice_ids, state.model.api
)
tool = supervision_runner.get_tool(tool_id)
tool_call_data = model_provider_helper.get_tool_call_from_response(response)[tool_call_idx]
@@ -195,25 +204,19 @@ async def wrapper(
)
)

# Find the supervisor by name within the supervisor chains
supervisor_name = supervisor_name_param or approve_func.__name__
supervisor, supervisor_chain_id, position_in_chain = find_supervisor_in_chains(
supervisor_chains, supervisor_name
)
if supervisor is None:
raise Exception(f"Supervisor {supervisor_name} not found in any chain")

# Execute the supervisor and get the decision
decision = await supervision_runner.execute_supervisor(
supervisor=supervisor,
tool=tool,
tool_call=tool_call_data,
tool_call_id=tool_call_id,
position_in_chain=position_in_chain,
supervision_context=supervision_context,
Expand All @@ -227,15 +230,12 @@ async def wrapper(
f"No decision made for supervisor {supervisor_name} in chain {supervisor_chain_id}"
)

# Handle modify decision and attach original call if needed
if decision.decision == SupervisionDecisionType.MODIFY and decision.modified is not None:
decision.modified.original_inspect_ai_call = call
logging.info(f"Returning approval: {decision.decision}")
logging.info(f"Returning approval: {decision.decision}")
return transform_asteroid_approval_to_inspect_ai_approval(decision)

# Set the wrapper function name
wrapper.__name__ = supervisor_name_param or approve_func.__name__
return wrapper
@@ -266,7 +266,7 @@ def handle_openai_provider(state: TaskState, supervision_context: SupervisionCon
response = convert_state_output_to_openai_response(state.output)

# Update the supervision context with messages
supervision_context.update_messages(request_kwargs["messages"])
supervision_context.update_messages(request_kwargs, provider=Provider.OPENAI)

# Extract tool calls from the response
response_tool_calls = response.choices[0].message.tool_calls
@@ -298,8 +298,8 @@ def handle_anthropic_provider(state: TaskState, supervision_context: Supervision

# Update the supervision context with messages
supervision_context.update_messages(
request_kwargs["messages"],
anthropic=True,
request_kwargs,
provider=Provider.ANTHROPIC,
system_message=request_kwargs.get("system", None),
)

@@ -321,28 +321,50 @@ def handle_google_provider(state: TaskState, supervision_context: SupervisionCon
Returns:
Tuple of (request_kwargs, response, response_tool_calls)
"""
# Convert state messages to Google format
#TODO: Implement
pass
# Convert state messages to Gemini format
gemini_messages = convert_state_messages_to_gemini_messages(state.messages[:-1])

def match_tool_call_ids(response_tool_calls: List, call: ToolCall, choice_ids: List):
# Prepare request kwargs
request_kwargs = {
"contents": gemini_messages,
"model": state.model.name,
}

# Convert state output to Gemini response
response = convert_state_output_to_gemini_response(state.output)

# Update the supervision context with messages
supervision_context.update_messages(request_kwargs, provider=Provider.GEMINI)

# Extract tool calls from the response
response_tool_calls = GeminiHelper().get_tool_call_from_response(response)

return request_kwargs, response, response_tool_calls

def match_tool_call_ids(response_tool_calls: List, call: InspectAIToolCall, choice_ids: List,
provider: str):
"""
Match Asteroid's tool call ID to the Inspect AI tool call ID.
Args:
response_tool_calls (List): List of tool calls from the response.
call (InspectAIToolCall): The original ToolCall object from Inspect AI.
choice_ids (List): List of choice IDs from the API response.
provider (str): The model provider.
Returns:
Tuple of (tool_call_idx, tool_id, tool_call_id)
"""
for idx, _tool_call in enumerate(response_tool_calls):
if _tool_call.id == call.id:
tool_id = choice_ids[0].tool_call_ids[idx].tool_id
tool_call_id = choice_ids[0].tool_call_ids[idx].tool_call_id
return idx, tool_id, tool_call_id
raise Exception("Tool call ID not found in response tool calls")
if provider == "google":
# There are no IDs in Gemini, so we need to match on the function name and arguments
if _tool_call.tool_name == call.function and _tool_call.tool_params == call.arguments:
break
else:
if _tool_call.id == call.id:
break
tool_id = choice_ids[0].tool_call_ids[idx].tool_id
tool_call_id = choice_ids[0].tool_call_ids[idx].tool_call_id
return idx, tool_id, tool_call_id

def find_supervisor_in_chains(supervisor_chains: List, supervisor_name: str):
"""
@@ -364,10 +386,8 @@ def find_supervisor_in_chains(supervisor_chains: List, supervisor_name: str):
return supervisor, supervisor_chain_id, position_in_chain
return None, None, None

@approver
def human_approver(timeout: int = 86400, n: int = 3) -> Approver:
@approver(name="human_approver")
def human_approver(timeout: int = 3000, n: int = 3) -> Approver:
def human_approver(timeout: int = 86400, n: int = 3) -> Approver:
"""
Human approver function for Inspect AI.
