Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions cli/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import click
import questionary
from rich.console import Console
from rich.pretty import pretty_repr
from rich.table import Column, Table

from cli.client import init_client
Expand All @@ -18,6 +20,47 @@ def models(ctx, web):
launch_web_or_invoke("models", ctx, web, list_models)


STRING_REPLACEMENTS = {
"\\n": "\n",
"\\t": "\t",
'\\"': '"',
}


def json_string_to_string(s: str) -> str:
for key, val in STRING_REPLACEMENTS.items():
s = s.replace(key, val)
return s


@models.command("calculate-metrics")
def metrics():
client = init_client()
models = client.models
prompt_to_id = {f"{m.id}: {m.name}": m.id for m in models}
ans = questionary.select(
"What model do you want to run metrics for?",
choices=list(prompt_to_id.keys()),
).ask()
model_id = prompt_to_id[ans]
jobs = client.validate.metrics(model_id)
console = Console()
with console.status("Calculating metrics"):
for job in jobs:
job.sleep_until_complete(False)

if len(job.errors()) == 0:
status = job.status()
click.echo(click.style("Done", fg="green"))
console.print(pretty_repr(status))
else:
click.echo(
click.style("Encountered errors during running", fg="green")
)
for error in job.errors():
click.echo(json_string_to_string(error))


@models.command("list")
def list_models():
"""List your Models"""
Expand Down
8 changes: 8 additions & 0 deletions nucleus/validate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ def evaluate_model_on_scenario_tests(
)
return AsyncJob.from_json(response, self.connection)

def metrics(self, model_id: str):
response = self.connection.post(
{},
f"validate/{model_id}/metrics",
)
jobs = [AsyncJob.from_json(job, self.connection) for job in response]
return jobs

def create_external_eval_function(
self,
name: str,
Expand Down
Loading