forked from patrick-kidger/jaxtyping
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconftest.py
84 lines (64 loc) · 2.47 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright (c) 2022 Google LLC
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import random
import jax.random as jr
import pytest
import typeguard
# Typecheckers that the `typecheck` fixture is parametrised over. typeguard is
# always included (it is imported unconditionally at the top of this file);
# beartype is optional and replaced by a skipping placeholder when missing.
typecheck_params = [typeguard.typechecked]
try:
    import beartype
except ImportError:

    def skip(*args, **kwargs):
        """Stand-in for ``beartype.beartype`` when beartype is not installed.

        Calling it (e.g. applying it as a decorator inside a test) skips the
        current test instead of typechecking anything.
        """
        pytest.skip("Beartype not installed")

    typecheck_params.append(skip)
else:
    typecheck_params.append(beartype.beartype)
@pytest.fixture(params=typecheck_params)
def typecheck(request):
    """Provide each entry of ``typecheck_params`` in turn.

    Each value is a decorator-style typechecker: ``typeguard.typechecked``,
    then either ``beartype.beartype`` or a placeholder that skips the test.
    """
    chosen_typechecker = request.param
    return chosen_typechecker
@pytest.fixture(params=(False, True))
def jaxtyp(request):
    """Provide a factory mapping a typechecker to a jaxtyping decorator.

    The fixture is parametrised over both decorator spellings:

    - ``request.param`` is True (new style)::

          @jaxtyping.jaxtyped(typechecker=typechecker)
          def f(...): ...

    - ``request.param`` is False (old style)::

          @jaxtyping.jaxtyped
          @typechecker
          def f(...): ...

      Decorating a function this way must emit a warning whose message
      matches "As of jaxtyping version 0.2.24"; ``pytest.warns`` fails the
      test otherwise.
    """
    import jaxtyping

    def new_style(typechecker):
        return jaxtyping.jaxtyped(typechecker=typechecker)

    def old_style(typechecker):
        def decorator(fn):
            # The typechecker wraps fn first; jaxtyped then wraps the result.
            # Both steps run inside pytest.warns, so the warning is checked
            # when the block exits, before the wrapped function is returned.
            with pytest.warns(match="As of jaxtyping version 0.2.24"):
                wrapped = jaxtyping.jaxtyped(typechecker(fn))
            return wrapped

        return decorator

    return new_style if request.param else old_style
@pytest.fixture()
def getkey():
    """Provide a zero-argument callable that returns a new JAX PRNG key.

    Every call draws a fresh seed from Python's global ``random`` state, which
    this fixture does not seed itself.
    """
    # Seeds stay within the non-negative signed 32-bit range. The true upper
    # bound accepted by jr.PRNGKey is not established here; this range suffices.
    max_seed = 2**31 - 1

    def _make_key():
        seed = random.randint(0, max_seed)
        return jr.PRNGKey(seed)

    return _make_key
@pytest.fixture(scope="module")
def beartype_or_skip():
    """Provide the ``beartype`` module, skipping dependent tests if it is absent."""
    beartype_module = pytest.importorskip("beartype")
    yield beartype_module
@pytest.fixture(scope="module")
def typeguard_or_skip():
    """Provide the ``typeguard`` module, skipping dependent tests if it is absent.

    NOTE(review): this conftest also imports typeguard unconditionally at the
    top of the file, so if typeguard were missing the conftest would fail to
    import before this fixture could skip anything.
    """
    typeguard_module = pytest.importorskip("typeguard")
    yield typeguard_module