Skip to content

Commit

Permalink
tests for testing helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
sneakers-the-rat committed Oct 11, 2024
1 parent 5268884 commit b0a63af
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 12 deletions.
10 changes: 5 additions & 5 deletions src/numpydantic/testing/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,21 +190,21 @@ def validate_case(self, path: Optional[Path] = None) -> bool:
Raises:
ValueError: if an ``interface`` is missing
"""
if self.interface is None:
if self.interface is None: # pragma: no cover
raise ValueError("Missing an interface")
if path is None:
if self.path:
path = self.path
else:
else: # pragma: no cover
raise ValueError("Missing a path to generate arrays into")

return self.interface.validate_case(self, path)

def array(self, path: Path) -> NDArrayType:
"""Generate an array for the validation case if we have an interface to do so"""
if self.interface is None:
if self.interface is None: # pragma: no cover
raise ValueError("Missing an interface")
if path is None:
if path is None: # pragma: no cover
if self.path:
path = self.path
else:
Expand Down Expand Up @@ -242,7 +242,7 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
"""
Merge multiple validation cases
"""
if len(args) == 1:
if len(args) == 1: # pragma: no cover
return args[0]

dumped = [
Expand Down
11 changes: 5 additions & 6 deletions src/numpydantic/testing/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def make_array(
path: Optional[Path] = None,
array: Optional[NDArrayType] = None,
) -> Optional[H5ArrayPath]:
if cls.skip(shape, dtype):
if cls.skip(shape, dtype): # pragma: no cover
return None

hdf5_file = path / "h5f.h5"
Expand Down Expand Up @@ -99,7 +99,7 @@ def make_array(
path: Optional[Path] = None,
array: Optional[NDArrayType] = None,
) -> Optional[H5ArrayPath]:
if cls.skip(shape, dtype):
if cls.skip(shape, dtype): # pragma: no cover
return None

hdf5_file = path / "h5f.h5"
Expand Down Expand Up @@ -140,7 +140,7 @@ def make_array(
array: Optional[NDArrayType] = None,
) -> da.Array:
if array is not None:
return da.array(array, dtype=dtype, chunks=-1)
return da.array(array, dtype=dtype)
if issubclass(dtype, BaseModel):
return da.full(shape=shape, fill_value=dtype(x=1), chunks=-1)
else:
Expand Down Expand Up @@ -244,11 +244,11 @@ def make_array(
path: Optional[Path] = None,
array: Optional[NDArrayType] = None,
) -> Optional[Path]:
if cls.skip(shape, dtype):
if cls.skip(shape, dtype): # pragma: no cover
return None

if array is not None:
array = np.ndarray(shape, dtype=np.uint8)
array = np.array(array, dtype=np.uint8)
shape = array.shape

is_color = len(shape) == 4
Expand All @@ -263,7 +263,6 @@ def make_array(
(frame_shape[1], frame_shape[0]),
is_color,
)

for i in range(frames):
if array is not None:
frame = array[i]
Expand Down
1 change: 0 additions & 1 deletion tests/test_interface/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def test_interface_to_numpy_array(dtype_by_interface_instance):
All interfaces should be able to have the output of their validation stage
coerced to a numpy array with np.array()
"""

_ = np.array(dtype_by_interface_instance.array)


Expand Down
60 changes: 60 additions & 0 deletions tests/test_testing_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
Tests for the testing helpers lmao
"""

import numpy as np
import pytest
from pydantic import BaseModel

from numpydantic import NDArray, Shape
from numpydantic.testing.cases import INTERFACE_CASES
from numpydantic.testing.helpers import ValidationCase
from numpydantic.testing.interfaces import NumpyCase


def test_validation_case_merge():
case_1 = ValidationCase(id="1", interface=NumpyCase, passes=False)
case_2 = ValidationCase(id="2", dtype=str, passes=True)
case_3 = ValidationCase(id="3", shape=(1, 2, 3), passes=True)

merged_simple = case_2.merge(case_3)
assert merged_simple.dtype == case_2.dtype
assert merged_simple.shape == case_3.shape

merged_multi = case_1.merge([case_2, case_3])
assert merged_multi.dtype == case_2.dtype
assert merged_multi.shape == case_3.shape
assert merged_multi.interface == case_1.interface

# passes should be true only if all the cases are
assert merged_simple.passes
assert not merged_multi.passes

# ids should merge
assert merged_simple.id == "2-3"
assert merged_multi.id == "1-2-3"


@pytest.mark.parametrize(
"interface",
[
pytest.param(
i.interface, marks=getattr(pytest.mark, i.interface.interface.name)
)
for i in INTERFACE_CASES
if i.id not in ("hdf5_compound")
],
)
def test_make_array(interface, tmp_output_dir_func):
"""
An interface case can generate an array from params or a given array
Not testing correctness here, that's what hte rest of the testing does.
"""
arr = np.zeros((10, 10, 2, 3), dtype=np.uint8)
arr = interface.make_array(array=arr, dtype=np.uint8, path=tmp_output_dir_func)

class MyModel(BaseModel):
array: NDArray[Shape["10, 10, 2, 3"], np.uint8]

_ = MyModel(array=arr)

0 comments on commit b0a63af

Please sign in to comment.