Abstract
- Denoising diffusion models are used for image generation.
- Training these models can be slow.
- Conflicting optimization directions between timesteps can cause slow convergence.
- Min-SNR-$\gamma$ is a method to address this issue.
- Min-SNR-$\gamma$ speeds up convergence and achieves a new record FID score of 2.06 on the ImageNet 256×256 benchmark.
Paper Content
Introduction
- Diffusion models have demonstrated superior performance across a range of generation tasks, but they converge slowly.
- Conflicting gradients across timesteps are the likely cause of the slow convergence.
- The paper proposes a Min-SNR-γ weighting strategy to tackle the conflicting gradients issue.
- The strategy significantly improves the convergence rate of diffusion training and achieves a new record FID score of 2.06 on the ImageNet 256×256 benchmark.
- The strategy is effective and general across various generation scenarios.
- The paper compares the strategy thoroughly with existing works and provides a comprehensive analysis of it.
- The source code of the strategy is released.
Related works
- Diffusion models are strong generative models used in image generation
- Diffusion models have been used in text-to-video, 3D Avatar, image to image translation, image manipulation, music generation, and drug discovery
- UNet and Vision Transformers are widely used network structures for diffusion models
- Recent studies have tried to improve diffusion models by guiding the sampling process, fast sampling methods, and denoising networks
- Multi-task learning is used to learn multiple related tasks jointly, but can lead to negative transfer
- GradNorm and MTO are approaches used to address negative transfer in multi-task learning
Method
Preliminary
- Diffusion models have two processes: forward noising and reverse denoising
- Forward process adds noise to real data point to obtain a series of noisy latent variables
- Reverse process denoises the latent variables and restores the real data
- Previous works predict the noise, while later works reparameterize to predict the noiseless state
- The network can also be trained to predict velocity, a combination of the noise and the noiseless state
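The forward process and the signal-to-noise ratio that the later weighting relies on can be sketched as follows. This is a minimal illustration, not the paper's code: it assumes the common form x_t = alpha_t * x_0 + sigma_t * eps with SNR(t) = alpha_t^2 / sigma_t^2, and it assumes a cosine noise schedule. The function names and the defaults `T=1000` and `s=0.008` are hypothetical choices made for the example.

```python
import numpy as np

def cosine_alpha_sigma(t, T=1000, s=0.008):
    """Assumed cosine schedule; returns (alpha_t, sigma_t) with alpha_t**2 + sigma_t**2 == 1.

    Intended for timesteps t in [1, T]; at t = 0 sigma_t is zero and the SNR is undefined.
    """
    f_t = np.cos((t / T + s) / (1 + s) * np.pi / 2) ** 2
    f_0 = np.cos(s / (1 + s) * np.pi / 2) ** 2
    alpha_bar = np.clip(f_t / f_0, 1e-8, 1.0)
    return np.sqrt(alpha_bar), np.sqrt(1.0 - alpha_bar)

def forward_noise(x0, t, rng, T=1000):
    """Forward (noising) step: sample x_t given the clean data x0."""
    alpha, sigma = cosine_alpha_sigma(t, T)
    eps = rng.standard_normal(x0.shape)
    return alpha * x0 + sigma * eps, eps

def snr(t, T=1000):
    """Signal-to-noise ratio alpha_t^2 / sigma_t^2 at timestep t."""
    alpha, sigma = cosine_alpha_sigma(t, T)
    return alpha ** 2 / sigma ** 2

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 3))          # stand-in for a batch of real data
xt, eps = forward_noise(x0, t=500, rng=rng)

# The three prediction targets are interchangeable given x_t:
alpha, sigma = cosine_alpha_sigma(500)
x0_from_eps = (xt - sigma * eps) / alpha  # noiseless state recovered from the noise
velocity = alpha * eps - sigma * x0       # velocity target
```

Under this schedule the SNR falls monotonically as t grows, so early timesteps are nearly clean and late timesteps are nearly pure noise.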
Diffusion training as multi-task learning
- Previous studies have shared parameters of denoising models across all steps.
- Different steps may have different requirements for denoising.
- Experiment conducted to analyze correlation between different timesteps.
- Finetuning specific steps benefited the surrounding steps but often harmed timesteps farther away.
- Goal is to find an efficient solution that benefits all timesteps simultaneously.
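A toy calculation can make the conflict between timesteps concrete. The sketch below is not the paper's finetuning experiment; it assumes a hypothetical one-parameter denoiser x0_hat = w * x_t, data x0 ~ N(0, 1), and a variance-preserving forward process (alpha_t^2 + sigma_t^2 = 1). Under those assumptions the expected loss at timestep t is w^2 - 2 * w * alpha_t + 1, so its gradient is 2 * (w - alpha_t) and each timestep pulls the shared parameter toward its own alpha_t.

```python
import numpy as np

def grad_wrt_w(w, alpha_t):
    """Exact gradient of E[(w * x_t - x0)^2] under the toy assumptions above."""
    return 2.0 * (w - alpha_t)

# Three hypothetical timesteps: low, medium, and high noise.
alphas = np.array([0.95, 0.6, 0.2])
w = 0.5                                  # current value of the shared parameter

grads = grad_wrt_w(w, alphas)            # one scalar gradient per timestep

# Two timesteps conflict when their gradients point in opposite directions.
signs = np.sign(grads)
conflict = (signs[:, None] * signs[None, :]) < 0
```

Here the two less noisy timesteps agree with each other, while the noisiest timestep wants to move `w` the opposite way, which mirrors the observation that nearby steps help each other and distant steps can interfere.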
Pareto optimality of diffusion models
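- Theorem 1 considers the minimum-norm convex combination of the per-timestep gradients: either it is zero and the current point is Pareto stationary, or it gives a descent direction that improves all timesteps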
- A more general form of the theorem was proposed in [11]
- Diffusion models require all timesteps to be included in training
- A regularization term is included to prevent loss weights from becoming too small
- Frank-Wolfe and Unconstrained Gradient Descent algorithms are used to solve the optimization problem
- Min-SNR-γ weighting strategy is proposed to optimize different timesteps simultaneously
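The weight optimization mentioned above can be sketched with a Frank-Wolfe solver. This is a minimal sketch under stated assumptions, not the paper's implementation: it assumes the objective is the squared norm of the weighted gradient sum plus an optional `lam`-scaled squared-weight regularizer, with the weights constrained to the probability simplex. The function name, the `lam` parameter, and the iteration count are hypothetical.

```python
import numpy as np

def frank_wolfe_min_norm(grads, lam=0.0, iters=100):
    """Find simplex weights w minimizing ||sum_t w_t g_t||^2 + lam * ||w||^2.

    grads: array of shape (T, D), one flattened gradient per timestep.
    """
    T = grads.shape[0]
    M = grads @ grads.T + lam * np.eye(T)   # objective is the quadratic form w^T M w
    w = np.full(T, 1.0 / T)                 # start from uniform weights
    for _ in range(iters):
        Mw = M @ w
        t = int(np.argmin(Mw))              # simplex vertex minimizing the linearized objective
        d = -w.copy()
        d[t] += 1.0                         # direction from w toward vertex e_t
        dMd = d @ M @ d
        if dMd <= 1e-12:                    # already at that vertex, or no curvature
            break
        # Exact line search for a quadratic, clipped to stay on the simplex.
        step = float(np.clip(-(d @ Mw) / dMd, 0.0, 1.0))
        if step == 0.0:
            break
        w = w + step * d
    return w

# Two opposing gradients: the min-norm combination cancels them exactly.
G = np.array([[2.0, 0.0], [-1.0, 0.0]])
w = frank_wolfe_min_norm(G)
```

For this two-gradient example the solver returns weights close to (1/3, 2/3), for which the combined gradient is zero. Solving such a problem at every training step is expensive, which is one motivation for a closed-form weighting like Min-SNR-γ.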
Experiments
- Overview of experimental setup
- Ablation studies to show versatility
- Comparison to state-of-the-art methods
Setup
- Experiments performed on CelebA and ImageNet datasets
- Data pre-processing involves center cropping and resizing
- ViT and UNet used as diffusion model backbones
- AdamW optimizer used
- Heun sampler from EDM used for image generation
- FID score used to measure quality of generated images
- Comparison of different weighting strategies (constant, SNR, truncated SNR, Min-SNR-γ)
- Min-SNR-γ strategy converges faster than other strategies
- Experiments conducted on different prediction targets (x0, noise, velocity)
- Min-SNR-γ strategy robust across different network architectures
- Robustness analysis conducted with different truncation values γ (1, 5, 10, 20)
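The weighting strategies and prediction targets compared above can be written as small functions of SNR(t). This is a sketch, assuming the weights are defined on the x0-prediction loss as w_t = 1 (constant), SNR(t), max(SNR(t), 1) (truncated SNR), and min(SNR(t), γ) (Min-SNR-γ), and assuming a variance-preserving process so that the noise and velocity losses equal the x0 loss scaled by SNR(t) and SNR(t) + 1 respectively. Function names and the default `gamma=5.0` are illustrative choices.

```python
import numpy as np

def x0_loss_weight(snr, strategy="min_snr", gamma=5.0):
    """Weight applied to the x0-prediction MSE at a given SNR."""
    snr = np.asarray(snr, dtype=float)
    if strategy == "constant":
        return np.ones_like(snr)
    if strategy == "snr":
        return snr
    if strategy == "truncated_snr":
        return np.maximum(snr, 1.0)
    if strategy == "min_snr":
        return np.minimum(snr, gamma)
    raise ValueError(f"unknown strategy: {strategy}")

def min_snr_weight(snr, target="x0", gamma=5.0):
    """Min-SNR-gamma weight expressed on the MSE of the chosen prediction target."""
    snr = np.asarray(snr, dtype=float)
    w = np.minimum(snr, gamma)
    if target == "x0":
        return w
    if target == "noise":
        return w / snr            # noise MSE = SNR * x0 MSE
    if target == "velocity":
        return w / (snr + 1.0)    # velocity MSE = (SNR + 1) * x0 MSE
    raise ValueError(f"unknown target: {target}")

snrs = np.array([0.1, 1.0, 10.0, 100.0])
for gamma in (1, 5, 10, 20):
    print(gamma, min_snr_weight(snrs, target="noise", gamma=gamma))
```

Clamping at γ keeps low-noise timesteps, where the SNR is very large, from dominating the loss, while leaving the high-noise timesteps untouched.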
Conclusion
- Diffusion training process is a multi-task learning problem
- Introduce a novel weighting strategy, Min-SNR-γ, to balance different timesteps
- Experiments demonstrate faster convergence and state-of-the-art FID score
- Pareto optimality: a point at which any change increases at least one loss term
- Equation 15 is converted to $\min_t \langle g_t, u \rangle \geq 0$
- $C_T = \{(w_1, w_2, \ldots, w_T) \mid w_1, w_2, \ldots, w_T \geq 0\}$
- The non-conflicting weighting strategy converges 3.4 times faster than previous weighting strategies
- Ablation studies on UNet backbone show robustness to hyper-parameter γ