Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Explored a new class of diffusion models based on the transformer architecture
- Replaced U-Net backbone with a transformer that operates on latent patches
- Analyzed the scalability of Diffusion Transformers (DiTs) in terms of forward-pass complexity, measured in Gflops
- DiTs with higher Gflops consistently achieve lower FID
- DiT-XL/2 models outperform prior diffusion models on ImageNet 512x512 and 256x256 benchmarks
Paper Content
Introduction
- Machine learning is being powered by transformers
- Transformers are used in natural language processing, vision, and other domains
- Image-level generative models have not adopted transformers as much
- U-Net architecture is the de-facto choice for generative models
- This paper aims to replace U-Net with transformers
- Diffusion Transformers (DiTs) are based on Vision Transformers (ViTs)
- DiTs are scalable architectures for diffusion models
- DiTs can achieve state-of-the-art results on ImageNet generation benchmark
Related work
- Transformers have replaced domain-specific architectures in language, vision, reinforcement learning and meta-learning.
- Transformers have been used as autoregressive models and to predict pixels.
- Transformers have been used to generate non-spatial data.
- Parameter counts are not always a good measure of complexity; Gflops are often used instead.
Diffusion transformers
Preliminaries
- Diffusion models assume a forward noising process which gradually applies noise to real data
- Reparameterization trick is used to sample data
- Diffusion models are trained to learn the reverse process that inverts forward process corruptions
- Variational lower bound is used to train the model
- Classifier-free guidance is used to encourage the sampling procedure to find data with high p(x|c)
- Latent diffusion models use an autoencoder to compress images into smaller spatial representations
- DiTs are applied to latent space, but could be applied to pixel space without modification
- The overall pipeline is a hybrid approach: a convolutional VAE paired with a transformer-based DDPM
Diffusion transformer design space
- DiTs are a new architecture for diffusion models
- DiTs are based on Vision Transformer (ViT) architecture
- Input to DiT is a spatial representation
- First layer of DiT is “patchify” which converts spatial input into sequence of tokens
- Standard ViT frequency-based positional embeddings are applied to all input tokens
- Number of tokens determined by patch size hyperparameter
- Four variants of transformer blocks to process conditional inputs
- In-context conditioning, cross-attention block, adaptive layer norm and adaLN-Zero blocks in DiT design space
- Sequence of N DiT blocks, each operating at hidden dimension size d
- Four configs: DiT-S, DiT-B, DiT-L and DiT-XL
- Standard linear decoder to decode sequence of image tokens into output noise prediction and output diagonal covariance prediction
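Two of the design choices above, patchify and the adaLN-Zero block, can be illustrated with a short NumPy sketch. It is a simplified illustration under stated assumptions, not the authors' implementation: the 32x32x4 latent shape, the helper names, the `(1 + scale)` modulation convention, and the toy sublayer are assumptions, and the linear token embedding and positional embeddings are omitted. The sketch shows that the token count is (I / p)^2 for input size I and patch size p, and that a zero-initialized gate makes the residual branch an identity function.

```python
import numpy as np


def patchify(x, p):
    """Split an (H, W, C) array into (H/p * W/p) flattened p x p patches."""
    H, W, C = x.shape
    assert H % p == 0 and W % p == 0, "patch size must divide the input"
    x = x.reshape(H // p, p, W // p, p, C)
    x = x.transpose(0, 2, 1, 3, 4)  # (H/p, W/p, p, p, C)
    return x.reshape((H // p) * (W // p), p * p * C)


def layer_norm(x, eps=1e-6):
    """LayerNorm over the last axis, without learned scale or shift."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def adaln_zero_branch(tokens, cond, W, b, sublayer):
    """One residual branch with adaLN-Zero style conditioning.

    shift, scale and gate are regressed from the conditioning vector; the gate
    multiplies the sublayer output before the residual addition.
    """
    shift, scale, gate = np.split(cond @ W + b, 3)
    h = layer_norm(tokens) * (1.0 + scale) + shift
    return tokens + gate * sublayer(h)


rng = np.random.default_rng(0)
latent = rng.standard_normal((32, 32, 4))  # assumed latent shape

# Halving the patch size quadruples the number of tokens.
shapes = {p: patchify(latent, p).shape for p in (2, 4, 8)}

tokens = patchify(latent, 2)               # (256, 16)
d = tokens.shape[1]
cond = rng.standard_normal(8)              # hypothetical conditioning vector
W_zero, b_zero = np.zeros((8, 3 * d)), np.zeros(3 * d)
mix = rng.standard_normal((d, d))          # stand-in for attention or MLP
out = adaln_zero_branch(tokens, cond, W_zero, b_zero, lambda h: h @ mix)
```

With the regression weights set to zero, `gate` is zero and `out` equals `tokens`, so the branch starts as an identity and only begins to contribute as training moves the weights away from zero.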
Experimental setup
- We explore the DiT design space and study the scaling properties of our model class.
- We train class-conditional latent DiT models on the ImageNet dataset.
- We use standard weight initialization techniques and a constant learning rate of 1 × 10⁻⁴.
- We use an off-the-shelf pre-trained variational autoencoder (VAE) model from Stable Diffusion.
- We measure scaling performance with Fréchet Inception Distance (FID), Inception Score, sFID and Precision/Recall.
- We implement all models in JAX and train them using TPU-v3 pods.
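For reference, FID fits a Gaussian to Inception features of real images and another to features of generated images, then measures the distance between the two. The formula below is the standard definition; the symbols (mean and covariance of real features, \mu_r and \Sigma_r, and of generated features, \mu_g and \Sigma_g) are notation chosen here, not taken from this summary.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Lower values indicate that the generated feature distribution is closer to the real one.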
Experiments
- Four DiT-XL/2 models trained with different block designs
- FID lower with adaLN-Zero block than cross-attention and in-context conditioning
- Initializing each DiT block as identity function improves performance
- Increasing model size and decreasing patch size improves diffusion models
- Gflops more important than parameter counts for model quality
- Larger DiT models more compute-efficient
- Scaling model size and patch size improves sample quality
State-of-the-art diffusion models
- Trained DiT-XL/2 model on 256x256 ImageNet for 7M steps
- Compared against state-of-the-art class-conditional generative models
- DiT-XL/2 outperforms all prior diffusion models, decreasing FID-50K from 3.60 to 2.27
- DiT-XL/2 is more compute-efficient than U-Net models
- DiT-XL/2 achieves a lower FID than all prior generative models
- DiT-XL/2 achieves higher recall values than LDM-
Model compute vs. sampling compute
- Diffusion models can use additional compute after training by increasing the number of sampling steps when generating an image.
- The paper asks whether DiTs with less model compute can outperform larger ones by using more sampling compute.
- In general they cannot: scaling up sampling compute does not compensate for a lack of model compute.
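Sampling compute is roughly the per-step model cost multiplied by the number of sampling steps, so the two budgets can be compared directly. The sketch below uses hypothetical round numbers (80 and 120 Gflops per step, 1000 and 128 steps) purely for illustration; they are not results from the paper, and `sampling_tflops` is a name made up for this example.

```python
def sampling_tflops(gflops_per_step, steps):
    """Total compute to sample one image, in Tflops."""
    return gflops_per_step * steps / 1000.0


# A smaller model run for many steps versus a larger model run for few steps.
small_many = sampling_tflops(gflops_per_step=80.0, steps=1000)
large_few = sampling_tflops(gflops_per_step=120.0, steps=128)

print(f"smaller model, 1000 steps: {small_many:.2f} Tflops")
print(f"larger model, 128 steps: {large_few:.2f} Tflops")
```

The calculation only accounts for cost. It shows that the smaller model can spend several times more sampling compute than the larger one; whether that extra compute buys better samples is the empirical question the bullets above address.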
Conclusion
- Introduce Diffusion Transformers (DiTs)
- Outperform prior U-Net models
- Scaling properties of transformer model class
- Future work should continue to scale DiTs
- Explore as drop-in backbone for text-to-image models
Appendix and figure captions
- Table 4 includes information about the DiT models
- Table 6 includes Gflop counts for DDPM U-Net models
- Frequency embedding followed by two-layer MLP
- GELU nonlinearities in core transformer
- Pre-trained VAEs used across experiments
- Ablate three different choices of VAE decoder
- XL/2 outperforms all prior diffusion models
- Figures 1 and 11 show selected samples
- Figures 13-32 show uncurated samples
- Bubble area indicates flops of diffusion model
- Comparing different conditioning strategies
- Scaling the DiT model improves FID
- Increasing transformer forward pass Gflops increases sample quality
- Gflops of DiT models correlated with FID
- Larger DiT models use large compute more efficiently
- More sampling compute does not compensate for less model compute
- Benchmarking class-conditional image generation on ImageNet 256x256 and 512x512
- DiT-XL/2 achieves state-of-the-art FID