reland [finishing colesbury's PR 100642] Guard on nn.Module dicts and…
anijain2305 authored and pytorchmergebot committed Sep 9, 2023
1 parent ed7f9ca commit d4230e5
Showing 16 changed files with 358 additions and 124 deletions.
@@ -1,5 +1,4 @@
name,accuracy,graph_breaks
-selecsls42b,pass,0
adv_inception_v3,pass,0
beit_base_patch16_224,pass,0
botnet26t_256,pass,0
@@ -48,6 +47,7 @@ resmlp_12_224,pass,0
resnest101e,pass,0
rexnet_100,pass,0
sebotnet33ts_256,pass,0
+selecsls42b,pass,0
spnasnet_100,pass,0
swin_base_patch4_window7_224,pass,0
swsl_resnext101_32x16d,pass,0
@@ -1,5 +1,4 @@
name,accuracy,graph_breaks
-selecsls42b,pass,8
adv_inception_v3,pass,8
beit_base_patch16_224,pass,8
botnet26t_256,pass,8
@@ -41,6 +40,7 @@ res2next50,pass,8
resmlp_12_224,pass,8
resnest101e,pass,8
rexnet_100,pass,8
+selecsls42b,pass,8
spnasnet_100,pass,8
swin_base_patch4_window7_224,pass,8
swsl_resnext101_32x16d,pass,8
@@ -1,5 +1,4 @@
name,accuracy,graph_breaks
-selecsls42b,pass,0
adv_inception_v3,pass,0
beit_base_patch16_224,pass,0
botnet26t_256,pass,0
@@ -48,6 +47,7 @@ resmlp_12_224,pass,0
resnest101e,pass,0
rexnet_100,pass,0
sebotnet33ts_256,pass,0
+selecsls42b,pass,0
spnasnet_100,pass,0
swin_base_patch4_window7_224,pass,0
swsl_resnext101_32x16d,pass,0
@@ -1,5 +1,4 @@
name,accuracy,graph_breaks
-selecsls42b,pass,8
adv_inception_v3,pass,8
beit_base_patch16_224,pass,8
botnet26t_256,pass,8
@@ -43,6 +42,7 @@ resmlp_12_224,pass,8
resnest101e,pass,8
rexnet_100,pass,8
sebotnet33ts_256,pass,8
+selecsls42b,pass,8
spnasnet_100,pass,8
swin_base_patch4_window7_224,pass,8
swsl_resnext101_32x16d,pass,8
@@ -10,7 +10,7 @@ basic_gnn_gcn,pass,6
basic_gnn_gin,pass,0
basic_gnn_sage,pass,0
clip,pass,0
-cm3leon_generate,pass,6
+cm3leon_generate,pass,8
dcgan,pass,0
dlrm,pass,0
doctr_det_predictor,pass,2
@@ -5,7 +5,7 @@ LearningToPaint,pass,8
Super_SloMo,pass,8
alexnet,pass,8
attention_is_all_you_need_pytorch,pass,8
-basic_gnn_edgecnn,pass,23
+basic_gnn_edgecnn,pass,25
basic_gnn_gcn,pass,14
basic_gnn_gin,pass,8
basic_gnn_sage,pass,8
@@ -20,7 +20,7 @@ hf_Albert,pass,7
hf_Bart,pass,7
hf_DistilBert,pass,7
hf_GPT2,pass,7
-hf_Reformer,pass,27
+hf_Reformer,pass,30
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,7
lennard_jones,pass,8
@@ -10,7 +10,7 @@ basic_gnn_gcn,pass,6
basic_gnn_gin,pass,0
basic_gnn_sage,pass,0
clip,pass,0
-cm3leon_generate,pass,6
+cm3leon_generate,pass,8
dcgan,pass,0
dlrm,pass,0
doctr_det_predictor,pass,2
@@ -5,7 +5,7 @@ LearningToPaint,pass,8
Super_SloMo,pass,8
alexnet,pass,8
attention_is_all_you_need_pytorch,pass,8
-basic_gnn_edgecnn,pass,23
+basic_gnn_edgecnn,pass,25
basic_gnn_gcn,pass,14
basic_gnn_gin,pass,8
basic_gnn_sage,pass,8
@@ -20,7 +20,7 @@ hf_Albert,pass,7
hf_Bart,pass,7
hf_DistilBert,pass,7
hf_GPT2,pass,7
-hf_Reformer,pass,27
+hf_Reformer,pass,30
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,7
lennard_jones,pass,8
@@ -51,5 +51,5 @@ timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,8
tts_angular,pass,10
vgg16,pass,8
-vision_maskrcnn,fail_accuracy,40
+vision_maskrcnn,fail_accuracy,42
yolov3,pass,10
5 changes: 3 additions & 2 deletions benchmarks/dynamo/common.py
@@ -2294,6 +2294,9 @@ def warmup(fn, model, example_inputs, mode, niters=5):
# Use distributed wrapping as necessary
model = self.deepcopy_and_maybe_ddp(model)

+if not hasattr(model, name):
+    model.name = name

self.init_optimizer(name, current_device, model.parameters())
with self.pick_grad(name, self.args.training):
ok, total = Stats.reset_counters()
@@ -2344,8 +2347,6 @@ def warmup(fn, model, example_inputs, mode, niters=5):
f"{ok:3}/{total:3} +{frames_third_pass} frames {compilation_time:3.0f}s"
)

-if not hasattr(model, name):
-    model.name = name
results.append(experiment(model, example_inputs, **experiment_kwargs))
return " ".join(map(str, results))

30 changes: 8 additions & 22 deletions docs/source/torch.compiler_nn_module.rst
@@ -1,5 +1,5 @@
-PyTorch 2.0 NNModule Support
-============================
+PyTorch 2.0 nn.Module Support
+=============================

**Author**: `Will Constable <https://github.com/wconstab>`_

@@ -8,12 +8,9 @@ arbitrary python classes, with the intent of producing faster code by making assumptions

This doc describes some of the tradeoffs or edge cases that come up due to this specialization.

-NNModule Hooks Support
-----------------------
-Previously, `torch.compile` had no support for hooks on nn.Modules, and if hooks were registered
-they would simply be ignored in the compiled program. Indeed many users do not
-use nn.Module hooks at all, or only use them for debug workflows, but there are valid use cases
-for composing nn.Module hooks with `torch.compile`.
+`nn.Module` Hooks Support
+-------------------------
+`torch.compile` now has partial support for forward and backward hooks on nn.Modules.

Hooks that are orchestrated via nn.Module.__call__ implementation include `_forward_pre_hooks`,
`forward_hooks`, `_backward_pre_hooks`, and `_backward_hooks`, and will be referred to as 'call hooks'.
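
A minimal sketch of how these call hooks are registered on a module (the lambda bodies are
illustrative placeholders)::

    import torch

    mod = torch.nn.Linear(4, 4)

    # Registration APIs that populate the 'call hook' dictionaries:
    mod.register_forward_pre_hook(lambda m, args: None)                       # _forward_pre_hooks
    mod.register_forward_hook(lambda m, args, output: None)                   # _forward_hooks
    mod.register_full_backward_pre_hook(lambda m, grad_output: None)          # _backward_pre_hooks
    mod.register_full_backward_hook(lambda m, grad_input, grad_output: None)  # _backward_hooks
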
Expand All @@ -25,11 +22,11 @@ unsupported by `torch.compile`.
`nn.Module.__call__` Hooks Usage and limitations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
By default, `torch.compile` will trace the contents of `nn.Module.__call__` which means it will encounter
-and run forward/pre-forward hooks. If you install hooks before calling `torch.compile` and then do not remove
-or alter the hooks later, your use case should be supported by default.
+and run forward/pre-forward hooks. `torch.compile` installs guards that detect added and removed hooks,
+and will trigger a recompilation if the forward/pre-forward hooks change.
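
A minimal sketch of that recompilation behavior, using the `CompileCounter` test backend purely to
make recompiles observable (exact frame counts are not asserted here, since backward hooks and
graph breaks can add frames)::

    import torch
    import torch._dynamo.testing

    class M(torch.nn.Module):
        def forward(self, x):
            return 2 * x + 1

    m = M()
    cnt = torch._dynamo.testing.CompileCounter()

    @torch.compile(backend=cnt)
    def outer(x):
        return m(x) + 1

    outer(torch.ones(4))   # first call: compiles one frame
    handle = m.register_forward_hook(lambda mod, args, out: 2 * out)
    outer(torch.ones(4))   # the hook dict changed, the guard fails, and this recompiles with the hook traced in
    handle.remove()
    outer(torch.ones(4))   # removing the hook is also detected, typically causing one more recompile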

Backward/Pre-backward hooks are generally also supported, with similar caveats: currently graph-breaks in dynamo
-occur when accessing backward_hooks dicts, which is probably avoiable with some work. Graph-breaks also impact the
+occur when accessing backward_hooks dicts, which is probably avoidable with some work. Graph-breaks also impact the
timing of firing backward hooks, since graph-segments are run as autograd-functions which produce all their grads at
the same time. Assuming it were possible for dynamo to not graph-break on the presence of backward-hooks, we would
still expect the backward hooks for a series of modules to all fire together after the whole compiled graph's backward
@@ -41,17 +38,6 @@ by allowing them to be called opaquely in the dynamo graph instead of traced into
currently trigger a graph-break so that the affected modules run outside of dynamo. Depending on the model, this could
introduce a significant performance regression, and additional work is required to improve this support.

-**skip_nnmodule_hook_guards**
-By default, `torch._dynamo.config.skip_nnmodule_hook_guards` is set to True, meaning no guards will be installed
-on each nn.Module hook dictionary, improving runtime by reducing guard execution time, at the cost of not noticing
-if any hook dict is changed after compilation.

-If you want to be able to remove or modify hooks after compilation and have `torch.compile` react appropriately
-(by recompiling), then you need to set `skip_nnmodule_hook_guards=False` and expect a runtime penalty for the added
-guards.

-TODO: confirm if backward/pre_backward hooks are working or not and document accordingly

state_dict Hooks
~~~~~~~~~~~~~~~~
State dict hooks have not yet been supported in `torch.compile`.
164 changes: 110 additions & 54 deletions test/dynamo/test_modules.py
@@ -1258,6 +1258,7 @@ def forward(self, x):
module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)})
pre = m(data)
cnt.clear()
+torch._dynamo.reset()

with torch._dynamo.optimize(cnt, nopython=False):
opt_pre = m(data)
@@ -1266,8 +1267,8 @@ def forward(self, x):
out1 = m(data)

out_post = m(data)
-self.assertEqual(cnt.frame_count, 1)
-self.assertEqual(cnt.op_count, 1)
+self.assertEqual(cnt.frame_count, 2)
+self.assertEqual(cnt.op_count, 2)
self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
self.assertTrue(torch._dynamo.testing.same(out1, out_post))

@@ -1741,7 +1742,111 @@ def fn(x):
)
)

@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
def test_hooks_recompile(self):
# Modifying hooks should lead to a recompilation
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return 2 * x + 1

def compute_output_and_grad(m, x):
output = m(x)
output.sum().backward()
return x.grad

def forward_pre_hook(module: torch.nn.Module, inputs: Tuple[torch.Tensor]):
return (2 * inputs[0] + 1,)

def forward_hook(
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
):
return 2 * output + 1

def backward_hook(module, grad_input, grad_output):
if len(grad_input) == 1:
return (grad_input[0] * 3,)
else:
return (grad_input[0] * 3, None)

def backward_pre_hook(module, grad_outputs):
return (grad_outputs[0] * 5,)

def run_test_case(hook_type, hook_func, expected_grad):
m = TestModule()
input = torch.ones(10, requires_grad=True)
cnt = torch._dynamo.testing.CompileCounter()
opt = torch._dynamo.optimize(cnt)(compute_output_and_grad)

grad1 = opt(m, input)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(grad1, torch.full_like(grad1, 2))

input.grad = None
handle = getattr(m, hook_type)(hook_func)
grad2 = opt(m, input)
frame_count2 = cnt.frame_count
# Some backward hooks lead to graph breaks so frame_count may be 2 or 3
self.assertGreaterEqual(frame_count2, 2)
self.assertEqual(grad2, torch.full_like(grad2, expected_grad))

# Running again should not recompile
opt(m, input)
self.assertEqual(cnt.frame_count, frame_count2)

# Removing handle should lead to original result
input.grad = None
handle.remove()
grad3 = opt(m, input)
self.assertEqual(grad1, grad3)

run_test_case("register_forward_pre_hook", forward_pre_hook, 4)
run_test_case("register_forward_hook", forward_hook, 4)
run_test_case("register_backward_hook", backward_hook, 6)
run_test_case("register_full_backward_hook", backward_hook, 6)
run_test_case("register_full_backward_pre_hook", backward_pre_hook, 10)

def test_unspecialized_nn_module(self):
# This test is a little confusing because of the combination of
# nn_module_guard and unspecialized nn module variable.

# The graph break in forward causes two graphs
# 1) The first graph has self.relu which introduces a nn_module_guard
# 2) The second graph first assumes self to be NNModuleVariable, but then
# restarts the analysis with self mapping to
# UnspecializedNNModuleVariable, on witnessing self.a += 1.

# Now, when we run the compiled mod the first time, it changes the value
# of self.a. This is fine for the first run. But, when we run the
# compiled module again, the first graph recompiles. This is because
# self.a has changed, changing the ma_version_tag, causing
# nn_module_guard to fail.

# At this point, we might feel that this is doomed as we will always
# keep recompiling on the first graph. But by then Dynamo has already
# marked self as UnspecializedNNModuleVariable (because of self.a in the
# second graph), and therefore during the recompilation we do not
# introduce any nn_module_guard. So, in all, we have just one
# recompilation.
class Mock(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = 5
self.relu = torch.nn.ReLU()

def forward(self, x):
z = self.relu(x)
torch._dynamo.graph_break()
self.a += 1
return z * self.a

mod = Mock()
cnt = torch._dynamo.testing.CompileCounter()
opt = torch.compile(mod, backend=cnt)

for _ in range(5):
opt(torch.randn(4))

self.assertEqual(cnt.frame_count, 4)

def test_hooks_outer(self):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1788,7 +1893,6 @@ def guard_fail_fn(failure):
the eval_frame entrypoint to Module.__call__?
"""

-@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
def test_hooks_inner(self):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1833,7 +1937,7 @@ def guard_fail_fn(failure):
handle.remove()
self.assertEqual(compiled_func(inp), outer_func(inp))
self.assertEqual(compiled_func(inp).item(), 7)
-self.assertTrue("forward_hooks.keys" in failure_reason)
+self.assertTrue("__nn_module_guard" in failure_reason)
self.assertEqual(cc.frame_count, 1 + 1)
self.assertEqual(cc.op_count, 6 + 4)

@@ -1853,55 +1957,7 @@ def new_forward_hook(
m._forward_hooks[handle.id] = new_forward_hook
self.assertEqual(compiled_func(inp), outer_func(inp))
self.assertEqual(compiled_func(inp).item(), 16)
self.assertTrue("___check_obj_id(L['m']._forward_hooks" in failure_reason)

@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)
def test_hooks_skip_guards(self):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return 2 * x + 1

m = TestModule()

def forward_hook(
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
) -> torch.Tensor:
return 2 * output + 1

handle = m.register_forward_hook(forward_hook)

def outer_func(tensor):
x = tensor * 2 + 1
y = m(x)
return y

inp = torch.tensor(1.0, requires_grad=True)

failure_reason = None

def guard_fail_fn(failure):
nonlocal failure_reason
failure_reason = failure[0]

cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
compiled_func = torch._dynamo.optimize(
guard_fail_fn=guard_fail_fn,
backend=cc,
)(outer_func)

m = TestModule()
handle = m.register_forward_hook(forward_hook)
failure_reason = None
self.assertEqual(compiled_func(inp), outer_func(inp))
self.assertEqual(compiled_func(inp).item(), 15)
self.assertEqual(cc.frame_count, 1)
self.assertEqual(cc.op_count, 6)

# if we remove the hook, dynamo shouldn't notice
handle.remove()
self.assertNotEqual(compiled_func(inp), outer_func(inp))
self.assertEqual(compiled_func(inp).item(), 15)
self.assertEqual(cc.frame_count, 1)
self.assertTrue("__nn_module_guard" in failure_reason)

def _forward_hook_test_helper(self, model):
forward_handles = {}