# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
# Distributed training via Horovod.
#
# This tutorial demonstrates how to distribute SVI training across multiple
# machines (or multiple GPUs on one or more machines) using the Horovod
# library. Horovod enables data-parallel training by aggregating stochastic
# gradients at each step of training. Horovod is not intended for model
# parallelism. We focus on integration between Horovod and Pyro. For further
# details on distributed computing with Horovod, see
# https://horovod.readthedocs.io/en/stable
#
# This assumes you have installed horovod, e.g. via
# pip install pyro-ppl[horovod]
# For detailed instructions see
# https://horovod.readthedocs.io/en/stable/install.html
# On my Mac laptop I was able to install horovod with
# CFLAGS=-mmacosx-version-min=10.15 HOROVOD_WITH_PYTORCH=1 \
# pip install --no-cache-dir 'horovod[pytorch]'
#
# Finally, you'll need to run this script via horovodrun, e.g.
# horovodrun -np 2 python svi_horovod.py
# For details on running Horovod see
# https://github.com/horovod/horovod/blob/master/docs/running.rst
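# As a rough sketch, a multi-machine launch uses horovodrun's host list, e.g.
#   horovodrun -np 4 -H host1:2,host2:2 python svi_horovod.py
# (the host names and slot counts above are placeholders, not part of this
# example).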

import argparse

import torch
import torch.multiprocessing as mp

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule
from pyro.optim import Adam, HorovodOptimizer


# We define a model as usual, with no reference to Horovod.
# This model is data parallel and supports subsampling.
class Model(PyroModule):
    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, covariates, data=None):
        coeff = pyro.sample("coeff", dist.Normal(0, 1))
        bias = pyro.sample("bias", dist.Normal(0, 1))
        scale = pyro.sample("scale", dist.LogNormal(0, 1))

        # Since we'll use a distributed dataloader during training, we need to
        # manually pass minibatches of (covariates, data) that are smaller than
        # the full self.size. In particular we cannot rely on pyro.plate to
        # automatically subsample, since that would lead to all workers drawing
        # identical subsamples.
        with pyro.plate("data", self.size, len(covariates)):
            loc = bias + coeff * covariates
            return pyro.sample("obs", dist.Normal(loc, scale), obs=data)
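

# For contrast with the manual minibatching in Model.forward above: in a
# single-process setting one could instead let pyro.plate draw the subsample.
# This is only a sketch of that alternative (not used in this script), and it
# is exactly what we must avoid under Horovod, since every worker would draw
# the same indices:
#
#     with pyro.plate("data", self.size, subsample_size=100) as ind:
#         loc = bias + coeff * covariates[ind]
#         obs = None if data is None else data[ind]
#         return pyro.sample("obs", dist.Normal(loc, scale), obs=obs)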


# The following is a standard training loop. To emphasize the Horovod-specific
# parts, we've guarded them by `if args.horovod:`.
def main(args):
    # Create a model, synthetic data, and a guide.
    pyro.set_rng_seed(args.seed)
    model = Model(args.size)
    covariates = torch.randn(args.size)
    data = model(covariates)
    guide = AutoNormal(model)

    if args.horovod:
        # Initialize Horovod and set PyTorch globals.
        import horovod.torch as hvd

        hvd.init()
        torch.set_num_threads(1)
        if args.cuda:
            torch.cuda.set_device(hvd.local_rank())
    if args.cuda:
        torch.set_default_device("cuda")
    device = torch.tensor(0).device  # Reads off the current default device.
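
    # A quick way to confirm the process layout when debugging a Horovod run
    # (sketch only; hvd is imported above only when args.horovod is set):
    #
    #     if args.horovod:
    #         print("rank {} of {} (local rank {})".format(
    #             hvd.rank(), hvd.size(), hvd.local_rank()))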

    if args.horovod:
        # Initialize parameters and broadcast to all workers.
        guide(covariates[:1], data[:1])  # Initializes model and guide.
        hvd.broadcast_parameters(guide.state_dict(), root_rank=0)
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # Create an ELBO loss and a Pyro optimizer.
    elbo = Trace_ELBO()
    optim = Adam({"lr": args.learning_rate})

    if args.horovod:
        # Wrap the basic optimizer in a distributed optimizer.
        optim = HorovodOptimizer(optim)
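
    # For reference, outside of Pyro the analogous raw-PyTorch pattern is
    # Horovod's own wrapper (sketch only; `torch_optimizer` and `net` are
    # hypothetical names, not objects defined in this script):
    #
    #     torch_optimizer = hvd.DistributedOptimizer(
    #         torch_optimizer, named_parameters=net.named_parameters()
    #     )
    #
    # HorovodOptimizer plays the same role for Pyro optimizers, averaging
    # gradients across workers at each step.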

    # Create a dataloader.
    dataset = torch.utils.data.TensorDataset(covariates, data)
    if args.horovod:
        # Horovod requires a distributed sampler.
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, hvd.size(), hvd.rank()
        )
    else:
        sampler = torch.utils.data.RandomSampler(dataset)
    config = {"batch_size": args.batch_size, "sampler": sampler}
    if args.cuda:
        config["num_workers"] = 1
        config["pin_memory"] = True
        # Try to use forkserver to spawn workers instead of fork.
        if (
            hasattr(mp, "_supports_context")
            and mp._supports_context
            and "forkserver" in mp.get_all_start_methods()
        ):
            config["multiprocessing_context"] = "forkserver"
    dataloader = torch.utils.data.DataLoader(dataset, **config)
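
    # Note that the distributed sampler gives each worker a disjoint shard of
    # the data, so with N workers each synchronized optimization step
    # effectively processes N * batch_size examples in total.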

    # Run stochastic variational inference.
    svi = SVI(model, guide, optim, elbo)
    for epoch in range(args.num_epochs):
        if args.horovod:
            # Set the epoch on distributed samplers. This reseeds shuffling so
            # that each epoch visits the data in a new order, consistently
            # across workers; it is required.
            sampler.set_epoch(epoch)

        for step, (covariates_batch, data_batch) in enumerate(dataloader):
            loss = svi.step(covariates_batch.to(device), data_batch.to(device))
            if args.horovod:
                # Optionally average loss metric across workers.
                # You can do this with arbitrary torch.Tensors.
                loss = torch.tensor(loss)
                loss = hvd.allreduce(loss, "loss")
                loss = loss.item()

                # Print only on the rank=0 worker.
                if step % 100 == 0 and hvd.rank() == 0:
                    print("epoch {} step {} loss = {:0.4g}".format(epoch, step, loss))
            else:
                if step % 100 == 0:
                    print("epoch {} step {} loss = {:0.4g}".format(epoch, step, loss))

    if args.horovod:
        # After we're done with the distributed parts of the program,
        # we can shutdown all but the rank=0 worker.
        hvd.shutdown()
        if hvd.rank() != 0:
            return

    if args.outfile:
        print("saving to {}".format(args.outfile))
        torch.save({"model": model, "guide": guide}, args.outfile)
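        # To reuse the results later, a loading sketch could look like this
        # (newer torch versions may need weights_only=False in torch.load):
        #
        #     checkpoint = torch.load(args.outfile)
        #     model, guide = checkpoint["model"], checkpoint["guide"]
        #     print(guide.median())  # learned posterior medians from AutoNormal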


if __name__ == "__main__":
    assert pyro.__version__.startswith("1.9.1")
    parser = argparse.ArgumentParser(description="Distributed training via Horovod")
    parser.add_argument("-o", "--outfile")
    parser.add_argument("-s", "--size", default=1000000, type=int)
    parser.add_argument("-b", "--batch-size", default=100, type=int)
    parser.add_argument("-n", "--num-epochs", default=10, type=int)
    parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--horovod", action="store_true", default=True)
    parser.add_argument("--no-horovod", action="store_false", dest="horovod")
    parser.add_argument("--seed", default=20200723, type=int)
    args = parser.parse_args()
    main(args)
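
# Example invocations, using the flags defined above (the output filename and
# the size/epoch values are arbitrary):
#
#   horovodrun -np 2 python svi_horovod.py -o checkpoint.pt
#   python svi_horovod.py --no-horovod -s 10000 -n 2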