Implementation of MedSegDiff in PyTorch: SOTA medical segmentation out of Baidu using DDPM and enhanced conditioning on the feature level, with filtering of features in Fourier space.
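Filtering features in Fourier space can be pictured roughly as follows: take a 2D FFT of a feature map, scale each frequency coefficient by a learnable weight, and transform back. The sketch below illustrates that idea under those assumptions only. The class name `FourierFilter` and its per-channel, per-frequency weight layout are hypothetical, and this is not this library's own conditioning module.

```python
import torch
from torch import nn


class FourierFilter(nn.Module):
    """Hypothetical sketch: learnable per-frequency gating of feature maps."""

    def __init__(self, channels: int, height: int, width: int):
        super().__init__()
        # rfft2 keeps only width // 2 + 1 frequency columns for real inputs;
        # weights start at 1, so the filter is initially an identity map
        self.weight = nn.Parameter(torch.ones(channels, height, width // 2 + 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, height, width), real-valued features
        freq = torch.fft.rfft2(x, norm="ortho")
        # scale every complex frequency coefficient by its learnable weight
        freq = freq * self.weight
        # back to the spatial domain, restoring the original spatial size
        return torch.fft.irfft2(freq, s=x.shape[-2:], norm="ortho")


feats = torch.randn(2, 64, 32, 32)
filtered = FourierFilter(64, 32, 32)(feats)
print(filtered.shape)  # torch.Size([2, 64, 32, 32])
```

Because the weights are trained jointly with the rest of the network, the model can learn to suppress or amplify particular spatial frequencies of the features it conditions on.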
Appreciation
StabilityAI for the generous sponsorship, as well as my other sponsors out there
Isamu and Daniel for adding a training script for a skin lesion dataset!
Install
$ pip install med-seg-diff-pytorch
Usage
```python
import torch
from med_seg_diff_pytorch import Unet, MedSegDiff

model = Unet(
    dim=64,
    image_size=128,
    mask_channels=1,          # segmentation has 1 channel
    input_img_channels=3,     # input images have 3 channels
    dim_mults=(1, 2, 4, 8)
)

diffusion = MedSegDiff(
    model,
    timesteps=1000
).cuda()

# inputs are normalized from 0 to 1 and must live on the same device as the model
segmented_imgs = torch.rand(8, 1, 128, 128).cuda()
input_imgs = torch.rand(8, 3, 128, 128).cuda()

loss = diffusion(segmented_imgs, input_imgs)
loss.backward()

# after a lot of training

pred = diffusion.sample(input_imgs)   # pass in your unsegmented images
pred.shape                            # predicted segmentation masks - (8, 1, 128, 128)
```
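Conceptually, the `loss` call trains a DDPM over the mask: the ground-truth mask is noised at a random timestep, and the network, conditioned on the input image, learns to predict that noise. The sketch below shows the standard epsilon-prediction objective with a linear beta schedule. The schedule, the epsilon target, and the names `ddpm_loss` and `zero_denoiser` are assumptions for illustration and may differ from what `MedSegDiff` does internally.

```python
import torch
import torch.nn.functional as F


def ddpm_loss(denoiser, mask, image, timesteps=1000):
    """Sketch of a standard epsilon-prediction DDPM loss on a segmentation mask."""
    # assumed linear beta schedule (1e-4 to 0.02, as in the original DDPM paper)
    betas = torch.linspace(1e-4, 0.02, timesteps)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    x0 = mask * 2 - 1                      # map mask from [0, 1] to [-1, 1]
    t = torch.randint(0, timesteps, (mask.shape[0],))
    noise = torch.randn_like(x0)

    # closed-form forward process: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * noise
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise

    # the network sees the noisy mask, the timestep and the conditioning image
    pred_noise = denoiser(x_t, t, image)
    return F.mse_loss(pred_noise, noise)


def zero_denoiser(x_t, t, image):
    # toy stand-in that always predicts zero noise, only to show the shapes
    return torch.zeros_like(x_t)


loss = ddpm_loss(zero_denoiser, torch.rand(8, 1, 128, 128), torch.rand(8, 3, 128, 128))
```

With the toy denoiser the loss is simply the mean squared noise, so it lands close to 1; a trained network drives it down.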
To add self-conditioning, where the model is additionally conditioned on the mask it has predicted so far, pass the --self_condition flag.
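In the usual formulation of self-conditioning for diffusion models, training works like this: some fraction of the time, the model first makes a prediction with gradients disabled, and that estimate is fed back as an extra input for the pass that is actually trained; otherwise an all-zero estimate is fed. The sketch below assumes that scheme. `self_conditioned_forward`, the 0.5 probability, and the `model(x_t, t, image, x_self_cond)` call signature are hypothetical and are not this library's API.

```python
import random

import torch


def self_conditioned_forward(model, x_t, t, image, p=0.5):
    """Hypothetical training-time self-conditioning wrapper."""
    # default: condition on an all-zero "previous estimate" of the mask
    x_self_cond = torch.zeros_like(x_t)
    if random.random() < p:
        # first pass without gradients to get the model's current estimate
        with torch.no_grad():
            x_self_cond = model(x_t, t, image, x_self_cond).detach()
    # the pass that is actually trained sees that estimate (or zeros)
    return model(x_t, t, image, x_self_cond)
```

At sampling time, the same input slot is typically filled with the prediction from the previous denoising step rather than with an extra no-grad pass.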
Todo
some basic training code, with a Trainer that takes in a custom dataset tailored to medical image formats - thanks to @isamu-isozaki
a full-blown transformer of any depth in the middle, as done in simple diffusion
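A transformer in the middle of the U-Net, in the spirit of simple diffusion, amounts to flattening the lowest-resolution feature map into a sequence of tokens and running a stack of self-attention blocks over it. The hypothetical sketch below uses PyTorch's built-in `nn.TransformerEncoder`. The class name `MidTransformer`, the depth, and the head count are placeholders, and positional embeddings are omitted for brevity.

```python
import torch
from torch import nn


class MidTransformer(nn.Module):
    """Hypothetical U-Net bottleneck: self-attention over all spatial positions."""

    def __init__(self, dim: int, depth: int = 4, heads: int = 8):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=dim * 4, batch_first=True
        )
        # "any depth": simply stack as many encoder layers as desired
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)   # (b, h*w, c): one token per pixel
        tokens = self.encoder(tokens)
        return tokens.transpose(1, 2).reshape(b, c, h, w)


mid = MidTransformer(dim=64, depth=2)
out = mid(torch.randn(2, 64, 8, 8))
```

Placing attention only at the bottleneck keeps the token count small (here 8 x 8 = 64 tokens), which is what makes a deep transformer affordable there.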
Citations
@article{Wu2022MedSegDiffMI,
title = {MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model},
author = {Junde Wu and Huihui Fang and Yu Zhang and Yehui Yang and Yanwu Xu},
journal = {ArXiv},
year = {2022},
volume = {abs/2211.00611}
}
@inproceedings{Hoogeboom2023simpleDE,
title = {simple diffusion: End-to-end diffusion for high resolution images},
author = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
year = {2023}
}