Skip to content

Commit

Permalink
[COMMIT FOR JOE] FIX TESTS
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-diacono committed Jan 30, 2025
1 parent b1dd49d commit 44420e2
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 18 deletions.
71 changes: 55 additions & 16 deletions tests/acceptance/abstract_acceptance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def setUpGlobal(self, helper: ModelProviderHelper):

self.run_id = uuid.uuid4()

def resamples_then_works_globals(self, mock_get_client):
def resamples_then_works_globals(self, mock_get_client, should_add_get_run_call: bool = False):
"""
Sets up the global mocks for the resamples_then_works test
"""
Expand Down Expand Up @@ -135,22 +135,43 @@ def resamples_then_works_globals(self, mock_get_client):
# src/asteroid_sdk/registration/helper.py:453
send_supervision_result__response = make_created_response_with_id(supervision_result_id)

# Setup mock calls to asteroid API for during supervision
self.mock_asteroid_client.get_httpx_client.return_value.request.side_effect = [
send_chats_response,
get_tool_supervisor_chains_response,
get_tool_response,
send_supervision_request_response,
send_supervision_result__response,
get_run_response = None
if should_add_get_run_call:
# asteroid_sdk.registration.helper.get_run
# src/asteroid_sdk/registration/helper.py:346
get_run_response = make_ok_response(
{

send_chats_response,
get_tool_supervisor_chains_response,
get_tool_response,
send_supervision_request_response,
send_supervision_result__response,
]
"id": "123e4567-e89b-12d3-a456-426614174000",
"task_id": "123e4567-e89b-12d3-a456-426614174001",
"created_at": "2023-10-01T12:34:56Z",
"status": "pending",
"result": "Success",
"metadata": {
"key1": "value1",
"key2": "value2"
}
}
)

# # Setup mock calls to asteroid API for during supervision
api_responses = []
if should_add_get_run_call:
api_responses.append(get_run_response)
api_responses.append(send_chats_response)
api_responses.append(get_tool_supervisor_chains_response)
api_responses.append(get_tool_response)
api_responses.append(send_supervision_request_response)
api_responses.append(send_supervision_result__response)

def original_response_when_supervision_successful(self, mock_get_client):
api_responses.append(send_chats_response)
api_responses.append(get_tool_supervisor_chains_response)
api_responses.append(get_tool_response)
api_responses.append(send_supervision_request_response)
api_responses.append(send_supervision_result__response)
self.mock_asteroid_client.get_httpx_client.return_value.request.side_effect = api_responses

def original_response_when_supervision_successful(self, mock_get_client, should_add_get_run_call: bool = False):
# Mocking API client from the point it's called in registration
mock_get_client.return_value = self.mock_asteroid_client

Expand Down Expand Up @@ -185,7 +206,7 @@ def original_response_when_supervision_successful(self, mock_get_client):

# asteroid_sdk.api.asteroid_chat_supervision_manager.AsteroidChatSupervisionManager.handle_language_model_interaction
# src/asteroid_sdk/api/asteroid_chat_supervision_manager.py:73
send_chats_response = make_ok_response(
send_chats_response = make_created_response(
ChatIds(
chat_id=first_chat_id,
choice_ids=[ChoiceIds(
Expand Down Expand Up @@ -244,8 +265,26 @@ def original_response_when_supervision_successful(self, mock_get_client):
# src/asteroid_sdk/registration/helper.py:453
send_supervision_result__response = make_created_response_with_id(supervision_result_id)

get_run_response = None
if should_add_get_run_call:
# asteroid_sdk.registration.helper.get_run
# src/asteroid_sdk/registration/helper.py:346
get_run_response = make_ok_response(
{

"id": "123e4567-e89b-12d3-a456-426614174000",
"task_id": "123e4567-e89b-12d3-a456-426614174001",
"created_at": "2023-10-01T12:34:56Z",
"status": "pending",
"metadata": {
"key1": "value1",
"key2": "value2"
}
}
)
# Setup mock calls to asteroid API for during supervision
self.mock_asteroid_client.get_httpx_client.return_value.request.side_effect = [
get_run_response,
send_chats_response,
get_tool_supervisor_chains_response,
get_tool_response,
Expand Down
4 changes: 2 additions & 2 deletions tests/acceptance/test_open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def setUp(self):

@patch('asteroid_sdk.registration.helper.APIClientFactory.get_client')
def test_original_response_is_returned_with_supervision_is_successful(self, mock_get_client):
self.original_response_when_supervision_successful(mock_get_client)
self.original_response_when_supervision_successful(mock_get_client, True)
# THIS IS KEY- maybe we should instantiate the wrapper after we've run the init?
self.openai_wrapper.run_id = self.run_id

Expand Down Expand Up @@ -56,7 +56,7 @@ def test_original_response_is_returned_with_supervision_is_successful(self, mock

@patch('asteroid_sdk.registration.helper.APIClientFactory.get_client')
def test_resamples_and_then_works(self, mock_get_client):
self.resamples_then_works_globals(mock_get_client)
self.resamples_then_works_globals(mock_get_client, True)
# Mock call to LM
# Note- the allow: true param is what the supervisor is after to approve
desired_completion_message = self.create_chat_completion_with_tool_calls(
Expand Down

0 comments on commit 44420e2

Please sign in to comment.