-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathwrn_utils.py
72 lines (51 loc) · 2.39 KB
/
wrn_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from torch.nn.init import kaiming_normal_
import torch.nn.functional as F
from torch.nn.parallel._functions import Broadcast
from torch.nn.parallel import scatter, parallel_apply, gather
from functools import partial
from nested_dict import nested_dict
def cast(params, dtype='float'):
    """Convert a tensor, or every tensor in a nested dict, via its ``dtype`` method.

    Args:
        params: Either a dict (handled recursively, value by value) or an
            object exposing ``.cuda()`` and a conversion method named
            ``dtype`` (e.g. a ``torch.Tensor``).
        dtype: Name of the zero-argument conversion method to call on each
            leaf, such as ``'float'``, ``'double'`` or ``'half'``.

    Returns:
        A new dict with the same keys and converted leaves when ``params`` is
        a dict; otherwise the converted leaf itself.
    """
    if not isinstance(params, dict):
        # Leaves are moved to the GPU first whenever CUDA is available.
        leaf = params.cuda() if torch.cuda.is_available() else params
        convert = getattr(leaf, dtype)
        return convert()

    return {name: cast(value, dtype) for name, value in params.items()}
def conv_params(ni, no, k=1):
    """Create a Kaiming-normal initialised square convolution kernel.

    Args:
        ni: Number of input channels.
        no: Number of output channels.
        k: Side length of the square kernel.

    Returns:
        A tensor of shape ``(no, ni, k, k)`` filled in place by
        ``kaiming_normal_`` with its default arguments.
    """
    kernel = torch.Tensor(no, ni, k, k)
    # kaiming_normal_ fills the tensor in place, so `kernel` is the result.
    kaiming_normal_(kernel)
    return kernel
def linear_params(ni, no):
    """Create the weight and bias tensors for a fully connected layer.

    Args:
        ni: Number of input features.
        no: Number of output features.

    Returns:
        A dict with ``'weight'`` (shape ``(no, ni)``, initialised in place by
        ``kaiming_normal_`` with default arguments) and ``'bias'`` (zeros of
        length ``no``).
    """
    weight = torch.Tensor(no, ni)
    kaiming_normal_(weight)
    bias = torch.zeros(no)
    return {'weight': weight, 'bias': bias}
def bnparams(n):
    """Create the parameter and running-statistic tensors for batch norm.

    Args:
        n: Number of channels/features; every returned tensor has length ``n``.

    Returns:
        A dict with ``'weight'`` drawn from ``torch.rand``, ``'bias'`` and
        ``'running_mean'`` set to zeros, and ``'running_var'`` set to ones.
    """
    stats = {}
    stats['weight'] = torch.rand(n)
    stats['bias'] = torch.zeros(n)
    stats['running_mean'] = torch.zeros(n)
    stats['running_var'] = torch.ones(n)
    return stats
def data_parallel(f, input, params, mode, device_ids, output_device=None):
    """Run the functional model ``f`` on one device or replicate it on several.

    Args:
        f: Callable invoked as ``f(input, params, mode)``.
        input: Model input; scattered across ``device_ids`` when more than
            one device is given.
        params: Dict of parameters; its values are broadcast to every device.
        mode: Forwarded unchanged to ``f`` as its ``mode`` argument.
        device_ids: List of device ids (must be a ``list``).
        output_device: Device on which outputs are gathered; defaults to
            ``device_ids[0]``.

    Returns:
        ``f(input, params, mode)`` directly for a single device, otherwise the
        per-device outputs gathered onto ``output_device``.
    """
    assert isinstance(device_ids, list)
    target_device = device_ids[0] if output_device is None else output_device

    n_devices = len(device_ids)
    if n_devices == 1:
        # Nothing to replicate: call the model directly.
        return f(input, params, mode)

    # Broadcast returns one flat tuple: all params for device 0, then all
    # params for device 1, and so on, in the order of params.values().
    broadcasted = Broadcast.apply(device_ids, *params.values())
    n_params = len(params)

    workers = []
    for dev_idx in range(n_devices):
        offset = dev_idx * n_params
        device_params = {name: broadcasted[offset + i]
                         for i, name in enumerate(params.keys())}
        workers.append(partial(f, params=device_params, mode=mode))

    scattered = scatter([input], device_ids)
    results = parallel_apply(workers, scattered)
    return gather(results, target_device)
def flatten(params):
    """Flatten a nested dict into a single-level dict with dotted keys.

    Args:
        params: Nested mapping accepted by ``nested_dict``.

    Returns:
        A dict whose keys are the key paths from ``items_flat()`` joined with
        ``'.'``; entries whose value is ``None`` are dropped.
    """
    flat = {}
    for path, leaf in nested_dict(params).items_flat():
        if leaf is None:
            continue
        flat['.'.join(path)] = leaf
    return flat
def batch_norm(x, params, base, mode):
    """Apply ``F.batch_norm`` using entries of a flat parameter dict.

    Args:
        x: Input tensor passed to ``F.batch_norm``.
        params: Dict holding ``base + '.weight'``, ``'.bias'``,
            ``'.running_mean'`` and ``'.running_var'`` entries.
        base: Key prefix selecting which batch-norm layer's entries to use.
        mode: Passed as ``training`` to ``F.batch_norm``.

    Returns:
        The result of ``F.batch_norm`` on ``x``.
    """
    scale = params[base + '.weight']
    shift = params[base + '.bias']
    mean = params[base + '.running_mean']
    var = params[base + '.running_var']
    return F.batch_norm(x, mean, var, weight=scale, bias=shift, training=mode)
def print_tensor_dict(params):
    """Print one aligned summary line per tensor in ``params``.

    Each line shows the entry's index, its key, its shape as a tuple, its
    ``torch.typename`` and its ``requires_grad`` flag.

    Args:
        params: Dict mapping string keys to tensors. An empty dict prints
            nothing.
    """
    # default=0 keeps max() from raising ValueError on an empty dict.
    kmax = max((len(key) for key in params), default=0)
    for i, (key, v) in enumerate(params.items()):
        print(str(i).ljust(5), key.ljust(kmax + 3), str(tuple(v.shape)).ljust(23), torch.typename(v), v.requires_grad)
def set_requires_grad_except_bn_(params):
    """Enable gradients in place on every entry except running statistics.

    Args:
        params: Dict mapping string keys to tensors. Entries whose key ends
            with ``'running_mean'`` or ``'running_var'`` are left untouched;
            all others get ``requires_grad = True``.
    """
    running_stat_suffixes = ('running_mean', 'running_var')
    for name, tensor in params.items():
        if name.endswith(running_stat_suffixes):
            continue
        tensor.requires_grad = True