Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Introduce a joint diffusion model that learns meaningful internal representations for both generative and predictive tasks.
- Joint machine learning models often offer uneven performance or are unstable to train.
- Contemporary deep diffusion-based generative models can be used in both generative and predictive settings.
- Extension of the vanilla diffusion model with a classifier allows for stable joint training with shared parametrization.
- Joint diffusion model offers superior performance across various tasks.
Paper Content
Introduction
- Training a single machine learning model to synthesize data and make predictions
- Benefits on downstream problems such as calibration of model uncertainty, semi-supervised learning, unsupervised domain adaptation, and continual learning
- Deep generative models such as Variational Autoencoders used to build joint models
- Diffusion-based deep generative models (DDGMs) popular for sample quality
- Joint diffusion model with shared parametrization between classifier and UNet encoder
- Conditional sampling algorithm to optimize internal diffusion representations with a classifier
- Works with semi-supervised learning, domain adaptation, and counterfactual explanations
Background
- Joint models consider two random variables: x and y
- The joint distribution p(x, y) can be factorized in two ways
- Training the joint model has advantages
- Diffusion-based deep generative models are used
- Variational lower bound is optimized with respect to parameters of the backward diffusion
- Training objective is a modified version of equation (4) of the paper, with the scaling term dropped
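The two factorizations and an unscaled diffusion objective can be written out as follows. The notation is an assumption borrowed from the standard DDPM formulation (a noise-prediction network \epsilon_\theta and a cumulative noise schedule \bar\alpha_t); the paper's own equation (4) is not reproduced in this summary, so this is a sketch of the usual simplified objective rather than the paper's exact expression.

```latex
% Two ways to factorize the joint distribution over data x and label y
p_\theta(x, y) = p_\theta(y \mid x)\, p_\theta(x) = p_\theta(x \mid y)\, p_\theta(y)

% Simplified (unscaled) noise-prediction objective, standard DDPM notation
L_{\text{simple}}(\theta) =
  \mathbb{E}_{t,\, x_0,\, \epsilon}
  \left[ \left\| \epsilon - \epsilon_\theta\!\left(
    \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon,\; t
  \right) \right\|^2 \right],
  \qquad \epsilon \sim \mathcal{N}(0, I)
```

Note that \bar\alpha_t here is the noise schedule and is unrelated to the step size used later for conditional sampling.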
Related work
- Diffusion models can be used to improve the quality of sampled generations
- Diffusion models can be conditioned with class identities
- A joint model can be used to share parameterization between a diffusion model and a classifier
- Diffusion models can be used to learn data representations
- Joint training can be used to combine diffusion models with a classifier
Diffusion models learn data representations
- Learning useful data representations is important for having a good generator or classifier.
- Investigating parameterizations of DDGMs and using an autoencoder as a denoising decoder.
- Decomposing denoising function into two parts: encoding and decoding.
- Pooling features encoded by the same filter and concatenating the averaged representations into a single vector.
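The pooling step can be illustrated with a minimal sketch. It assumes each encoder level outputs a list of 2-D feature maps (one per filter) and that the pooling is a plain spatial average; the function names and the toy data are hypothetical and not taken from the paper's code.

```python
from statistics import fmean

def average_pool(feature_map):
    """Average one 2-D feature map (a list of rows) into a single number."""
    return fmean(v for row in feature_map for v in row)

def pooled_representation(levels):
    """Concatenate per-filter averages from every encoder level into one vector.

    `levels` is a list of encoder outputs; each output is a list of 2-D
    feature maps, one per filter. Levels may differ in spatial size and
    in the number of filters.
    """
    z = []
    for feature_maps in levels:
        z.extend(average_pool(fm) for fm in feature_maps)
    return z

# Toy encoder output: level 1 has two 2x2 maps, level 2 has one 1x1 map.
levels = [
    [[[1.0, 3.0], [5.0, 7.0]], [[0.0, 0.0], [2.0, 2.0]]],
    [[[10.0]]],
]
z = pooled_representation(levels)
print(z)  # [4.0, 1.0, 10.0]
```

The resulting vector has one entry per filter across all levels, so its length does not depend on the spatial resolution of the input.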
UNet representations are useful for prediction
- Averaged representations extracted from an original image by the UNet contain predictive information.
- Performance of an MLP-based classifier fed with z_0 is measured.
- Representations encoded in z_0 are very informative and can lead to performance comparable to a stand-alone classifier.
Diffusion models learn features of increasing granularity
- Trained an unsupervised DDGM on the CelebA dataset
- Fitted a binary logistic regression classifier for each of the 40 attributes in the dataset
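A linear probe of this kind can be sketched as below. This is an illustrative stand-in, not the paper's setup: the features are hypothetical two-dimensional vectors in place of pooled UNet representations, there is a single binary attribute, and the classifier is fitted with plain full-batch gradient descent.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def fit_logistic_regression(features, labels, lr=0.5, epochs=500):
    """Fit weights and bias by full-batch gradient descent on the log loss."""
    dim, n = len(features[0]), len(features)
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(features, labels):
            err = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) - y
            for i in range(dim):
                grad_w[i] += err * x[i] / n
            grad_b += err / n
        w = [wi - lr * gi for wi, gi in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b

def accuracy(w, b, features, labels):
    """Fraction of examples whose thresholded prediction matches the label."""
    preds = [int(sum(wi * xi for wi, xi in zip(w, x)) + b > 0.0) for x in features]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

# Hypothetical pooled features and labels for one binary attribute.
features = [[0.1, 1.0], [0.3, 0.8], [0.2, 0.9], [0.9, 0.2], [0.8, 0.1], [0.7, 0.3]]
labels = [0, 0, 0, 1, 1, 1]
w, b = fit_logistic_regression(features, labels)
probe_accuracy = accuracy(w, b, features, labels)
```

Repeating the fit for each attribute gives one accuracy score per attribute, which is how a probe indicates what information the representations carry.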
Method
- Train a joint model composed of a classifier and a DDGM
- Use a shared parameterization, a shared encoder of the UNet architecture
- Shared encoder serves as the generative part and for calculating pooled features for the classifier
Joint diffusion models: DDGMs with classifiers
- Pool latent representations of data from different levels of UNet architecture into one vector z
- Build classifier model to assign label to data example represented by vector z
- Consider parameterization of denoising diffusion model within single diffusion timestep
- Distinguish encoder that maps input into set of vectors
- Use logarithm of categorical distribution (cross-entropy loss) as training objective
- Optimize objective jointly with single optimizer over parameters
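The joint objective can be sketched for a single example as the sum of a generative term and a predictive term. The equal weighting, the use of a noise-prediction mean squared error for the diffusion part, and all function names are assumptions made for illustration; the summary only states that a cross-entropy term is optimized jointly with the diffusion objective.

```python
import math

def noise_prediction_loss(true_noise, predicted_noise):
    """Mean squared error between the sampled noise and the network's estimate."""
    return sum((e - p) ** 2 for e, p in zip(true_noise, predicted_noise)) / len(true_noise)

def cross_entropy(logits, label):
    """Negative log-probability of `label` under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_norm - logits[label]

def joint_loss(true_noise, predicted_noise, logits, label):
    """Sum of the generative and the predictive term (equal weights assumed)."""
    return noise_prediction_loss(true_noise, predicted_noise) + cross_entropy(logits, label)

loss = joint_loss(
    true_noise=[0.5, -1.0],
    predicted_noise=[0.0, -1.0],
    logits=[0.0, 0.0],
    label=1,
)
print(round(loss, 4))  # 0.8181
```

Because both terms depend on the shared encoder, a single optimizer step on this sum updates the encoder with gradients from both tasks.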
An alternative training of joint diffusion models
- Training of proposed approach is straightforward
- Monte-Carlo approximation of sum over all timesteps used for training loss
- Cross-entropy loss calculated for given y
- Diffusion model provides data representations related to different granularity of input features
- Robustness of method improved by applying same classifier to intermediate noisy images
- Noisy classifiers used for training, not prediction
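The Monte-Carlo approximation of the sum over timesteps can be illustrated with a small sketch. The per-timestep loss below is a hypothetical placeholder, and the rescaling by the number of timesteps is a choice of this sketch (it makes the estimate unbiased for the full sum); in practice a training loop would typically draw one timestep per example per step.

```python
import random

def monte_carlo_timestep_sum(per_timestep_loss, num_timesteps, num_samples, rng):
    """Estimate sum_{t=1..T} loss(t) from uniformly sampled timesteps.

    Each draw contributes T * loss(t); under a uniform t its expectation
    equals the full sum, so the average is an unbiased estimate.
    """
    total = 0.0
    for _ in range(num_samples):
        t = rng.randint(1, num_timesteps)  # inclusive on both ends
        total += num_timesteps * per_timestep_loss(t)
    return total / num_samples

def toy_loss(t):
    """Hypothetical per-timestep loss, used only to check the estimator."""
    return 1.0 / t

T = 10
exact = sum(toy_loss(t) for t in range(1, T + 1))
estimate = monte_carlo_timestep_sum(toy_loss, T, num_samples=20000, rng=random.Random(0))
print(abs(estimate - exact) < 0.2)
```

The same sampled timestep can be reused for the noisy-classifier term, so the classifier sees the representation of the noisy image at that timestep during training.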
Conditional sampling in joint diffusion models
- Dhariwal & Nichol (2021) propose a classifier guidance approach for DDGMs to improve sample quality.
- The equation for sampling is changed to add a scaled gradient with respect to the target class from an externally trained classifier.
- Joint training of a classifier and diffusion model is used to simplify the classifier guidance technique.
- Algorithm 1 is used to optimize the representations with respect to the desired class.
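The core idea of optimizing a representation toward a desired class can be sketched as one gradient step on the pooled vector. Everything concrete here is an assumption for illustration: the classifier is a hypothetical linear softmax model (so the gradient has a closed form), the step size is arbitrary, and the decoder that would map the optimized representation to the next, less noisy image is omitted. This is not a reproduction of the paper's Algorithm 1.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    s = sum(exps)
    return [e / s for e in exps]

def classify(z, weights):
    """Hypothetical linear classifier: one weight row per class."""
    return softmax([sum(w * v for w, v in zip(row, z)) for row in weights])

def optimize_representation(z, weights, target, step_size):
    """One gradient step on z that lowers the cross-entropy for `target`.

    For a linear softmax classifier, d(cross-entropy)/dz = W^T (p - onehot).
    """
    probs = classify(z, weights)
    residual = [p - (1.0 if k == target else 0.0) for k, p in enumerate(probs)]
    grad = [sum(residual[k] * weights[k][i] for k in range(len(weights)))
            for i in range(len(z))]
    return [v - step_size * g for v, g in zip(z, grad)]

weights = [[1.0, 0.0], [0.0, 1.0]]  # two classes, two features
z = [0.5, 0.5]                      # stand-in for the representation of a noisy image
z_opt = optimize_representation(z, weights, target=1, step_size=1.0)
before = classify(z, weights)[1]
after = classify(z_opt, weights)[1]
print(before < after)  # True
```

In a full sampler this update would be applied at every denoising step before decoding, and a larger step size pushes samples harder toward the target class.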
Experiments
- We measure the quality of a classifier to evaluate if training with a diffusion model improves robustness.
- We measure the generative capability of our model to check if representations optimized by the classifier lead to more accurate conditional generations.
Predictive performance of joint diffusion models
- Evaluated predictive performance of method
- Compared method with baseline classifier and MLP classifier
- Results show proposed joint diffusion model achieves best performance on all four datasets
- A possible reason is that the shared encoder, which is also trained on the unsupervised generative objective, learns more robust features
Generative performance of joint diffusion models
- Joint diffusion model with conditional sampling outperforms DDGMs
- Vanilla DDGM and DDGM with classifier guidance perform better in some cases
- Increasing step size value leads to more precise but less diverse samples
- Sweet spot for Precision and Recall is around α ∈ [100, 250]
A comparison to state-of-the-art approaches
- Joint diffusion model performs on par with Joint Energy-based Model (JEM) in terms of classification accuracy, but better in terms of FID score.
- Flow-based methods (Residual Flows, Glow) perform worse than joint diffusion model.
- Joint diffusion model with only 5% of labeled data performs almost as well as stand-alone classifier trained with fully labeled training dataset.
Domain adaptation with diffusion-based fine-tuning
- Joint diffusion model can adapt to new data domain without labels
- Joint diffusion model outperforms stand-alone classifier in all three cases by a significant margin
Visual counterfactual explanations
- Joint diffusion model applied to real-world medical data
- High classification accuracy (98%) on test set
- Visual counterfactual explanations (VCE) method used to generate counterfactual explanations
- Joint diffusion model can remove/add parasite from/to image
- Difference between the original example and the one generated with the changed class label highlights the malaria plasmodium
Conclusion
- Introduced joint model combining diffusion model and classifier
- Demonstrated DDGMs learn semantically meaningful data representations
- Improved performance of classification task while retaining high quality of generations
- Enabled conditional generations with built-in classifier guidance
- Used joint diffusion model in semi-supervised learning, domain adaptation, and counterfactual explanations
- Generated counterfactual explanations for the classifier on a medical dataset
- Evaluated generative capabilities by measuring FID score, Precision and Recall
- Compared joint diffusion model with other joint models and SOTA discriminative and generative models
- Evaluated accuracy of classifier trained in semi-supervised and domain adaptation tasks