# aot_autograd.py
import collections
import dataclasses
import itertools
import logging
import warnings
import pprint
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, NewType
from unittest.mock import patch
from functorch import make_fx
import torch
import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._subclasses.meta_utils import safe_is_leaf
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo import compiled_autograd
from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code, preserve_rng_state
from torch._guards import detect_fake_mode, tracing
from torch._prims_common import CUDARngStateHelper
from torch._logging import getArtifactLogger
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx import immutable_collections, Interpreter
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
from torch.fx.experimental.symbolic_shapes import ShapeEnv, is_concrete_int, fx_placeholder_vals
from torch.multiprocessing.reductions import StorageWeakRef
from torch.nn.utils import stateless
from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions
from . import config
from .partitioners import default_partition
from torch._guards import TracingContext, DuplicateInputs, Source
log = logging.getLogger(__name__)
aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")
MutationType = Enum(
"MutationType", ("none", "metadata_only", "data", "data_and_metadata")
)
OutputType = Enum(
"OutputType", (
# output is not an alias
"non_alias",
# output aliases an input
"alias_of_input",
# output **is** an input tensor
"is_input",
# output has a ._base tensor, which is a graph intermediate.
# We need to return its ._base as a graph output,
# so its requires_grad info is populated correctly.
# Instructs the runtime code to regenerate the current output
# from a base tensor, graph_intermediates[base_idx]
"alias_of_intermediate_save_as_output",
# Same as above; but we don't need to explicitly add its ._base
# as a graph output, because it already **is** a graph output.
"alias_of_intermediate",
# Same as above; but the output's ._base is **already** a user output.
# Instructs the runtime code to regenerate the current output from
# a base tensor, user_outputs[base_idx]
"alias_of_intermediate_base_is_user_output",
# See Note [Intermediate Bases Optimization]
"unsafe_view_alias",
# output is an alias, but has a custom autograd.Function backward.
# In this case, we don't want to do view-replay, since we won't be able to replay the custom function.
# Instead, we'll treat this output "normally", and trace its backward into the graph.
"custom_function_view",
)
)
pytree._register_pytree_node(
immutable_collections.immutable_list,
lambda x: (list(x), None),
lambda x, c: immutable_collections.immutable_list(x),
)
pytree._register_pytree_node(
immutable_collections.immutable_dict,
lambda x: (list(x.values()), list(x.keys())),
lambda x, c: immutable_collections.immutable_dict(
dict(zip(c, x))
),
)
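# Illustrative sketch (not executed here): with the registrations above, pytree can
# flatten/unflatten fx's immutable containers just like plain lists/dicts, e.g.
#
#   d = immutable_collections.immutable_dict({"a": torch.ones(2), "b": 3})
#   leaves, spec = pytree.tree_flatten(d)     # leaves == [tensor([1., 1.]), 3]
#   d2 = pytree.tree_unflatten(leaves, spec)  # round-trips back to an immutable_dict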
def partial_asdict(obj: Any) -> Any:
if dataclasses.is_dataclass(obj):
return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)}
elif isinstance(obj, (list, tuple)):
return obj.__class__([partial_asdict(item) for item in obj])
elif isinstance(obj, dict):
return {k: partial_asdict(v) for k, v in obj.items()}
else:
return obj
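# Illustrative sketch: unlike dataclasses.asdict(), this does not deep-copy leaf values,
# so tensors stored on a dataclass come through untouched, e.g.
#
#   @dataclasses.dataclass
#   class _Foo:
#       t: torch.Tensor
#
#   partial_asdict(_Foo(torch.ones(2)))  # -> {"t": tensor([1., 1.])}, same tensor object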
aten = torch.ops.aten
# This global counter increments every time we compile a graph with
# AOTAutograd. You can use this to correlate runtime error messages
# with compile time (e.g., if you get an error at runtime saying
# compiled graph 3 failed, you can set a breakpoint at compile time
# for this graph number to investigate further.)
#
# NB: this is different from get_aot_compilation_context, which tracks
# each underlying graph that is compiled. In contrast, AOT_COUNTER
# corresponds to top-level invocations of aot_module/aot_function;
# one counter is allocated per entire compiled block (but this block
# may involve compiling multiple subgraphs; e.g., for forwards/backwards)
AOT_COUNTER = itertools.count()
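# Illustrative sketch (hypothetical usage): a top-level compilation might do
#   aot_id = next(AOT_COUNTER)
# and stamp that id into the debug names of the graphs it produces, so that a runtime
# failure in "compiled graph 3" can be traced back to the third aot_function/aot_module call.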
KNOWN_TYPES = tuple(
[torch.Tensor, int, str, float, bool, type(None)] + list(py_sym_types)
)
# Set up hooks so that during backward the fx's stack_trace is properly set
callback_set = False
def setup_stacktrace_preservation_hooks(roots: List):
def iter_graph(roots):
if not roots:
return
seen = set()
q = collections.deque()
for node in roots:
if node is not None:
seen.add(node)
q.append(node)
while q:
node = q.popleft()
for fn, _idx in node.next_functions:
if fn in seen or fn is None:
continue
seen.add(fn)
q.append(fn)
yield node
def get_callback(saved_stack_):
def callback():
global callback_set
fx_traceback.set_stack_trace(saved_stack_)
callback_set = False
return callback
def get_prehook(stack_, seq_nr):
def prehook(grad_output):
global callback_set
if not callback_set:
torch.autograd.variable.Variable._execution_engine.queue_callback(
get_callback(fx_traceback.format_stack())
)
callback_set = True
fx_traceback.set_stack_trace(stack_)
fx_traceback.set_grad_fn_seq_nr(seq_nr)
return prehook
def get_posthook(special_stack_, seq_nr):
def posthook(grad_input, grad_output):
fx_traceback.set_stack_trace(special_stack_)
fx_traceback.reset_grad_fn_seq_nr()
return posthook
for node in iter_graph(roots):
forward_node_stack = node.metadata.get("traceback_", [])
node.register_prehook(get_prehook(forward_node_stack,
node._sequence_nr()))
special_stack = forward_node_stack.copy()
special_stack.append(
"Gradient addition node due to multiple use of tensor around:"
)
node.register_hook(get_posthook(special_stack, node._sequence_nr()))
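# Illustrative usage sketch (assumed call pattern, not taken verbatim from this file):
# the hooks are installed on the autograd graph right before running a traced backward, e.g.
#
#   outs = traced_forward(*args)
#   roots = [o.grad_fn for o in outs if isinstance(o, torch.Tensor) and o.grad_fn is not None]
#   setup_stacktrace_preservation_hooks(roots)
#   torch.autograd.backward(outs, grad_tensors=grads)
#
# so that fx_traceback reports the forward node's stack trace while each grad_fn runs.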
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# AOT Autograd contains a pretty non-trivial amount of logic to handle edge cases around aliasing and mutation
# that are external to the graph (they show up as side effects in some way when you run the graph).
#
# Take a look at `test_aotdispatch.py TestAOTAutograd.test_input_mutation*` tests for some examples functions
# and what their compiled graphs look like.
# Below is a very long comment detailing several edge cases, and showing how AOT Autograd handles them.
#
# Note [AOT Autograd: input data mutations]
#
# If we compile a function that mutates inputs, then those input mutations are real side effects
# that a user expects to see after running the compiled graph.
# However, the graph that we want to send to a backend needs to be *entirely* functional.
# The way we reconcile this difference is that we remove the mutations completely from the graph that we compile
# but we update the graph to return (updated_inputs, user_outputs).
# In the epilogue that runs after the compiled graph is executed, we copy the updated inputs back to the originals.
#
# Example: original user code:
# def f(x):
# x.mul_(2)
# out = x.mul(3)
# return out
#
# After AOT Autograd compiles, we end up with a:
# (a) compiled graph
# (b) autograd.Function.forward() method, that executes the compiled graph
# (c) wrapper function, that calls the autograd.Function.forward() and performs the epilogue
#
# The output of (a, b, c) are all written below.
#
# def compiled_forward_graph(x):
# x_updated = x.mul(2)
# out = x_updated.mul(3)
# return x_updated, out
#
# # x_updated gets a gradient in the compiled backward
# def compiled_backward_graph(grad_x_updated, grad_out):
# grad_x = ...
# return grad_x
#
# def autograd.Function.forward(x):
# x_updated, out = compiled_forward_graph(x)
# return x_updated, out
#
# def compiled_wrapper(x):
# x_updated, out = autograd.Function.apply(x)
# x.copy_(x_updated)
# return out
#
# Another important thing to note is that updated inputs (due to data mutations) *do* participate
# in the compiled backward graph! Since the compiled forward graph gets N extra outputs
# (due to updated inputs showing up as graph outputs),
# the compiled backward gets an additional N inputs.
# That way, during the x.copy_(x_updated) bit in the epilogue, gradients will flow from the updated input
# back to the original input.
# Note [AOT Autograd: input metadata mutations]
#
# For the same reason as input mutations, we also don't put input metadata mutations in the graph.
# Instead, we return the updated version of the input (a view), and mutate the input's metadata outside of the graph
#
# Example: original user code:
# def f(x):
# x.t_()
# out = x.mul(3)
# return out
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
# x_updated = x.t()
# out = x_updated.mul(3)
# return x_updated, out
#
# # x_updated does *not* get a gradient in the compiled backward
# def compiled_backward_graph(grad_out):
# grad_x = ...
# return grad_x
#
# def autograd.Function.forward(x):
# x_updated, out = compiled_forward_graph(x)
# return x_updated, out
#
# def compiled_wrapper(x):
# x_updated, out = autograd.Function.apply(x)
# x.as_strided_(x_updated)
# return out
# Note [AOT Autograd: outputs aliasing inputs or intermediates!]
#
# AOT Autograd needs special handling for outputs that alias graph inputs or intermediates!
# Why?
# (1) autograd.Function.forward() has a limitation, where views that are returned in the forward cannot later be mutated.
# (2) views don't need to be compiled in the graph anyway - it's cheap to generate them outside of the compiled graph,
# in an epilogue.
# For outputs that alias inputs, we do the following:
# (a) *still* return the aliased output as a graph output
# (b) In the AOT Autograd wrapper/epilogue, we don't return that aliased output. Instead, we use it to regenerate the output.
#
# For outputs that alias *intermediates*, we do the following:
# (a) Return the output in the compiled forward, **and** return its ._base (a graph intermediate) as an output in the forward
# (b) Use (output, graph_intermediate) to regenerate the alias, and return that to the user (instead of the compiled fw output).
# You might wonder why we return the aliased output directly in the graph (and make the graph compute it),
# only to not return it and instead generate a fresh alias off of the intermediate,
# instead of (say) just storing metadata about the size/stride of the output somewhere to generate the alias. There are two reasons:
# (1) Getting the actual alias tensor allows us to use view-replay to generate the alias, instead of an as_strided() call
# (2) Inductor (and other backends) are free to change the memory format of graph outputs, if it results in better performance.
# This can result in problems if a user later tries to .view() that output expecting it to have one set of strides,
# when it has a different set of strides.
# By including the view op directly in the graph, inductor takes that into account when deciding what memory format
# the graph intermediate should be.
#
# Another important thing to note is how our traced backward() graph handles aliases.
# (this applies to outputs aliasing inputs, outputs aliasing intermediates,
# *and* updated inputs returned in the compiled forward due to metadata-only mutations).
# Any outputs that alias (either inputs or intermediates) do NOT participate in the compiled backward graph
# It would be wasteful to include them in the compiled backward(), because we regenerate them eagerly
# at the end of the forward.
#
# Example: original user code:
# def f(x):
# out1 = x.t()
# intermediate = x.mul(2)
# out2 = intermediate.view(-1)
# return out1, out2
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
# out1 = x.t()
# intermediate = x.mul(2)
# out2 = intermediate.view(-1)
# # the compiled graph also returns the intermediate
# return out1, out2, intermediate
#
# # intermediate gets a gradient in the compiled backward.
# # both output aliases (out1 and out2) do not.
# def compiled_backward_graph(grad_intermediate):
# grad_x = ...
# return grad_x
#
# def autograd.Function.forward(x):
# out1, out2, intermediate = compiled_forward_graph(x)
# return out1, out2, intermediate
#
# def compiled_wrapper(x):
# out1, out2, intermediate = autograd.Function.apply(x)
# # regenerate out1 from the input
# out1_regenerated = out1._view_func(x)
#     # regenerate out2 from the intermediate
# out2_regenerated = out2._view_func(intermediate)
# return out1_regenerated, out2_regenerated
# Note [AOT Autograd: mutations to inputs that alias other inputs]
#
# Another edge case that is (only partially) handled today is when an input is mutated, but itself aliases another input.
# AOT Autograd needs to **ensure** that functionalization knows that the two inputs are aliased to each other.
# That way, when the aliased input is accessed later in the graph, functionalization knows to "update" the alias
# given the mutation that occurred.
#
# This is handled by updating the calling convention: we create a "synthetic base" that becomes a new input
# in the compiled function, and we regenerate the original (aliased) inputs directly off of the base
# inside of the compiled function.
#
# This logic is fully encapsulated in aot_wrapper_synthetic_base()
#
# Example: original user code:
# def f(x, x_view):
# x.mul_(2)
# out = x * x_view
# return out
# f(x, x.view(-1))
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(base):
# x = generate_x(base)
# x_view = generate_x_view(base)
# x_updated = x.mul(2)
# x_view_updated = x_updated.view(-1)
#     out = x_updated * x_view_updated
# return x_updated, out
#
# # The calling convention change from (aliases) -> (base) happens
# # *outside* of the autograd.Function.forward().
# # That means the forward() only has 1 input (base),
# # and the backward() only has 1 output (grad_base)
# def compiled_backward_graph(grad_out):
# grad_base = ...
# return grad_base
#
# def autograd.Function.forward(base):
# x_updated, out = compiled_forward_graph(base)
# return x_updated, out
#
# # The compiled wrapper is where we create synthetic bases.
# # The info on which inputs are mutated is also tracked *before* synthetic base creation.
# def compiled_wrapper(x, x_view):
# base = merge_view_inputs(x, x_view)
# x_updated, out = autograd.Function.apply(base)
# # x and x_view are aliased in eager mode, so this mutation to x will automatically affect x_view.
# x.copy_(x_updated)
# return out
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# This class stores info about every user output.
@dataclass(frozen=True)
class OutputAliasInfo:
# Tells us if this output is:
# (1) a regular (non-aliased) output
# (2) an alias of a forward input
# (3) **is** a forward input (special case of "alias_of_input")
# (4) an alias of an intermediate (aka an alias of an output of the inner traced forward)
# (5) an alias of an intermediate, that explicitly requires returning the intermediate
# as a graph output
# (6) an alias of an intermediate, where that intermediate is also a user output
output_type: OutputType
# The raw type of the output (torch.Tensor, SymInt, etc)
raw_type: type
# If (1) above, then
# - base_idx is None
# If (2) or (3) above, then
# - Tells us that the base of this alias is user_fwd_input[base_idx]
# (This is an index into the inputs *before* we make synthetic bases)
# If (4) or (5) above, then
# - Tells us that the base of this alias is output_graph_intermediates[base_idx]
    #   here, this is an index into the intermediate bases that the *directly* traced forward returns as extra outputs
# If (6) above, then:
# - Tells us that the base of this alias is output_user_fwds[base_idx]
    #   here, this is an index into the user outputs of the *directly* traced forward
base_idx: Optional[int]
# If it is a Tensor, what the dynamic dims are (otherwise is None)
dynamic_dims: Optional[Set[int]]
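# Illustrative sketch (hypothetical example; assume x requires grad): for
#
#   def f(x):
#       y = x.mul(2)
#       return x.view(-1), y, y.view(-1)
#
# the three outputs would roughly be described as
#   OutputAliasInfo(output_type=OutputType.alias_of_input, raw_type=torch.Tensor, base_idx=0, ...)
#   OutputAliasInfo(output_type=OutputType.non_alias, raw_type=torch.Tensor, base_idx=None, ...)
#   OutputAliasInfo(output_type=OutputType.alias_of_intermediate_base_is_user_output, raw_type=torch.Tensor, base_idx=1, ...)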
# This class tells us info about user inputs.
@dataclass(frozen=True)
class InputAliasInfo:
is_leaf: bool
mutates_data: bool
mutates_metadata: bool
# This class encapsulates all aliasing + mutation info we need about the forward graph
# See a more detailed overview of the edge case handling at
# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit
@dataclass(eq=False)
class ViewAndMutationMeta:
# length = # user inputs
# This gives us info about every input, and what sort of mutation happened to it (if any)
input_info: List[InputAliasInfo]
# length = # user outputs
# This gives us info about every output (mostly around whether it aliases other tensors)
output_info: List[OutputAliasInfo]
# length = # mutated inps + # user outputs
# For every output *and* mutated input returned from the forward,
    # tells us whether or not the output should require gradients
requires_grad_info: List[bool]
# length = the number of intermediate bases appended as outputs to the end of the forward graph.
# Note: this is not necessarily the same thing as:
# len([x for x in output_info if x.output_type == OutputType.alias_of_intermediate])
# Because outputs might share a ._base, or an output's ._base might itself be
# another user output (in both cases, we won't redundantly append bases to the end of the graph)
num_intermediate_bases: int
# For inference only: instructs us to keep data-only input mutations directly in the graph
keep_input_mutations: int
# length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors)
# + (# intermediate bases)
# These are the FakeTensor (or potential SymInt) outputs that we traced from our
# metadata pass of the user's forward function.
# Their only use today is to pass them as a best-guess for tangents when tracing the joint.
# Stashing them as part of our "metadata" makes it simpler if we want to run our analysis
# pass once, and re-use the output throughout AOTAutograd
traced_tangents: List[Any]
num_symints_saved_for_bw: Optional[int] = None
def __post_init__(self):
mutated_inp_indices = [
i for i, m in enumerate(self.input_info) if m.mutates_metadata or m.mutates_data
]
# pre-compute the indices of the inputs that are mutated.
# When keep_input_mutations is set, we don't need to worry about our epilogue
# handling data-only mutations, because we keep them directly in the graph.
mutated_inp_runtime_indices = [
i for i, m in enumerate(self.input_info) if m.mutates_metadata or (not self.keep_input_mutations and m.mutates_data)
]
aliased_out_indices = [
i
for i, m in enumerate(self.output_info)
if m.output_type not in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
]
unsafe_view_out_indices = [
i for i, m in enumerate(self.output_info) if m.output_type is OutputType.unsafe_view_alias
]
self.mutated_inp_indices = mutated_inp_indices
# This is pre-computed in post_init for perf.
# It contains the index of every element
# of input_info that corresponds to a mutation (data or metadata or both)
self.mutated_inp_runtime_indices = mutated_inp_runtime_indices
# This is pre-computed for perf.
# It contains the index of every element
# of output_info that corresponds to an alias (either of an input or intermediate)
self.aliased_out_indices = aliased_out_indices
self.unsafe_view_out_indices = unsafe_view_out_indices
self.num_outputs = len(self.output_info)
self.num_outputs_non_aliased = len(
[x for x in self.output_info
if x.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]]
)
self.num_outputs_aliased_to_inputs = len(
[
x
for x in self.output_info
if x.output_type in [
OutputType.alias_of_input,
OutputType.is_input,
]
]
)
self.num_unsafe_view_outputs = len(self.unsafe_view_out_indices)
self.num_outputs_aliased_to_intermediates = len(
[
x
for x in self.output_info
if x.output_type in [
OutputType.alias_of_intermediate,
OutputType.alias_of_intermediate_save_as_output,
OutputType.alias_of_intermediate_base_is_user_output,
]
]
)
self.num_outputs_aliased = (
self.num_outputs_aliased_to_inputs + self.num_outputs_aliased_to_intermediates
)
self.num_mutated_data_inputs = len(
[x for x in self.input_info if x.mutates_data]
)
self.num_mutated_metadata_inputs = len(
[
x
for x in self.input_info
if x.mutates_metadata
]
)
self.num_mutated_metadata_only_inputs = len(
[
x
for x in self.input_info
if not x.mutates_data and x.mutates_metadata
]
)
self.num_mutated_inputs = self.num_mutated_data_inputs + self.num_mutated_metadata_only_inputs
self.dynamic_outputs = any(
o.dynamic_dims for o in self.output_info
)
self.is_rng_op_functionalized = config.functionalize_rng_ops
# All of the above metadata is collected by tracing the fw function.
# However, extra outputs for rng offsets behave differently. Both fwd
# and bwd graphs have their own outputs for the total consumed offsets.
# Unlike mutated inputs, we don't have to worry about sending the right
# set of tensors between fwd and bwd. Fwd and bwd offsets are
# independent and simpler to handle. Therefore, we track them
# separately.
self.num_outputs_rng_offset = 1 if self.is_rng_op_functionalized else 0
        # Our forward() returns (mutated_inputs, outputs, output_intermediate_bases, saved_tensors, saved_symints)
self.num_forward_returns = self.num_mutated_inputs + self.num_outputs + self.num_intermediate_bases
# In case of functionalization of rng ops, the fw_module returns one
        # additional output for the rng offset. This rng offset is used right
        # away to advance the rng state, and is not passed on to the raw
        # outputs. However, we need to know the exact boundary to identify
        # which tensors should be saved for the bwd graph. num_forward captures
# this information.
self.num_forward = self.num_forward_returns + self.num_outputs_rng_offset
@property
def tensors_saved_for_backwards_slice(self):
assert self.num_symints_saved_for_bw is not None
if self.num_symints_saved_for_bw > 0:
return slice(self.num_forward, -self.num_symints_saved_for_bw)
else:
return slice(self.num_forward, None)
@property
def symints_saved_for_backwards_slice(self):
assert self.num_symints_saved_for_bw is not None
if self.num_symints_saved_for_bw > 0:
return slice(-self.num_symints_saved_for_bw, None)
else:
return slice(0, 0) # empty slice
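    # Illustrative sketch (hypothetical sizes): if num_forward == 4 and the compiled forward
    # additionally saved 3 tensors and 2 symints for backward, its flat output list looks like
    #
    #   [fw_out_0, fw_out_1, fw_out_2, fw_out_3, saved_t0, saved_t1, saved_t2, saved_s0, saved_s1]
    #
    # and the two properties above select slice(4, -2) (the saved tensors) and
    # slice(-2, None) (the saved symints), respectively.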
def __eq__(self, other):
if not isinstance(other, ViewAndMutationMeta):
return NotImplemented
return (self.input_info == other.input_info and
self.output_info == other.output_info and
self.requires_grad_info == other.requires_grad_info and
self.num_intermediate_bases == other.num_intermediate_bases and
self.keep_input_mutations == other.keep_input_mutations and
self.is_rng_op_functionalized == other.is_rng_op_functionalized and
self.num_outputs_rng_offset == other.num_outputs_rng_offset and
len(self.traced_tangents) == len(other.traced_tangents) and
all(x.shape == y.shape and x.dtype == y.dtype for x, y, in zip(self.traced_tangents, other.traced_tangents)))
# This class exists because:
# - the autograd.Function.forward() in aot autograd returns outputs that might alias inputs
# - we only care about the metadata on those aliases, so we can regenerate them.
# We do not want them to participate in the autograd.Function.
# We do that by wrapping them in an opaque class, so the autograd.Function
# does not know to treat them as tensors.
@dataclass(frozen=True)
class TensorAlias:
alias: torch.Tensor
def has_same_metadata(t1, t2):
return (
t1.size() == t2.size()
and t1.stride() == t2.stride()
        and t1.storage_offset() == t2.storage_offset()
and t1.is_conj() == t2.is_conj()
and t1.is_neg() == t2.is_neg()
)
def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires_grad):
# Try to do view-replay if possible.
# fall back to .as_strided() if we can't.
if target_meta_tensor._base is not None:
# The base that we want to replay our view off of might have a different shape than the view's original base.
b = target_meta_tensor._base
abt = aliased_base_tensor
# Don't unnecessarily call as_strided if nothing changed; as_strided's
# backward is poorly implemented and slow
if abt is not b and (
abt.size() != b.size() or
abt.stride() != b.stride() or
abt.storage_offset() != b.storage_offset()
):
reshaped_base_tensor = aliased_base_tensor.as_strided(
b.size(), b.stride(), b.storage_offset()
)
else:
reshaped_base_tensor = aliased_base_tensor
out = target_meta_tensor._view_func(reshaped_base_tensor)
# This shape mismatch can happen due to a bug in inplace/view handling in autograd.
# Try putting a breakpoint here and running
# `test/functorch/test_aotdispatch TestAOTAutograd.test_output_all_alias_types`
# Also, https://github.com/pytorch/pytorch/issues/49825
#
# As a stopgap, we'll fall back to as_strided.
if out is not None and out.shape == target_meta_tensor.shape:
if aliased_base_tensor.requires_grad and not target_requires_grad:
out = out.detach()
elif not aliased_base_tensor.requires_grad and target_requires_grad:
out.requires_grad_(True)
return out
size = target_meta_tensor.size()
stride = target_meta_tensor.stride()
storage_offset = target_meta_tensor.storage_offset()
if aliased_base_tensor.is_complex() and not target_meta_tensor.is_complex():
aliased_out = torch.view_as_real(aliased_base_tensor).as_strided(
size, stride, storage_offset
)
elif not aliased_base_tensor.is_complex() and target_meta_tensor.is_complex():
aliased_out = torch.view_as_complex(aliased_base_tensor).as_strided(
size, stride, storage_offset
)
else:
aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset)
# For outputs aliasing inputs, we need to check if the requires-gradness has changed.
if aliased_base_tensor.requires_grad and not target_requires_grad:
aliased_out = aliased_out.detach()
elif not aliased_base_tensor.requires_grad and target_requires_grad:
aliased_out.requires_grad_(True)
return aliased_out
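# Illustrative sketch (hypothetical names) of how a runtime wrapper might use this: for an
# output marked OutputType.alias_of_input with base_idx=i, we regenerate the alias off of the
# *real* runtime input instead of returning the graph's version of it, e.g.
#
#   base = runtime_args[info.base_idx]
#   regenerated_out = gen_alias_from_base(base, traced_out, traced_out.requires_grad)
#
# i.e. replay the original view (or fall back to as_strided) on the caller's actual tensor.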
def to_fun(t):
if isinstance(t, Tensor):
out = torch._to_functional_tensor(t)
torch._mirror_autograd_meta_to(t, out)
return out
else:
return t
def from_fun(t):
if not isinstance(t, Tensor) or not torch._is_functional_tensor(t):
return t
torch._sync(t)
return torch._from_functional_tensor(t)
def _get_hints(exprs):
"""
Get the hints of a list/tuple of int/SymInt.
"""
if isinstance(exprs, (list, tuple)):
return type(exprs)(_get_hints(e) for e in exprs)
elif isinstance(exprs, torch.SymInt):
return exprs.node.shape_env.size_hint(exprs.node.expr)
else:
return exprs
# This is a version of functionalization that is specifically designed
# for the AOTAutograd use case.
#
# Unlike functorch's variant, this doesn't use the functorch level system,
# instead it directly uses PyTorch's conventional dispatcher to hit the
# functionalization key. In particular, this means that FunctionalTensorWrapper
# can have autograd data stored directly on it.
#
# In typical AOTAutograd usage, the dispatch key order will look like:
#
# Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor
# outer tensor inner tensor
#
# Returns:
# - ViewAndMutationMeta, telling us metadata about the inputs and outputs, and
#   the list of outputs from the forward, but **only** the outputs that we need
# to pass in as tangents into the backward.
# Specifically, aliased outputs from the forward get regenerated, and don't participate
# in the compiled backward function.
def run_functionalized_fw_and_collect_metadata(
f,
*,
keep_input_mutations: bool
) -> ViewAndMutationMeta:
memo = {}
def to_fun(t):
if isinstance(t, Tensor):
if t in memo:
return memo[t]
r = torch._to_functional_tensor(t)
torch._mirror_autograd_meta_to(t, r)
memo[t] = r
return r
else:
return t
def from_fun(t):
if not isinstance(t, Tensor) or not torch._is_functional_tensor(t):
return t
torch._sync(t)
return torch._from_functional_tensor(t)
@wraps(f)
def inner(*flat_args):
# This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args.
assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)
input_info: List[InputAliasInfo] = []
output_info: List[OutputAliasInfo] = []
input_requires_grad_info: List[bool] = []
output_requires_grad_info: List[bool] = []
flat_f_args = pytree.tree_map(to_fun, flat_args)
torch._enable_functionalization(reapply_views=True)
try:
# precondition: The passed in function already handles unflattening inputs + flattening outputs
flat_f_outs = f(*flat_f_args)
finally:
torch._disable_functionalization()
# Inspect the state of the input tensor functional wrapper to detect input mutation info
# If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
for (i, (arg, f_arg)) in enumerate(zip(flat_args, flat_f_args)):
if not isinstance(arg, Tensor):
new_arg = arg
else:
torch._sync(f_arg)
new_arg = torch._from_functional_tensor(f_arg)
if arg is not new_arg:
if StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(new_arg.untyped_storage()):
mutates_data = False
mutates_metadata = True
else:
mutates_data = True
mutates_metadata = torch._functionalize_has_metadata_mutation(f_arg)
# Only track requires_grad info on *mutated* inputs,
# because they show up in the autograd.Function.forward as outputs
input_requires_grad_info.append(
isinstance(f_arg, torch.Tensor) and f_arg.requires_grad
)
else:
mutates_data = False
mutates_metadata = False
input_info.append(InputAliasInfo(
is_leaf=isinstance(arg, torch.Tensor) and safe_is_leaf(arg),
mutates_data=mutates_data,
mutates_metadata=mutates_metadata
))
        # If a function involves creating a tensor and returning a view of it, such that its ._base is the intermediate,
# We need to make sure our graph returns the _base as a graph output, and we manually recreate the view
# to return to the user. Why? The backend compiler is free to (incorrectly) not set requires_grad
# on the base tensor, but we are obligated to properly set requires-gradness on the real output.
num_mutated_inps = len(
[x for x in input_info if x.mutates_data or x.mutates_metadata]
)
inp_storage_refs = {
StorageWeakRef(inpt.untyped_storage()): idx
for idx, inpt in enumerate(flat_f_args)
if isinstance(inpt, torch.Tensor)
}
        # We need input tensor ids to tell if any outputs **are** inputs.
inp_tensor_ids = {
id(inpt) for inpt in flat_f_args if isinstance(inpt, torch.Tensor)
}
        # We need output tensor ids to tell if any `output._base` attributes **are** other outputs.
# (This is also a dict because we need to know that output's index, so we can regenerate
# the alias from it).
out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)}
# Keep track of which outputs alias other outputs
out_tensor_alias_counts = collections.defaultdict(int)
out_storage_to_tensors = collections.defaultdict(set)
for o in flat_f_outs:
if isinstance(o, torch.Tensor):
curr_storage = StorageWeakRef(o.untyped_storage())
out_tensor_alias_counts[curr_storage] += 1
out_storage_to_tensors[curr_storage].add(o)
# maps the id of an intermediate base to its index in the output of the compiled forward
intermediate_base_tensor_id_to_output_idx: Dict[int, int] = {}
intermediate_bases: List[torch.Tensor] = []
for o in flat_f_outs:
curr_storage = None if not isinstance(o, torch.Tensor) else StorageWeakRef(o.untyped_storage())
outs_with_identical_metadata_that_require_grad = [] if not isinstance(o, torch.Tensor) else [
curr for curr in out_storage_to_tensors[curr_storage]
if has_same_metadata(o, curr) and curr.requires_grad and o is not curr
]
is_result_of_custom_autograd_fn = False
if isinstance(o, torch.Tensor):
# Need to check for both custom cpp (CppFunction) and python (BackwardCFunction) autograd fns
if type(o.grad_fn).__name__ == "CppFunction":
is_result_of_custom_autograd_fn = True
if isinstance(o.grad_fn, torch.autograd.function.BackwardCFunction):
is_result_of_custom_autograd_fn = True
if not isinstance(o, torch.Tensor):
output_type = OutputType.non_alias
base_idx = None
elif curr_storage in inp_storage_refs and o.grad_fn is not None \
and is_result_of_custom_autograd_fn:
output_type = OutputType.custom_function_view
base_idx = None
elif curr_storage in inp_storage_refs:
base_idx = inp_storage_refs[curr_storage]
is_input_tensor = id(o) in inp_tensor_ids
if is_input_tensor:
output_type = OutputType.is_input
else:
output_type = OutputType.alias_of_input
# We only need to handle the intermediate base case when both
# the intermediate base and the output require gradients.
# See Note [AOT Autograd: outputs aliasing inputs or intermediates!]
elif (
o._base is not None
and o.requires_grad
and o._base.requires_grad
):
if out_tensor_alias_counts[curr_storage] == 1:
# Note [Intermediate Bases Optimization]
# Normally if we have an output that aliases an intermediate,
# we need to add the extra "intermediate base" logic further down
# to prevent autograd from yelling at us if the user later tries to
# mutate that output.
# However, the common case here is if we have an output that aliases an intermediate,
# but doesn't alias any other outputs.
# In that case, autograd shouldn't have to worry about the aliasing at all
# (if that output is mutated, there are no other live aliases for autograd to worry about).
# The "intermediate bases" can hurt inductor perf by forcing more variables to become outputs.
# So as an optimization, we won't do intermediate base handling in this case.
# Instead, we'll hide the aliasing from autograd using aten._unsafe_view().
output_type = OutputType.unsafe_view_alias
base_idx = None
else:
# First, check if o's ._base is an existing output
maybe_existing_out_idx = out_tensor_ids.get(id(o._base), None)
if maybe_existing_out_idx is not None:
# Special case where the output is an alias of a graph intermediate, but that intermediate
# is itself also a user output.
output_type = OutputType.alias_of_intermediate_base_is_user_output
base_idx = maybe_existing_out_idx
else:
# Next, check if o's ._base is an intermediate base that we already returned
maybe_existing_base_output_idx = intermediate_base_tensor_id_to_output_idx.get(
id(o._base), None
)
if maybe_existing_base_output_idx is not None:
output_type = OutputType.alias_of_intermediate
base_idx = maybe_existing_base_output_idx
else:
# Otherwise, take o._base and explicitly return it as an output in the compiled graph
new_out_idx = len(intermediate_bases)
base_idx = new_out_idx
# Indicate to the logic later on (when we trace the joint)
                            # that this particular output should get its ._base appended to the forward graph outputs
output_type = OutputType.alias_of_intermediate_save_as_output
intermediate_base_tensor_id_to_output_idx[id(o._base)] = new_out_idx
intermediate_bases.append(o._base)
elif (
# See https://github.com/pytorch/pytorch/issues/100348 for this case.
# This protects against the specific case where a user fn returns (output, output.detach())
out_tensor_alias_counts[curr_storage] > 1
and len(outs_with_identical_metadata_that_require_grad) > 0
and not o.requires_grad
):
assert len(outs_with_identical_metadata_that_require_grad) > 0
                # In theory we could use any of these tensors to regenerate the aliased outputs from,
                # since they all alias each other and have identical metadata
out_alias = outs_with_identical_metadata_that_require_grad[0]
existing_out_idx = out_tensor_ids[id(out_alias)]
output_type = OutputType.alias_of_intermediate_base_is_user_output
base_idx = existing_out_idx
else:
output_type = OutputType.non_alias
base_idx = None
if isinstance(o, torch.Tensor):
dynamic_dims = {i for i, s in enumerate(o.shape) if not is_concrete_int(s)}
else:
dynamic_dims = None
out_info = OutputAliasInfo(
output_type=output_type,
raw_type=type(o),
base_idx=base_idx,
dynamic_dims=dynamic_dims,
)
output_info.append(out_info)
output_requires_grad_info.append(
isinstance(o, torch.Tensor) and o.requires_grad
)
# Our autograd.Function.forward returns both mutated inputs and outputs,
# so we need grad info on all of them.
requires_grad_info = input_requires_grad_info + output_requires_grad_info
assert len(requires_grad_info) == len(output_info) + len(
[x for x in input_info if x.mutates_data or x.mutates_metadata]
)
# This analysis function returns *only* the outputs that are meant to be tangents to the backwards.
# Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates)
        # is *regenerated* later, and not used directly in the autograd graph
f_input_tangents = [
inp
for inp, info in zip(flat_f_args, input_info)
if info.mutates_data
]
f_output_tangents = [
o
for o, info in zip(flat_f_outs, output_info)
if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
and issubclass(info.raw_type, torch.Tensor)
]
# intermediate bases are also included in the backward graph
f_tangents = f_input_tangents + f_output_tangents + intermediate_bases
traced_tangents = pytree.tree_map(from_fun, f_tangents)
metadata = ViewAndMutationMeta(
input_info=input_info,
requires_grad_info=requires_grad_info,
output_info=output_info,
num_intermediate_bases=len(intermediate_bases),
keep_input_mutations=keep_input_mutations,
traced_tangents=traced_tangents,
)
return metadata
return inner
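# Illustrative usage sketch (hypothetical example): the wrapped function is run once, purely
# to collect aliasing/mutation metadata; its outputs are discarded, e.g.
#
#   def f(x):
#       x.mul_(2)          # data mutation on an input
#       return x.view(-1)  # output aliasing that input
#
#   meta = run_functionalized_fw_and_collect_metadata(f, keep_input_mutations=False)(
#       torch.randn(4)
#   )
#   # meta.input_info[0].mutates_data  -> True
#   # meta.output_info[0].output_type  -> OutputType.alias_of_input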
@dataclass
class BackwardSignature:
"""
Provides information about the backward section of an exported
joint forward-backward graph.
For a particular fx GraphModule, this class contains information on:
(1) A mapping from each gradient (backwards output) to the parameter
it corresponds to (forward input)
(2) A mapping from each gradient (backwards output) to the user input
it corresponds to (forward input)
    (3) Which of the forward outputs corresponds to the loss, which we backprop on.
Each string name is the `node.name` of the corresponding node in the fx graph.
"""
gradients_to_parameters: Dict[str, str]
gradients_to_user_inputs: Dict[str, str]
loss_output: str
GraphOutputName = NewType('GraphOutputName', str)
GraphInputName = NewType('GraphInputName', str)
FQN = NewType('FQN', str)
@dataclass
class GraphSignature: