forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
332 lines (280 loc) · 11.8 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import dataclasses
import traceback
from typing import Any, Callable, Container, Dict, List, Optional, OrderedDict, Tuple, TypeVar, overload
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel._functions import _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple
from torch.nn.utils.rnn import PackedSequence
# Intentionally empty: this module declares no public names for
# ``from <module> import *``; every helper below is underscore-prefixed.
__all__ = []  # type: ignore[var-annotated]
def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
    """
    Flatten positional and keyword arguments into a single value tuple plus a
    tuple of keyword names (``_unpack_kwargs`` performs the inverse).

    Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70

    Usage::

        flat_args, kwarg_keys = _pack_kwargs(1, 2, a=3, b=4)
        assert flat_args == (1, 2, 3, 4)
        assert kwarg_keys == ("a", "b")
        args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}

    Returns:
        Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first element holds the
        positional args followed by the kwarg values, with the positional args
        preceding the kwarg values and the kwarg values ordered consistently
        with the kwarg keys. The second element holds the kwarg keys, so its
        length is at most the first element's length.
    """
    # ``kwargs`` preserves call-site order, so keys and values stay aligned.
    kwarg_keys: Tuple[str, ...] = tuple(kwargs)
    flat_args: Tuple[Any, ...] = args + tuple(kwargs.values())
    return flat_args, kwarg_keys
def _cast_forward_inputs(
    dtype: Optional[torch.dtype],
    *args: Any,
    **kwargs: Any,
) -> Tuple[Any, Any]:
    """
    Cast floating-point tensors in ``args`` and ``kwargs`` to ``dtype``.

    Tensors that are not floating point, or that already have ``dtype``, are
    passed through unchanged. The cast uses ``Tensor.to``, so the existing
    ``requires_grad`` on the tensors is respected.

    Args:
        dtype: Target floating-point dtype. ``None`` disables casting.
        *args: Positional forward inputs (may be nested containers).
        **kwargs: Keyword forward inputs (may be nested containers).

    Returns:
        ``(args, kwargs)`` unchanged when ``dtype`` is ``None``; otherwise the
        pair rebuilt by ``_apply_to_tensors`` with the casts applied.
    """
    if dtype is None:
        return args, kwargs

    def cast_fn(x: torch.Tensor) -> torch.Tensor:
        # Leave integer/bool tensors and already-correct tensors untouched.
        if not torch.is_floating_point(x) or x.dtype == dtype:
            return x
        return x.to(dtype)

    return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))
