Skip to content

Commit

Permalink
added save and montage for medmnist2d
Browse files Browse the repository at this point in the history
  • Loading branch information
duducheng committed Aug 19, 2021
1 parent ea9788a commit df11795
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 73 deletions.
16 changes: 14 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ Please note that this dataset is **NOT** intended for clinical use.
# Installation and Requirements
Setup the required environments and install `medmnist` as a standard Python package:

pip install git+https://github.com/MedMNIST/MedMNIST.git
pip install --upgrade git+https://github.com/MedMNIST/MedMNIST.git

Check whether you have isnstalled the latest [version](medmnist/info.py):

>>> import medmnist
>>> print(medmnist.__version__)

The code requires only common Python environments for machine learning. Basicially, it was tested with
* Python 3 (Anaconda 3.6.3 specifically)
Expand Down Expand Up @@ -68,10 +72,18 @@ The MedMNIST dataset contains several subsets. Each subset (e.g., `pathmnist.npz

* Print the dataset details given a subset flag:

python -m medmnist info <subset:xxxmnist>
python -m medmnist info --flag=xxxmnist

* Save the dataset as standard figures, which could be used for AutoML tools, e.g., Google AutoML Vision:

python -m medmnist save --flag=xxxmnist --folder=tmp/

* Download the dataset manually or automatically (by setting `download=True` in [`dataset.py`](medmnist/dataset.py)).

* Explore the MedMNIST dataset with jupyter notebook ([`getting_started.ipynb`](examples/getting_started.ipynb)), and train basic neural networks in PyTorch.

* If you do not use PyTorch, go to [`getting_started_without_PyTorch.ipynb`](examples/getting_started_without_PyTorch.ipynb), which provides snippets about how to use MedMNIST data (the `.npz` files) without PyTorch.

* Please refer to our another repository [`MedMNIST/experiments`](https://github.com/MedMNIST/experiments) for all experiments, including PyTorch, auto-sklearn, AutoKeras and Google AutoML Vision together with their weights!

# Citation
Expand Down
139 changes: 88 additions & 51 deletions examples/getting_started.ipynb

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion medmnist/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from medmnist.info import __version__, HOMEPAGE
from medmnist.dataset import PathMNIST, ChestMNIST, DermaMNIST, OCTMNIST, PneumoniaMNIST, RetinaMNIST, BreastMNIST, BloodMNIST, TissueMNIST, OrganAMNIST, OrganCMNIST, OrganSMNIST
from medmnist.dataset import (PathMNIST, ChestMNIST, DermaMNIST, OCTMNIST, PneumoniaMNIST, RetinaMNIST,
BreastMNIST, BloodMNIST, TissueMNIST, OrganAMNIST, OrganCMNIST, OrganSMNIST,
OrganMNIST3D, NoduleMNIST3D, AdrenalMNIST3D, FractureMNIST3D, VesselMNIST3D, SynapseMNIST3D)
# from medmnist.evaluator import Evaluator
67 changes: 60 additions & 7 deletions medmnist/__main__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from medmnist.info import __version__, HOMEPAGE, INFO, DEFAULT_ROOT
import medmnist
from medmnist.info import INFO, DEFAULT_ROOT


def available():
'''List all available datasets.'''
print(f"MedMNIST v{__version__} @ {HOMEPAGE}")
print(f"MedMNIST v{medmnist.__version__} @ {medmnist.HOMEPAGE}")

print("All available datasets:")
for key in INFO.keys():
Expand All @@ -15,7 +16,7 @@ def download(root=DEFAULT_ROOT):
for key in INFO.keys():
print(f"Downloading {key}...")
_ = getattr(medmnist, INFO[key]['python_class'])(
split="train", root=root)
split="train", root=root, download=True)


def clean(root=DEFAULT_ROOT):
Expand All @@ -34,18 +35,70 @@ def info(flag):
pprint(INFO[flag])


def test():
def save(flag, folder, postfix="png", root=DEFAULT_ROOT):
'''Save the dataset as standard figures, which could be used for AutoML tools, e.g., Google AutoML Vision.'''
print(f"Saving {flag} train...")
train_dataset = getattr(medmnist, INFO[flag]['python_class'])(
split="train", root=root)
train_dataset.save(folder, postfix)

print(f"Saving {flag} val...")
val_dataset = getattr(medmnist, INFO[flag]['python_class'])(
split="val", root=root)
val_dataset.save(folder, postfix)

print(f"Saving {flag} test...")
test_dataset = getattr(medmnist, INFO[flag]['python_class'])(
split="test", root=root)
test_dataset.save(folder, postfix)


def test(save_folder="tmp/", root=DEFAULT_ROOT):
'''For developmemnt only.'''

available()

download()

clean()
download(root)

for key in INFO.keys():
print(f"Verifying {key}....")

info(key)

train_dataset = getattr(medmnist, INFO[key]['python_class'])(
split="train", root=root)
assert len(train_dataset) == INFO[key]["n_samples"]["train"]

val_dataset = getattr(medmnist, INFO[key]['python_class'])(
split="val", root=root)
assert len(val_dataset) == INFO[key]["n_samples"]["val"]

test_dataset = getattr(medmnist, INFO[key]['python_class'])(
split="test", root=root)
assert len(test_dataset) == INFO[key]["n_samples"]["test"]

n_channels = INFO[key]["n_channels"]

_, *shape = train_dataset.img.shape
if n_channels == 3:
assert shape == [28, 28, 3]
else:
assert n_channels == 1
assert shape == [28, 28] or shape == [28, 28, 28]

if save_folder != "null":
try:
train_dataset.montage(save_folder=save_folder)
except NotImplementedError:
print(f"{key} `montage` method not implemented.")

try:
save(key, save_folder, postfix=".jpg", root=root)
except:
print(f"{key} `save` method not implemented.")

# clean(root)


if __name__ == "__main__":
import fire
Expand Down
110 changes: 101 additions & 9 deletions medmnist/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from sys import base_prefix
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
Expand All @@ -13,7 +14,7 @@ def __init__(self,
split,
transform=None,
target_transform=None,
download=True,
download=False,
as_rgb=False,
root=DEFAULT_ROOT):
''' dataset
Expand Down Expand Up @@ -94,12 +95,10 @@ class MedMNIST2D(MedMNIST):

def __getitem__(self, index):
img, target = self.img[index], self.label[index].astype(int)
img = Image.fromarray(np.uint8(img))
img = Image.fromarray(img)

if self.as_rgb:
img = Image.fromarray(img).convert('RGB')
else:
img = Image.fromarray(img)
img = img.convert('RGB')

if self.transform is not None:
img = self.transform(img)
Expand All @@ -109,11 +108,80 @@ def __getitem__(self, index):

return img, target

def save(self, folder):
pass
def save(self, folder, postfix="png", write_csv=True):

split_dict = {
"train": "TRAIN",
"val": "VALIDATION",
"test": "TEST"
} # compatible for Google AutoML Vision

from tqdm import trange

_transform = self.transform
_target_transform = self.target_transform
self.transform = None
self.target_transform = None

base_folder = os.path.join(folder, self.flag)

if not os.path.exists(base_folder):
os.makedirs(base_folder)

if write_csv:
csv_file = open(os.path.join(folder, f"{self.flag}.csv"), "a")

for idx in trange(self.__len__()):

img, label = self.__getitem__(idx)

file_name = f"{self.split}{idx}_{'_'.join(map(str,label))}.{postfix}"

img.save(os.path.join(base_folder, file_name))

if write_csv:
line = f"{split_dict[self.split]},{file_name},{','.join(map(str,label))}\n"
csv_file.write(line)

self.transform = _transform
self.target_transform = _target_transform
csv_file.close()

def montage(self, length=20, replace=False, save_folder=None):
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage.util import montage as skimage_montage

n_imgs = length * length
sel = np.random.choice(self.__len__(), size=n_imgs, replace=replace)
sel_img = self.img[sel]
if self.info['n_channels'] == 3:
montage_arr = skimage_montage(sel_img, multichannel=True)
else:
assert self.info['n_channels'] == 1
montage_arr = skimage_montage(sel_img, multichannel=False)

montage_img = Image.fromarray(montage_arr)

if save_folder is not None:
montage_img.save(
os.path.join(save_folder,
f"{self.flag}_{self.split}_montage.jpg"))

def montage(self, length):
pass
return montage_img


class MedMNIST3D(MedMNIST):

def __getitem__(self, index):
return super().__getitem__(index)

def save(self, folder, postfix="png", write_csv=True):
raise NotImplementedError

def montage(self, length=20, replace=False, save_folder=None):
raise NotImplementedError


class PathMNIST(MedMNIST2D):
Expand Down Expand Up @@ -164,6 +232,30 @@ class OrganSMNIST(MedMNIST2D):
flag = "organsmnist"


class OrganMNIST3D(MedMNIST3D):
flag = "organmnist3d"


class NoduleMNIST3D(MedMNIST3D):
flag = "nodulemnist3d"


class AdrenalMNIST3D(MedMNIST3D):
flag = "adrenalmnist3d"


class FractureMNIST3D(MedMNIST3D):
flag = "fracturemnist3d"


class VesselMNIST3D(MedMNIST3D):
flag = "vesselmnist3d"


class SynapseMNIST3D(MedMNIST3D):
flag = "synapsemnist3d"


# backward-compatible
OrganMNISTAxial = OrganAMNIST
OrganMNISTCoronal = OrganCMNIST
Expand Down
6 changes: 3 additions & 3 deletions medmnist/info.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
__version__ = "1.1"
__version__ = "1.2"


import os
import warnings
from os.path import expanduser
import warnings


def get_default_root():
Expand Down Expand Up @@ -413,7 +413,7 @@ def get_default_root():
"label": {
"0": "Buckle Rib Fracture",
"1": "Nondisplaced Rib Fracture",
"2": "Displaced Rib Fracture, Segmental Rib Fracture"
"2": "Displaced Rib Fracture"
},
"n_channels": 1,
"n_samples": {
Expand Down

0 comments on commit df11795

Please sign in to comment.