fix load_state_dict() for larger models
unixpickle committed Jul 16, 2021
1 parent 0ba878e commit 63a928f
Showing 1 changed file with 12 additions and 2 deletions.
guided_diffusion/dist_util.py: 14 changes (12 additions, 2 deletions)
@@ -55,12 +55,22 @@ def load_state_dict(path, **kwargs):
     """
     Load a PyTorch file without redundant fetches across MPI ranks.
     """
+    chunk_size = 2 ** 30  # MPI has a relatively small size limit
     if MPI.COMM_WORLD.Get_rank() == 0:
         with bf.BlobFile(path, "rb") as f:
             data = f.read()
+        num_chunks = len(data) // chunk_size
+        if len(data) % chunk_size:
+            num_chunks += 1
+        MPI.COMM_WORLD.bcast(num_chunks)
+        for i in range(0, len(data), chunk_size):
+            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
     else:
-        data = None
-    data = MPI.COMM_WORLD.bcast(data)
+        num_chunks = MPI.COMM_WORLD.bcast(None)
+        data = bytes()
+        for _ in range(num_chunks):
+            data += MPI.COMM_WORLD.bcast(None)
+
     return th.load(io.BytesIO(data), **kwargs)
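
Why the chunking is needed: mpi4py's bcast pickles the whole object and hands it to MPI, whose message counts are C ints, so a single broadcast of a multi-gigabyte checkpoint can exceed the limit the in-code comment alludes to. The fix has rank 0 announce how many chunks to expect and then broadcast the file in pieces of at most 2**30 bytes, which every other rank reassembles before th.load(). Below is a minimal standalone sketch of the same chunked-broadcast pattern, assuming mpi4py is installed; the helper name bcast_large_bytes and the demo payload are invented for illustration and are not part of the commit.

from mpi4py import MPI

def bcast_large_bytes(payload, comm=MPI.COMM_WORLD, chunk_size=2 ** 30):
    """Broadcast an arbitrarily large bytes object from rank 0 to every rank."""
    if comm.Get_rank() == 0:
        # Announce the chunk count, then send each slice as its own broadcast.
        num_chunks = len(payload) // chunk_size + bool(len(payload) % chunk_size)
        comm.bcast(num_chunks)
        for i in range(0, len(payload), chunk_size):
            comm.bcast(payload[i : i + chunk_size])
        return payload
    # Non-root ranks: learn the chunk count, then reassemble the payload.
    num_chunks = comm.bcast(None)
    received = bytes()
    for _ in range(num_chunks):
        received += comm.bcast(None)
    return received

if __name__ == "__main__":
    rank = MPI.COMM_WORLD.Get_rank()
    payload = b"x" * (3 * 10 ** 6) if rank == 0 else b""
    data = bcast_large_bytes(payload)
    print(f"rank {rank} received {len(data)} bytes")

Run with, e.g., mpiexec -n 2 python bcast_sketch.py; every rank should report the same byte count. Capping each broadcast at 2**30 bytes leaves comfortable headroom under the 2**31 - 1 count limit without splitting the transfer into an excessive number of messages.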


