Revert "[Typing] Fix PEP 484 Violation (pytorch#105022)"
pytorchmergebot committed Jul 14, 2023
1 parent 528ab47 commit b4d91b1
Showing 30 changed files with 74 additions and 75 deletions.
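
For context: the reverted PR (pytorch#105022) made these signatures PEP 484-explicit by annotating every parameter whose default is `None` as `Optional[...]`; this commit restores the implicit form throughout the diff below. A minimal sketch of the two styles — the function and parameter names here are hypothetical, not taken from this diff:

```python
from typing import Optional


def pad_explicit(text: str, width: Optional[int] = None) -> str:
    # Explicit PEP 484 style, as introduced by the reverted PR:
    # the annotation admits the None default.
    return text if width is None else text.ljust(width)


def pad_implicit(text: str, width: int = None) -> str:
    # Implicit style restored by this revert: the default is None even
    # though the annotation says `int`. Strict type checkers (for example
    # mypy with implicit Optional disallowed) flag this; at runtime the
    # two functions behave identically.
    return text if width is None else text.ljust(width)
```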
4 changes: 2 additions & 2 deletions torch/_functorch/functional_call.py
@@ -1,5 +1,5 @@
from collections import Counter
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Sequence, Tuple, Union

import torch
import torch.nn as nn
@@ -12,7 +12,7 @@ def functional_call(
module: "torch.nn.Module",
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
args: Union[Any, Tuple],
-kwargs: Optional[Dict[str, Any]] = None,
+kwargs: Dict[str, Any] = None,
*,
tie_weights: bool = True,
strict: bool = False,
8 changes: 4 additions & 4 deletions torch/_meta_registrations.py
@@ -1064,7 +1064,7 @@ def _linalg_svd_meta(
A: Tensor,
full_matrices: bool = False,
compute_uv: bool = True,
-driver: Optional[str] = None,
+driver: str = None,
):
checkIsMatrix(A, "linalg.svd")
checkFloatingOrComplex(A, "linalg.svd")
@@ -1207,7 +1207,7 @@ def linalg_solve_triangular_meta(
upper: bool,
left: bool = True,
unitriangular: bool = False,
-out: Optional[Tensor] = None,
+out: Tensor = None,
) -> Tensor:
if out is None:
out = A.new_empty([0])
@@ -4755,8 +4755,8 @@ def upsample_nearest2d_backward(
grad_output: Tensor,
output_size: Sequence[Union[int, torch.types.SymInt]],
input_size: Sequence[Union[int, torch.types.SymInt]],
-scales_h: Optional[float] = None,
-scales_w: Optional[float] = None,
+scales_h: float = None,
+scales_w: float = None,
):
full_output_size = upsample_common_check(
input_size, output_size, num_spatial_dims=2
2 changes: 1 addition & 1 deletion torch/_prims/__init__.py
@@ -331,7 +331,7 @@ class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
def _elementwise_meta(
*args,
type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
-args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
+args_with_fixed_dtypes: Tuple[TensorLikeType, ...] = None,
) -> FakeTensor:
"""
Meta function for elementwise operations that produce outputs in the same dtype
8 changes: 4 additions & 4 deletions torch/_prims/context.py
@@ -1,6 +1,6 @@
import functools
from contextlib import nullcontext
-from typing import Any, Callable, Dict, Optional, Sequence
+from typing import Any, Callable, Dict, Sequence
from warnings import warn

import torch
@@ -111,7 +111,7 @@ def __torch_function__(
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
-kwargs: Optional[Dict] = None,
+kwargs: Dict = None,
):
if kwargs is None:
kwargs = {}
@@ -161,7 +161,7 @@ def __torch_function__(
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
-kwargs: Optional[Dict] = None,
+kwargs: Dict = None,
):
if kwargs is None:
kwargs = {}
@@ -374,7 +374,7 @@ def __torch_function__(
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
-kwargs: Optional[Dict] = None,
+kwargs: Dict = None,
):
if kwargs is None:
kwargs = {}
6 changes: 3 additions & 3 deletions torch/_refs/__init__.py
@@ -1832,7 +1832,7 @@ def clamp(
@out_wrapper()
def clamp_min(
self: TensorLikeType,
-min: Optional[TensorOrNumberLikeType] = None,
+min: TensorOrNumberLikeType = None,
) -> TensorLikeType:
return torch.clamp(self, min=min) # type: ignore[arg-type]

@@ -1841,7 +1841,7 @@ def clamp_min(
@out_wrapper()
def clamp_max(
self: TensorLikeType,
-max: Optional[TensorOrNumberLikeType] = None,
+max: TensorOrNumberLikeType = None,
) -> TensorLikeType:
return torch.clamp(self, max=max) # type: ignore[arg-type]

@@ -4654,7 +4654,7 @@ def logspace(
ret = torch.linspace(
start,
end,
-steps, # type: ignore[arg-type]
+steps,
dtype=torch.float64,
layout=layout,
device=device,
2 changes: 1 addition & 1 deletion torch/ao/nn/quantizable/modules/activation.py
@@ -63,7 +63,7 @@ class MultiheadAttention(nn.MultiheadAttention):
def __init__(self, embed_dim: int, num_heads: int,
dropout: float = 0., bias: bool = True,
add_bias_kv: bool = False, add_zero_attn: bool = False,
-kdim: Optional[int] = None, vdim: Optional[int] = None, batch_first: bool = False,
+kdim: int = None, vdim: int = None, batch_first: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(embed_dim, num_heads, dropout,
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Dict, Any, List
import torch
from collections import defaultdict
from torch import nn
@@ -205,7 +205,7 @@ def register_layer(self, layer: nn.Module, aggregate_fn=None, reduce_fn=None,
# or sparsify_hook()
self.data_groups[name]['hook_state'] = "aggregate" # aggregate hook is attached

-def get_mask(self, name: Optional[str] = None, layer: Optional[nn.Module] = None):
+def get_mask(self, name: str = None, layer: nn.Module = None):
"""
Returns mask associated to the layer.
@@ -1,7 +1,7 @@
import torch
from torch.nn import functional as F
from functools import reduce
-from typing import Any, List, Optional, Tuple
+from typing import Tuple, Any, List

from .base_data_sparsifier import BaseDataSparsifier

@@ -31,9 +31,9 @@ class DataNormSparsifier(BaseDataSparsifier):
arguments and could be overriden by the configuration provided in the
`add_data` step.
"""
-def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, sparsity_level: float = 0.5,
+def __init__(self, data_list: List[Tuple[str, Any]] = None, sparsity_level: float = 0.5,
sparse_block_shape: Tuple[int, int] = (1, 4),
-zeros_per_block: Optional[int] = None, norm: str = 'L1'):
+zeros_per_block: int = None, norm: str = 'L1'):
if zeros_per_block is None:
zeros_per_block = reduce((lambda x, y: x * y), sparse_block_shape)

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from torch.ao.pruning.sparsifier.utils import module_to_fqn, fqn_to_module
-from typing import Dict, List, Optional
+from typing import Dict, List

SUPPORTED_MODULES = {
nn.Embedding,
@@ -28,7 +28,7 @@ def _fetch_all_embeddings(model):
def post_training_sparse_quantize(model,
data_sparsifier_class,
sparsify_first=True,
-select_embeddings: Optional[List[nn.Module]] = None,
+select_embeddings: List[nn.Module] = None,
**sparse_config):
"""Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags.
The quantization step can happen before or after sparsification depending on the `sparsify_first` argument.
2 changes: 1 addition & 1 deletion torch/distributed/_tensor/__init__.py
@@ -210,7 +210,7 @@ def full(
def zeros(
*size,
requires_grad: bool = False,
-dtype: Optional[torch.dtype] = None,
+dtype: torch.dtype = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
5 changes: 2 additions & 3 deletions torch/distributed/algorithms/_comm_hooks/default_hooks.py
@@ -1,7 +1,6 @@
import functools
import torch
import torch.distributed as dist
-from typing import Optional


class DefaultState:
@@ -128,7 +127,7 @@ def _low_precision_hook(prec: torch.dtype, state: LowPrecisionState, grad: torch
allreduce_hook(state, grad)
_decompress(state, grad)

-def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None):
+def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor = None):
r"""
This FSDP communication hook implements a simple gradient compression
approach that casts ``grad`` to half-precision floating-point format (``torch.float16``).
@@ -145,7 +144,7 @@ def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Opt
fp16_hook = functools.partial(_low_precision_hook, torch.float16)
return fp16_hook(state, grad, output)

-def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None):
+def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor = None):
r"""
This FSDP communication hook implements a simple gradient compression
approach that casts ``grad`` to half-precision floating-point format (``torch.float16``).
2 changes: 1 addition & 1 deletion torch/distributed/checkpoint/state_dict_loader.py
@@ -20,7 +20,7 @@ def load_state_dict(
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False,
-planner: Optional[LoadPlanner] = None,
+planner: LoadPlanner = None,
) -> None:
"""
Loads a distributed ``state_dict`` in SPMD style.
2 changes: 1 addition & 1 deletion torch/distributed/checkpoint/state_dict_saver.py
@@ -22,7 +22,7 @@ def save_state_dict(
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False,
-planner: Optional[SavePlanner] = None,
+planner: SavePlanner = None,
) -> Metadata:
"""
Saves a distributed model in SPMD style.
6 changes: 3 additions & 3 deletions torch/distributed/collective_utils.py
@@ -58,7 +58,7 @@ def broadcast(
raise AssertionError("Data or Function is expected to be None if not successful")

payload: Optional[T] = None
-exception : Optional[Exception] = None
+exception : Exception = None
# if no pg is passed then execute if rank is 0
if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank):
# determine if it is an executable function or data payload only
@@ -119,7 +119,7 @@ def all_gather(
>> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg)
"""
payload: Optional[T] = None
-exception : Optional[Exception] = None
+exception : Exception = None
success = True
# determine if it is an executable function or data payload only
if callable(data_or_fn):
@@ -143,7 +143,7 @@ def all_gather(
total_list = [None] * dist.get_world_size(pg)
all_gather_object_enforce_type(pg, total_list, sync_obj)
# Each rank will throw RuntimeError in case of failure on any rank.
-stage_name = cast(SyncPayload[T], total_list[0]).stage_name
+stage_name: Optional[str] = cast(SyncPayload[T], total_list[0]).stage_name
exception_list: List[Tuple[int, Exception]] = []
ret_list: List[T] = []
error_msg: str = ""
2 changes: 1 addition & 1 deletion torch/distributed/elastic/metrics/api.py
@@ -68,7 +68,7 @@ def add_value(self, metric_name: str, metric_value: int):


# pyre-fixme[9]: group has type `str`; used as `None`.
-def configure(handler: MetricHandler, group: Optional[str] = None):
+def configure(handler: MetricHandler, group: str = None):
if group is None:
global _default_metrics_handler
# pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used
2 changes: 1 addition & 1 deletion torch/distributed/elastic/timer/file_based_local_timer.py
@@ -158,7 +158,7 @@ def __init__(
file_path: str,
max_interval: float = 10,
daemon: bool = True,
-log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None
+log_event: Callable[[str, Optional[FileTimerRequest]], None] = None
) -> None:
self._file_path = file_path
self._max_interval = max_interval
8 changes: 4 additions & 4 deletions torch/distributed/nn/api/remote_module.py
@@ -129,8 +129,8 @@ def __init__(
self,
remote_device: str,
module_cls: Type[nn.Module],
-args: Optional[Tuple] = None,
-kwargs: Optional[Dict[str, Any]] = None,
+args: Tuple = None,
+kwargs: Dict[str, Any] = None,
_module_interface_cls: Any = None,
):
"""
@@ -685,8 +685,8 @@ def __init__(
self,
remote_device: str,
module_cls: Type[nn.Module],
-args: Optional[Tuple] = None,
-kwargs: Optional[Dict[str, Any]] = None,
+args: Tuple = None,
+kwargs: Dict[str, Any] = None,
):
super().__init__(remote_device, module_cls, args, kwargs)

4 changes: 2 additions & 2 deletions torch/distributed/optim/named_optimizer.py
@@ -63,8 +63,8 @@ def __init__(
self,
named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]],
optimizer_class: optim.Optimizer,
-param_groups: Optional[Collection[Mapping[str, Any]]] = None,
-module: Optional[nn.Module] = None,
+param_groups: Collection[Mapping[str, Any]] = None,
+module: nn.Module = None,
*args,
**kwargs,
) -> None:
4 changes: 2 additions & 2 deletions torch/distributed/pipeline/sync/utils.py
@@ -1,12 +1,12 @@
from torch import nn
-from typing import List, Optional
+from typing import List

__all__ = ["partition_model"]

def partition_model(
module: nn.Sequential,
balance: List[int],
-devices: Optional[List[int]] = None):
+devices: List[int] = None):
"""
Given an :class:`nn.Sequential <torch.nn.Sequential>` module, partitions
the model across multiple GPU devices according the provided ``balance``
8 changes: 4 additions & 4 deletions torch/distributions/wishart.py
@@ -1,7 +1,7 @@
import math
import warnings
from numbers import Number
-from typing import Optional, Union
+from typing import Union

import torch
from torch import nan
@@ -72,9 +72,9 @@ class Wishart(ExponentialFamily):

def __init__(self,
df: Union[torch.Tensor, Number],
-covariance_matrix: Optional[torch.Tensor] = None,
-precision_matrix: Optional[torch.Tensor] = None,
-scale_tril: Optional[torch.Tensor] = None,
+covariance_matrix: torch.Tensor = None,
+precision_matrix: torch.Tensor = None,
+scale_tril: torch.Tensor = None,
validate_args=None):
assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \
"Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
2 changes: 1 addition & 1 deletion torch/fx/experimental/meta_tracer.py
@@ -259,7 +259,7 @@ def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None):


def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
-meta_args : Optional[Dict[str, torch.Tensor]] = None,
+meta_args : Dict[str, torch.Tensor] = None,
concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
tracer = MetaTracer()
graph = tracer.trace(root, meta_args, concrete_args)
2 changes: 1 addition & 1 deletion torch/fx/passes/infra/partitioner.py
@@ -15,7 +15,7 @@
logger.setLevel(logging.WARNING)

class Partition:
-def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
+def __init__(self, id: int = None, nodes: Iterable[Node] = None):
self.id = id
self.nodes: Set[Node] = set(nodes) if nodes is not None else set()

4 changes: 2 additions & 2 deletions torch/fx/passes/pass_manager.py
@@ -1,6 +1,6 @@
from functools import wraps
from inspect import unwrap
-from typing import Callable, List, Optional
+from typing import Callable, List
import logging

logger = logging.getLogger(__name__)
@@ -76,7 +76,7 @@ def wrapped_fn(gm):



-def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
+def loop_pass(base_pass: Callable, n_iter: int = None, predicate: Callable = None):
"""
Convenience wrapper for passes which need to be applied multiple times.
4 changes: 2 additions & 2 deletions torch/masked/_ops.py
@@ -1302,7 +1302,7 @@ def amin(
@_apply_docstring_templates
def argmax(
input: Union[Tensor, MaskedTensor],
-dim: Optional[int] = None,
+dim: int = None,
*,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,
@@ -1328,7 +1328,7 @@ def argmax(
@_apply_docstring_templates
def argmin(
input: Union[Tensor, MaskedTensor],
-dim: Optional[int] = None,
+dim: int = None,
*,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,
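
As a purely illustrative aside (not part of this commit, and not PyTorch tooling): the implicit pattern reinstated above is easy to enumerate mechanically. The sketch below walks a file's AST and reports parameters whose default is `None` but whose annotation does not mention `Optional` or `None`; treating such a mention as "explicit" is an assumed heuristic, and all names in the script are the sketch's own.

```python
# Minimal sketch: list parameters that default to None but whose annotation
# does not spell out Optional/None -- the style this revert reinstates.
# Requires Python 3.9+ for ast.unparse.
import ast
import sys
from pathlib import Path


def implicit_optional_params(source: str):
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        a = node.args
        positional = a.posonlyargs + a.args
        # Defaults align with the trailing positional args; a None entry in
        # kw_defaults means "no default", so filter those out.
        pairs = list(zip(positional[len(positional) - len(a.defaults):], a.defaults))
        pairs += [(arg, d) for arg, d in zip(a.kwonlyargs, a.kw_defaults) if d is not None]
        for arg, default in pairs:
            defaults_to_none = isinstance(default, ast.Constant) and default.value is None
            annotation = ast.unparse(arg.annotation) if arg.annotation is not None else ""
            if defaults_to_none and annotation and "Optional" not in annotation and "None" not in annotation:
                yield node.lineno, node.name, arg.arg, annotation


if __name__ == "__main__":
    for path in sys.argv[1:]:
        for lineno, func, arg, annotation in implicit_optional_params(Path(path).read_text()):
            print(f"{path}:{lineno}: {func}(... {arg}: {annotation} = None ...)")
```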