Skip to content

Commit

Permalink
Update bundle docs to use pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
dmchoiboi committed Jul 18, 2024
1 parent 792fce0 commit b7f6103
Show file tree
Hide file tree
Showing 10 changed files with 1,041 additions and 991 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
docker:
# Important: Don't change this otherwise we will stop testing the earliest
# python version we have to support.
- image: python:3.7-buster
- image: python:3.8-bullseye
resource_class: small
steps:
- checkout # checkout source code to working directory
Expand Down
4 changes: 2 additions & 2 deletions .pylintrc
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
[tool.pylint.MESSAGE_CONTROL]
disable=
bad-continuation,
no-else-return,
too-few-public-methods,
line-too-long,
duplicate-code,
import-error,
unused-argument,
no-self-use,
import-outside-toplevel,
too-many-instance-attributes,
no-member,
W3101,
R1735,
W0511,
R0914,
R0913,
Expand Down
40 changes: 20 additions & 20 deletions docs/concepts/model_bundles.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,16 @@ Each of these modes of creating a model bundle is called a "Flavor".
=== "Creating From Callables"
```py
import os
from pydantic import BaseModel
from pydantic import BaseModel, RootModel
from launch import LaunchClient


class MyRequestSchema(BaseModel):
x: int
y: str

class MyResponseSchema(BaseModel):
__root__: int
class MyResponseSchema(RootModel):
root: int


def my_load_predict_fn(model):
Expand Down Expand Up @@ -107,7 +107,7 @@ Each of these modes of creating a model bundle is called a "Flavor".
```py
import os
import tempfile
from pydantic import BaseModel
from pydantic import BaseModel, RootModel
from launch import LaunchClient

directory = tempfile.mkdtemp()
Expand Down Expand Up @@ -151,8 +151,8 @@ Each of these modes of creating a model bundle is called a "Flavor".
x: int
y: str

class MyResponseSchema(BaseModel):
__root__: int
class MyResponseSchema(RootModel):
root: int

print(directory)
print(model_filename)
Expand Down Expand Up @@ -183,16 +183,16 @@ Each of these modes of creating a model bundle is called a "Flavor".
=== "Creating From a Runnable Image"
```py
import os
from pydantic import BaseModel
from pydantic import BaseModel, RootModel
from launch import LaunchClient


class MyRequestSchema(BaseModel):
x: int
y: str

class MyResponseSchema(BaseModel):
__root__: int
class MyResponseSchema(RootModel):
root: int


BUNDLE_PARAMS = {
Expand All @@ -218,16 +218,16 @@ Each of these modes of creating a model bundle is called a "Flavor".
=== "Creating From a Triton Enhanced Runnable Image"
```py
import os
from pydantic import BaseModel
from pydantic import BaseModel, RootModel
from launch import LaunchClient


class MyRequestSchema(BaseModel):
x: int
y: str

class MyResponseSchema(BaseModel):
__root__: int
class MyResponseSchema(RootModel):
root: int


