Skip to content

Commit

Permalink
hoo boy. working combinatoric testing.
Browse files Browse the repository at this point in the history
Split out annotation dtype and shape, swap out all interface tests, fix numpy and dask model casting, make merging models more efficient, correctly parameterize and mark tests!
  • Loading branch information
sneakers-the-rat committed Oct 11, 2024
1 parent 5d4f03a commit 1187b37
Show file tree
Hide file tree
Showing 12 changed files with 476 additions and 272 deletions.
28 changes: 26 additions & 2 deletions src/numpydantic/interface/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Iterable, List, Literal, Optional, Union

import numpy as np
from pydantic import SerializationInfo
from pydantic import BaseModel, SerializationInfo

from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType, NDArrayType
Expand Down Expand Up @@ -70,9 +70,33 @@ def check(cls, array: Any) -> bool:
else:
return False

def before_validation(self, array: DaskArray) -> NDArrayType:
    """
    Try and coerce dicts that should be model objects into the model objects

    When ``self.dtype`` is a pydantic ``BaseModel`` subclass and the first
    element of the (flattened) array is a ``dict``, every item that is not
    already an instance of ``self.dtype`` is converted with
    ``self.dtype(**item)``. The conversion is applied through
    ``map_blocks``. In every other case the array is returned unchanged.

    Args:
        array: The dask array to (possibly) coerce.

    Returns:
        The original array, or a new array produced by ``map_blocks``
        whose items are instances of ``self.dtype``.
    """
    model = self.dtype

    try:
        is_model = issubclass(model, BaseModel)
    except TypeError:
        # fine, dtype isn't a type
        is_model = False
    if not is_model:
        return array

    try:
        # Only the first element is computed to decide whether coercion
        # is needed.
        first = array.reshape(-1)[0].compute()
    except IndexError:
        # Empty array: there is no first element and nothing to coerce.
        return array
    if not isinstance(first, dict):
        return array

    def _to_model(item: Union[dict, BaseModel]) -> BaseModel:
        # Leave items that are already model instances untouched.
        return item if isinstance(item, model) else model(**item)

    def _chunked_to_model(chunk: np.ndarray) -> np.ndarray:
        # An explicit ``otypes`` lets vectorize handle zero-size chunks and
        # avoids the extra probing call on the first element.
        return np.vectorize(_to_model, otypes=[object])(chunk)

    return array.map_blocks(_chunked_to_model, dtype=self.dtype)

def get_object_dtype(self, array: NDArrayType) -> DtypeType:
    """
    Dask arrays require a compute() call to retrieve a single value

    The array is flattened with ``reshape(-1)``, its first element is
    computed, and the Python type of that element is returned.

    Args:
        array: The dask array whose element type should be inspected.

    Returns:
        ``type()`` of the first computed element.
    """
    # Single return: the earlier ``ravel()`` form was a superseded duplicate
    # that made the ``reshape(-1)`` line unreachable.
    return type(array.reshape(-1)[0].compute())

@classmethod
def enabled(cls) -> bool:
Expand Down
13 changes: 12 additions & 1 deletion src/numpydantic/interface/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import Any, Literal, Union

from pydantic import SerializationInfo
from pydantic import BaseModel, SerializationInfo

from numpydantic.interface.interface import Interface, JsonDict

Expand Down Expand Up @@ -59,6 +59,9 @@ def check(cls, array: Any) -> bool:
Check that this is in fact a numpy ndarray or something that can be
coerced to one
"""
if array is None:
return False

if isinstance(array, ndarray):
return True
elif isinstance(array, dict):
Expand All @@ -77,6 +80,14 @@ def before_validation(self, array: Any) -> ndarray:
"""
if not isinstance(array, ndarray):
array = np.array(array)

try:
if issubclass(self.dtype, BaseModel) and isinstance(array.flat[0], dict):
array = np.vectorize(lambda x: self.dtype(**x))(array)
except TypeError:
# fine, dtype isn't a type
pass

return array

@classmethod
Expand Down
4 changes: 3 additions & 1 deletion src/numpydantic/interface/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class ZarrJsonDict(JsonDict):
type: Literal["zarr"]
file: Optional[str] = None
path: Optional[str] = None
dtype: Optional[str] = None
value: Optional[list] = None

def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
Expand All @@ -73,7 +74,7 @@ def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
if self.file:
array = ZarrArrayPath(file=self.file, path=self.path)
else:
array = zarr.array(self.value)
array = zarr.array(self.value, dtype=self.dtype)
return array


Expand Down Expand Up @@ -194,6 +195,7 @@ def to_json(
is_file = False

as_json = {"type": cls.name}
as_json["dtype"] = array.dtype.name
if hasattr(array.store, "dir_path"):
is_file = True
as_json["file"] = array.store.dir_path()
Expand Down
Loading

0 comments on commit 1187b37

Please sign in to comment.