"Hello there!". Welcome to the implementation of YODA (You Only Denoise once - or Average) as described in the paper "Regression is all you need for medical image translation" (see here for a visual abstract)
Our key findings are:
- Using 2.5D diffusion, we achieve highly accurate 3D image synthesis while avoiding the artifacts of 2D slice-wise synthesis.
- Using regression sampling, we achieve highly accurate, noise-free image synthesis without expensive diffusion sampling and without averaging multiple images for noise suppression.
- By averaging several diffusion images in ExpA sampling to approximate the expected value of the random diffusion sampling process, we show that diffusion and regression sampling are equivalent, i.e. the additional fine-grained, high-frequency details generated by diffusion are non-systematic and mainly imitate acquisition noise (see the sketch below).
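As a notational sketch of the ExpA idea (the symbols here are ours, not taken from the paper): averaging $N$ diffusion samples $y_i$ drawn for the same input $x$ approximates the conditional expectation that regression sampling predicts directly,

$$\hat{y}_{\mathrm{ExpA}} = \frac{1}{N}\sum_{i=1}^{N} y_i \;\approx\; \mathbb{E}\left[y \mid x\right].$$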
Here are some example results demonstrating YODA's performance on the public test case of the Rhineland study (RS) for translating T1w and T2w to FLAIR images. Click here to skip to the code instructions and reproduce our results.
Synthetic FLAIR images from single-step (regression-like) sampling. This takes only a fraction of the time required for diffusion sampling.
To generate realistic images, i.e. to simulate acquisition noise, we can also use diffusion sampling.
However, this takes considerably longer.
Note that given the probabilistic nature of the sampling, the results are not deterministic, so we can draw multiple samples from the same model and inputs.
We can sample and average multiple images to approximate noise-free images (i.e. the expected value of the corresponding random variable) similar to physical multi-excitation (MEX) signal averages:
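For illustration, such MEX-style averaging of several sampled volumes can be done with any tool of your choice; here is a minimal sketch using FreeSurfer's mri_concat (the file names are placeholders for repeated samples drawn from the same inputs; the repository's scripts/postprocessing/average_echos.py used below serves the same purpose):

# average N independently sampled FLAIR volumes into one ExpA estimate
mri_concat pred_flair_sample_*.nii.gz --mean --o pred_flair_expa.nii.gz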
So in this example, after only 31 samples drawn from the model (>8h on a V100), you would get the same image quality (in terms of SSIM) as with single-step regression sampling.

Here are some instructions to run our code and replicate some of our results:
This code is based on PyTorch and makes heavy use of the force of MONAI
and the (by now deprecated) MONAI generative
frameworks.
The exact dependencies can be found in the requirements.txt
file; however, we recommend using docker/singularity:
The dependencies are available from dockerhub
and can be pulled using the following command:
docker pull srassmann/dif
Alternatively, the docker image can be converted to a singularity image using the following command:
SING_FILE=$HOME/singularity/${USER}_dagobah.sif
singularity build $SING_FILE docker://srassmann/dif:latest
We will, for now, assume that python
refers to the correct environment, e.g. by using singularity exec $SING_FILE python
or docker run -v <binds> -it srassmann/dif python
.
This can be done by setting an alias in your bash session:
alias python="singularity exec --nv -B <potential binds of symlinked data etc> $SING_FILE python"
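A quick sanity check that the alias resolves to the intended environment (the printed version numbers will of course differ):

# should print the PyTorch and MONAI versions from the container
python -c "import torch, monai; print(torch.__version__, monai.__version__)"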
Our preprocessing pipeline consists of registration & resampling, followed by segmentation. We used the following tools for this purpose:
The registration of source and target modalities is performed using FreeSurfer (v7.4).
This can be installed natively or via
docker/singularity (note that FreeSurfer requires a license).
However, other tools can likely also be used to perform the registration.
The full-brain segmentation is performed using FastSurfer (v2.2), which is also available as docker/singularity image.
Yet, unless you need precise label-wise brain metrics (such as the noise level of the WM),
any other brain segmentation or masking tool (including FreeSurfer's mri_synthstrip
) can be used just as well.
The mask is only used to constrain the synthesis ROI and, optionally, for skull-stripping / background masking.
If you want, you can also omit it altogether; however, computation time is then wasted on translating the background,
which is particularly bothersome for diffusion sampling (and there is not really a need for that ...).
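For reference, a segmentation-only FastSurfer run via docker could look roughly like this (a sketch only; paths are placeholders, and the image tag and flags should be checked against the FastSurfer documentation for your version):

# sketch: run the FastSurfer segmentation pipeline for one subject
docker pull deepmi/fastsurfer:latest
docker run --gpus all --rm -v /path/to/data:/data -v /path/to/fastsurfer_output:/output \
    deepmi/fastsurfer:latest --t1 /data/subj_0000/T1_RMS.nii.gz --sid subj_0000 --sd /output --seg_only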
Model weights will be released on Zenodo (link tba).
We expect the model weights to be placed in output/<run_name>/ckpt
, where <run_name>
is the name of the run and the model's base config to be in output/<run_name>/config.yml
.
For simplicity, we assume the data to be stored in ../data/<dataset_name>
where <dataset_name>
is the name of the dataset.
Within this directory, we expect one folder per subject, each containing the modalities as .nii.gz
files.
E.g. to reproduce FLAIR synthesis in the Rhineland study using the released example images (as shown above), the data should be organized as follows:
RAW_DATA=../data/rs_example_raw
mkdir -p $RAW_DATA
wget https://zenodo.org/records/11186582/files/sub_rs_mri_raw.zip -O $RAW_DATA/sub_rs_mri_raw.zip
unzip -j $RAW_DATA/sub_rs_mri_raw.zip sub_rs_mri_raw/T1_RMS.nii.gz sub_rs_mri_raw/T2_caipi.nii.gz sub_rs_mri_raw/FLAIR.nii.gz -d $RAW_DATA/subj_0000 && trash $RAW_DATA/sub_rs_mri_raw.zip
tree $RAW_DATA
../data/rs_example_raw/
├── subj_0000
│ ├── FLAIR.nii.gz
│ ├── T1_RMS.nii.gz
│ └── T2_caipi.nii.gz
└── subj_0001 # in case you had more subjects
├── [...]
see here for details
In case (like here) the data is not already registered and resampled, do so with your tool of choice, e.g. (assuming FreeSurfer to be sourced):
REGISTERED_DATA=../data/rs_example_registered
SOURCE_MODS=("T1_RMS T2_caipi")
TARGET_MOD="FLAIR"
mkdir -p $REGISTERED_DATA
for subj in $RAW_DATA/*; do
    subj_name=$(basename $subj)
    mkdir -p $REGISTERED_DATA/$subj_name
    ln -s $(realpath $RAW_DATA/$subj_name/$TARGET_MOD.nii.gz) $REGISTERED_DATA/$subj_name/$TARGET_MOD.nii.gz
    # brain mask of the target modality (used as registration reference mask and later as translation ROI)
    mri_synthstrip -i $RAW_DATA/$subj_name/$TARGET_MOD.nii.gz -m $REGISTERED_DATA/$subj_name/${TARGET_MOD}_brainmask.nii.gz --gpu
    for mod in $SOURCE_MODS; do
        mri_synthstrip -i $RAW_DATA/$subj_name/${mod}.nii.gz -m $REGISTERED_DATA/$subj_name/${mod}_brainmask.nii.gz --gpu
        mri_coreg --mov $RAW_DATA/$subj_name/${mod}.nii.gz --ref $REGISTERED_DATA/$subj_name/$TARGET_MOD.nii.gz --reg $REGISTERED_DATA/$subj_name/${mod}_to_${TARGET_MOD}.lta \
            --mov-mask $REGISTERED_DATA/$subj_name/${mod}_brainmask.nii.gz --ref-mask $REGISTERED_DATA/$subj_name/${TARGET_MOD}_brainmask.nii.gz --threads 16
        mri_vol2vol --cubic --mov $RAW_DATA/$subj_name/${mod}.nii.gz --targ $REGISTERED_DATA/$subj_name/$TARGET_MOD.nii.gz \
            --reg $REGISTERED_DATA/$subj_name/${mod}_to_${TARGET_MOD}.lta --o $REGISTERED_DATA/$subj_name/${mod}.nii.gz
    done
done
This might take a couple of minutes per subject.
Note that here we register to the target modality (FLAIR). If the target modality is not available (e.g. for IXI or HCP), we recommend registering to the T2w images (resampled to ~1 mm isotropic).
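As a sketch of that alternative (file names are placeholders; the loop above then registers to the resampled T2w instead of the FLAIR):

# resample the T2w reference to ~1 mm isotropic and set it as TARGET_MOD
mri_convert -vs 1 1 1 $RAW_DATA/$subj_name/T2.nii.gz $REGISTERED_DATA/$subj_name/T2_1mm.nii.gz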
We rely on the FastSurfer conform script to robustly normalize the intensities of the registered images to 8 bit.
To do so, we use the following command (assuming an appropriate python env, see above, e.g. replace python with
singularity exec --nv $SING_FILE python
and don't forget to mount the data via -B
or, in docker, via -v
):
INPUT=$REGISTERED_DATA # change if registered otherwise
CONFORMED_DATA=../data/rs_example_conformed
python scripts/preprocessing/conform.py -i $INPUT -o $CONFORMED_DATA --seqs $SOURCE_MODS $TARGET_MOD
Note that conformed/normalized/other pre-processed datasets (e.g. BraTS) might not require this step.
Furthermore, both inference and training require a tissue mask to define the translation ROI.
Here, we simply use the mri_synthstrip
mask of the target modality created above, which already lives in the registration target space:
for subj in $REGISTERED_DATA/*; do ln -s $(realpath $subj/${TARGET_MOD}_brainmask.nii.gz) $CONFORMED_DATA/$(basename $subj)/mask.nii.gz ; done
In case you use Fast/FreeSurfer for brain masking, you will also want to map the brain mask (aseg.auto_noCCseg.mgz
) back to the original space.
See the respective script to this end.
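That script is the reference; as a rough sketch of the idea (assuming FreeSurfer is sourced and $SUBJECTS_DIR contains the Fast/FreeSurfer output; the actual script may additionally binarize the labels):

# map the label map back onto the target geometry with nearest-neighbour interpolation
mri_vol2vol --mov $SUBJECTS_DIR/subj_0000/mri/aseg.auto_noCCseg.mgz --targ $CONFORMED_DATA/subj_0000/FLAIR.nii.gz \
    --regheader --nearest --o $CONFORMED_DATA/subj_0000/mask.nii.gz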
In the lazy case, you can, however, omit the mask and simply symlink e.g. one of the input modalities instead. Then, the whole image (cropped to the maximum size of the model) will be translated.
To inform YODA about the data, you need to define a dataset JSON file.
This file looks something like this:
JASON=../data/rs_example.json
cat > $JASON <<'EOF'
{
    "training": [
        {
            "subject_ID": "subj_0001",
            "_comment": "theoretically, multiple scans per subject are possible for each sequence",
            "flair": ["subj_0001/FLAIR.nii.gz"],
            "t1": ["subj_0001/T1.nii.gz", "subj_0001/T1_RMS.nii.gz"],
            "t2": ["subj_0001/T2_caipi.nii.gz"],
            "mask": "subj_0001/mask.nii.gz"
        }
    ],
    "validation": [
        {
            "subject_ID": "subj_0000",
            "_comment": "same structure as training, but only one scan per sequence (no lists)",
            "flair": "subj_0000/FLAIR.nii.gz",
            "t1": "subj_0000/T1_RMS.nii.gz",
            "t2": "subj_0000/T2_caipi.nii.gz",
            "mask": "subj_0000/mask.nii.gz"
        }
    ]
}
EOF
JASONwM=../data/rs_example_noMask.json
cat > $JASONwM <<'EOF'
{
    "training": [],
    "validation": [
        {
            "subject_ID": "subj_0000_noMask",
            "_comment": "same as before, but using one of the inputs as a dummy mask",
            "flair": "subj_0000/FLAIR.nii.gz",
            "t1": "subj_0000/T1_RMS.nii.gz",
            "t2": "subj_0000/T2_caipi.nii.gz",
            "mask": "subj_0000/T2_caipi.nii.gz"
        }
    ]
}
EOF
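Optionally, sanity-check that the generated files are valid JSON:

python -m json.tool $JASON > /dev/null && python -m json.tool $JASONwM > /dev/null && echo "both JSONs are valid"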
Here you can find the basic usage of the prediction scripts.
See the respective --help
outputs for further options and ways to customize, e.g. using different guidance/target sequence names.
To predict the FLAIR image of subj_0000
using the model weights and regression single-step sampling, run the following command
(assuming python to be in the correct environment; don't forget to mount the data via -B
or, in docker, via -v
, and to enable GPU access via --nv
(singularity) or --gpus all (docker)!):
RUN=rs_FLAIR_from_T1T2 # name of the run, the main configs are taken from output/<run_name>/config.yml
OUTNAME=predict_RS_example
CONF=configs/inference_schedulers/Regression.yml # define regression sampling
SHARED_ARGS=" -r $RUN -dj $JASON -dd $CONFORMED_DATA" # shared arguments
python predict/2d_yoda_predict.py $SHARED_ARGS $CONF -o $OUTNAME
Congrats, you have just used the force of YODA to predict a noise-free FLAIR image from T1w and T2w.
If you now want to also predict the other views for view aggregation, you can additionally run the following commands:
python predict/2d_yoda_predict.py $SHARED_ARGS $CONF -o ${OUTNAME}_cor -sd coronal
python predict/2d_yoda_predict.py $SHARED_ARGS $CONF -o ${OUTNAME}_sag -sd sagittal
python scripts/postprocessing/average_echos.py output/$RUN/${OUTNAME}* --o output/$RUN/${OUTNAME}_rms -s "pred_flair.nii.gz" # average the views
The view-aggregation results are in output/$RUN/${OUTNAME}_rms/subj_0000/pred_flair.nii.gz
.
Note: experts use the --force
flag to maximize YODA's capabilities.
Sampling without a mask (as specified in $JASONwM
) can be done as:
SHARED_ARGS=" -r $RUN -dj $JASONwM -dd $CONFORMED_DATA" # shared arguments
python predict/2d_yoda_predict.py $SHARED_ARGS $CONF -o $OUTNAME -om
python predict/2d_yoda_predict.py $SHARED_ARGS $CONF -o ${OUTNAME}_cor -sd coronal -om
python predict/2d_yoda_predict.py $SHARED_ARGS $CONF -o ${OUTNAME}_sag -sd sagittal -om
rm -r output/$RUN/${OUTNAME}_rms
python scripts/postprocessing/average_echos.py output/$RUN/${OUTNAME}* --o output/$RUN/${OUTNAME}_rms -s "pred_flair.nii.gz" # average the views
However, note that this will simply center-crop the image, which might cut off some important parts.
Alternatively, diffusion sampling can be used, optionally with lazy truncation and ExpA averaging of multiple samples:
NEX=4 # how many images to average, can also be one
LAZY=250 # truncation, i.e. the step to which to skip --> here the diffusion skips from step 999 to 250, running only the last 250 of 1000 steps
MEXds=250 # expectation approximation (ExpA) diversion step --> step from which to diverge into individual sampling trajectories
OUTNAME=predict_RS_example_diffusion_expa$NEX
python predict/25d_yoda_predict.py $SHARED_ARGS -o $OUTNAME -cor $RUN -sag $RUN \
-nex $NEX -lazy $LAZY -mexds $MEXds
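As a rough back-of-the-envelope reading of these flags (our interpretation of the cost, not a verified property of the implementation):

# with LAZY=250 each trajectory starts at step 250 instead of 999, and with MEXds=250
# the NEX=4 samples diverge immediately, so roughly 4 * 250 = 1000 denoising steps are
# run in total (vs. 4 * 1000 = 4000 for naive full-length sampling of four images);
# a lower MEXds would additionally share part of the trajectory across the samples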
Here, -cor
and -sag
could be distinct, view-specific models. Yet, we usually don't do that, as we found no benefit that would justify the extra training effort.
Note that we use a different script 25d_
rather than 2d_
.
Furthermore, note that diffusion sampling is inherently very time-consuming.
Thus, if the computational force is strong in your lab,
you can go for subject-wise parallelization on multi-GPU systems or on a SLURM cluster, for which we provide the scripts in the batch
folder.
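The provided scripts are the reference; a purely hypothetical sketch of subject-wise parallelization via a SLURM array (assuming one per-subject dataset JSON prepared beforehand; resources and paths are placeholders):

# each array task samples one subject from its own dataset JSON
sbatch --array=0-9 --gres=gpu:1 --time=24:00:00 --wrap \
    "singularity exec --nv -B <binds> $SING_FILE python predict/25d_yoda_predict.py \
     -r $RUN -dj ../data/rs_example_subj\${SLURM_ARRAY_TASK_ID}.json -dd $CONFORMED_DATA \
     -o ${OUTNAME}_part\${SLURM_ARRAY_TASK_ID} -cor $RUN -sag $RUN -nex $NEX -lazy $LAZY -mexds $MEXds"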
You can also use configs for pre-defined combinations, such as datasets.
E.g., to test the RS YODA model on other datasets, you would otherwise always have to set the -ds
and -dj
flags.
For e.g. the IXI dataset (which does not include a FLAIR sequence), you would also need to specify the source and target sequences.
To simplify this, we can instead merge in the corresponding config like so:
python predict/2d_yoda_predict.py $SHARED_ARGS -o $OUTNAME configs/inference_schedulers/Regression.yml configs/datasets/ixi_test.yml
Note that when using multiple configs, later ones overwrite earlier ones (from left to right),
i.e. the model config output/$RUN/config.yml
is overwritten by **/Regression.yml
, which is again overwritten by **/ixi_test.yml
.
Furthermore, note that some options (e.g. setting the 'target_sequences=null' or the skullstripping
) are not supported via the flags.
Just create simple (tmp
) configs instead as shown above.
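For instance, a throwaway override config could look like this (the key nesting is an assumption on our side; check configs/defaults.yml for the actual structure):

TMP_CONF=$(mktemp --suffix=.yml)
cat > $TMP_CONF <<'EOF'
# hypothetical override: disable explicit target sequences
data:
  target_sequences: null
EOF
python predict/2d_yoda_predict.py $SHARED_ARGS $CONF $TMP_CONF -o ${OUTNAME}_custom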
To train your own YODA model, first preprocess the data, i.e. register the images and create tissue masks. For brain MRI translation, we recommend the same processing as described above for inference.
You will need to create a JSON file specifying your data, similar to the inference cases explained previously.
Some examples (for IXI, BraTS, and the Gold Atlas) for creating these JSONs can be found at nb/config_creation
.
The models can be trained using
python train/train_yoda_ddp.py -n new_hope output/rs_FLAIR_from_T1T2/config.yml
You can either add configs or cmd-line flags to the train script.
Child nodes (c
) of parent nodes (p
) can be specified in the dot notation (--p.c <value>
), so e.g. the batch size can be set using --data.batch_size <value>
.
Note that, again, the configs are overwritten from left to right, and cmd-line flags overwrite the respective configs. E.g., assume we want to train the BraTS model on the RS with an effective batch size of 96 (12 * 8):
The options and their default values are defined in the configs/defaults.yml
file.
See the comments for an explanation of the options.
python train/train_yoda_ddp.py -n empire_strikes_back \
output/brats_FLAIR_from_T1T2/config.yml configs/datasets/rs_train.yml \
--data.batch_size 12 --data.num_workers 8 --trainer.gradient_accumulation_steps 8
YODA can be easily trained on multiple GPUs (on a single node) with DDP
(again assuming that torchrun refers to the correct env, e.g., by
alias torchrun="singularity exec --nv -B <binds> $SING_FILE torchrun"
):
NUM_GPUS=$(nvidia-smi -L | wc -l) # use all GPUs, assume 8, i.e. 12 * 8 = 96 effective batch size
torchrun --nproc_per_node $NUM_GPUS train/train_yoda_ddp.py -n return_of_jedi \
output/brats_FLAIR_from_T1T2/config.yml configs/datasets/rs_train.yml \
--data.batch_size 12 --data.num_workers 8
We also provide a template for SLURM jobs (batch/example_train_job_slurm.sh
).
Congrats, you have now trained your very own first YODA model! "I feel the force is strong with you."