-
Notifications
You must be signed in to change notification settings - Fork 893
/
Copy pathutil.py
186 lines (116 loc) · 4.31 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import jax
import jax.numpy as jnp
from jax.experimental.pjit import with_sharding_constraint
from optax import AdditiveWeightDecayState, GradientTransformation, OptState
def maybe_shard(x, resource):
    """Apply ``with_sharding_constraint(x, resource)`` when possible.

    Behaves like ``with_sharding_constraint`` but does not fail when run
    outside of a pjit/mesh context: if the constraint raises ``ValueError``,
    the error is printed and ``x`` is returned unconstrained.

    NOTE(review): any ``ValueError`` is swallowed here, not only the
    "no mesh" case — confirm that is intended.
    """
    try:
        constrained = with_sharding_constraint(x, resource)
    except ValueError as err:
        print(err)
        constrained = x
    return constrained
def gpt3_schedule(warmup_steps,
                  total_steps,
                  peak_lr,
                  end_lr):
    """Build a linear-warmup + cosine-anneal learning-rate schedule.

    The returned function maps a step to a learning rate:

    * steps ``0 .. warmup_steps``: linear ramp from 0 to ``peak_lr``;
    * the next ``total_steps`` steps: cosine decay from ``peak_lr`` to ``end_lr``;
    * afterwards: constant ``end_lr``.

    Note that ``total_steps`` is the length of the anneal phase *after*
    warmup, not the length of the whole run.

    Args:
        warmup_steps: number of warmup steps. ``0`` disables warmup, so the
            schedule starts at ``peak_lr``.
        total_steps: number of cosine-anneal steps after warmup. ``0`` means
            the rate drops to ``end_lr`` immediately after warmup.
        peak_lr: learning rate reached at the end of warmup.
        end_lr: learning rate at the end of (and after) the anneal phase.

    Returns:
        A function ``sch(step)`` that returns the learning rate as a jax array.
    """
    def sch(step):
        # Guard both divisions: a zero-length phase would otherwise compute 0/0 = NaN.
        if warmup_steps > 0:
            warmup_pct = jnp.clip(step, 0, warmup_steps) / warmup_steps
        else:
            warmup_pct = 1.0
        if total_steps > 0:
            anneal_pct = jnp.clip(step - warmup_steps, 0, total_steps) / total_steps
        else:
            anneal_pct = jnp.where(step > warmup_steps, 1.0, 0.0)
        # The cosine term goes from 0 to (peak_lr - end_lr) as anneal_pct goes from 0 to 1.
        return warmup_pct * peak_lr - (peak_lr - end_lr) * (1 - jnp.cos(jnp.pi * anneal_pct)) / 2

    return sch
def global_norm(updates, use_psum=True):
    """Return the L2 norm of all leaves of the pytree ``updates`` taken together.

    Args:
        updates: a pytree of arrays (e.g. gradients).
        use_psum: if True, the local sum of squares is ``psum``-reduced over the
            ``"shard"`` named axis before the square root, so the result is the
            norm across all shards. This requires being called inside a
            transformation that binds the ``"shard"`` axis name.

    Returns:
        A scalar array holding the global norm (0 for an empty pytree).
    """
    # jax.tree_util.tree_leaves is the stable spelling; the top-level
    # jax.tree_leaves alias is deprecated and removed in newer JAX releases.
    pre_sqrt = sum(jnp.sum(jnp.square(x)) for x in jax.tree_util.tree_leaves(updates))
    if use_psum:
        pre_sqrt = jax.lax.psum(pre_sqrt, "shard")
    return jnp.sqrt(pre_sqrt)
class ClipByGlobalNormState(OptState):
    """State object for the `clip_by_global_norm` transformation.

    The `clip_by_global_norm` transformation is stateless, so this class
    declares no fields; it is what that transformation's `init_fn` returns.
    """
    # NOTE(review): subclassing `OptState` assumes the pinned optax version
    # defines it as a subclassable base (early optax releases aliased it to
    # `typing.NamedTuple`); newer optax defines it as a type alias — confirm.
def clip_by_global_norm(max_norm, use_psum=True) -> GradientTransformation:
    """Clip updates using their global norm.

    If the global norm of ``updates`` is at least ``max_norm``, every leaf is
    rescaled by ``max_norm / global_norm`` so the clipped updates have norm
    ``max_norm``. Otherwise the updates pass through unchanged.

    References:
      [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)

    Args:
      max_norm: the maximum global norm for an update.
      use_psum: forwarded to ``global_norm``. If True, the norm is reduced over
        the ``"shard"`` named axis, so every shard uses the same norm.

    Returns:
      An (init_fn, update_fn) tuple.
    """
    def init_fn(_):
        return ClipByGlobalNormState()

    def update_fn(updates, state, params=None):
        del params  # clipping depends only on the updates themselves
        g_norm = global_norm(updates, use_psum=use_psum)
        trigger = g_norm < max_norm
        # jnp.where evaluates both branches; the untouched leaf is selected
        # whenever the global norm is already below the limit.
        updates = jax.tree_util.tree_map(
            lambda t: jnp.where(trigger, t, (t / g_norm) * max_norm), updates)
        return updates, state

    return GradientTransformation(init_fn, update_fn)
