Skip to content

Commit

Permalink
feat: add run pause/fail and metadata update helpers
Browse files Browse the repository at this point in the history
- Add pause_run function to pause a run
- Add fail_run function to fail a run and update metadata
- Add update_run_metadata function to update run metadata
- Use APIClientFactory to get API client in helper functions
  • Loading branch information
joehewett committed Feb 1, 2025
1 parent 2051f8d commit 0164082
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 4 deletions.
53 changes: 52 additions & 1 deletion src/asteroid_sdk/interaction/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@

from asteroid_sdk.api.generated.asteroid_api_client.api.run.get_run import sync as get_run_sync
from asteroid_sdk.api.generated.asteroid_api_client.client import Client
from asteroid_sdk.api.generated.asteroid_api_client.api.run.update_run_status import sync_detailed as update_run_status_sync
from asteroid_sdk.api.generated.asteroid_api_client.models.status import Status
from asteroid_sdk.registration.helper import APIClientFactory
from asteroid_sdk.api.generated.asteroid_api_client.api.run.update_run_metadata import sync_detailed as update_run_metadata_sync
from asteroid_sdk.api.generated.asteroid_api_client.models.update_run_metadata_body import UpdateRunMetadataBody

async def wait_for_unpaused(run_id: str, client: Client):

async def wait_for_unpaused(run_id: str):
"""Wait until the run is no longer in paused state."""
client = APIClientFactory.get_client()

start_time = time.time()
timeout = 1200 # 20 minute timeout

Expand All @@ -30,3 +38,46 @@ async def wait_for_unpaused(run_id: str, client: Client):
except Exception as e:
logging.error(f"Error checking run status: {e}")
break # Exit the loop on error instead of continuing indefinitely

def pause_run(run_id: str):
"""Pause a running run."""
client = APIClientFactory.get_client()

try:
response = update_run_status_sync(client=client, run_id=run_id, body=Status(status="paused"))
if response is not None:
raise Exception(f"Failed to pause run {run_id}: {response.status_code} {response.content}")
except Exception as e:
logging.error(f"Error pausing run {run_id}: {e}")
raise e

def fail_run(run_id: str, error_message: str):
"""Fail a running run."""
client = APIClientFactory.get_client()

try:
response = update_run_status_sync(client=client, run_id=run_id, body=Status(status="failed", error_message=error_message))
if response is not None:
raise Exception(f"Failed to fail run {run_id}: {response.status_code} {response.content}")
update_run_metadata(run_id, {"fail_reason": error_message})
except Exception as e:
logging.error(f"Error failing run {run_id}: {e}")
raise e

def update_run_metadata(run_id: str, metadata: dict):
"""Update the metadata of a run with the provided dictionary."""
client = APIClientFactory.get_client()

try:
metadata_body = UpdateRunMetadataBody.from_dict(metadata)
response = update_run_metadata_sync(
client=client,
run_id=run_id,
body=metadata_body
)
if response.status_code != 204:
raise Exception(f"Failed to update run metadata for {run_id}: {response.status_code} {response.content}")
except Exception as e:
logging.error(f"Error updating run metadata for {run_id}: {e}")
raise e

2 changes: 1 addition & 1 deletion src/asteroid_sdk/wrappers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def create(
**kwargs,
) -> Any:
# Wait for unpaused state before proceeding - blocks until complete
future = schedule_task(wait_for_unpaused(self.run_id, self.chat_supervision_manager.client))
future = schedule_task(wait_for_unpaused(self.run_id))
future.result() # This blocks until the future is done

# If parallel tool calls are not set to false, then update accordingly.
Expand Down
2 changes: 1 addition & 1 deletion src/asteroid_sdk/wrappers/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def generate_content(
**kwargs,
) -> Any:
# Wait for unpaused state before proceeding - blocks until complete
future = schedule_task(wait_for_unpaused(self.run_id, self.chat_supervision_manager.client))
future = schedule_task(wait_for_unpaused(self.run_id))
future.result() # This blocks until the future is done

# TODO - Check if there's any other config that we need to sort out here
Expand Down
2 changes: 1 addition & 1 deletion src/asteroid_sdk/wrappers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def create(
**kwargs,
) -> Any:
# Wait for unpaused state before proceeding - blocks until complete
future = schedule_task(wait_for_unpaused(self.run_id, self.chat_supervision_manager.client))
future = schedule_task(wait_for_unpaused(self.run_id))
future.result() # This blocks until the future is done

# If parallel tool calls not set to false (or doesn't exist, defaulting to true), then raise an error.
Expand Down

0 comments on commit 0164082

Please sign in to comment.