Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Explored a new class of diffusion models based on the transformer architecture
  • Replaced U-Net backbone with a transformer that operates on latent patches
  • Analyzed scalability of Diffusion Transformers (DiTs) through Gflops
  • Higher Gflops (from greater transformer depth or width, or from more input tokens) consistently lead to lower FID
  • DiT-XL/2 models outperform prior diffusion models on ImageNet 512x512 and 256x256 benchmarks

Paper Content

Introduction

  • Transformers now power much of modern machine learning
  • Transformers are used in natural language processing, vision, and other domains
  • Image-level generative models have not adopted transformers as much
  • The convolutional U-Net architecture is the de-facto choice of backbone for diffusion models
  • This paper aims to replace U-Net with transformers
  • Diffusion Transformers (DiTs) are based on Vision Transformers (ViTs)
  • DiTs are scalable architectures for diffusion models
  • DiTs can achieve state-of-the-art results on ImageNet generation benchmark
  • Transformers have replaced domain-specific architectures in language, vision, reinforcement learning and meta-learning.
  • Transformers have been used as autoregressive models and to predict pixels.
  • Transformers have been used to generate non-spatial data.
  • Parameter counts are not always a good measure of model complexity; Gflops are often used instead.
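
To illustrate why parameter counts and Gflops can diverge, here is a rough sketch that estimates both for a plain ViT-style transformer stack. It is an approximation under stated assumptions, not the paper's accounting: it counts one multiply-accumulate (MAC) as one operation, ignores embeddings, layer norms, conditioning layers and the final decoder, and the function name and the example settings (28 blocks, hidden size 1152, a 32 x 32 latent input) are illustrative choices.

```python
def transformer_cost(num_blocks, hidden, latent_size, patch_size):
    """Rough parameter and multiply-accumulate (MAC) counts for a ViT-style stack.

    Assumption: embeddings, layer norms, conditioning layers and the final
    decoder are ignored, so these are estimates only.
    """
    tokens = (latent_size // patch_size) ** 2
    # Per block: 4*d^2 weights for the attention projections (Q, K, V, output)
    # plus 8*d^2 weights for an MLP with 4x hidden expansion.
    params_per_block = 12 * hidden ** 2
    # Every weight is applied once per token; attention adds two T x T x d products.
    macs_per_block = tokens * params_per_block + 2 * tokens ** 2 * hidden
    return num_blocks * params_per_block, num_blocks * macs_per_block


# Assumed settings: 28 blocks, hidden size 1152, 32 x 32 latent input.
for p in (8, 4, 2):
    params, macs = transformer_cost(28, 1152, 32, p)
    print(f"patch size {p}: {params / 1e6:.0f}M params, {macs / 1e9:.1f} G MACs")
```

Under this estimate the parameter count is identical for every patch size, while the forward-pass cost grows a little more than fourfold each time the patch size is halved. Two models with the same number of parameters can therefore differ widely in compute.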

Diffusion transformers

Preliminaries

  • Diffusion models assume a forward noising process which gradually applies noise to real data
  • The reparameterization trick is used to sample noised data directly from the clean data and Gaussian noise
  • Diffusion models are trained to learn the reverse process that inverts forward process corruptions
  • Variational lower bound is used to train the model
  • Classifier-free guidance is used to encourage the sampling procedure to find data with high p(x|c)
  • Latent diffusion models use an autoencoder to compress images into smaller spatial representations
  • DiTs are applied to latent space, but could be applied to pixel space without modification
  • The overall pipeline is a hybrid approach: off-the-shelf convolutional VAEs combined with transformer-based DDPMs
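
The forward noising step and classifier-free guidance can be sketched in a few lines. This is a minimal illustration using plain Python lists, assuming the standard DDPM parameterization x_t = sqrt(ᾱ_t) x_0 + sqrt(1 − ᾱ_t) ε and the usual guidance rule ε̂ = ε_uncond + s (ε_cond − ε_uncond). The function names and toy values are hypothetical, and a real model would predict the noise with a network rather than take it as input.

```python
import math
import random


def forward_noise(x0, alpha_bar_t, rng):
    """Sample x_t ~ q(x_t | x_0) with the reparameterization trick.

    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, eps ~ N(0, I)
    """
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    a = math.sqrt(alpha_bar_t)
    b = math.sqrt(1.0 - alpha_bar_t)
    x_t = [a * x + b * e for x, e in zip(x0, eps)]
    return x_t, eps


def classifier_free_guidance(eps_uncond, eps_cond, scale):
    """Combine unconditional and conditional noise predictions.

    scale = 1 recovers the conditional prediction; scale > 1 pushes
    sampling toward data with high p(x | c).
    """
    return [u + scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]


rng = random.Random(0)
x0 = [1.0, -2.0, 0.5]  # toy "clean" data
x_t, eps = forward_noise(x0, alpha_bar_t=0.5, rng=rng)
guided = classifier_free_guidance([0.0, 0.0], [1.0, -1.0], scale=1.5)
print(guided)  # [1.5, -1.5]
```

With alpha_bar_t close to 1 the sample stays near x0, and as it approaches 0 the sample approaches pure Gaussian noise.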

Diffusion transformer design space

  • DiTs are a new architecture for diffusion models
  • DiTs are based on Vision Transformer (ViT) architecture
  • Input to DiT is a spatial representation
  • The first layer of DiT is “patchify”, which converts the spatial input into a sequence of tokens
  • Standard ViT frequency-based positional embeddings are applied to all input tokens
  • The number of tokens T is determined by the patch size hyperparameter p: an I × I input yields T = (I/p)² tokens, so halving p quadruples T
  • Four variants of transformer blocks to process conditional inputs
  • In-context conditioning, cross-attention block, adaptive layer norm and adaLN-Zero blocks in DiT design space
  • Sequence of N DiT blocks, each operating at hidden dimension size d
  • Four configs: DiT-S, DiT-B, DiT-L and DiT-XL
  • Standard linear decoder to decode sequence of image tokens into output noise prediction and output diagonal covariance prediction
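
The patchify step can be sketched as follows. This is a minimal pure-Python illustration: the 32 x 32 x 4 latent shape and the patch sizes 8, 4 and 2 are assumed example values, the function name is hypothetical, and the linear embedding that maps each flattened patch to the hidden dimension d is omitted.

```python
def patchify(latent, p):
    """Split an I x I x C latent (nested lists) into (I / p)^2 tokens.

    Each token is the flattened p x p x C patch, in row-major patch order.
    """
    size = len(latent)
    if size % p != 0:
        raise ValueError("patch size must divide the spatial size")
    tokens = []
    for top in range(0, size, p):
        for left in range(0, size, p):
            token = []
            for i in range(top, top + p):
                for j in range(left, left + p):
                    token.extend(latent[i][j])
            tokens.append(token)
    return tokens


I, C = 32, 4  # assumed latent size and channel count
latent = [[[0.0] * C for _ in range(I)] for _ in range(I)]
for p in (8, 4, 2):
    tokens = patchify(latent, p)
    print(f"p={p}: T={len(tokens)} tokens, each of length {len(tokens[0])}")
# p=8: T=16 tokens, each of length 256
# p=4: T=64 tokens, each of length 64
# p=2: T=256 tokens, each of length 16
```

A smaller patch size produces a longer sequence of shorter tokens. The transformer then processes more tokens at the same hidden dimension, which raises compute without changing the block parameters.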

Experimental setup

  • We explore the DiT design space and study the scaling properties of our model class.
  • We train class-conditional latent DiT models on the ImageNet dataset.
  • We use standard weight initialization techniques and a constant learning rate of 1 × 10⁻⁴.
  • We use an off-the-shelf pre-trained variational autoencoder (VAE) to map images to and from the latent space in which diffusion is run.
  • We measure scaling performance with Fréchet Inception Distance (FID), Inception Score, sFID and Precision/Recall.
  • We implement all models in JAX and train them using TPU-v3 pods.
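
For reference, FID is conventionally defined as the Fréchet distance between two Gaussians fitted to Inception features of real and generated images. The symbols below are not from the summary and are introduced only for this illustration: μ_r, Σ_r are the feature mean and covariance for real images, and μ_g, Σ_g those for generated images.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Lower values mean the generated feature statistics are closer to those of real images.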

Experiments

  • Four DiT-XL/2 models trained with different block designs
  • The adaLN-Zero block yields lower FID than both cross-attention and in-context conditioning
  • Initializing each DiT block as the identity function (adaLN-Zero) improves performance over plain adaLN
  • Increasing model size and decreasing patch size improves diffusion models
  • Gflops more important than parameter counts for model quality
  • Larger DiT models more compute-efficient
  • Scaling model size and patch size improves sample quality

State-of-the-art diffusion models

  • Trained DiT-XL/2 model on 256x256 ImageNet for 7M steps
  • Compared against state-of-the-art class-conditional generative models
  • DiT-XL/2 outperforms all prior diffusion models, decreasing FID-50K from 3.60 to 2.27
  • DiT-XL/2 is more compute-efficient than U-Net models
  • DiT-XL/2 achieves a lower FID than all prior generative models
  • DiT-XL/2 achieves higher recall values than LDM-

Model compute vs. sampling compute

  • Diffusion models can use additional compute after training by increasing the number of sampling steps when generating an image.
  • Smaller-model compute DiTs generally cannot outperform larger ones by using more sampling compute.
  • Sampling compute cannot compensate for a lack of model compute.
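
The trade-off can be made concrete with simple arithmetic: total sampling compute is the per-step forward-pass cost multiplied by the number of sampling steps. The per-step costs below (about 80.7 Gflops for a smaller model and 118.6 Gflops for a larger one) and the step counts are assumed values chosen for illustration, and the function name is hypothetical.

```python
def sampling_tflops(gflops_per_step, num_steps):
    """Total compute to sample one image, in Tflops."""
    return gflops_per_step * num_steps / 1000.0


# Assumed per-step costs and step counts, for illustration only.
l2_tflops = sampling_tflops(80.7, 1000)   # smaller model, many steps
xl2_tflops = sampling_tflops(118.6, 128)  # larger model, few steps
print(f"smaller model: {l2_tflops:.1f} Tflops per image")
print(f"larger model:  {xl2_tflops:.1f} Tflops per image")
```

In this setting the smaller model spends roughly five times more sampling compute per image than the larger one. The finding summarized above is that such extra sampling compute still does not close the quality gap.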

Conclusion

  • Introduce Diffusion Transformers (DiTs)
  • Outperform prior U-Net models
  • DiTs inherit the strong scaling properties of the transformer model class
  • Future work should continue to scale DiTs
  • Explore as drop-in backbone for text-to-image models
  • Table 4 includes detailed information about the DiT models
  • Table 6 includes Gflop counts for DDPM U-Net models
  • Frequency embedding followed by two-layer MLP
  • GELU nonlinearities in core transformer
  • Pre-trained VAEs used across experiments
  • Ablate three different choices of VAE decoder
  • XL/2 outperforms all prior diffusion models
  • Figures 1 and 11 show selected samples
  • Figures 13-32 show uncurated samples
  • Bubble area indicates flops of diffusion model
  • Comparing different conditioning strategies
  • Scaling the DiT model improves FID
  • Increasing transformer forward pass Gflops increases sample quality
  • Gflops of DiT models correlated with FID
  • Larger DiT models use large compute more efficiently
  • More sampling compute does not compensate for less model compute
  • Benchmarking class-conditional image generation on ImageNet 256x256 and 512x512
  • DiT-XL/2 achieves state-of-the-art FID