diff --git a/medmnist/utils.py b/medmnist/utils.py index cd4257e..d3d5a26 100644 --- a/medmnist/utils.py +++ b/medmnist/utils.py @@ -20,9 +20,9 @@ def save2d(imgs, labels, img_folder, def montage2d(imgs, n_channels, sel): - sel_img = imgs[sel] - montage_arr = skimage_montage(sel_img, multichannel=(n_channels == 3)) + channel_axis = 3 if n_channels == 3 else None + montage_arr = skimage_montage(sel_img, channel_axis=channel_axis) montage_img = Image.fromarray(montage_arr) return montage_img