Skip to content

Commit

Permalink
two fixes
Browse files Browse the repository at this point in the history
 * super resolution fp16 sampling
 * fractional channel multiplier (used for 512x512 models)
  • Loading branch information
unixpickle committed Jul 16, 2021
1 parent 63a928f commit 18ff408
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 25 deletions.
6 changes: 4 additions & 2 deletions guided_diffusion/script_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def create_model(
):
if channel_mult == "":
if image_size == 512:
channel_mult = (1, 1, 2, 2, 4, 4)
channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
elif image_size == 256:
channel_mult = (1, 1, 2, 2, 4, 4)
elif image_size == 128:
Expand Down Expand Up @@ -235,7 +235,9 @@ def create_classifier(
classifier_resblock_updown,
classifier_pool,
):
if image_size == 256:
if image_size == 512:
channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
elif image_size == 256:
channel_mult = (1, 1, 2, 2, 4, 4)
elif image_size == 128:
channel_mult = (1, 1, 2, 3, 4)
Expand Down
38 changes: 15 additions & 23 deletions guided_diffusion/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,16 +477,12 @@ def __init__(
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)

ch = input_ch = int(channel_mult[0] * model_channels)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
[TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
self._feature_size = ch
input_block_chans = [ch]
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
Expand All @@ -495,13 +491,13 @@ def __init__(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
out_channels=int(mult * model_channels),
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
ch = int(mult * model_channels)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
Expand Down Expand Up @@ -576,13 +572,13 @@ def __init__(
ch + ich,
time_embed_dim,
dropout,
out_channels=model_channels * mult,
out_channels=int(model_channels * mult),
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = model_channels * mult
ch = int(model_channels * mult)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
Expand Down Expand Up @@ -616,7 +612,7 @@ def __init__(
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
)

def convert_to_fp16(self):
Expand Down Expand Up @@ -739,16 +735,12 @@ def __init__(
linear(time_embed_dim, time_embed_dim),
)

ch = int(channel_mult[0] * model_channels)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
[TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
self._feature_size = ch
input_block_chans = [ch]
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
Expand All @@ -757,13 +749,13 @@ def __init__(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
out_channels=int(mult * model_channels),
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
ch = int(mult * model_channels)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
Expand Down
2 changes: 2 additions & 0 deletions scripts/super_res_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def main():
dist_util.load_state_dict(args.model_path, map_location="cpu")
)
model.to(dist_util.dev())
if args.use_fp16:
model.convert_to_fp16()
model.eval()

logger.log("loading data...")
Expand Down

0 comments on commit 18ff408

Please sign in to comment.