Skip to content

Commit

Permalink
Port NumPy typing testing style to PyTorch (pytorch#54234)
Browse files Browse the repository at this point in the history
Summary:
This is a follow-up PR of pytorch#52408 and includes the `pass/` and `fail/` directories.

Pull Request resolved: pytorch#54234

Reviewed By: walterddr

Differential Revision: D27681410

Pulled By: malfet

fbshipit-source-id: e6817df77c758f4c1295ea62613106c71cfd3fc3
  • Loading branch information
guilhermeleobas authored and facebook-github-bot committed Apr 15, 2021
1 parent a128938 commit 6eeffc6
Show file tree
Hide file tree
Showing 10 changed files with 659 additions and 53 deletions.
82 changes: 79 additions & 3 deletions test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import os
import re
import shutil
from typing import IO, Dict, List
from collections import defaultdict
from typing import IO, Dict, List, Optional

import pytest

Expand All @@ -18,6 +19,8 @@

DATA_DIR = os.path.join(os.path.dirname(__file__), "typing")
REVEAL_DIR = os.path.join(DATA_DIR, "reveal")
PASS_DIR = os.path.join(DATA_DIR, "pass")
FAIL_DIR = os.path.join(DATA_DIR, "fail")
MYPY_INI = os.path.join(DATA_DIR, os.pardir, os.pardir, "mypy.ini")
CACHE_DIR = os.path.join(DATA_DIR, ".mypy_cache")

Expand All @@ -35,6 +38,12 @@ def _key_func(key: str) -> str:
return os.path.join(drive, tail.split(":", 1)[0])


def _strip_filename(msg: str) -> str:
"""Strip the filename from a mypy message."""
_, tail = os.path.splitdrive(msg)
return tail.split(":", 1)[-1]


@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
@pytest.fixture(scope="module", autouse=True)
def run_mypy() -> None:
Expand All @@ -46,7 +55,7 @@ def run_mypy() -> None:
if os.path.isdir(CACHE_DIR):
shutil.rmtree(CACHE_DIR)

for directory in (REVEAL_DIR,):
for directory in (REVEAL_DIR, PASS_DIR, FAIL_DIR):
# Run mypy
stdout, stderr, _ = api.run(
[
Expand Down Expand Up @@ -80,17 +89,84 @@ def get_test_cases(directory):
id=relpath,
)


@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
@pytest.mark.parametrize("path", get_test_cases(PASS_DIR))
def test_success(path):
# Alias `OUTPUT_MYPY` so that it appears in the local namespace
output_mypy = OUTPUT_MYPY
if path in output_mypy:
msg = "Unexpected mypy output\n\n"
msg += "\n".join(_strip_filename(v) for v in output_mypy[path])
raise AssertionError(msg)


@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
@pytest.mark.parametrize("path", get_test_cases(FAIL_DIR))
def test_fail(path):
__tracebackhide__ = True

with open(path) as fin:
lines = fin.readlines()

errors = defaultdict(lambda: "")

output_mypy = OUTPUT_MYPY
assert path in output_mypy
for error_line in output_mypy[path]:
error_line = _strip_filename(error_line)
match = re.match(
r"(?P<lineno>\d+): (error|note): .+$",
error_line,
)
if match is None:
raise ValueError(f"Unexpected error line format: {error_line}")
lineno = int(match.group('lineno'))
errors[lineno] += f'{error_line}\n'

for i, line in enumerate(lines):
lineno = i + 1
if line.startswith('#') or (" E:" not in line and lineno not in errors):
continue

target_line = lines[lineno - 1]
if "# E:" in target_line:
marker = target_line.split("# E:")[-1].strip()
expected_error = errors.get(lineno)
_test_fail(path, marker, expected_error, lineno)
else:
pytest.fail(f"Unexpected mypy output\n\n{errors[lineno]}")


_FAIL_MSG1 = """Extra error at line {}
Extra error: {!r}
"""

_FAIL_MSG2 = """Error mismatch at line {}
Expected error: {!r}
Observed error: {!r}
"""


def _test_fail(path: str, error: str, expected_error: Optional[str], lineno: int) -> None:
if expected_error is None:
raise AssertionError(_FAIL_MSG1.format(lineno, error))
elif error not in expected_error:
raise AssertionError(_FAIL_MSG2.format(lineno, expected_error, error))


def _construct_format_dict():
dct = {
'ModuleList': 'torch.nn.modules.container.ModuleList',
'AdaptiveAvgPool2d': 'torch.nn.modules.pooling.AdaptiveAvgPool2d',
'AdaptiveMaxPool2d': 'torch.nn.modules.pooling.AdaptiveMaxPool2d',
'Tensor': 'torch.tensor.Tensor',
'Tensor': 'torch._tensor.Tensor',
'Adagrad': 'torch.optim.adagrad.Adagrad',
'Adam': 'torch.optim.adam.Adam',
}
return dct


#: A dictionary with all supported format keys (as keys)
#: and matching values
FORMAT_DICT: Dict[str, str] = _construct_format_dict()
Expand Down
9 changes: 9 additions & 0 deletions test/typing/fail/bitwise_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# flake8: noqa
import torch

# binary ops: <<, >>, |, &, ~, ^

a = torch.ones(3, dtype=torch.float64)
i = int()

i | a # E: Unsupported operand types
6 changes: 6 additions & 0 deletions test/typing/fail/creation_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# flake8: noqa
import torch

torch.tensor([3], dtype='int32') # E: expected "Optional[dtype]"
torch.ones(3, dtype='int32') # E: No overload variant of "ones" matches argument types "int", "str"
torch.zeros(3, dtype='int32') # E: No overload variant of "zeros" matches argument types "int", "str"
4 changes: 4 additions & 0 deletions test/typing/fail/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# flake8: noqa
import torch

torch.set_rng_state([1, 2, 3]) # E: Argument 1 to "set_rng_state" has incompatible type "List[int]"; expected "Tensor"
118 changes: 118 additions & 0 deletions test/typing/pass/creation_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# flake8: noqa
import torch
from torch.testing._internal.common_utils import TEST_NUMPY
if TEST_NUMPY:
import numpy as np

# From the docs, there are quite a few ways to create a tensor:
# https://pytorch.org/docs/stable/tensors.html

# torch.tensor()
torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])
torch.tensor([0, 1])
torch.tensor([[0.11111, 0.222222, 0.3333333]],
dtype=torch.float64,
device=torch.device('cuda:0'))
torch.tensor(3.14159)