def _unpack_kwargs(flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Inverse of ``_pack_kwargs``: split ``flat_args`` back into positional args
    and a keyword dict.

    The last ``len(kwarg_keys)`` entries of ``flat_args`` are the keyword
    values (paired with ``kwarg_keys`` in order); everything before them is
    positional.

    Raises:
        AssertionError: if there are more keys than flattened values.
    """
    num_keys = len(kwarg_keys)
    num_flat = len(flat_args)
    assert num_keys <= num_flat, f"too many keys {num_keys} vs. {num_flat}"
    if not num_keys:
        return flat_args, {}
    # Index where the keyword values begin.
    split_at = num_flat - num_keys
    positional = flat_args[:split_at]
    keyword = dict(zip(kwarg_keys, flat_args[split_at:]))
    return positional, keyword
# Constrained type variables used only by the ``_recursive_to`` overloads below.
S = TypeVar("S", dict, list, tuple)
T = TypeVar("T", torch.Tensor, PackedSequence)


@overload
def _recursive_to(inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool) -> List[S]:
    ...


@overload
def _recursive_to(inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool) -> Tuple[T]:
    ...


def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
    r"""
    Recursively moves input to the target_device.

    Tensors and ``PackedSequence`` objects are moved with ``.to(target_device)``
    (no-op if already on ``target_device``). Namedtuples, non-empty tuples,
    non-empty lists and non-empty dicts are traversed and rebuilt; any other
    object (including empty containers) is returned as-is.

    Args:
        inputs: Object or (nested) container to move.
        target_device: Destination device.
        use_side_stream_for_tensor_copies: If ``True`` and the source device
            has a registered ``torch.<device.type>`` module (and is not CPU),
            copy on the stream returned by ``_get_stream(target_device)`` and
            make the current stream wait on it.

    Returns:
        Every leaf maps to a one-element sequence, so the result is a
        one-element tuple (tensor / ``PackedSequence`` input) or a one-element
        list (everything else) holding the moved input.
    """

    def to_map(obj):
        if isinstance(obj, (torch.Tensor, PackedSequence)):
            # For a PackedSequence the device is read from its ``data`` tensor.
            device = obj.data.device if isinstance(obj, PackedSequence) else obj.device
            if device == target_device:
                return (obj,)
            if not use_side_stream_for_tensor_copies:
                return (obj.to(target_device),)
            else:
                # If the custom module is not registered to torch, stream is not used for acceleration
                device_mod = getattr(torch, device.type, None)
                # NOTE(review): ``device`` is the *source* device, so tensors that
                # start on CPU take the plain ``.to()`` path and never use the side
                # stream, although the comment below speaks of CPU -> target_device
                # copies. Confirm this is intended.
                if device.type == "cpu" or device_mod is None:
                    return (obj.to(target_device),)
                # Perform CPU -> target_device copies in a background stream. This code is
                # motivated from similar logic in torch/nn/parallel/_functions.py
                stream = _get_stream(target_device)
                with device_mod.stream(stream):
                    output = obj.to(target_device)
                # synchronize with the copy stream
                with device_mod.device(target_device.index):
                    current_stream = device_mod.current_stream()
                    # Sync the current stream with the copy stream
                    current_stream.wait_stream(stream)
                    # Ensure tensor memory is not reused until work on
                    # main stream is complete
                    if isinstance(obj, PackedSequence):
                        output.data.record_stream(current_stream)  # type: ignore[arg-type]
                    else:
                        assert isinstance(output, torch.Tensor)
                        output.record_stream(current_stream)  # type: ignore[arg-type]
                return (output,)
        # Containers: map each element, then ``zip(*...)`` transposes the
        # per-element one-element sequences into one rebuilt container.
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(to_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(to_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(to_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            # Each (key, value) item goes through the tuple branch above.
            return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
        return [obj]

    # Avoid reference cycle: ``to_map`` refers to itself through the enclosing
    # scope, so clearing the name breaks the function <-> closure-cell cycle.
    try:
        res = to_map(inputs)
    finally:
        to_map = None  # type: ignore[assignment]
    return res
def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
    """
    Drop-in replacement for ``assert`` for use inside backward passes.

    In the backward context an assertion message can be swallowed, so on
    failure the message ``s`` and the current stack are printed first. The
    ``AssertionError`` is raised only when ``raise_assertion_error`` is true.
    """
    if cond:
        return
    print(s)
    traceback.print_stack()
    if raise_assertion_error:
        raise AssertionError(s)
def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
    """
    Make sure ``tensor``'s typed storage holds ``size.numel()`` elements.

    If the storage does not already have exactly that many elements, it is
    expected to be empty (size 0; checked via ``_p_assert``) and is resized
    to ``size.numel()``.

    Returns:
        bool: ``True`` if this call allocated the storage, ``False`` if the
        storage already had the requested size.
    """
    wanted_numel = size.numel()
    with torch.no_grad():
        current_numel = tensor._typed_storage()._size()
        if current_numel == wanted_numel:
            return False
        _p_assert(
            current_numel == 0,
            f"Tensor storage should have been resized to be 0 but got {current_numel}",
        )
        tensor._typed_storage()._resize_(wanted_numel)
        return True
def _free_storage(tensor: torch.Tensor) -> bool:
    """
    Release ``tensor``'s underlying storage by resizing it to 0 elements.

    Freeing is only considered safe when the tensor starts at storage offset
    0; this is checked via ``_p_assert`` before resizing.

    Returns:
        bool: ``True`` if this call freed the storage, ``False`` if it was
        already empty.
    """
    with torch.no_grad():
        storage_numel = tensor._typed_storage()._size()
        if storage_numel == 0:
            return False
        offset = tensor.storage_offset()
        _p_assert(
            offset == 0,
            "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
            f"storage offset: {offset}\n"
            f"storage size: {storage_numel}\n"
            f"tensor shape: {tensor.shape}",
        )
        tensor._typed_storage()._resize_(0)
        return True
# Type variables used only by the ``_apply_to_tensors`` overloads below.
Q = TypeVar("Q")
R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any)


@overload
def _apply_to_tensors(fn: Callable[[torch.Tensor], Q], container: torch.Tensor) -> Q:
    ...


@overload
def _apply_to_tensors(fn: Callable[[torch.Tensor], Any], container: R) -> R:
    ...


def _apply_to_tensors(fn, container):
    """
    Recursively apply ``fn`` to every tensor found in ``container``.

    Supported containers are dataclasses (frozen or not), ``OrderedDict``,
    ``PackedSequence``, ``dict``, namedtuples, ``list``, ``tuple`` and ``set``.
    Containers are rebuilt with ``fn``'s results, except ``PackedSequence``
    (see the note in that branch). A plain ``dict`` is rebuilt as a plain
    ``dict``. Non-tensor leaves are returned unchanged.

    Args:
        fn: Callable applied to each tensor.
        container: A tensor or a (nested) container of tensors.

    Returns:
        ``fn(container)`` for a tensor; otherwise a rebuilt container.
    """

    def apply(x):
        if isinstance(x, torch.Tensor):
            return fn(x)
        elif hasattr(x, "__dataclass_fields__"):
            # Work on a shallow copy so the caller's dataclass is not mutated.
            dc = dataclasses.replace(x)
            for f in dataclasses.fields(dc):
                new_value = apply(getattr(dc, f.name))
                try:
                    setattr(dc, f.name, new_value)
                except dataclasses.FrozenInstanceError:
                    # Frozen dataclasses reject ``setattr``. ``dc`` is our private
                    # copy, so writing through ``object.__setattr__`` (as the
                    # generated frozen ``__init__`` does) is safe.
                    object.__setattr__(dc, f.name, new_value)
            return dc
        elif isinstance(x, OrderedDict):
            # Preserve the concrete OrderedDict subclass and key order.
            od = x.__class__()
            for key, value in x.items():
                od[key] = apply(value)
            return od
        elif isinstance(x, PackedSequence):
            # NOTE(review): the result of ``apply(x.data)`` is discarded, so
            # ``fn`` only runs for its side effects here and the original
            # PackedSequence is returned (e.g. a dtype cast is not reflected).
            # Confirm this is intended before relying on it.
            apply(x.data)
            return x
        elif isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        elif _is_namedtuple(x):
            res = (apply(el) for el in x)
            return type(x)(*res)
        elif isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        else:
            return x

    return apply(container)
def _to_kwargs(
    inputs: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    target_device: torch.device,
    use_side_stream_for_tensor_copies: bool,
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
    """
    Move ``inputs`` and ``kwargs`` to ``target_device`` via ``_recursive_to``.

    Empty/``None`` arguments are not moved. The shorter of the two resulting
    lists is padded (with ``()`` for inputs, ``{}`` for kwargs) so both
    returned tuples have the same length.

    Args:
        inputs: Positional inputs.
        kwargs: Keyword inputs, or ``None``.
        target_device: Destination device.
        use_side_stream_for_tensor_copies: Forwarded to ``_recursive_to``.

    Returns:
        ``(moved_inputs, moved_kwargs)`` as equal-length tuples.
    """
    moved_inputs = (
        _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies)
        if inputs
        else []
    )
    moved_kwargs = (
        _recursive_to(kwargs, target_device, use_side_stream_for_tensor_copies)
        if kwargs
        else []
    )
    # Both paddings are measured against the moved lists (not the raw
    # ``inputs`` argument) so the two results always end up the same length.
    if len(moved_inputs) < len(moved_kwargs):
        moved_inputs.extend([() for _ in range(len(moved_kwargs) - len(moved_inputs))])
    elif len(moved_kwargs) < len(moved_inputs):
        moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))])
    return tuple(moved_inputs), tuple(moved_kwargs)
def _verify_param_shape_across_processes(
    process_group: dist.ProcessGroup,
    tensors: List[torch.Tensor],
    logger: Optional[dist.Logger] = None,
):
    """
    Thin wrapper around ``dist._verify_params_across_processes``.

    Forwards ``process_group``, ``tensors`` and ``logger`` positionally and
    returns whatever the underlying call returns.
    """
    verify_fn = dist._verify_params_across_processes
    return verify_fn(process_group, tensors, logger)
def _sync_module_states(
    module: nn.Module,
    process_group: dist.ProcessGroup,
    broadcast_bucket_size: int,
    src: int,
    params_and_buffers_to_ignore: Container[str],
    broadcast_buffers: bool = True,
) -> None:
    """
    Synchronize ``module``'s parameters (and optionally buffers) so that all
    ranks hold the same module state.

    Parameters come first, then buffers (only if ``broadcast_buffers``). Any
    entry whose name is in ``params_and_buffers_to_ignore`` is skipped. The
    collected detached tensors are handed to ``_sync_params_and_buffers``.
    This API assumes all parameter shapes are already consistent across ranks
    (see ``_verify_param_shape_across_processes``).
    """

    def _kept(named_tensors) -> List[torch.Tensor]:
        # Detached tensors for every entry whose name is not ignored.
        return [
            t.detach()
            for tensor_name, t in named_tensors
            if tensor_name not in params_and_buffers_to_ignore
        ]

    module_states = _kept(module.named_parameters())
    if broadcast_buffers:
        module_states.extend(_kept(module.named_buffers()))
    _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src)
def _sync_params_and_buffers(
    process_group: dist.ProcessGroup,
    module_states: List[torch.Tensor],
    broadcast_bucket_size: int,
    src: int,
) -> None:
    """
    Synchronize ``module_states`` (a list of tensors) across all processes in
    ``process_group`` by broadcasting them from ``src``.

    Uses ``dist._broadcast_coalesced``; ``broadcast_bucket_size`` is forwarded
    to it unchanged (presumably the coalescing bucket size; confirm units
    against the c10d binding). An empty list is a no-op.
    """
    if module_states:
        dist._broadcast_coalesced(
            process_group, module_states, broadcast_bucket_size, src
        )
def _replace_by_prefix(
    state_dict: Dict[str, Any],
    old_prefix: str,
    new_prefix: str,
) -> None:
    """
    Replace all keys that start with ``old_prefix`` by ``new_prefix`` (in-place).

    Usage::

        state_dict = {"layer.xyz": torch.tensor(1)}
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        assert state_dict == {"module.layer.xyz": torch.tensor(1)}

    All matching entries are removed before any renamed key is inserted, so a
    renamed key never clobbers another entry that is itself being renamed.
    For example, ``old_prefix=""`` and ``new_prefix="p."`` applied to keys
    ``"x"`` and ``"p.x"`` yields ``"p.x"`` and ``"p.p.x"``. A renamed key that
    collides with an entry *not* matching ``old_prefix`` still overwrites it.

    Raises:
        ValueError: if ``old_prefix`` and ``new_prefix`` are equal.
    """
    if old_prefix == new_prefix:
        raise ValueError("old_prefix and new_prefix must be distinct")
    prefix_len = len(old_prefix)
    # Phase 1: pop every matching entry (iterate over a snapshot of the keys).
    moved = []
    for key in list(state_dict.keys()):
        if key.startswith(old_prefix):
            moved.append((new_prefix + key[prefix_len:], state_dict.pop(key)))
    # Phase 2: re-insert under the new names, preserving their relative order.
    state_dict.update(moved)