Skip to content

Commit

Permalink
[SDK-999] feat: allow api key to be passed
Browse files Browse the repository at this point in the history
  • Loading branch information
joehewett committed Jan 30, 2025
1 parent b1dd49d commit 0c0eefb
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 10 deletions.
20 changes: 17 additions & 3 deletions scripts/check_commit_message.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

# Define the regex pattern for commit messages
pattern='^\[SDK-[0-9]+\] (feat|fix|docs|style|refactor|perf|test|chore): .{1,80}(\n.*)*$'
pattern='^\[SDK-[0-9]+\]? ?(feat|fix|docs|style|refactor|perf|test|chore): .{1,80}(\n.*)*$'

# Get the commit messages for the PR
commit_messages=$(git log --format='%H' origin/main..HEAD)
Expand All @@ -11,8 +11,22 @@ for commit_hash in $commit_messages; do
commit_message=$(git log --format=%B -n 1 $commit_hash)
echo "Checking commit message: $commit_message"
if ! [[ $commit_message =~ $pattern ]]; then
echo "Commit message does not match the required pattern:"
echo "$commit_message"
echo "Error: Commit message does not match the required pattern!"
echo "Your message: '$commit_message'"
echo -e "\nThe message should:"
# if ! [[ $commit_message =~ ^\[SDK-[0-9]+\] ]]; then
# echo "- Start with [SDK-XXX] where XXX is a number"
# fi
if ! [[ $commit_message =~ ^\[SDK-[0-9]+\]\ (feat|fix|docs|style|refactor|perf|test|chore): ]]; then
echo "- Include one of these types after the SDK number: feat, fix, docs, style, refactor, perf, test, chore"
fi
if ! [[ $commit_message =~ ^\[SDK-[0-9]+\]\ (feat|fix|docs|style|refactor|perf|test|chore):\ .+ ]]; then
echo "- Include a description after the type"
fi
if [[ ${#commit_message} -gt 80 ]]; then
echo "- Be no longer than 80 characters on the first line"
fi
echo -e "\nExample: [SDK-123] feat: add new awesome feature"
exit 1
fi
done
Expand Down
25 changes: 22 additions & 3 deletions src/asteroid_sdk/registration/initialise_project.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any, Callable, List, Optional, Dict
from uuid import UUID

from asteroid_sdk import settings
from asteroid_sdk.api.generated.asteroid_api_client.client import Client
from asteroid_sdk.api.generated.asteroid_api_client.models import Status
from asteroid_sdk.registration.helper import (
create_run, register_project, register_task, register_tools_and_supervisors_from_registry, submit_run_status,
APIClientFactory, create_run, register_project, register_task, register_tools_and_supervisors_from_registry, submit_run_status,
register_tool, create_supervisor_chain, register_supervisor_chains
)
from asteroid_sdk.supervision.config import ExecutionMode, RejectionPolicy, get_supervision_config
Expand All @@ -17,11 +19,28 @@ def asteroid_init(
run_name: str = "My Run",
execution_settings: Dict[str, Any] = {},
message_supervisors: Optional[List[Callable]] = None,
run_id: Optional[UUID] = None
run_id: Optional[UUID] = None,
api_key: Optional[str] = None
) -> UUID:
"""
Initializes supervision for a project, task, and run.
Args:
project_name: Name of the project
task_name: Name of the task
run_name: Name of the run
execution_settings: Dictionary of execution settings
message_supervisors: Optional list of message supervisor functions
run_id: Optional UUID for the run
api_key: Optional API key to override the default from environment variables
"""
if api_key:
# If the user provided an API key, override the client in settings
logging.info("Overriding API key env variable with provided API key")
APIClientFactory._instance = Client(
base_url=settings.api_url,
headers={"X-Asteroid-Api-Key": api_key}
)

project_id = register_project(project_name)
logging.info(f"Registered new project '{project_name}' with ID: {project_id}")
Expand All @@ -34,7 +53,7 @@ def asteroid_init(
supervision_config.set_execution_settings(execution_settings)

register_tools_and_supervisors_from_registry(run_id=run_id,
message_supervisors=message_supervisors)
message_supervisors=message_supervisors)

return run_id

Expand Down
6 changes: 2 additions & 4 deletions src/asteroid_sdk/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ def __init__(self):
logging.info("Initializing Asteroid SDK settings")

# Asteroid API settings
self.api_key = os.getenv('ASTEROID_API_KEY')
self.api_key = os.getenv('ASTEROID_API_KEY') # Don't error out if this is not set, user might provide in init
self.api_url = os.getenv('ASTEROID_API_URL', "https://api.asteroid.ai/api/v1")
self.openai_api_key = os.getenv('OPENAI_API_KEY')

# NEW: Optional integration
# Optional integration
self.langfuse_enabled = (
os.getenv('LANGFUSE_ENABLED', 'false').lower() in ['true', '1']
)
Expand All @@ -30,8 +30,6 @@ def __init__(self):
raise ValueError("LANGFUSE_HOST environment variable is required")

# Validate required settings
if not self.api_key:
raise ValueError("ASTEROID_API_KEY environment variable is required")
if not self.api_url:
raise ValueError("ASTEROID_API_URL environment variable is required")

Expand Down

0 comments on commit 0c0eefb

Please sign in to comment.