krulllab/COSDD

COSDD (COrrelated and Signal-Dependent Denoising)

ArXiv
Benjamin Salmon¹ and Alexander Krull²
¹,²University of Birmingham
¹[email protected], ²[email protected]

This repository contains code for the paper "Unsupervised Denoising for Signal-Dependent and Row-Correlated Imaging Noise". COSDD is a fully unsupervised denoiser for removing noise that is correlated along the rows or columns of images. This type of noise commonly occurs in scanning-based microscopy and in some camera sensors such as EMCCD and sCMOS. It is trained using only noisy images, i.e., the very data that is to be denoised.

Reproducing results in publication
There have been updates to this code since the results in the WACV paper were obtained. Please check out the WACV_reproducibility branch to reproduce those results.

Getting started

Environment

It is recommended to install the dependencies in a conda environment. If you haven't already, install Miniconda on your system by following this link.
Once conda is installed, create and activate an environment by entering these lines into a command line interface:

  conda create --name cosdd python=3.12
  conda activate cosdd

Next, install PyTorch and torchvision for your system by following this link.
After that, you're ready to install the dependencies for this repository:
pip install lightning jupyterlab matplotlib tifffile scikit-learn scikit-image tensorboard
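
To check that the installation worked, a small Python sketch like the following can be run inside the activated environment. Some import names differ from the pip package names (scikit-learn is imported as sklearn and scikit-image as skimage). The module list is an assumption based on the commands above, with torch and torchvision added for the PyTorch step; the helper name missing_modules is hypothetical.

```python
import importlib.util

# Import names for the dependencies installed above (assumed list).
# Note: scikit-learn is imported as sklearn, scikit-image as skimage.
REQUIRED_MODULES = [
    "torch", "torchvision", "lightning", "jupyterlab", "matplotlib",
    "tifffile", "sklearn", "skimage", "tensorboard",
]


def missing_modules(names):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All dependencies found.")
```

Checking with find_spec avoids actually importing heavy packages such as torch just to confirm they are installed.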

Data

The tutorial notebook training.ipynb will download an example dataset and store it as a .tiff in ./data. By default, this repo uses skimage.io.imread to load images, which works for common image types including .tiff. If your file type is not supported, edit utils.get_imread_fn to use a different function. The function should return images as a numpy array.
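
As an illustration of the kind of reader that utils.get_imread_fn could hand back, here is a minimal sketch for .npy files. The function name read_npy is hypothetical, and how get_imread_fn chooses a reader for each file type is not shown; check utils.py for its actual signature before wiring this in.

```python
import numpy as np


def read_npy(path):
    """Hypothetical reader for .npy files that returns the stored image as a numpy array."""
    image = np.load(path)  # allow_pickle defaults to False, so object arrays are rejected
    # Reject arrays that do not hold numeric pixel values (e.g. strings).
    if not np.issubdtype(image.dtype, np.number):
        raise TypeError(f"{path} does not contain numeric image data")
    return image
```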

Tutorial notebooks

This repository contains three tutorial notebooks: training.ipynb, prediction.ipynb and generation.ipynb. Respectively, they walk through training a model on an example dataset, using the trained model to denoise that dataset, and using the trained model to generate entirely new images from nothing.

Command line interface

COSDD takes a long time to train, so it is recommended to train it from the terminal instead of in the notebooks. Training and prediction configurations are set using .yaml files. An example training config file is example-train-config.yaml and an example prediction config file is example-predict-config.yaml. See below for options.

To train, run:
python training.py example-train-config.yaml

After training, use the model to denoise by running:
python prediction.py example-predict-config.yaml
The prediction script will save results in a directory called denoised-<date>_<time>, with a file called denoised-<i>.tif for each input file.
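
To pick up the newest results after running prediction.py, a helper like the following sketch could be used. It is hypothetical code, not part of the repository. Because the exact <date>_<time> format is not documented here, it selects the most recently modified denoised-* directory instead of parsing the name.

```python
import os


def latest_denoised_dir(root="."):
    """Return the most recently modified denoised-* directory under root, or None."""
    candidates = [
        os.path.join(root, name)
        for name in os.listdir(root)
        if name.startswith("denoised-") and os.path.isdir(os.path.join(root, name))
    ]
    if not candidates:
        return None
    # Use modification time rather than parsing the <date>_<time> suffix.
    return max(candidates, key=os.path.getmtime)
```

The denoised-<i>.tif files inside that directory could then be loaded with tifffile.imread.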

Yaml training config file options

Important options are: model-name, data: paths, patterns & axes, train-parameters: max-time, hyper-parameters: number-gaussians & noise-direction. If training fails due to NaNs, see data: clip-outliers, hyper-parameters: number-gaussians and train-parameters: monte-carlo-kl.

model-name
  (str) Name the trained model will be saved as.

data:
paths
  (str) Path to the directory the training data is stored in. Can be a list of strings if using more than one directory.
patterns
  (str) Glob pattern to identify files within `paths` that will be used as training data. Currently accepted file types are tiff/tif, czi & png. Edit get_imread_fn in utils.py to add data loading functions for currently unsupported file types.
axes
  (str) (S(ample) | C(hannel) | T(ime) | Z | Y | X). Describes the axes of the data as they are stored. I.e., when we call tifffile.imread("your-data.tiff"), what will be the shape of the returned numpy array? 
  The sample axis can be repeated, e.g. SCSZYX, if there are extra axes that should be concatenated as samples.
number-dimensions
  (int) Number of spatial dimensions of your images. Default: 2.
  If your data has shape [T(ime), Y, X], the time dimension can optionally be treated as a spatial dimension and included in convolutions by setting this parameter to 3. If your data has shape [Z, Y, X], the Z axis can optionally be treated as a sample dimension and excluded from convolutions by setting this parameter to 2.
patch-size
  (list(int) | null) [(Depth), Height, Width]. Set to patch data into non-overlapping windows. Default: null.
  The training/validation split is made along the sample axis. If your data has only one sample, use this to break it into patches that will be concatenated along the sample axis.
  This is different from crop-size below, as it is deterministic and done once at the start of training.
clip-outliers
  (bool) Hot or dead outlier pixels can disrupt training. Default: False.
  Set this to True to clip extreme pixel values between 1st and 99th percentile.
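
The data options above can be illustrated with NumPy. This is a minimal sketch with hypothetical function names, not COSDD's own preprocessing code. It assumes spatial axes come last, that patching drops any border too small to fill a whole patch, and that clip-outliers uses percentiles computed over the whole array rather than per image.

```python
import numpy as np


def merge_sample_axes(array, axes):
    """Move every 'S' axis to the front and flatten them into one sample axis."""
    if array.ndim != len(axes):
        raise ValueError(f"axes {axes!r} does not match an array with {array.ndim} dims")
    sample_idx = [i for i, a in enumerate(axes) if a == "S"]
    other_idx = [i for i, a in enumerate(axes) if a != "S"]
    moved = np.transpose(array, sample_idx + other_idx)
    n_samples = int(np.prod([array.shape[i] for i in sample_idx]))  # 1 if there is no S axis
    new_axes = "S" + "".join(axes[i] for i in other_idx)
    return moved.reshape((n_samples,) + moved.shape[len(sample_idx):]), new_axes


def patch_non_overlapping(array, patch_size):
    """Split trailing spatial axes into non-overlapping patches stacked on axis 0.

    Axis 0 must be the sample axis; leftover borders smaller than a patch are dropped.
    """
    d = len(patch_size)
    lead, spatial = array.shape[:-d], array.shape[-d:]
    counts = [s // p for s, p in zip(spatial, patch_size)]
    window = tuple(slice(0, n * p) for n, p in zip(counts, patch_size))
    array = array[(Ellipsis,) + window]
    # Reshape each spatial axis into a (number of patches, patch length) pair.
    array = array.reshape(lead + tuple(v for n, p in zip(counts, patch_size) for v in (n, p)))
    k = len(lead)
    count_axes = [k + 2 * i for i in range(d)]
    patch_axes = [k + 2 * i + 1 for i in range(d)]
    # Order: sample, patch counts, other leading axes (e.g. C), patch contents.
    array = array.transpose([0] + count_axes + list(range(1, k)) + patch_axes)
    return array.reshape((-1,) + array.shape[1 + d:])


def clip_outliers(array, low=1.0, high=99.0):
    """Clip values to the [low, high] percentiles computed over the whole array."""
    lo, hi = np.percentile(array, [low, high])
    return np.clip(array, lo, hi)
```

For example, data stored with axes "SCSYX" and shape (2, 3, 4, 5, 6) becomes 8 samples with axes "SCYX", and patch-size [2, 3] turns a (2, 4, 6) stack into 8 patches of shape (2, 3).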

train-parameters:
max-epochs
  (int) Maximum number of epochs to train. Default: 1000.
max-time
  (str | null) Maximum time to train for. Default: null.
  Must be of the form "DD:HH:MM:SS", or just `null`.
  COSDD can take a long time to converge, so use this to stop training in a reasonable time.
patience
  (int) Stop training when validation loss plateaus for this number of epochs. Default: 100.
batch-size
  (int) Number of images passed through network at a time. Default: 4.
number-grad-batches
  (int) Gradient accumulation. Default: 4.
  Number of batches to pass through network before updating model parameters.
crop-size
  (list(int))  [(Depth), Height, Width]. Default: [256, 256].
  As a form of data augmentation, at each training step a patch is randomly cropped from each training image. Set the size of that patch here.
  This is different from patch-size above as it is random and repeated at every training step.
training-split
  (float) Fraction of data to use as the training set. Default: 0.9.
  The remaining fraction, 1 - training-split, is used as the validation set.
monte-carlo-kl
  (bool) Experimental. Default: False.
  Set True to calculate KL divergence using random samples from posterior. 
  I've found this can help when training crashes due to NaNs.
  Set False to calculate KL divergence analytically (common method).
direct-denoiser-loss
  (str) "L1" or "MSE". Default: "MSE".
  Train direct denoiser to calculate coordinate-median or mean, respectively.
use-direct-denoiser
  (bool) Train the direct denoiser to predict the average of samples. Default: True.
  Increases training time but reduces inference time. Worthwhile if inference is on a large dataset (GBs).
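
Several of the training parameters above reduce to simple arithmetic or array operations. The sketch below illustrates max-time parsing, the effective batch size under gradient accumulation, the training/validation split and random cropping. It uses hypothetical function names and is not COSDD's training code; rounding the split down and cropping the trailing axes are assumptions.

```python
import datetime

import numpy as np


def parse_max_time(value):
    """Convert a "DD:HH:MM:SS" string to a timedelta; None means no time limit."""
    if value is None:
        return None
    days, hours, minutes, seconds = (int(part) for part in value.split(":"))
    return datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)


def effective_batch_size(batch_size=4, number_grad_batches=4):
    """Images that contribute to each parameter update with gradient accumulation."""
    return batch_size * number_grad_batches


def split_counts(n_samples, training_split=0.9):
    """Training/validation sample counts (rounding down is an assumption)."""
    n_train = int(n_samples * training_split)
    return n_train, n_samples - n_train


def random_crop(image, crop_size, rng):
    """Randomly crop the trailing spatial axes of one image to crop_size."""
    d = len(crop_size)
    spatial = image.shape[-d:]
    if any(c > s for s, c in zip(spatial, crop_size)):
        raise ValueError(f"crop {crop_size} is larger than image {spatial}")
    starts = [int(rng.integers(0, s - c + 1)) for s, c in zip(spatial, crop_size)]
    window = tuple(slice(st, st + c) for st, c in zip(starts, crop_size))
    return image[(Ellipsis,) + window]
```

With the defaults, each parameter update sees 4 × 4 = 16 images. PyTorch Lightning's Trainer also accepts max_time directly as a "DD:HH:MM:SS" string, so converting to a timedelta is only needed if you want to work with the value yourself.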

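To illustrate the two settings of monte-carlo-kl, the sketch below compares the closed-form KL divergence between two univariate Gaussians with a Monte Carlo estimate that averages log q(z) - log p(z) over samples z drawn from q. This is a generic NumPy illustration with hypothetical names, not COSDD's implementation, where the KL terms involve the Ladder VAE's latent variables.

```python
import numpy as np


def gaussian_log_pdf(z, mean, std):
    """Log-density of a univariate Gaussian."""
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((z - mean) / std) ** 2


def kl_analytic(mu_q, std_q, mu_p, std_p):
    """Closed-form KL(q || p) for univariate (or elementwise diagonal) Gaussians."""
    return (np.log(std_p / std_q)
            + (std_q ** 2 + (mu_q - mu_p) ** 2) / (2 * std_p ** 2) - 0.5)


def kl_monte_carlo(mu_q, std_q, mu_p, std_p, n_samples, rng):
    """Estimate KL(q || p) as the mean of log q(z) - log p(z) with z ~ q."""
    z = mu_q + std_q * rng.standard_normal(n_samples)
    return np.mean(gaussian_log_pdf(z, mu_q, std_q) - gaussian_log_pdf(z, mu_p, std_p))
```

The Monte Carlo estimate is unbiased but noisy: with only a few samples it can even come out negative, whereas the analytic value is always non-negative.
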
hyper-parameters:
noise-direction
  (str) "x", "y", "z" or "none". Default: "x".
  Axis along which noise is correlated.
s-code-channels
  (int) Number of feature channels in the latent code describing the clean signal. Default: 64.
  Half of this value will be used as feature channels in the VAE.
number-layers
  (int) Number of levels in the Ladder VAE. Default: 14.
number-gaussians
  (int) Number of components in Gaussian mixture model used to model the noise. Default: 3.
  If noise is reproduced in the output, increase this value. If training fails, reduce this value.
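
To make noise-direction concrete, the sketch below synthesises noise that is correlated along x, assuming this means that pixels in the same row share a random offset, and measures the correlation between neighbouring pixels along each axis. It is a toy model with hypothetical names, not the noise model COSDD learns.

```python
import numpy as np


def row_correlated_noise(height, width, row_std, pixel_std, rng):
    """White noise plus a random offset shared by every pixel in a row (correlated along x)."""
    row_offsets = rng.normal(0.0, row_std, size=(height, 1))  # broadcast along x
    return row_offsets + rng.normal(0.0, pixel_std, size=(height, width))


def neighbour_correlation(noise, axis):
    """Pearson correlation between each pixel and its neighbour along the given axis."""
    a = np.take(noise, np.arange(noise.shape[axis] - 1), axis=axis).ravel()
    b = np.take(noise, np.arange(1, noise.shape[axis]), axis=axis).ravel()
    return np.corrcoef(a, b)[0, 1]
```

With equal row and pixel standard deviations, neighbours along x (axis 1) have a correlation of about 0.5, the row offset's share of the total variance, while neighbours along y (axis 0) are roughly uncorrelated.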

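number-gaussians sets the number of components K in a Gaussian mixture. The sketch below evaluates a 1-D mixture log-density using the log-sum-exp trick for numerical stability. It is a generic illustration with a hypothetical function name; COSDD's actual noise model, which is signal-dependent and accounts for correlation along noise-direction, is more involved.

```python
import numpy as np


def gmm_log_likelihood(x, weights, means, stds):
    """Log-density of x under a 1-D Gaussian mixture with K components."""
    x = np.asarray(x, dtype=float)[..., None]  # add a trailing component axis
    log_w = np.log(weights)
    log_comp = (-0.5 * np.log(2 * np.pi) - np.log(stds)
                - 0.5 * ((x - means) / stds) ** 2)
    a = log_w + log_comp
    # log-sum-exp over components, shifted by the max to avoid underflow.
    a_max = a.max(axis=-1, keepdims=True)
    return (a_max + np.log(np.exp(a - a_max).sum(axis=-1, keepdims=True)))[..., 0]
```

More components make the density more flexible (for example skewed or heavy-tailed shapes), at the cost of more parameters to fit.
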
memory:
precision
  (str) Floating point precision for training. Default: "bf16-mixed".
  "32-true"
  "32"
  "16-mixed"
  "bf16-mixed"
  "16-true"
  "bf16-true"
  "64-true"
  See https://lightning.ai/docs/pytorch/stable/common/precision.html
checkpointed
  (bool) Whether to use activation checkpointing. Default: True.
  Set True to save GPU memory. Set False to increase training speed.
gpu
  (list(int)) Index of which available GPU to use. Default: [0].
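
Putting the options above together, a training config could be represented as the nested mapping below. This sketch mirrors the documented option names and defaults, but the model name, paths, pattern, axes and max-time values are hypothetical placeholders, and example-train-config.yaml remains the reference for the exact layout the scripts expect. The check_config function is also hypothetical and only tests a few documented constraints.

```python
# Hypothetical training config mirroring the documented options and defaults.
train_config = {
    "model-name": "my-model",             # hypothetical name
    "data": {
        "paths": "./data",
        "patterns": "*.tif",              # hypothetical glob pattern
        "axes": "SYX",                    # hypothetical axes string
        "number-dimensions": 2,
        "patch-size": None,
        "clip-outliers": False,
    },
    "train-parameters": {
        "max-epochs": 1000,
        "max-time": "00:12:00:00",        # hypothetical 12-hour limit
        "patience": 100,
        "batch-size": 4,
        "number-grad-batches": 4,
        "crop-size": [256, 256],
        "training-split": 0.9,
        "monte-carlo-kl": False,
        "direct-denoiser-loss": "MSE",
        "use-direct-denoiser": True,
    },
    "hyper-parameters": {
        "noise-direction": "x",
        "s-code-channels": 64,
        "number-layers": 14,
        "number-gaussians": 3,
    },
    "memory": {
        "precision": "bf16-mixed",
        "checkpointed": True,
        "gpu": [0],
    },
}


def check_config(config):
    """Check a few documented constraints; return a list of problems (empty if none)."""
    problems = []
    if config["hyper-parameters"]["noise-direction"] not in {"x", "y", "z", "none"}:
        problems.append("noise-direction must be 'x', 'y', 'z' or 'none'")
    if config["train-parameters"]["direct-denoiser-loss"] not in {"L1", "MSE"}:
        problems.append("direct-denoiser-loss must be 'L1' or 'MSE'")
    # Requiring a non-empty validation set here is an assumption.
    if not 0.0 < config["train-parameters"]["training-split"] < 1.0:
        problems.append("training-split should be strictly between 0 and 1")
    return problems


if __name__ == "__main__":
    print(check_config(train_config) or "config looks OK")
```

If PyYAML is installed, yaml.safe_dump(train_config, f, sort_keys=False) writes this structure out as YAML in the same key order.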

Yaml prediction config file options

Important options are: model-name, data: paths, patterns & axes.

model-name
  (str) Name of the trained model.
n-samples
  (int) When randomly sampling denoised images (i.e. not using the direct denoiser), sets the number of images to sample for averaging. Default: 100.
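
The relationship between n-samples, use-direct-denoiser and direct-denoiser-loss can be illustrated with NumPy. Without the direct denoiser, a denoised image is formed by averaging sampled images; the direct denoiser is instead trained to output the pixel-wise median (L1) or mean (MSE) directly. The sketch below uses random arrays as stand-ins for model samples; the function names are hypothetical and this is not the prediction script's code. best_constant checks by brute force that the mean minimises squared error and the median minimises absolute error.

```python
import numpy as np


def summarise_samples(samples):
    """Pixel-wise mean and median over a stack of sampled images with shape [n, ...]."""
    samples = np.asarray(samples, dtype=float)
    return samples.mean(axis=0), np.median(samples, axis=0)


def best_constant(values, loss, n_candidates=2001):
    """Brute-force the constant that minimises the mean 'MSE' or 'L1' loss to values."""
    values = np.asarray(values, dtype=float)
    candidates = np.linspace(values.min(), values.max(), n_candidates)
    residuals = values[None, :] - candidates[:, None]
    if loss == "MSE":
        costs = (residuals ** 2).mean(axis=1)
    elif loss == "L1":
        costs = np.abs(residuals).mean(axis=1)
    else:
        raise ValueError("loss must be 'MSE' or 'L1'")
    return candidates[np.argmin(costs)]
```

For independent samples, the Monte Carlo error of the averaged image shrinks roughly as 1/sqrt(n-samples), so larger values trade prediction time for a less noisy estimate.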

data:
paths
  (str) Path to the directory the data is stored in. Can be a list of strings if using more than one directory.
patterns
  (str) Glob pattern to identify files within `paths` that will be used as prediction data. Currently accepted file types are tiff/tif, czi & png. Edit get_imread_fn in utils.py to add data loading functions for currently unsupported file types.
axes
  (str) (S(ample) | C(hannel) | T(ime) | Z | Y | X). Describes the axes of the data as they are stored. I.e., when we call tifffile.imread("your-data.tiff"), what will be the shape of the returned numpy array? 
  The sample axis can be repeated, e.g. SCSZYX, if there are extra axes that should be concatenated as samples.
number-dimensions
  (int) Number of spatial dimensions of your images. Default: 2.
  If your data has shape [T(ime), Y, X], the time dimension can optionally be treated as a spatial dimension and included in convolutions by setting this parameter to 3. If your data has shape [Z, Y, X], the Z axis can optionally be treated as a sample dimension and excluded from convolutions by setting this parameter to 2.
patch-size
  (list(int) | null) [(Depth), Height, Width]. Set to patch data into non-overlapping windows. Default: null.
  Beware of patching data during prediction. The prediction.py script will automatically unpatch denoised images back to the original data shape, but this will result in boundary artefacts. 
clip-outliers
  (bool) Hot or dead outlier pixels can disrupt prediction. Default: False.
  Set this to True to clip extreme pixel values between 1st and 99th percentile.

predict-parameters:
batch-size
  (int) Number of images passed through network at a time. Default: 1.

memory:
precision
  (str) Floating point precision for prediction. Default: "bf16-mixed".
  "16-true"
  "16-mixed"
  "bf16-true"
  "bf16-mixed"
  "32-true"
  "64-true"
  See https://lightning.ai/docs/pytorch/stable/common/precision.html
gpu
  (list(int)) Index of which available GPU to use. Default: [0].

BibTeX

@misc{salmon2024unsuperviseddenoisingsignaldependentrowcorrelated,
      title={Unsupervised Denoising for Signal-Dependent and Row-Correlated Imaging Noise}, 
      author={Benjamin Salmon and Alexander Krull},
      year={2024},
      eprint={2310.07887},
      archivePrefix={arXiv},
      primaryClass={eess.IV},
      url={https://arxiv.org/abs/2310.07887}, 
}
