fix cuda memory, super res settings, fp16 training (openai#5)
unixpickle authored May 26, 2021
1 parent b16b0a1 commit 0ba878e
Showing 3 changed files with 19 additions and 3 deletions.
guided_diffusion/dist_util.py: 2 additions & 1 deletion
@@ -24,6 +24,7 @@ def setup_dist():
     """
     if dist.is_initialized():
         return
+    os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
 
     comm = MPI.COMM_WORLD
     backend = "gloo" if not th.cuda.is_available() else "nccl"
@@ -46,7 +47,7 @@ def dev():
     Get the device to use for torch.distributed.
     """
     if th.cuda.is_available():
-        return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
+        return th.device(f"cuda")
     return th.device("cpu")
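
The memory fix above works by masking GPU visibility per MPI rank before CUDA is initialized: each process then sees exactly one device as "cuda", so no rank can accidentally allocate memory on GPU 0. A minimal standalone sketch of the pattern (not part of the commit; the mpiexec launch line and the 8-GPU assumption are illustrative, though GPUS_PER_NODE matches the constant in dist_util.py):

    # Run under MPI, e.g.: mpiexec -n 8 python pin_gpu.py
    import os

    from mpi4py import MPI

    GPUS_PER_NODE = 8  # same constant dist_util.py uses

    # Mask visibility before torch touches CUDA: each rank then sees only
    # its own GPU, so no process can create a context on GPU 0 by accident.
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"

    import torch as th  # imported after the mask so device enumeration respects it

    def dev():
        # "cuda" now refers to this rank's single visible device.
        return th.device("cuda") if th.cuda.is_available() else th.device("cpu")

    print(f"rank {MPI.COMM_WORLD.Get_rank()} -> {dev()}")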
guided_diffusion/fp16_util.py: 2 additions & 1 deletion
@@ -199,7 +199,8 @@ def _optimize_fp16(self, opt: th.optim.Optimizer):
         logger.logkv_mean("grad_norm", grad_norm)
         logger.logkv_mean("param_norm", param_norm)
 
-        opt.step(grad_scale=2.0 ** self.lg_loss_scale)
+        self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
+        opt.step()
         zero_master_grads(self.master_params)
         master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
         self.lg_loss_scale += self.fp16_scale_growth
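
The fp16 fix replaces opt.step(grad_scale=...), an argument stock torch.optim optimizers do not accept, with manual unscaling: the fp32 master gradient is divided by the loss scale (2 ** lg_loss_scale) before a plain opt.step(). A minimal sketch of the same dynamic loss-scaling loop, assuming a toy fp32 model for clarity (the real trainer keeps fp16 model weights alongside fp32 master copies; names like train_step are illustrative):

    import torch as th

    model = th.nn.Linear(16, 1)                    # stand-in for the UNet
    opt = th.optim.AdamW(model.parameters(), lr=1e-4)
    lg_loss_scale = 20.0                           # log2 of the loss scale
    fp16_scale_growth = 1e-3

    def train_step(x, y):
        global lg_loss_scale
        opt.zero_grad()
        loss = th.nn.functional.mse_loss(model(x), y)
        (loss * 2 ** lg_loss_scale).backward()     # scaled backward so grads don't underflow
        grads = [p.grad for p in model.parameters()]
        if not all(th.isfinite(g).all() for g in grads):
            lg_loss_scale -= 1                     # overflow: shrink scale, skip this step
            return
        for g in grads:
            g.mul_(1.0 / 2 ** lg_loss_scale)       # unscale, as in the diff above
        opt.step()
        lg_loss_scale += fp16_scale_growth         # slowly grow the scale back

    train_step(th.randn(4, 16), th.randn(4, 1))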
guided_diffusion/script_util.py: 15 additions & 1 deletion
@@ -283,6 +283,7 @@ def sr_create_model_and_diffusion(
     num_channels,
     num_res_blocks,
     num_heads,
+    num_head_channels,
     num_heads_upsample,
     attention_resolutions,
     dropout,
@@ -295,6 +296,8 @@ def sr_create_model_and_diffusion(
     rescale_learned_sigmas,
     use_checkpoint,
     use_scale_shift_norm,
+    resblock_updown,
+    use_fp16,
 ):
     model = sr_create_model(
         large_size,
@@ -306,9 +309,12 @@ def sr_create_model_and_diffusion(
         use_checkpoint=use_checkpoint,
         attention_resolutions=attention_resolutions,
         num_heads=num_heads,
+        num_head_channels=num_head_channels,
         num_heads_upsample=num_heads_upsample,
         use_scale_shift_norm=use_scale_shift_norm,
         dropout=dropout,
+        resblock_updown=resblock_updown,
+        use_fp16=use_fp16,
     )
     diffusion = create_gaussian_diffusion(
         steps=diffusion_steps,
@@ -333,13 +339,18 @@ def sr_create_model(
     use_checkpoint,
     attention_resolutions,
     num_heads,
+    num_head_channels,
     num_heads_upsample,
     use_scale_shift_norm,
     dropout,
+    resblock_updown,
+    use_fp16,
 ):
     _ = small_size  # hack to prevent unused variable
 
-    if large_size == 256:
+    if large_size == 512:
+        channel_mult = (1, 1, 2, 2, 4, 4)
+    elif large_size == 256:
         channel_mult = (1, 1, 2, 2, 4, 4)
     elif large_size == 64:
         channel_mult = (1, 2, 3, 4)
@@ -362,8 +373,11 @@ def sr_create_model(
         num_classes=(NUM_CLASSES if class_cond else None),
         use_checkpoint=use_checkpoint,
         num_heads=num_heads,
+        num_head_channels=num_head_channels,
         num_heads_upsample=num_heads_upsample,
         use_scale_shift_norm=use_scale_shift_norm,
+        resblock_updown=resblock_updown,
+        use_fp16=use_fp16,
     )
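
The script_util changes thread three UNet options (num_head_channels, resblock_updown, use_fp16) through the super-resolution factory functions, which previously did not accept them, and add a large_size == 512 branch that reuses the 256 channel multipliers. A hedged usage sketch; the option values and small_size choice are illustrative, and it assumes sr_model_and_diffusion_defaults() supplies the remaining required keys after this commit:

    from guided_diffusion.script_util import (
        sr_create_model_and_diffusion,
        sr_model_and_diffusion_defaults,
    )

    opts = sr_model_and_diffusion_defaults()
    opts.update(
        large_size=512,        # valid thanks to the new channel_mult branch
        small_size=128,        # illustrative low-res input size
        num_head_channels=64,  # fix attention width per head instead of head count
        resblock_updown=True,  # resample inside residual blocks
        use_fp16=True,         # half-precision UNet weights
    )
    model, diffusion = sr_create_model_and_diffusion(**opts)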
