# verifier.py (forked from pytorch/pytorch)
import operator
from collections.abc import Iterable
from typing import Any, final, List, Set, Tuple, Type

import torch
from torch._ops import HigherOrderOperator, OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx import GraphModule
from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt


PRESERVED_META_KEYS: Set[str] = {
"val",
"stack_trace",
"source_fn_stack",
}


class SpecViolationError(Exception):
    pass


def is_functional(op: OpOverload) -> bool:
    return not op._schema.is_mutable
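

# Example (a sketch; these are standard ATen overloads):
#     is_functional(torch.ops.aten.add.Tensor)   # True: returns a new tensor
#     is_functional(torch.ops.aten.add_.Tensor)  # False: mutates its first input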


def _check_has_fake_tensor(node: torch.fx.Node) -> None:
    # TODO(angelayi): remove this in favor of _check_val
    return _check_val(node)


def _check_val(node: torch.fx.Node) -> None:
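    """
    Check that node.meta["val"] holds a well-formed value: None, a python
    scalar, a FakeTensor/Tensor, a symbolic value (SymInt/SymFloat/SymBool),
    a dtype/device/layout/memory_format, or an iterable of these.
    """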
def _check_correct_val(val):
if val is None:
return True
elif isinstance(val, (int, bool, str, float)):
return True
elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)):
return True
elif isinstance(val, (FakeTensor, torch.Tensor)):
return True
elif isinstance(val, (SymInt, SymFloat, SymBool)):
return True
elif isinstance(val, Iterable):
return all(_check_correct_val(x) for x in val)
return False
if "val" not in node.meta:
raise SpecViolationError(f"Node.meta {node.name} is missing val field.")
val = node.meta["val"]
if not _check_correct_val(val):
raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}")


class Verifier:
    def __call__(self, gm: GraphModule) -> None:
        self.check_valid(gm)

    def allowed_builtin_ops(self) -> List:
        return [operator.getitem]

    def allowed_op_types(self) -> Tuple[Type[Any], ...]:
        return (OpOverload, HigherOrderOperator)

    def check_valid_op(self, op) -> None:
        if op not in self.allowed_builtin_ops():
            if not isinstance(op, self.allowed_op_types()):
                raise SpecViolationError(
                    f"Operator '{op}' is not an allowed operator type.\n"
                    f"Valid op types: {self.allowed_op_types()}\n"
                    f"Valid builtin ops: {self.allowed_builtin_ops()}"
                )

        if isinstance(op, OpOverload):
            # All ops must be functional, i.e. their schemas must not mutate inputs
            if not is_functional(op):
                raise SpecViolationError(
                    f"Operator '{op}' is not functional"
                )

    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
        return (torch.fx.GraphModule,)

    def check_additional(self, gm: GraphModule) -> None:
        """
        Additional checks that are specific to some dialects.
        """
        pass

    @final
    def check_valid(self, gm: GraphModule) -> None:  # noqa: C901
gm.graph.lint()
if object in self.allowed_op_types():
raise SpecViolationError(
"'object' is too generic to be in the list of allowed op types"
)
if object in self.allowed_getattr_types():
raise SpecViolationError(
"'object' is too generic to be in the list of allowed getattr types"
)
for mod in gm.modules():
if not isinstance(mod, torch.fx.GraphModule):
continue
for node in mod.graph.nodes:
# TODO(T140410192): should have fake tensor for all dialects
                if node.op in {"call_module", "call_method"}:
                    raise SpecViolationError(
                        f"call_module and call_method are not valid: got target '{node.target}'"
                    )
elif node.op == "call_function":
_check_val(node)
self.check_valid_op(node.target)
if isinstance(node.target, OpOverload):
# Check preserved metadata
for meta in PRESERVED_META_KEYS:
if node.meta.get(meta, None) is None:
raise SpecViolationError(
f"node {node} is missing metadata {meta}"
)
elif node.op == "get_attr":
if not isinstance(node.target, str):
raise SpecViolationError(
f"Expected get_attr target to be string, but got {type(node.target)}"
)
attr = getattr(mod, node.target)
if not isinstance(attr, self.allowed_getattr_types()):
raise SpecViolationError(
f"Invalid get_attr type {type(attr)}. \n"
f"Valid get_attr types: {self.allowed_getattr_types}"
)
elif node.op == "placeholder":
_check_val(node)
        self.check_additional(gm)

    def is_valid(self, gm: GraphModule) -> bool:
try:
self.check_valid(gm)
return True
except SpecViolationError:
return False
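

# Example usage (a sketch; assumes `gm` is a torch.fx.GraphModule produced by
# tracing or export):
#     verifier = Verifier()
#     verifier(gm)                # raises SpecViolationError on violations
#     ok = verifier.is_valid(gm)  # or get a bool instead of an exception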


class ATenDialectVerifier(Verifier):
    def check_valid_op(self, op) -> None:
        super().check_valid_op(op)
        if isinstance(op, OpOverload):
            # NOTE(qihan): whether view_copy operators are marked as canonical
            # is still under discussion.
            if (
                torch.Tag.core not in op.tags
                and torch.Tag.view_copy not in op.tags
            ):
                raise SpecViolationError(
                    f"Operator {op.__module__}.{op.__name__} is not ATen canonical."
                )
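

# Extending Verifier for a custom dialect (a sketch; the extra allowed builtin
# op is purely illustrative):
#     class MyDialectVerifier(Verifier):
#         def allowed_builtin_ops(self) -> List:
#             return super().allowed_builtin_ops() + [operator.add]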


def verify_exported_program_signature(exported_program) -> None:
    # Check that the ExportedProgram's graph signature matches its graph
gs = exported_program.graph_signature
bs_grad_to_param = {}
bs_grad_to_user_inputs = {}
if gs.backward_signature is not None:
bs_grad_to_param = gs.backward_signature.gradients_to_parameters
bs_grad_to_user_inputs = gs.backward_signature.gradients_to_user_inputs
    # Check that every node referenced in the signature exists in the graph
    input_node_names = [
        node.name
        for node in exported_program.graph.nodes
        if node.op == "placeholder"
    ]
    output_node = list(exported_program.graph.nodes)[-1]
    assert output_node.op == "output"
    output_node_names = [node.name for node in output_node.args[0]]

    def check_exists(node_list, container):
        for node in node_list:
            if node not in container:
                raise SpecViolationError(
                    f"Node {node} found in the signature is not in the graph."
                )

check_exists(gs.user_inputs, input_node_names)
check_exists(gs.user_outputs, output_node_names)
check_exists(gs.inputs_to_parameters.keys(), input_node_names)
check_exists(gs.inputs_to_parameters.values(), gs.parameters)
check_exists(gs.inputs_to_buffers.keys(), input_node_names)
check_exists(gs.inputs_to_buffers.values(), gs.buffers)
check_exists(gs.buffers_to_mutate.keys(), output_node_names)
check_exists(gs.buffers_to_mutate.values(), gs.buffers)
check_exists(bs_grad_to_param.keys(), output_node_names)
check_exists(bs_grad_to_param.values(), gs.parameters)
check_exists(bs_grad_to_user_inputs.keys(), output_node_names)
check_exists(bs_grad_to_user_inputs.values(), gs.user_inputs)
# Check parameters
for param in gs.parameters:
if param not in exported_program.state_dict:
raise SpecViolationError(
f"Parameter {param} is not in the state dict."
)
if not isinstance(exported_program.state_dict[param], torch.nn.Parameter):
raise SpecViolationError(
f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter."
)
# Check buffers
for buffer in gs.buffers:
if buffer not in exported_program.state_dict:
raise SpecViolationError(
f"Buffer {buffer} is not in the state dict."
)
# Check inputs
placeholder_nodes = [n.name for n in exported_program.graph.nodes if n.op == "placeholder"]
total_gs_placeholders = len(gs.inputs_to_parameters) + len(gs.inputs_to_buffers) + len(gs.user_inputs)
    if len(placeholder_nodes) != total_gs_placeholders:
        raise SpecViolationError(
            f"Number of placeholder nodes {len(placeholder_nodes)} is different "
            "than the number of inputs specified by the graph signature: \n"
            f"Number of parameters: {len(gs.inputs_to_parameters)}. \n"
            f"Number of buffers: {len(gs.inputs_to_buffers)}. \n"
            f"Number of user inputs: {len(gs.user_inputs)}. \n"
        )
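    # Placeholders are expected in a fixed order: parameters first, then
    # buffers, then user inputs.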
parameter_nodes = placeholder_nodes[:len(gs.parameters)]
buffer_nodes = placeholder_nodes[len(gs.parameters):len(gs.parameters) + len(gs.buffers)]
user_input_nodes = placeholder_nodes[len(gs.parameters) + len(gs.buffers):]
for param_node, param_name in zip(parameter_nodes, gs.parameters):
if (
param_node not in gs.inputs_to_parameters or
gs.inputs_to_parameters[param_node] != param_name
):
raise SpecViolationError(
f"Parameter input {param_node} is not in the correct "
"order or is not found in the exported program's parameter list. \n"
f"List of parameters, in order: {gs.parameters} \n"
f"Parameter node to parameter name mapping: {gs.inputs_to_parameters} \n"
)
for buffer_node, buffer_name in zip(buffer_nodes, gs.buffers):
if (
buffer_node not in gs.inputs_to_buffers or
gs.inputs_to_buffers[buffer_node] != buffer_name
):
raise SpecViolationError(
f"Buffer input {buffer_node} is not in the correct "
"order or is not found in the exported program's buffer list. \n"
f"List of buffers, in order: {gs.buffers} \n"
f"Buffer node to buffer name mapping: {gs.inputs_to_buffers} \n"
)
    for user_input_node, user_input_name in zip(user_input_nodes, gs.user_inputs):
        if user_input_node != user_input_name:
            raise SpecViolationError(
                f"User input {user_input_node} is not in the correct "
                "order or is not found in the "
                f"exported program's user_inputs list: {gs.user_inputs}. "
            )
    # Check outputs (output_node and output_node_names were computed above)
    output_nodes = output_node_names
total_gs_outputs = (
len(gs.buffers_to_mutate) +
len(gs.user_outputs) +
len(bs_grad_to_param) +
len(bs_grad_to_user_inputs)
)
    if len(output_nodes) != total_gs_outputs:
        raise SpecViolationError(
            f"Number of output nodes {len(output_nodes)} is different "
            "than the number of outputs specified by the graph signature: \n"
            f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n"
            f"Number of user outputs: {len(gs.user_outputs)}. \n"
            f"Number of gradients to parameters: {len(bs_grad_to_param)}. \n"
            f"Number of gradients to user inputs: {len(bs_grad_to_user_inputs)}. \n"
        )
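    # Outputs are expected in a fixed order: mutated buffers first, then user
    # outputs, with any gradient outputs after those.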
buffer_mutate_nodes = output_nodes[:len(gs.buffers_to_mutate)]
user_output_nodes = output_nodes[len(gs.buffers_to_mutate):len(gs.user_outputs) + len(gs.buffers_to_mutate)]
for buffer_node in buffer_mutate_nodes:
if (
buffer_node not in gs.buffers_to_mutate or
gs.buffers_to_mutate[buffer_node] not in gs.buffers
):
            raise SpecViolationError(
                f"Buffer output {buffer_node} is not in the buffer mutation dictionary "
                "or it does not point to a buffer that exists. \n"
                f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
                f"Buffer nodes available: {gs.buffers} \n"
            )
    for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs):
        if user_output_node != user_output_name:
            raise SpecViolationError(
                f"User output {user_output_node} is not in the correct "
                "order or is not found in the "
                f"exported program's user_outputs list: {gs.user_outputs}. "
            )
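

# Example usage (a sketch; assumes `ep` is an ExportedProgram obtained from
# export):
#     verify_exported_program_signature(ep)  # raises SpecViolationError on mismatch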