BUNDLE_PARAMS = {
Expand Down Expand Up @@ -260,16 +260,16 @@ Each of these modes of creating a model bundle is called a "Flavor".
=== "Creating From a Streaming Enhanced Runnable Image"
```py
import os
from pydantic import BaseModel
from pydantic import BaseModel, RootModel
from launch import LaunchClient


class MyRequestSchema(BaseModel):
x: int
y: str

class MyResponseSchema(BaseModel):
__root__: int
class MyResponseSchema(RootModel):
root: int


BUNDLE_PARAMS = {
Expand Down Expand Up @@ -305,7 +305,7 @@ tasks.
```py title="Creating Model Bundles with app_config"
import os
from launch import LaunchClient
from pydantic import BaseModel
from pydantic import BaseModel, RootModel
from typing import List, Union
from typing_extensions import Literal

Expand All @@ -320,11 +320,11 @@ class MyRequestSchemaBatched(BaseModel):
x: List[int]
y: List[str]

class MyRequestSchema(BaseModel):
__root__: Union[MyRequestSchemaSingle, MyRequestSchemaBatched]
class MyRequestSchema(RootModel):
root: Union[MyRequestSchemaSingle, MyRequestSchemaBatched]

class MyResponseSchema(BaseModel):
__root__: Union[int, List[int]]
class MyResponseSchema(RootModel):
root: Union[int, List[int]]


def my_load_predict_fn(app_config, model):
Expand Down
8 changes: 4 additions & 4 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ import os
import time
from launch import LaunchClient
from launch import EndpointRequest
from pydantic import BaseModel
from pydantic import BaseModel, RootModel
from rich import print


class MyRequestSchema(BaseModel):
x: int
y: str

class MyResponseSchema(BaseModel):
__root__: int
class MyResponseSchema(RootModel):
root: int


def my_load_predict_fn(model):
Expand Down Expand Up @@ -86,7 +86,7 @@ request = MyRequestSchema(x=5, y="hello")
response = predict_on_endpoint(request)
print(response)
"""
MyResponseSchema(__root__=10)
MyResponseSchema(root=10)
"""
```

Expand Down
2 changes: 1 addition & 1 deletion launch/cli/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def edit_endpoint(ctx: click.Context, endpoint_name: str):
post_inference_hooks = model_endpoint.post_inference_hooks or []
for hook in PostInferenceHooks:
value = hook.value # type: ignore
post_inference_hooks_choices.append(q.Choice(title=value, checked=(value in post_inference_hooks)))
post_inference_hooks_choices.append(q.Choice(title=value, checked=value in post_inference_hooks))

if model_endpoint.status != "READY":
pretty_print(f"Endpoint '{endpoint_name}' is not ready. Please wait for it to be ready " "before editing.")
Expand Down
4 changes: 2 additions & 2 deletions launch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3014,7 +3014,7 @@ def completions_stream(
stream=True,
timeout=timeout,
)
sse_client = sseclient.SSEClient(response)
sse_client = sseclient.SSEClient(response) # type: ignore
events = sse_client.events()
for event in events:
yield json.loads(event.data)
Expand All @@ -3027,7 +3027,7 @@ def create_fine_tune(
fine_tuning_method: Optional[str] = None,
hyperparameters: Optional[Dict[str, str]] = None,
wandb_config: Optional[Dict[str, Any]] = None,
suffix: str = None,
suffix: Optional[str] = None,
) -> CreateFineTuneResponse:
"""
Create a fine-tune
Expand Down
17 changes: 7 additions & 10 deletions launch/pydantic_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@
import pydantic
from pydantic import BaseModel

if hasattr(pydantic, "VERSION") and pydantic.VERSION.startswith("1."):
PYDANTIC_VERSION = 1
from pydantic.schema import (
PYDANTIC_V2 = hasattr(pydantic, "VERSION") and pydantic.VERSION.startswith("2.")

if not PYDANTIC_V2:
from pydantic.schema import ( # pylint: disable=no-name-in-module
get_flat_models_from_models,
model_process_schema,
)
elif hasattr(pydantic, "VERSION") and pydantic.VERSION.startswith("2."):
PYDANTIC_VERSION = 2
else:
raise ImportError("Unsupported pydantic version.")


REF_PREFIX = "#/components/schemas/"
Expand All @@ -36,10 +33,10 @@ def get_model_definitions_v2(request_schema: Type[BaseModel], response_schema: T
}


if PYDANTIC_VERSION == 1:
get_model_definitions: Callable = get_model_definitions_v1 # type: ignore
elif PYDANTIC_VERSION == 2:
if PYDANTIC_V2:
get_model_definitions: Callable = get_model_definitions_v2 # type: ignore
else:
get_model_definitions: Callable = get_model_definitions_v1 # type: ignore


def get_model_definitions_from_flat_models(
Expand Down
Loading

0 comments on commit b7f6103

Please sign in to comment.