Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Introduce a joint diffusion model that learns meaningful internal representations for both generative and predictive tasks.
  • Joint machine learning models often offer uneven performance or are unstable to train.
  • Contemporary deep diffusion-based generative models can be used in both generative and predictive settings.
  • Extension of the vanilla diffusion model with a classifier allows for stable joint training with shared parametrization.
  • Joint diffusion model offers superior performance across various tasks.

Paper Content

Introduction

  • Training a single machine learning model to synthesize data and make predictions
  • Benefits on downstream problems such as calibration of model uncertainty, semi-supervised learning, unsupervised domain adaptation, and continual learning
  • Deep generative models such as Variational Autoencoders used to build joint models
  • Diffusion-based deep generative models (DDGMs) popular for sample quality
  • Joint diffusion model with shared parametrization between classifier and UNet encoder
  • Conditional sampling algorithm to optimize internal diffusion representations with a classifier
  • Works with semi-supervised learning, domain adaptation, and counterfactual explanations

Background

  • Joint models consider two random variables: x and y
  • Joint distribution p(x, y) can be factorized in two ways
  • Training the joint model has advantages
  • Diffusion-based deep generative models are used
  • Variational lower bound is optimized with respect to parameters of the backward diffusion
  • Training objective is a simplified version of the variational bound (equation (4) in the paper) with the per-timestep scaling dropped
  • Diffusion models can be used to improve the quality of sampled generations
  • Diffusion models can be conditioned with class identities
  • A joint model can be used to share parameterization between a diffusion model and a classifier
  • Diffusion models can be used to learn data representations
  • Joint training can be used to combine diffusion models with a classifier
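
The background points above can be written out as a short derivation sketch. The notation is an assumption: it follows the common DDPM convention (\bar{\alpha}_t for the cumulative noise schedule, \epsilon_\theta for the noise-prediction network) and is not copied from the paper, whose equation numbering and symbols may differ.

```latex
% Two factorizations of the joint distribution over data x and label y
p(x, y) = p(x \mid y)\, p(y) = p(y \mid x)\, p(x)

% Forward diffusion in closed form (standard DDPM notation, assumed here)
q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ \sqrt{\bar{\alpha}_t}\, x_0,\ (1 - \bar{\alpha}_t)\, I\right)

% Simplified objective: the variational bound with the per-timestep scaling dropped
L_{\text{simple}} = \mathbb{E}_{t,\, x_0,\, \epsilon}
  \left[ \left\| \epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\, x_0
  + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,\ t\right) \right\|^2 \right]
```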

Diffusion models learn data representations

  • Learning useful data representations is important for having a good generator or classifier.
  • Investigating parameterizations of DDGMs and using an autoencoder as a denoising decoder.
  • Decomposing denoising function into two parts: encoding and decoding.
  • Pooling features encoded by the same filter and concatenating the averaged representations into a single vector.
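
The pooling step described above can be sketched in a few lines. This is a minimal, dependency-free illustration, assuming each UNet level yields a list of 2D feature maps (one per filter); the function name `pool_unet_features` and the nested-list layout are hypothetical and not taken from the paper.

```python
def pool_unet_features(feature_maps):
    """Average each filter's feature map and concatenate the results.

    feature_maps: list of levels; each level is a list of channels;
    each channel is a 2D list (height x width) of floats.
    Returns a flat list z with one averaged value per filter.
    """
    z = []
    for level in feature_maps:
        for channel in level:
            values = [v for row in channel for v in row]
            z.append(sum(values) / len(values))  # global average pooling
    return z


# Toy input: a 2-filter level at 2x2 resolution and a 3-filter level at 1x1.
levels = [
    [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 0.0]]],
    [[[5.0]], [[6.0]], [[7.0]]],
]
z = pool_unet_features(levels)
```

The resulting vector has one entry per filter across all levels, so its length does not depend on the spatial resolution of any level.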

UNet representations are useful for prediction

  • Averaged representations extracted from an original image by the UNet contain predictive information.
  • Performance of an MLP-based classifier fed with z_0 is measured.
  • Representations encoded in z_0 are very informative and can lead to performance comparable to a stand-alone classifier.

Diffusion models learn features of increasing granularity

  • Trained an unsupervised DDGM on the CelebA dataset
  • Fitted a binary logistic regression classifier for each of the 40 attributes in the dataset

Method

  • Train a joint model composed of a classifier and a DDGM
  • Use a shared parameterization, a shared encoder of the UNet architecture
  • Shared encoder serves as the generative part and for calculating pooled features for the classifier

Joint diffusion models: DDGMs with classifiers

  • Pool latent representations of data from different levels of UNet architecture into one vector z
  • Build classifier model to assign label to data example represented by vector z
  • Consider parameterization of denoising diffusion model within single diffusion timestep
  • Distinguish encoder that maps input into set of vectors
  • Use logarithm of categorical distribution (cross-entropy loss) as training objective
  • Optimize objective jointly with single optimizer over parameters
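
The joint objective above can be sketched as a single scalar loss. This is a minimal illustration under stated assumptions: the diffusion term is a mean squared error on predicted noise, the two terms are added without weighting, and all function names are hypothetical rather than the authors' implementation.

```python
import math


def softmax_cross_entropy(logits, label):
    """Negative log of the categorical probability assigned to `label`."""
    m = max(logits)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_norm - logits[label]


def noise_prediction_loss(eps_true, eps_pred):
    """Mean squared error between the sampled and the predicted noise."""
    return sum((a - b) ** 2 for a, b in zip(eps_true, eps_pred)) / len(eps_true)


def joint_loss(eps_true, eps_pred, logits, label):
    """Diffusion loss plus classifier loss (unweighted sum is an assumption)."""
    return noise_prediction_loss(eps_true, eps_pred) + softmax_cross_entropy(logits, label)


# Toy values standing in for one training example.
loss = joint_loss(eps_true=[0.0, 1.0], eps_pred=[0.0, 0.0], logits=[0.0, 0.0], label=0)
```

Because both terms depend on the shared encoder, a single optimizer step on this sum updates the encoder with gradients from the generative and the predictive part at once.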

An alternative training of joint diffusion models

  • Training of proposed approach is straightforward
  • Monte-Carlo approximation of sum over all timesteps used for training loss
  • Cross-entropy loss calculated for given y
  • Diffusion model provides data representations related to different granularity of input features
  • Robustness of method improved by applying same classifier to intermediate noisy images
  • Noisy classifiers used for training, not prediction
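
The Monte-Carlo approximation above amounts to drawing one random timestep per training example instead of summing over all of them. The sketch below assumes a linear noise schedule with commonly used DDPM defaults; the schedule values and function names are illustrative assumptions, not settings reported in this summary.

```python
import math
import random


def make_alpha_bars(num_steps, beta_start=1e-4, beta_end=0.02):
    """Cumulative products of (1 - beta_t) for a linear schedule (assumed)."""
    alpha_bars, prod = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def sample_noisy_input(x0, alpha_bars, rng):
    """Draw one timestep uniformly and return (t, eps, x_t) for that step."""
    t = rng.randrange(len(alpha_bars))
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    a = alpha_bars[t]
    xt = [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]
    return t, eps, xt


alpha_bars = make_alpha_bars(1000)
x0 = [0.5, -1.0, 2.0]
t, eps, xt = sample_noisy_input(x0, alpha_bars, random.Random(0))
```

In this setup the noise-prediction loss is evaluated at the sampled timestep, and the same x_t can also be passed to the classifier when training with noisy inputs.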

Conditional sampling in joint diffusion models

  • Dhariwal & Nichol (2021) propose a classifier guidance approach to improve the sample quality of DDGMs.
  • The equation for sampling is changed to add a scaled gradient with respect to the target class from an externally trained classifier.
  • Joint training of a classifier and diffusion model is used to simplify the classifier guidance technique.
  • Algorithm 1 is used to optimize the representations with respect to the desired class.
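
The core idea of optimizing representations toward a desired class can be sketched as one gradient-ascent step on the pooled vector z. This is not the paper's Algorithm 1: it assumes a linear softmax classifier as a stand-in (so the gradient has a closed form), and the names `guide_representation` and `step_size` are hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def class_log_prob(z, weights, target):
    """log p(target | z) under a linear softmax classifier (assumed stand-in)."""
    logits = [sum(w * v for w, v in zip(row, z)) for row in weights]
    return math.log(softmax(logits)[target])


def guide_representation(z, weights, target, step_size):
    """Move z one step along the gradient of log p(target | z)."""
    logits = [sum(w * v for w, v in zip(row, z)) for row in weights]
    probs = softmax(logits)
    # For a linear classifier: grad_j = W[target][j] - sum_k p_k * W[k][j]
    grad = [
        weights[target][j] - sum(p * row[j] for p, row in zip(probs, weights))
        for j in range(len(z))
    ]
    return [v + step_size * g for v, g in zip(z, grad)]


weights = [[1.0, 0.0], [0.0, 1.0]]  # toy two-class classifier
z = [0.0, 0.0]
z_guided = guide_representation(z, weights, target=1, step_size=0.5)
```

In a full sampler, the guided representation would be passed to the UNet decoder at each denoising step; a larger step size pushes samples harder toward the target class.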

Experiments

  • We measure the quality of a classifier to evaluate if training with a diffusion model improves robustness.
  • We measure the generative capability of our model to check if representations optimized by the classifier lead to more accurate conditional generations.

Predictive performance of joint diffusion models

  • Evaluated predictive performance of method
  • Compared method with baseline classifier and MLP classifier
  • Results show proposed joint diffusion model achieves best performance on all four datasets
  • A possible reason is that the additional unsupervised training makes the shared encoder more robust

Generative performance of joint diffusion models

  • Joint diffusion model with conditional sampling outperforms DDGMs
  • Vanilla DDGM and DDGM with classifier guidance perform better in some cases
  • Increasing step size value leads to more precise but less diverse samples
  • Sweet spot for Precision and Recall is around α ∈ [100, 250]

A comparison to state-of-the-art approaches

  • Joint diffusion model performs on par with Joint Energy-based Model (JEM) in terms of classification accuracy, but better in terms of FID score.
  • Flow-based methods (Residual Flows, Glow) perform worse than joint diffusion model.
  • Joint diffusion model with only 5% of labeled data performs almost as well as stand-alone classifier trained with fully labeled training dataset.

Domain adaptation with diffusion-based fine-tuning

  • Joint diffusion model can adapt to new data domain without labels
  • Joint diffusion model outperforms stand-alone classifier in all three cases by a significant margin

Visual counterfactual explanations

  • Joint diffusion model applied to real-world medical data
  • High classification accuracy (98%) on test set
  • Visual counterfactual explanations (VCE) method used to generate counterfactual explanations
  • Joint diffusion model can remove/add parasite from/to image
  • Difference between the original example and its counterfactual with a changed class label indicates the malaria plasmodium

Conclusion

  • Introduced joint model combining diffusion model and classifier
  • Demonstrated DDGMs learn semantically meaningful data representations
  • Improved performance of classification task while retaining high quality of generations
  • Enabled conditional generations with built-in classifier guidance
  • Used joint diffusion model in semi-supervised learning, domain adaptation, and counterfactual explanations
  • Generated counterfactual explanations for the classifier's decisions using a medical dataset
  • Evaluated generative capabilities by measuring FID score, Precision and Recall
  • Compared joint diffusion model with other joint models and SOTA discriminative and generative models
  • Evaluated accuracy of classifier trained in semi-supervised and domain adaptation tasks