
Commit f3777f6

Support & instructions for MPS (Silicon Mac M1/M2) and CPU

1 parent: d7f7319
8 files changed, +49 −28 lines

README.md (+13 −1)

@@ -35,7 +35,19 @@
 
 ## Requirements
 
-Please follow the requirements of [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
+If you have a CUDA graphics card, please follow the requirements of [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
+
+Otherwise (for GPU acceleration on macOS with Apple Silicon M1/M2, or just CPU), try the following:
+
+```sh
+cat environment.yml | \
+grep -v -E 'nvidia|cuda' > environment-no-nvidia.yml && \
+conda env create -f environment-no-nvidia.yml
+conda activate stylegan3
+
+# On MacOS
+export PYTORCH_ENABLE_MPS_FALLBACK=1
+```
 
 ## Download pre-trained StyleGAN2 weights
 
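After creating the environment, a quick sanity check (not part of this commit) is to ask PyTorch which backend it can actually see; the scripts changed below pick CUDA first, then MPS, then CPU in exactly this order:

```python
# Minimal check, assuming the conda environment above is active.
# Mirrors the device-selection order used throughout this commit.
import torch

if torch.cuda.is_available():
    print("CUDA GPU available")
elif torch.backends.mps.is_available():
    print("Apple Silicon MPS backend available")
else:
    print("No GPU backend found; scripts will run on CPU")
```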

environment.yml (+10 −7)

@@ -5,20 +5,23 @@ channels:
 dependencies:
   - python >= 3.8
   - pip
-  - numpy>=1.20
+  - numpy>=1.25
   - click>=8.0
-  - pillow=8.3.1
-  - scipy=1.7.1
-  - pytorch=1.9.1
+  - pillow=9.4.0
+  - scipy=1.11.0
+  - pytorch>=2.0.1
+  - torchvision>=0.15.2
   - cudatoolkit=11.1
   - requests=2.26.0
   - tqdm=4.62.2
   - ninja=1.10.2
   - matplotlib=3.4.2
   - imageio=2.9.0
   - pip:
-    - imgui==1.3.0
-    - glfw==2.2.0
+    - imgui==2.0.0
+    - glfw==2.6.1
+    - gradio==3.35.2
     - pyopengl==3.1.5
     - imageio-ffmpeg==0.4.3
-    - pyspng
+    # pyspng is currently broken on MacOS (see https://github.com/nurpax/pyspng/pull/6 for instance)
+    - pyspng-seunglab

gen_images.py (+4 −3)

@@ -103,9 +103,10 @@ def generate_images(
     """
 
     print('Loading networks from "%s"...' % network_pkl)
-    device = torch.device('cuda')
+    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+    dtype = torch.float32 if device.type == 'mps' else torch.float64
     with dnnlib.util.open_url(network_pkl) as f:
-        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+        G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore
     # import pickle
     # G = legacy.load_network_pkl(f)
     # output = open('checkpoints/stylegan2-car-config-f-pt.pkl', 'wb')
@@ -126,7 +127,7 @@ def generate_images(
     # Generate images.
     for seed_idx, seed in enumerate(seeds):
         print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
-        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
+        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype)
 
         # Construct an inverse rotation/translation matrix and pass to the generator. The
         # generator expects this matrix as an inverse to avoid potentially failing numerical
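The same device and dtype selection recurs, essentially verbatim, in every script below. A minimal sketch of the pattern factored into a helper (the helper name is illustrative, not part of this commit); float32 is forced on MPS because the MPS backend does not support float64 tensors:

```python
import torch

def pick_device_and_dtype():
    # Prefer CUDA, then Apple Silicon's MPS backend, then plain CPU,
    # matching the selection order used throughout this commit.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    # MPS has no float64 support, so latents stay in float32 there.
    dtype = torch.float32 if device.type == 'mps' else torch.float64
    return device, dtype
```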

stylegan_human/generate.py (+4 −3)

@@ -63,9 +63,10 @@ def generate_images(
 
     else:
         import torch
-        device = torch.device('cuda')
+        device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+        dtype = torch.float32 if device.type == 'mps' else torch.float64
         with dnnlib.util.open_url(network_pkl) as f:
-            G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+            G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore
         os.makedirs(outdir, exist_ok=True)
 
 
@@ -92,7 +93,7 @@ def generate_images(
 
         else: ## stylegan v2/v3
             label = torch.zeros([1, G.c_dim], device=device)
-            z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
+            z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype)
             if target_z.size==0:
                 target_z= z.cpu()
             else:

stylegan_human/interpolation.py (+5 −4)

@@ -116,9 +116,10 @@ def main(
 ):
 
 
-    device = torch.device('cuda')
+    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+    dtype = torch.float32 if device.type == 'mps' else torch.float64
     with dnnlib.util.open_url(network_pkl) as f:
-        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+        G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore
 
     outdir = os.path.join(outdir)
     if not os.path.exists(outdir):
@@ -132,8 +133,8 @@ def main(
         print('Require two seeds, randomly generate two now.')
         seeds = [seeds[0],random.randint(0,10000)]
 
-    z1 = torch.from_numpy(np.random.RandomState(seeds[0]).randn(1, G.z_dim)).to(device)
-    z2 = torch.from_numpy(np.random.RandomState(seeds[1]).randn(1, G.z_dim)).to(device)
+    z1 = torch.from_numpy(np.random.RandomState(seeds[0]).randn(1, G.z_dim)).to(device, dtype=dtype)
+    z2 = torch.from_numpy(np.random.RandomState(seeds[1]).randn(1, G.z_dim)).to(device, dtype=dtype)
     img1 = generate_image_from_z(G, z1, noise_mode, truncation_psi, device)
     img2 = generate_image_from_z(G, z2, noise_mode, truncation_psi, device)
     img1.save(f'{outdir}/seed{seeds[0]:04d}.png')

stylegan_human/style_mixing.py (+4 −3)

@@ -49,16 +49,17 @@ def generate_style_mix(
 ):
 
     print('Loading networks from "%s"...' % network_pkl)
-    device = torch.device('cuda')
+    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+    dtype = torch.float32 if device.type == 'mps' else torch.float64
     with dnnlib.util.open_url(network_pkl) as f:
-        G = legacy.load_network_pkl(f)['G_ema'].to(device)
+        G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)
 
     os.makedirs(outdir, exist_ok=True)
 
     print('Generating W vectors...')
     all_seeds = list(set(row_seeds + col_seeds))
     all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
-    all_w = G.mapping(torch.from_numpy(all_z).to(device), None)
+    all_w = G.mapping(torch.from_numpy(all_z).to(device, dtype=dtype), None)
     w_avg = G.mapping.w_avg
     all_w = w_avg + (all_w - w_avg) * truncation_psi
     w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}

stylegan_human/stylemixing_video.py (+6 −5)

@@ -65,9 +65,10 @@ def style_mixing_video(network_pkl: str,
     print('col_seeds: ', dst_seeds)
     num_frames = int(np.rint(duration_sec * mp4_fps))
     print('Loading networks from "%s"...' % network_pkl)
-    device = torch.device('cuda')
+    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+    dtype = torch.float32 if device.type == 'mps' else torch.float64
     with dnnlib.util.open_url(network_pkl) as f:
-        Gs = legacy.load_network_pkl(f)['G_ema'].to(device)
+        Gs = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)
 
     print(Gs.num_ws, Gs.w_dim, Gs.img_resolution)
     max_style = int(2 * np.log2(Gs.img_resolution)) - 3
@@ -80,14 +81,14 @@ def style_mixing_video(network_pkl: str,
     src_z = scipy.ndimage.gaussian_filter(src_z, [smoothing_sec * mp4_fps] + [0] * (2- 1), mode="wrap")
     src_z /= np.sqrt(np.mean(np.square(src_z)))
     # Map into the detangled latent space W and do truncation trick
-    src_w = Gs.mapping(torch.from_numpy(src_z).to(device), None)
+    src_w = Gs.mapping(torch.from_numpy(src_z).to(device, dtype=dtype), None)
     w_avg = Gs.mapping.w_avg
     src_w = w_avg + (src_w - w_avg) * truncation_psi
 
     # Top row latents (fixed reference)
     print('Generating Destination W vectors...')
     dst_z = np.stack([np.random.RandomState(seed).randn(Gs.z_dim) for seed in dst_seeds])
-    dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device), None)
+    dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device, dtype=dtype), None)
     dst_w = w_avg + (dst_w - w_avg) * truncation_psi
     # Get the width and height of each image:
     H = Gs.img_resolution # 1024
@@ -120,7 +121,7 @@ def make_frame(t):
         for col, dst_image in enumerate(list(dst_images)):
             # Select the pertinent latent w column:
             w_col = np.stack([dst_w[col].cpu()]) # [18, 512] -> [1, 18, 512]
-            w_col = torch.from_numpy(w_col).to(device)
+            w_col = torch.from_numpy(w_col).to(device, dtype=dtype)
             # Replace the values defined by col_styles:
             w_col[:, col_styles] = src_w[frame_idx, col_styles]#.cpu()
             # Generate these synthesized images:

viz/renderer.py (+3 −2)

@@ -69,7 +69,8 @@ def add_watermark_np(input_image_array, watermark_text="AI Generated"):
 
 class Renderer:
     def __init__(self, disable_timing=False):
-        self._device = torch.device('cuda')
+        self._device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+        self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64
         self._pkl_data = dict() # {pkl: dict | CapturedException, ...}
         self._networks = dict() # {cache_key: torch.nn.Module, ...}
         self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...}
@@ -241,7 +242,7 @@ def init_network(self, res,
 
         if self.w_load is None:
            # Generate random latents.
-            z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device).float()
+            z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device, dtype=self._dtype)
 
             # Run mapping network.
             label = torch.zeros([1, G.c_dim], device=self._device)
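On Apple Silicon, the changes above also rely on the `PYTORCH_ENABLE_MPS_FALLBACK=1` export from the README: with that flag set, any operator the MPS backend does not implement falls back to the CPU instead of raising an error. A small sketch, as an alternative to the shell export rather than code from this commit, of setting it from Python:

```python
# Hypothetical alternative to the README's shell export: set the fallback
# flag before torch is imported so unsupported MPS ops run on the CPU.
import os
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch  # imported after setting the environment variable
```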