# torch.sparse_coo_tensor
i = torch.tensor([[0, 1, 1],
[2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
torch.sparse_coo_tensor(i, v, [2, 4])
torch.sparse_coo_tensor(i, v)
torch.sparse_coo_tensor(i, v, [2, 4],
dtype=torch.float64,
device=torch.device('cuda:0'))
torch.sparse_coo_tensor(torch.empty([1, 0]), [], [1])
torch.sparse_coo_tensor(torch.empty([1, 0]),
torch.empty([0, 2]), [1, 2])

# torch.as_tensor
a = [1, 2, 3]
torch.as_tensor(a)
torch.as_tensor(a, device=torch.device('cuda'))

# torch.as_strided
x = torch.randn(3, 3)
torch.as_strided(x, (2, 2), (1, 2))
torch.as_strided(x, (2, 2), (1, 2), 1)

# torch.from_numpy
if TEST_NUMPY:
torch.from_numpy(np.array([1, 2, 3]))

# torch.zeros/zeros_like
torch.zeros(2, 3)
torch.zeros((2, 3))
torch.zeros([2, 3])
torch.zeros(5)
torch.zeros_like(torch.empty(2, 3))

# torch.ones/ones_like
torch.ones(2, 3)
torch.ones((2, 3))
torch.ones([2, 3])
torch.ones(5)
torch.ones_like(torch.empty(2, 3))

# torch.arange
torch.arange(5)
torch.arange(1, 4)
torch.arange(1, 2.5, 0.5)

# torch.range
torch.range(1, 4)
torch.range(1, 4, 0.5)

# torch.linspace
torch.linspace(3, 10, steps=5)
torch.linspace(-10, 10, steps=5)
torch.linspace(start=-10, end=10, steps=5)
torch.linspace(start=-10, end=10, steps=1)

# torch.logspace
torch.logspace(start=-10, end=10, steps=5)
torch.logspace(start=0.1, end=1.0, steps=5)
torch.logspace(start=0.1, end=1.0, steps=1)
torch.logspace(start=2, end=2, steps=1, base=2)

# torch.eye
torch.eye(3)

# torch.empty/empty_like/empty_strided
torch.empty(2, 3)
torch.empty((2, 3))
torch.empty([2, 3])
torch.empty_like(torch.empty(2, 3), dtype=torch.int64)
torch.empty_strided((2, 3), (1, 2))

# torch.full/full_like
torch.full((2, 3), 3.141592)
torch.full_like(torch.full((2, 3), 3.141592), 2.71828)

# torch.quantize_per_tensor
torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8)

# torch.quantize_per_channel
x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
quant = torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8)

# torch.dequantize
torch.dequantize(x)

# torch.complex
real = torch.tensor([1, 2], dtype=torch.float32)
imag = torch.tensor([3, 4], dtype=torch.float32)
torch.complex(real, imag)

# torch.polar
abs = torch.tensor([1, 2], dtype=torch.float64)
pi = torch.acos(torch.zeros(1)).item() * 2
angle = torch.tensor([pi / 2, 5 * pi / 4], dtype=torch.float64)
torch.polar(abs, angle)

# torch.heaviside
inp = torch.tensor([-1.5, 0, 2.0])
values = torch.tensor([0.5])
torch.heaviside(inp, values)
Loading

0 comments on commit 6eeffc6

Please sign in to comment.