Skip to content

Commit

Permalink
add ability to set marks in test cases, add combinatoric test for jso…
Browse files Browse the repository at this point in the history
…n schema generation
  • Loading branch information
sneakers-the-rat committed Dec 14, 2024
1 parent 0cbdfb2 commit 8bf2203
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ markers = [
"numpy: numpy interface",
"video: video interface",
"zarr: zarr interface",
"union: union dtypes",
"pipe_union: union dtypes specified with a pipe",
]

[tool.black]
Expand Down
15 changes: 14 additions & 1 deletion src/numpydantic/testing/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,27 +143,35 @@ class SubClass(BasicModel):
dtype=np.uint32,
passes=True,
id="union-type-uint32",
marks={"union"},
),
ValidationCase(
annotation_dtype=UNION_TYPE,
dtype=np.float32,
passes=True,
id="union-type-float32",
marks={"union"},
),
ValidationCase(
annotation_dtype=UNION_TYPE,
dtype=np.uint64,
passes=False,
id="union-type-uint64",
marks={"union"},
),
ValidationCase(
annotation_dtype=UNION_TYPE,
dtype=np.float64,
passes=False,
id="union-type-float64",
marks={"union"},
),
ValidationCase(
annotation_dtype=UNION_TYPE, dtype=str, passes=False, id="union-type-str"
annotation_dtype=UNION_TYPE,
dtype=str,
passes=False,
id="union-type-str",
marks={"union"},
),
]
"""
Expand All @@ -181,30 +189,35 @@ class SubClass(BasicModel):
dtype=np.uint32,
passes=True,
id="union-pipe-uint32",
marks={"union", "pipe_union"},
),
ValidationCase(
annotation_dtype=UNION_PIPE,
dtype=np.float32,
passes=True,
id="union-pipe-float32",
marks={"union", "pipe_union"},
),
ValidationCase(
annotation_dtype=UNION_PIPE,
dtype=np.uint64,
passes=False,
id="union-pipe-uint64",
marks={"union", "pipe_union"},
),
ValidationCase(
annotation_dtype=UNION_PIPE,
dtype=np.float64,
passes=False,
id="union-pipe-float64",
marks={"union", "pipe_union"},
),
ValidationCase(
annotation_dtype=UNION_PIPE,
dtype=str,
passes=False,
id="union-pipe-str",
marks={"union", "pipe_union"},
),
]
)
Expand Down
28 changes: 25 additions & 3 deletions src/numpydantic/testing/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
from itertools import product
from operator import ior
from pathlib import Path
from typing import Generator, List, Literal, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Generator, List, Literal, Optional, Tuple, Type, Union

import numpy as np
from pydantic import BaseModel, ConfigDict, ValidationError, computed_field
from pydantic import BaseModel, ConfigDict, Field, ValidationError, computed_field

from numpydantic import NDArray, Shape
from numpydantic.dtype import Float
from numpydantic.interface import Interface
from numpydantic.types import DtypeType, NDArrayType

if TYPE_CHECKING:
from _pytest.mark.structures import MarkDecorator


class InterfaceCase(ABC):
"""
Expand Down Expand Up @@ -139,6 +142,8 @@ class ValidationCase(BaseModel):
"""The interface test case to generate and validate the array with"""
path: Optional[Path] = None
"""The path to generate arrays into, if any."""
marks: set[str] = Field(default_factory=set)
"""pytest marks to set for this test case"""

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down Expand Up @@ -179,6 +184,19 @@ class Model(BaseModel):

return Model

@property
def pytest_marks(self) -> list["MarkDecorator"]:
"""
Instantiated pytest marks from :attr:`.ValidationCase.marks`
plus the interface name.
"""
import pytest

marks = self.marks.copy()
if self.interface is not None:
marks.add(self.interface.interface.name)
return [getattr(pytest.mark, m) for m in marks]

def validate_case(self, path: Optional[Path] = None) -> bool:
"""
Whether the generated array correctly validated against the annotation,
Expand Down Expand Up @@ -246,7 +264,10 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
return args[0]

dumped = [
m.model_dump(exclude_unset=True, exclude={"model", "annotation"}) for m in args
m.model_dump(
exclude_unset=True, exclude={"model", "annotation", "pytest_marks"}
)
for m in args
]

# self_dump = self.model_dump(exclude_unset=True)
Expand All @@ -263,6 +284,7 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
merged = reduce(ior, dumped, {})
merged["passes"] = passes
merged["id"] = ids
merged["marks"] = set().union(*[v.get("marks", set()) for v in dumped])
return ValidationCase.model_construct(**merged)


Expand Down
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def pytest_addoption(parser):


@pytest.fixture(
scope="function", params=[pytest.param(c, id=c.id) for c in SHAPE_CASES]
scope="function",
params=[pytest.param(c, id=c.id, marks=c.pytest_marks) for c in SHAPE_CASES],
)
def shape_cases(request, tmp_output_dir_func) -> ValidationCase:
case: ValidationCase = request.param.model_copy()
Expand All @@ -23,7 +24,8 @@ def shape_cases(request, tmp_output_dir_func) -> ValidationCase:


@pytest.fixture(
scope="function", params=[pytest.param(c, id=c.id) for c in DTYPE_CASES]
scope="function",
params=[pytest.param(c, id=c.id, marks=c.pytest_marks) for c in DTYPE_CASES],
)
def dtype_cases(request, tmp_output_dir_func) -> ValidationCase:
case: ValidationCase = request.param.model_copy()
Expand Down
12 changes: 3 additions & 9 deletions tests/test_interface/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ def interface_cases(request) -> InterfaceCase:


@pytest.fixture(
params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
for p in ALL_CASES
)
params=(pytest.param(p, id=p.id, marks=p.pytest_marks) for p in ALL_CASES)
)
def all_cases(interface_cases, request) -> ValidationCase:
"""
Expand All @@ -83,10 +80,7 @@ def all_cases(interface_cases, request) -> ValidationCase:


@pytest.fixture(
params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
for p in ALL_CASES_PASSING
)
params=(pytest.param(p, id=p.id, marks=p.pytest_marks) for p in ALL_CASES_PASSING)
)
def all_passing_cases(request) -> ValidationCase:
"""
Expand Down Expand Up @@ -132,7 +126,7 @@ def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func):

@pytest.fixture(
params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
pytest.param(p, id=p.id, marks=p.pytest_marks)
for p in DTYPE_AND_INTERFACE_CASES_PASSING
)
)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_interface/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ def test_interface_revalidate(all_passing_cases_instance):
_ = type(all_passing_cases_instance)(array=all_passing_cases_instance.array)


@pytest.mark.json_schema
def test_interface_jsonschema(all_passing_cases_instance):
"""
All interfaces should be able to generate json schema
for all combinations of dtype and shape
Note that this does not test for json schema correctness -
see ndarray tests for that
"""
_ = all_passing_cases_instance.model_json_schema()


@pytest.mark.xfail
def test_interface_rematch(interface_cases, tmp_output_dir_func):
"""
Expand Down

0 comments on commit 8bf2203

Please sign in to comment.