def additive_weight_decay(weight_decay: float = 0.0) -> GradientTransformation:
    """Add ``weight_decay * param`` to the update of every parameter with more than one dimension.

    Parameters with one or zero dimensions are not decayed. Per the original
    intent, this excludes layer-norm scales, biases, etc.

    Args:
        weight_decay: a scalar weight decay rate.

    Returns:
        An (init_fn, update_fn) tuple.
    """
    def init_fn(_):
        return AdditiveWeightDecayState()

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError("additive_weight_decay requires `params` to be passed to update_fn")
        # Equivalent to jax.tree_multimap(f, updates, params), which has been
        # removed from newer JAX. Flatten `updates`, then flatten `params` up to
        # the same tree structure (this errors if the structures differ).
        update_leaves, treedef = jax.tree_util.tree_flatten(updates)
        param_leaves = treedef.flatten_up_to(params)
        decayed = [
            # The boolean mask zeroes the decay term for rank-0/1 leaves.
            g + weight_decay * p * (len(g.shape) > 1)
            for g, p in zip(update_leaves, param_leaves)
        ]
        return jax.tree_util.tree_unflatten(treedef, decayed), state

    return GradientTransformation(init_fn, update_fn)
def to_f32(t):
    """Cast every bfloat16 leaf of the pytree ``t`` to float32.

    Leaves of any other dtype are returned unchanged.
    """
    # jax.tree_util.tree_map: the top-level jax.tree_map alias is deprecated/removed in newer JAX.
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
def to_bf16(t):
    """Cast every float32 leaf of the pytree ``t`` to bfloat16.

    Leaves of any other dtype are returned unchanged.
    """
    # jax.tree_util.tree_map: the top-level jax.tree_map alias is deprecated/removed in newer JAX.
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
def to_f16(t):
    """Cast every float32 leaf of the pytree ``t`` to float16.

    Leaves of any other dtype (including bfloat16) are returned unchanged.
    """
    # jax.tree_util.tree_map: the top-level jax.tree_map alias is deprecated/removed in newer JAX.
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float16) if x.dtype == jnp.float32 else x, t)
# Forward: identity. Backward: the incoming cotangent is summed (psum)
# over the "shard" named axis.
@jax.custom_vjp
def f_psum(x):
    return x


def f_psum_fwd(x):
    """Forward rule: identity output; no residuals are saved for the backward pass."""
    out = f_psum(x)
    return out, None


def f_psum_bwd(_residuals, ct):
    """Backward rule: all-reduce (sum) the cotangent across the "shard" axis."""
    summed = jax.lax.psum(ct, "shard")
    return (summed,)


f_psum.defvjp(f_psum_fwd, f_psum_bwd)
# Forward: identity. Backward: the incoming cotangent is averaged (pmean)
# over the "shard" named axis.
@jax.custom_vjp
def f_pmean(x):
    return x


def f_pmean_fwd(x):
    """Forward rule: identity output; no residuals are saved for the backward pass."""
    # Return the primal directly rather than routing through another custom_vjp op.
    return x, None


def f_pmean_bwd(_, g):
    """Backward rule: average the cotangent across the "shard" axis."""
    return jax.lax.pmean(g, "shard"),


f_pmean.defvjp(f_pmean_fwd, f_pmean_bwd)
# Forward: sum (psum) over the "shard" named axis. Backward: the cotangent
# is passed through unchanged.
@jax.custom_vjp
def g_psum(x):
    return jax.lax.psum(x, "shard")


def g_psum_fwd(x):
    """Forward rule: compute the psum; nothing is saved for the backward pass."""
    reduced = g_psum(x)
    return reduced, None


def g_psum_bwd(_residuals, ct):
    """Backward rule: identity on the cotangent."""
    return (ct,)


g_psum.defvjp(g_psum_fwd, g_psum_bwd)
def shard_axis(x, axis_size, axis_name):
    """Return the caller's chunk of ``x`` along its leading axis.

    The leading dimension of ``x`` is split into ``axis_size`` equal chunks.
    The chunk at index ``jax.lax.axis_index(axis_name)`` is returned, with shape
    ``(x.shape[0] // axis_size,) + x.shape[1:]``. This must be called inside a
    transformation that binds ``axis_name``.

    Raises:
        ValueError: if ``x.shape[0]`` is not divisible by ``axis_size``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if x.shape[0] % axis_size != 0:
        raise ValueError(
            f"shard_axis: leading dimension {x.shape[0]} is not divisible by axis_size {axis_size}")
    # View as (axis_size, chunk, *rest), then pick this member's chunk.
    x = x.reshape((axis_size, -1) + x.shape[1:])
    return x[jax.lax.axis_index(axis_name)]
def unshard_axis(x, axis_name):
    """Gather ``x`` from every member of ``axis_name`` and concatenate along the leading axis.

    ``jax.lax.all_gather`` stacks the per-member values into shape
    ``(n, *x.shape)``. The first two axes are then merged, giving
    ``(n * x.shape[0], *x.shape[1:])``, which undoes ``shard_axis``.
    """
    gathered = jax.lax.all_gather(x, axis_name)
    merged_shape = (-1, *gathered.shape[2:])
    return gathered.reshape(merged_shape)
def head_print(*args, **kwargs):
    """``print`` the given arguments, but only on the process with index 0.

    Accepts the same arguments as the builtin ``print``.
    """
    # jax.host_id was renamed to jax.process_index and later removed; prefer
    # the new name and fall back to the old one on JAX versions predating it.
    process_index = getattr(jax, "process_index", None) or jax.host_id
    if process_index() == 0:
        print(*args, **kwargs)
if __name__ == "__main__":
    # Manual smoke check: print the schedule every 200 steps for steps 0..29,800.
    demo_schedule = gpt3_schedule(1_000, 20_000, 1e-4, 1e-5)
    for step in range(0, 150 * 200, 200):
        print(step, demo_schedule(step))