Enable mypy checking in compile_fx.py (pytorch#105830)
This is part of the effort tracked in issue pytorch#105230.

Pull Request resolved: pytorch#105830
Approved by: https://github.com/eellison
chenyang78 authored and pytorchmergebot committed Aug 9, 2023

1 parent 088e316 commit 40a15b5
Showing 2 changed files with 125 additions and 82 deletions.
1 change: 1 addition & 0 deletions .lintrunner.toml
@@ -188,6 +188,7 @@ include_patterns = [
'torch/_inductor/codegen/common.py',
'torch/_inductor/codegen/wrapper.py',
'torch/_inductor/cudagraph_trees.py',
'torch/_inductor/compile_fx.py',
'torch/_inductor/lowering.py',
'torch/_inductor/metrics.py',
'torch/_C/_dynamo/**/*.py',
206 changes: 124 additions & 82 deletions torch/_inductor/compile_fx.py
@@ -4,11 +4,11 @@
import itertools
import logging
import sys
import unittest
import warnings

from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Sequence
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Sequence, Union
from unittest import mock

from functorch.compile import min_cut_rematerialization_partition

@@ -42,7 +42,7 @@
from .virtualized import V

if config.is_fbcode():
from torch._inductor.fb.utils import time_and_log
from torch._inductor.fb.utils import time_and_log # type: ignore[import]
else:
# no-op decorator
def time_and_log(attr: str):
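For reference, a minimal sketch of the pattern this hunk types (the module and names below are illustrative stand-ins, and a try/except guard replaces the config.is_fbcode() check purely to keep the sketch self-contained): an environment-specific import is silenced with a targeted mypy ignore, and the fallback is a typed no-op decorator.

from typing import Any, Callable

try:
    from internal_only_pkg.utils import time_and_log  # type: ignore[import]  # hypothetical module
except ImportError:
    def time_and_log(attr: str) -> Callable[..., Any]:
        # no-op stand-in for the internal implementation
        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            return fn
        return decorator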
@@ -95,13 +95,13 @@ def get_expanded_dims(t):
return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]


def index_expanded_dims(t, expanded_dims):
def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor:
for expanded_dim in expanded_dims:
t = torch.ops.aten.slice(t, expanded_dim, 0, 1)
return t


def complex_memory_overlap(t):
def complex_memory_overlap(t: torch.Tensor) -> bool:
# if torch._debug_has_internal_overlap thinks this tensor potentially has
# memory overlap internally, let's dig deeper to find out whether it's true.
t = index_expanded_dims(t, get_expanded_dims(t))
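As a quick worked example of what these two newly annotated helpers compute (standard torch behavior, nothing specific to this commit): a broadcast dimension of an expanded tensor reports stride 0, and slicing it back to length 1 removes the internal aliasing before the overlap check runs.

import torch

x = torch.randn(3, 1).expand(3, 4)  # dim 1 is broadcast: size 4, stride 0
expanded = [i for i in range(x.ndim) if x.stride(i) == 0 and x.size(i) != 1]
# expanded == [1]
y = x
for d in expanded:
    y = torch.ops.aten.slice(y, d, 0, 1)
# y now has shape (3, 1) with no self-overlapping storage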
@@ -157,7 +157,12 @@ def is_tf32_warning_applicable(gm: torch.fx.GraphModule):


@DebugContext.wrap
def count_bytes_inner(gm, example_inputs, num_fixed=0, **kwargs):
def count_bytes_inner(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
num_fixed: int = 0,
**kwargs,
):
shape_env = _shape_env_from_inputs(example_inputs)

graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
@@ -169,7 +174,7 @@ def count_bytes_inner(gm, example_inputs, num_fixed=0, **kwargs):
return make_boxed_func(gm.forward)


def inner_compile_with_cpp_wrapper(inner_compile):
def inner_compile_with_cpp_wrapper(inner_compile: Callable[..., Any]):
@functools.wraps(inner_compile)
def wrapper(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], **kwargs):
"""
@@ -227,15 +232,15 @@ def materialize(x):
if param is not None
]
real_inputs = [
materialize(x) for x in [*params_flat, *V.real_inputs]
materialize(x) for x in (params_flat + V.real_inputs)
]
else:
real_inputs = [materialize(x) for x in V.real_inputs]

with torch.utils._python_dispatch._disable_current_modes():
compiled(real_inputs)

real_inputs = None
del real_inputs

# second pass
kwargs_patched = {**kwargs, "cpp_wrapper": True}
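The switch from rebinding to `del` above is one of the small rewrites static typing tends to force: once `real_inputs` is inferred as a list of tensors, assigning None to it is an incompatible assignment, while deleting the name releases the reference without fighting the declared type. In miniature:

from typing import List

real_inputs: List[int] = [1, 2, 3]
# real_inputs = None   # mypy: Incompatible types in assignment
del real_inputs        # fine; the name is simply removed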
@@ -247,7 +252,7 @@ def materialize(x):
def fake_tensor_prop(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
force_allow_non_fake_inputs=False,
force_allow_non_fake_inputs: bool = False,
):
"""
If we can not detect fake mode from the context of inputs, create one.
@@ -262,9 +267,9 @@ def fake_tensor_prop(
ctx = (
contextlib.nullcontext()
if not force_allow_non_fake_inputs
else unittest.mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
else mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
)
with ctx:
with ctx: # type: ignore[attr-defined]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
*example_inputs
)
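The ignore on the `with` statement reflects that the two branches give `ctx` unrelated static types (a nullcontext versus mock's patcher object). A hedged sketch of the same shape, including the common alternative of annotating the variable as a generic context manager instead of ignoring:

import contextlib
from typing import ContextManager
from unittest import mock

class Settings:
    def __init__(self) -> None:
        self.allow_non_fake_inputs = False

settings = Settings()
force = True

ctx: ContextManager[object] = (
    contextlib.nullcontext()
    if not force
    else mock.patch.object(settings, "allow_non_fake_inputs", True)
)
with ctx:
    assert settings.allow_non_fake_inputs is True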
@@ -279,15 +284,15 @@ def compile_fx_inner(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
cudagraphs: Optional[BoxedBool] = None,
num_fixed=0,
is_backward=False,
graph_id=None,
cpp_wrapper=False,
aot_mode=False,
is_inference=False,
boxed_forward_device_index=None,
user_visible_outputs=frozenset(),
layout_opt=None,
num_fixed: int = 0,
is_backward: bool = False,
graph_id: Optional[int] = None,
cpp_wrapper: bool = False,
aot_mode: bool = False,
is_inference: bool = False,
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
user_visible_outputs: FrozenSet[str] = frozenset(),
layout_opt: Optional[bool] = None,
):
if dynamo_utils.count_calls(gm.graph) == 0:
return make_boxed_func(gm.forward)
@@ -310,7 +315,7 @@ def compile_fx_inner(
}

compiled_graph: CompiledFxGraph = fx_codegen_and_compile(
*graph_args, **graph_kwargs
*graph_args, **graph_kwargs # type: ignore[arg-type]
)
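The `arg-type` ignore is needed because `graph_kwargs` mixes value types, so mypy only knows the dict values as a broad common type and will not match them against the specific keyword parameters. The same effect in miniature (hypothetical function):

from typing import Dict

def fx_step(num_fixed: int = 0, is_backward: bool = False) -> int:
    return num_fixed + int(is_backward)

graph_kwargs: Dict[str, object] = {"num_fixed": 1, "is_backward": True}
fx_step(**graph_kwargs)  # type: ignore[arg-type]  # values are only known as `object`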

if aot_mode:
@@ -391,6 +396,7 @@ def compile_fx_inner(
# if we cudagraph'd the forward and set the device, we need to let the cudagraph manager
# know we are running the backward even if we will not run it in cudagraphs
if is_backward and config.triton.cudagraph_trees:
assert boxed_forward_device_index is not None
assert boxed_forward_device_index.value is not None
compiled_graph_callable = compiled_graph.get_current_callable()
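The added asserts are the usual way to let mypy narrow an Optional before it is dereferenced; an illustrative stand-in (the real BoxedDeviceIndex lives in inductor, this class only mimics its shape):

from typing import Optional

class Boxed:
    def __init__(self, value: Optional[int]) -> None:
        self.value = value

def current_device(boxed: Optional[Boxed]) -> int:
    assert boxed is not None
    assert boxed.value is not None
    return boxed.value  # narrowed from Optional[int] to int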

@@ -451,14 +457,14 @@ def fx_codegen_and_compile(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
cudagraphs: Optional[BoxedBool] = None,
num_fixed=0,
is_backward=False,
graph_id=None,
cpp_wrapper=False,
aot_mode=False,
is_inference=False,
user_visible_outputs=frozenset(),
layout_opt=None,
num_fixed: int = 0,
is_backward: bool = False,
graph_id: Optional[int] = None,
cpp_wrapper: bool = False,
aot_mode: bool = False,
is_inference: bool = False,
user_visible_outputs: FrozenSet[str] = frozenset(),
layout_opt: Optional[bool] = None,
) -> CompiledFxGraph:
if is_tf32_warning_applicable(gm):
_warn_tf32_disabled()
@@ -522,6 +528,7 @@ def fx_codegen_and_compile(
if context is not None and context.output_strides is not None:
# Return the output strides to the caller via TracingContext
assert len(context.output_strides) == 0
assert graph.graph_outputs is not None
for out in graph.graph_outputs:
if hasattr(out, "layout"):
context.output_strides.append(
@@ -552,37 +559,46 @@ def fx_codegen_and_compile(
return compiled_graph


def clone_preserve_strides(x):
def clone_preserve_strides(x: torch.Tensor):
needed_size = (
sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
)
buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
return torch.as_strided(buffer, x.size(), x.stride())


def copy_misaligned_inputs(new_inputs, check_inputs_idxs: Sequence[int]) -> None:
def copy_misaligned_inputs(
new_inputs: List[torch.Tensor], check_inputs_idxs: Sequence[int]
) -> None:
for i in check_inputs_idxs:
if new_inputs[i].data_ptr() % ALIGNMENT:
new_inputs[i] = clone_preserve_strides(new_inputs[i])


def get_input_idxs_to_check(inputs, static_input_idxs) -> Sequence[int]:
def get_input_idxs_to_check(
inputs: Union[List[torch.Tensor], Sequence[int]],
static_input_idxs: Sequence[int],
) -> Sequence[int]:
def is_aligned(storage_offset, dtype):
return (storage_offset * get_dtype_size(dtype)) % ALIGNMENT == 0

return [
i
for i in range(len(inputs))
if isinstance(inputs[i], torch.Tensor)
and (
i not in static_input_idxs
or not is_aligned(inputs[i].storage_offset(), inputs[i].dtype)
)
and inputs[i].device.type == "cuda"
]
ids_to_check = []
for i, input in enumerate(inputs):
if (
isinstance(input, torch.Tensor)
and (
i not in static_input_idxs
or not is_aligned(input.storage_offset(), input.dtype)
)
and input.device.type == "cuda"
):
ids_to_check.append(i)
return ids_to_check
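Rewriting the comprehension as an explicit loop is largely about narrowing: an isinstance check on `inputs[i]` inside a comprehension does not narrow later `inputs[i]` accesses, whereas a loop-local variable narrows cleanly after the check. A reduced sketch of the same shape (hypothetical helper name):

from typing import List, Sequence, Union
import torch

def cuda_input_idxs(inputs: Sequence[Union[torch.Tensor, int]]) -> List[int]:
    idxs: List[int] = []
    for i, inp in enumerate(inputs):
        # after the isinstance check, mypy treats `inp` as torch.Tensor
        if isinstance(inp, torch.Tensor) and inp.device.type == "cuda":
            idxs.append(i)
    return idxs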


def align_inputs_from_check_idxs(model, inputs_to_check: Sequence[int]):
def align_inputs_from_check_idxs(
model: Callable[[List[torch.Tensor]], Any], inputs_to_check: Sequence[int]
):
if len(inputs_to_check) == 0:
return model

@@ -593,16 +609,20 @@ def run(new_inputs):
return run


def align_inputs(model, inputs, static_input_idxs=()):
def align_inputs(
model: Callable[[List[torch.Tensor]], Any],
inputs: List[torch.Tensor],
static_input_idxs: Sequence[int] = (),
):
inputs_to_check = get_input_idxs_to_check(inputs, static_input_idxs)
return align_inputs_from_check_idxs(model, inputs_to_check)


@dynamo_utils.dynamo_timed
def cudagraphify(
model,
inputs,
static_input_idxs=(),
model: torch.fx.GraphModule,
inputs: List[torch.Tensor],
static_input_idxs: Sequence[int] = (),
*,
device_index: int,
stack_traces: List[Optional[str]],
@@ -613,6 +633,7 @@ def cudagraphify(
cudagraphify_impl as new_cudagraphify_impl,
)

cudagraphify_fn: Callable[..., Any]
if config.triton.cudagraph_trees:
cudagraphify_fn = functools.partial(
new_cudagraphify_impl,
@@ -640,23 +661,24 @@ def run(new_inputs):
return run


def remove_unaligned_input_idxs(inputs, static_input_idxs):
def remove_unaligned_input_idxs(
inputs: Union[List[torch.Tensor], Sequence[int]],
static_input_idxs: Sequence[int],
):
"""
We require all inputs to be aligned, so introduce a copy for any
that aren't.
"""
aligned_static_input_idxs = {
idx
for idx in static_input_idxs
if isinstance(inputs[idx], torch.Tensor)
and (inputs[idx].data_ptr() % ALIGNMENT) == 0
}
aligned_static_input_idxs = []
for idx, input in zip(static_input_idxs, inputs):
if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0:
aligned_static_input_idxs.append(idx)
if len(aligned_static_input_idxs) != len(static_input_idxs):
return aligned_static_input_idxs
return static_input_idxs


def static_input(x):
def static_input(x: torch.Tensor):
"""
Copy an input while preserving strides
"""
@@ -669,22 +691,30 @@ def static_input(x):
return torch.as_strided(buffer, x.size(), x.stride())


def index_expanded_dims_and_copy_(dst, src, expanded_dims):
def index_expanded_dims_and_copy_(
dst: torch.Tensor,
src: torch.Tensor,
expanded_dims: List[int],
):
"Index into expanded dimensions of both dst and src then copy_"
dst = index_expanded_dims(dst, expanded_dims)
src = index_expanded_dims(src, expanded_dims)
dst.copy_(src)


def cudagraphify_impl(model, inputs, static_input_idxs=()):
def cudagraphify_impl(
model: torch.fx.GraphModule,
inputs: List[torch.Tensor],
static_input_idxs: Sequence[int] = (),
):
"""
Assumes inputs[static_input_idxs[i]] are always the same memory address
"""
check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
copy_misaligned_inputs(inputs, check_input_idxs)

assert isinstance(inputs, (list, tuple))
assert isinstance(inputs, list)

inps_expanded_dims = [
get_expanded_dims(x) if idx not in static_input_idxs else []
@@ -792,7 +822,7 @@ def is_saved_tensor(x):
def compile_fx_aot(
model_: torch.fx.GraphModule,
example_inputs_: List[torch.Tensor],
inner_compile=compile_fx_inner,
inner_compile: Callable[..., Any] = compile_fx_inner,
config_patches: Optional[Dict[str, Any]] = None,
):
config_patches = (
@@ -809,7 +839,7 @@ def compile_fx_aot(
"aot_inductor_output_path": code_hash(model_.code),
}

with unittest.mock.patch.object(_in_aot_compilation, "value", True):
with mock.patch.object(_in_aot_compilation, "value", True):
return compile_fx(
model_,
example_inputs_,
@@ -823,13 +853,13 @@ def compile_fx_aot(

def fw_compiler_freezing(
aot_autograd_model: torch.fx.GraphModule,
aot_example_inputs,
dynamo_model,
num_example_inputs,
inner_compile,
cudagraphs,
graph_id,
forward_device,
aot_example_inputs: List[torch.Tensor],
dynamo_model: torch.fx.GraphModule,
num_example_inputs: int,
inner_compile: Callable[..., Any],
cudagraphs: BoxedBool,
graph_id: int,
forward_device: BoxedDeviceIndex,
):
from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze

@@ -845,7 +875,7 @@ def fw_compiler_freezing(
opt_model, preserved_arg_indices = freeze(
dynamo_model,
aot_autograd_model,
aot_example_inputs,
aot_example_inputs, # type: ignore[arg-type]
)

aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
@@ -866,7 +896,7 @@ def fw_compiler_freezing(
if i not in preserved_arg_indices:
params_flat[i] = None

with unittest.mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
optimized_function = inner_compile(
opt_model,
aot_example_inputs,
@@ -889,17 +919,17 @@ def wrapper(args):
args.clear()
return optimized_function(args_new)

wrapper._boxed_call = True
wrapper._boxed_call = True # type: ignore[attr-defined]

return wrapper


def compile_fx(
model_: torch.fx.GraphModule,
example_inputs_: List[torch.Tensor],
inner_compile=compile_fx_inner,
inner_compile: Callable[..., Any] = compile_fx_inner,
config_patches: Optional[Dict[str, Any]] = None,
decompositions: Optional[Dict[OpOverload, Callable]] = None,
decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
):
"""Main entrypoint to a compile given FX graph"""
if config_patches:
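The `attr-defined` ignore on `wrapper._boxed_call` a few lines up is the standard workaround for tacking a marker attribute onto a plain function, since function objects declare no such attribute. A self-contained sketch of that pattern (hypothetical decorator):

from typing import Any, Callable, List

def boxed(fn: Callable[[List[Any]], Any]) -> Callable[[List[Any]], Any]:
    def wrapper(args: List[Any]) -> Any:
        return fn(list(args))
    # function objects have no declared `_boxed_call`, so mypy needs the ignore
    wrapper._boxed_call = True  # type: ignore[attr-defined]
    return wrapper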
@@ -974,7 +1004,11 @@ def compile_fx(
)

@dynamo_utils.dynamo_timed
def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
def fw_compiler_base(
model: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
is_inference: bool,
):
if is_inference:
# partition_fn won't be called
joint_graph_passes(model)
@@ -989,10 +1023,9 @@ def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
model_outputs, _ = pytree.tree_flatten(model_outputs_node.args)
num_model_outputs = len(model_outputs)

if torch._guards.TracingContext.get():
original_output_start_index = (
torch._guards.TracingContext.get().fw_metadata.num_mutated_inputs
)
context = torch._guards.TracingContext.get()
if context is not None and context.fw_metadata:
original_output_start_index = context.fw_metadata.num_mutated_inputs
else:
original_output_start_index = 0
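Binding TracingContext.get() to a local before checking it is the other recurring Optional idiom in this commit: mypy cannot connect two separate calls to the getter, but it does narrow a checked local and its attributes. In reduced form (stand-in classes):

from typing import Optional

class Meta:
    num_mutated_inputs = 2

class Tracing:
    fw_metadata: Optional[Meta] = Meta()

def get_tracing() -> Optional[Tracing]:
    return Tracing()

context = get_tracing()
if context is not None and context.fw_metadata:
    original_output_start_index = context.fw_metadata.num_mutated_inputs
else:
    original_output_start_index = 0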

@@ -1063,7 +1096,7 @@ def partition_fn(graph, joint_inputs, **kwargs):
)

@dynamo_utils.dynamo_timed
def bw_compiler(model: torch.fx.GraphModule, example_inputs):
def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
fixed = count_tangents(model)
return inner_compile(
model,
@@ -1107,7 +1140,7 @@ def get_patched_config_dict(config_patches=None):
return config.get_config_copy()


def _shape_env_from_inputs(inputs):
def _shape_env_from_inputs(inputs: List[torch.Tensor]):
shape_env = None
fake_mode = detect_fake_mode(inputs)

@@ -1153,7 +1186,11 @@ def graph_returns_tuple(gm: torch.fx.GraphModule):
return False


def make_graph_return_tuple(gm: torch.fx.GraphModule, inputs, compile_gm):
def make_graph_return_tuple(
gm: torch.fx.GraphModule,
inputs: List[torch.Tensor],
compile_gm: Callable[..., Any],
):
"""
Mutate gm so it returns a tuple. This is only needed for graphs
not created by torchdynamo that return non-tuples.
@@ -1189,6 +1226,7 @@ def __init__(self):
self.gm = gm

def forward(self, *args):
args: List[Any] = list(args)
return self.gm(*pytree.tree_unflatten(args, spec))

compiled_fn = compile_gm(GmWrapper(), inputs)
@@ -1201,7 +1239,11 @@ def wrapper(*args):
return wrapper


def handle_dynamo_export_graph(gm, inputs, compile_gm):
def handle_dynamo_export_graph(
gm: torch.fx.GraphModule,
inputs: List[torch.Tensor],
compile_gm: Callable[..., Any],
):
"""
`torch._dynamo.export` embeds pytrees in the FX graph codegen object,
convert that to a normal FX graph so inductor can compile it.
