This repository contains the code for the methods in our 2020 CVPR CVMI paper, Self-Supervised Feature Extraction for 3D Axon Segmentation.
The code was developed and tested with PyTorch 1.2.0 in a Conda environment.
conda activate {Project Environment}
git clone ${repository}
cd ssl-for-axons
conda install --yes --file requirements.yml
Our work used data generated by the MIT Chung Lab with the SHIELD imaging technique. We also used the publicly available Janelia dataset from the BigNeuron Project.
We propose first training the 3D U-Net encoder, together with an auxiliary classifier, on a self-supervised auxiliary task. The auxiliary task is to predict the permutation used to reorder the slices of an input subvolume. For example, if 10 permutations have been generated, the auxiliary classifier should predict a one-hot encoding of length 10, where the argmax of the encoding is the index of the permutation applied to that sample. Code for generating permutations is in utils/permutations.py. After training on the auxiliary task, the pre-trained 3D U-Net encoder and a randomly initialized decoder can be fine-tuned on the target task of axon segmentation.
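To make the auxiliary labels concrete, here is a minimal, dependency-free sketch of the idea described above. It is not the code in utils/permutations.py: the function names (`generate_permutations`, `make_aux_sample`, `one_hot`) are hypothetical, permutations are drawn uniformly at random (the repository may select them differently), and the subvolume is assumed to be indexable by slice along its first axis.

```python
import math
import random


def generate_permutations(num_slices, num_permutations, seed=0):
    """Sample distinct slice orderings; a permutation's list index is its class label.

    Assumption: permutations are drawn uniformly at random.
    """
    if num_permutations > math.factorial(num_slices):
        raise ValueError("more permutations requested than exist")
    rng = random.Random(seed)
    seen = set()
    perms = []
    while len(perms) < num_permutations:
        order = list(range(num_slices))
        rng.shuffle(order)
        candidate = tuple(order)
        if candidate not in seen:  # keep only orderings not sampled before
            seen.add(candidate)
            perms.append(candidate)
    return perms


def make_aux_sample(subvolume, perms, rng):
    """Reorder the slices of `subvolume` and return (shuffled, label)."""
    label = rng.randrange(len(perms))
    shuffled = [subvolume[i] for i in perms[label]]
    return shuffled, label


def one_hot(label, num_classes):
    """Target vector whose argmax is the permutation index."""
    return [1 if i == label else 0 for i in range(num_classes)]


# Toy usage: strings stand in for the 2D slices of a 4-slice subvolume.
perms = generate_permutations(num_slices=4, num_permutations=10)
subvolume = ["slice0", "slice1", "slice2", "slice3"]
shuffled, label = make_aux_sample(subvolume, perms, random.Random(1))
target = one_hot(label, len(perms))
```

With real data the same indexing would apply to a tensor or array along the slice axis, and the classifier would be trained to recover `label` from `shuffled`.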
# 3D U-Net Baseline
python train_unet.py configs/train.json
# Auxiliary Task
python train_aux_task.py configs/train_slices.json
# Fine Tuning 3D U-Net with pre-trained encoder
python transfer_train_unet.py configs/train.json weights/aux_task/best.ckpt
We natively support training with data stored in H5 datasets. H5 files in ./data/train are used for training, and H5 files in ./data/val are used for validation. All training parameters can be set in the JSON config and/or at the top of the corresponding Python file. For example, you should always specify the location of your data directory and the dataset names in the config:
"path": "./data/janelia",
"dataset_name_raw": "data",
"dataset_name_truth": "truth",
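As a quick sanity check before training, the three keys above can be validated when the config is loaded. The sketch below is illustrative only: the helper `load_data_config` is hypothetical, and it assumes the keys sit at the top level of the JSON object, which may not match the layout of the files in configs.

```python
import json

# Keys taken from the config fragment above.
REQUIRED_KEYS = ("path", "dataset_name_raw", "dataset_name_truth")


def load_data_config(text):
    """Parse JSON text and return the data-related settings.

    Assumption: the keys are top-level entries of the JSON object.
    """
    config = json.loads(text)
    missing = [key for key in REQUIRED_KEYS if key not in config]
    if missing:
        raise KeyError("config is missing: " + ", ".join(missing))
    return {key: config[key] for key in REQUIRED_KEYS}


example = """
{
    "path": "./data/janelia",
    "dataset_name_raw": "data",
    "dataset_name_truth": "truth"
}
"""
settings = load_data_config(example)
```

Failing early on a missing key is cheaper than discovering it after the data loader has started.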
As in training, testing is driven by a JSON configuration file. The configuration files are located in configs and can be modified as needed. Predictions can be saved back to an H5 file or a TIFF image stack, and voxel-based metrics (precision, recall, AUC, and F1 score) are reported when ground truth is provided. All H5 files in ./data/test will be processed.
# To test 3D U-Net
python test.py configs/test.json {path_to_weights_file}
# To test auxiliary classifier
python test_aux_task.py configs/test_slices.json {path_to_weights_file}
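For reference, the voxel-based precision, recall, and F1 score mentioned above can be sketched as follows. This is not the implementation in test.py: the function name `voxel_metrics`, the 0.5 threshold, and the flattened-list inputs are assumptions made for illustration, and AUC is omitted because it requires sweeping the threshold.

```python
def voxel_metrics(pred, truth, threshold=0.5):
    """Compute voxel-wise precision, recall, and F1.

    `pred` holds per-voxel scores and `truth` holds 0/1 labels, both
    flattened to one dimension. The threshold is an assumed default.
    """
    tp = fp = fn = 0
    for score, label in zip(pred, truth):
        predicted = score >= threshold
        actual = bool(label)
        if predicted and actual:
            tp += 1  # axon voxel correctly detected
        elif predicted:
            fp += 1  # background voxel labeled as axon
        elif actual:
            fn += 1  # axon voxel missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# Toy usage on five voxels.
scores = [0.9, 0.8, 0.2, 0.6, 0.1]
labels = [1, 1, 1, 0, 0]
metrics = voxel_metrics(scores, labels)
```

On full volumes a vectorized NumPy or PyTorch version would be far faster; the loop form is used here only to keep the definitions explicit.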
Our 2020 CVPR CVMI paper:
@inproceedings{klinghoffer2020ssl,
title={Self-Supervised Feature Extraction for 3D Axon Segmentation},
author={Klinghoffer, Tzofi and Morales, Peter and Park, Young-Gyun and Evans, Nicholas and Chung, Kwanghun and Brattain, Laura},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops},
year={2020}
}
Our implementation also draws upon the 3D U-Net code by Wolny et al.