Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Denoising diffusion models are used for image generation.
  • Training these models can be slow.
  • Conflicting optimization directions between timesteps can cause slow convergence.
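  • Min-SNR-$\gamma$ addresses this issue by treating diffusion training as multi-task learning and adapting each timestep's loss weight based on its clamped signal-to-noise ratio (SNR).
  • Min-SNR-$\gamma$ improves convergence speed and achieves a new record FID score of 2.06 on the ImageNet 256×256 benchmark.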

Paper Content

Introduction

  • We proposed a Min-SNR-γ weighting strategy to tackle the conflicting gradients issue.
  • We demonstrated that the proposed strategy can significantly improve the convergence rate of diffusion training and achieved a new record FID score of 2.06 on the ImageNet 256×256 benchmark.
  • We showed that the proposed strategy is effective and general for various generation scenarios.
  • We conducted a thorough comparison between our proposed strategy and existing works.
  • We provided a comprehensive analysis of the proposed strategy.
  • We released the source code of our proposed strategy.
  • Diffusion models have demonstrated superior performance across a range of generation tasks
  • Diffusion models have a slow convergence rate
  • Conflict in gradients across various timesteps is likely the cause of slow convergence
  • Diffusion models are strong generative models used in image generation
  • Diffusion models have been used in text-to-video generation, 3D avatar generation, image-to-image translation, image manipulation, music generation, and drug discovery
  • UNet and Vision Transformers are widely used network structures for diffusion models
  • Recent studies have tried to improve diffusion models by guiding the sampling process, designing fast sampling methods, and improving the denoising networks
  • Multi-task learning is used to learn multiple related tasks jointly, but can lead to negative transfer
  • GradNorm and MTO are approaches used to address negative transfer in multi-task learning

Method

Preliminary

  • Diffusion models have two processes: forward noising and reverse denoising
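  • The forward process adds Gaussian noise to a real data point $x_0$ to obtain a series of noisy latent variables $x_t = \alpha_t x_0 + \sigma_t \epsilon$, with $\epsilon \sim \mathcal{N}(0, I)$
  • The reverse process denoises the latent variables step by step and restores the real data
  • Earlier works train the network to predict the noise $\epsilon$, while later works reparameterize it to predict the noiseless state $x_0$
  • The network can also predict the velocity $v_t = \alpha_t \epsilon - \sigma_t x_0$; the choice of target implicitly reweights the timesteps, e.g., the noise-prediction loss equals the $x_0$-prediction loss multiplied by $\mathrm{SNR}(t) = \alpha_t^2 / \sigma_t^2$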
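The forward process and the three prediction targets (noise, noiseless state, velocity) can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the paper's code: it assumes a variance-preserving schedule with a DDPM-style linear β from 1e-4 to 0.02 over 1000 steps, and `linear_schedule`, `snr`, `forward_noising`, and `eps_from_x0` are hypothetical helper names. The last lines check numerically that turning an $x_0$ prediction into the implied noise prediction multiplies the squared error by SNR(t) = α_t²/σ_t².

```python
import numpy as np


def linear_schedule(num_steps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02):
    """Return (alpha_t, sigma_t) for an assumed DDPM-style linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    alpha_bar = np.cumprod(1.0 - betas)  # cumulative product \bar{alpha}_t
    # Variance preserving: alpha_t^2 + sigma_t^2 = 1 at every timestep.
    return np.sqrt(alpha_bar), np.sqrt(1.0 - alpha_bar)


def snr(alphas: np.ndarray, sigmas: np.ndarray) -> np.ndarray:
    """Signal-to-noise ratio SNR(t) = alpha_t^2 / sigma_t^2 (decreases as t grows)."""
    return alphas ** 2 / sigmas ** 2


def forward_noising(x0, t, alphas, sigmas, rng):
    """Sample x_t = alpha_t * x0 + sigma_t * eps and return the three possible targets."""
    eps = rng.standard_normal(x0.shape)
    a, s = alphas[t], sigmas[t]
    x_t = a * x0 + s * eps
    targets = {
        "x0": x0,               # noiseless state
        "eps": eps,             # added noise
        "v": a * eps - s * x0,  # velocity
    }
    return x_t, targets


def eps_from_x0(x_t, x0_hat, a, s):
    """Convert an x0 prediction into the noise prediction it implies."""
    return (x_t - a * x0_hat) / s


# Usage: compare the noise-space error with the SNR-weighted x0-space error.
rng = np.random.default_rng(0)
alphas, sigmas = linear_schedule()
x0 = rng.standard_normal((2, 4))
t = 300
x_t, targets = forward_noising(x0, t, alphas, sigmas, rng)
x0_hat = x0 + 0.1 * rng.standard_normal(x0.shape)  # an imperfect x0 prediction
eps_err = np.mean((eps_from_x0(x_t, x0_hat, alphas[t], sigmas[t]) - targets["eps"]) ** 2)
x0_err = np.mean((x0_hat - x0) ** 2)
print(np.isclose(eps_err, snr(alphas, sigmas)[t] * x0_err))  # True
```

Because SNR(t) spans several orders of magnitude under this schedule, the choice of prediction target alone already acts as a strong implicit weighting of the timesteps.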

Diffusion training as multi-task learning

  • Previous studies share the denoising model's parameters across all timesteps.
  • Different steps may have different requirements for denoising.
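  • An experiment fine-tunes the model on specific timestep ranges to analyze the correlation between different timesteps.
  • Fine-tuning on specific timesteps benefited the surrounding timesteps, suggesting that timesteps behave like related tasks that can help or interfere with each other.
  • The goal is to find an efficient solution that benefits all timesteps simultaneously.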
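The notion of conflicting gradients between timesteps can be made concrete with a toy example. This is an illustration only, not the paper's experiment: the quadratic per-task losses and the timestep-range "optima" are made-up values standing in for per-timestep denoising losses, and `task_gradient` is a hypothetical helper. Two tasks conflict at the shared parameters when the cosine similarity of their gradients is negative, so a step that helps one hurts the other.

```python
import numpy as np


def cosine_similarity(g1: np.ndarray, g2: np.ndarray) -> float:
    """Cosine of the angle between two gradient vectors; < 0 means they conflict."""
    return float(np.dot(g1, g2) / (np.linalg.norm(g1) * np.linalg.norm(g2)))


def task_gradient(theta: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Gradient of a hypothetical per-task loss ||theta - target||^2 w.r.t. shared theta."""
    return 2.0 * (theta - target)


# Shared parameters and hypothetical per-timestep-range optima (toy values, not from the paper).
theta = np.zeros(2)
optima = {
    "t in [0, 100)": np.array([1.0, 0.2]),
    "t in [100, 200)": np.array([0.8, 0.5]),
    "t in [900, 1000)": np.array([-1.0, 0.1]),
}
grads = {name: task_gradient(theta, opt) for name, opt in optima.items()}

names = list(grads)
for i, a in enumerate(names):
    for b in names[i + 1:]:
        cos = cosine_similarity(grads[a], grads[b])
        status = "conflict" if cos < 0 else "aligned"
        print(f"{a} vs {b}: cos = {cos:+.2f} ({status})")
```

In this toy setup the two neighboring ranges have aligned gradients, while the distant range pulls the shared parameters the other way.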

Pareto optimality of diffusion models

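  • Theorem 1 considers the optimization problem $\min_{w_t} \{ \|\sum_t w_t \nabla_\theta \mathcal{L}^t(\theta)\|^2 \mid \sum_t w_t = 1, w_t \ge 0 \}$: either its solution is zero and the current point is Pareto stationary, or the solution gives a descent direction that improves all timesteps at once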
  • A more general form of the theorem was proposed in [11]
  • Diffusion models require all timesteps to be included in training, because sampling relies on the denoiser at every timestep
  • A regularization term is included to prevent loss weights from becoming too small
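  • The Frank-Wolfe algorithm and unconstrained gradient descent are used to solve this optimization problem
  • Because solving it requires per-timestep gradients at every iteration, which is costly, the simpler Min-SNR-γ weighting strategy is proposed to optimize the timesteps simultaneously: for the $x_0$-prediction loss, each timestep's weight is $\min(\mathrm{SNR}(t), \gamma)$, the timestep's signal-to-noise ratio clamped at γ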
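The Min-SNR-γ rule itself is a one-line clamp; what changes with the prediction target is how it is expressed. The sketch below is our reading, not the paper's code: it uses the identities that the noise-prediction MSE equals SNR(t) times the $x_0$ MSE and the velocity MSE equals (SNR(t) + 1) times the $x_0$ MSE, so the same $x_0$-space weight is divided by those factors. The default γ = 5 and the names `min_snr_gamma_weight` and `weighted_mse` are assumptions for illustration.

```python
import numpy as np


def min_snr_gamma_weight(snr: np.ndarray, gamma: float = 5.0, target: str = "eps") -> np.ndarray:
    """Per-timestep loss weight for Min-SNR-gamma, expressed for the given prediction target.

    In x0 space the weight is min(SNR(t), gamma). Since the eps-prediction MSE is
    SNR(t) times the x0 MSE, and the v-prediction MSE is (SNR(t) + 1) times the x0 MSE,
    the same clamp is divided by those factors for eps or v prediction.
    """
    snr = np.asarray(snr, dtype=np.float64)
    clamped = np.minimum(snr, gamma)
    if target == "x0":
        return clamped
    if target == "eps":
        return clamped / snr
    if target == "v":
        return clamped / (snr + 1.0)
    raise ValueError(f"unknown prediction target: {target!r}")


def weighted_mse(pred: np.ndarray, target: np.ndarray, weights: np.ndarray) -> float:
    """Batch loss: per-sample MSE scaled by that sample's timestep weight, then averaged."""
    per_sample = np.mean((pred - target) ** 2, axis=tuple(range(1, pred.ndim)))
    return float(np.mean(weights * per_sample))


snr_values = np.array([0.1, 1.0, 5.0, 50.0])  # hypothetical SNRs, from high noise to low noise
for tgt in ("x0", "eps", "v"):
    print(tgt, np.round(min_snr_gamma_weight(snr_values, gamma=5.0, target=tgt), 3))
```

With noise prediction, for example, the weight stays 1 for noisy timesteps (SNR ≤ γ) and shrinks as γ/SNR(t) for nearly clean ones, which keeps the easy low-noise timesteps from dominating training.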

Experiments

  • Overview of experimental setup
  • Ablation studies to show versatility
  • Comparison to state-of-the-art methods

Setup

  • Experiments performed on CelebA and ImageNet datasets
  • Data pre-processing involves center cropping and resizing
  • ViT and UNet used as diffusion model backbones
  • AdamW optimizer used
  • Heun sampler from EDM used for image generation
  • FID score used to measure quality of generated images
  • Comparison of different weighting strategies (constant, SNR, truncated SNR, Min-SNR-γ)
  • Min-SNR-γ strategy converges faster than other strategies
  • Experiments conducted on different prediction targets (x0, noise, velocity)
  • Min-SNR-γ strategy robust across different network architectures
  • Robustness analysis conducted with different truncation values γ (1, 5, 10, 20)
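For reference, a deterministic second-order Heun sampler in the style of EDM can be sketched as below. The noise-level schedule (σ_min = 0.002, σ_max = 80, ρ = 7, 18 steps) reflects the EDM defaults as we recall them, and `denoise(x, sigma)` is a hypothetical stand-in for the trained network's clean-sample estimate; the paper's exact sampler settings may differ. The usage lines plug in the exact denoiser for 1-D Gaussian data N(0, 0.5²), so the samples should come out with a standard deviation close to 0.5.

```python
import numpy as np
from typing import Callable

# denoise(x, sigma) -> estimate of the clean sample; a hypothetical stand-in for the network.
Denoiser = Callable[[np.ndarray, float], np.ndarray]


def edm_sigmas(num_steps: int = 18, sigma_min: float = 0.002,
               sigma_max: float = 80.0, rho: float = 7.0) -> np.ndarray:
    """Noise levels from sigma_max down to sigma_min, with a final 0 appended."""
    i = np.arange(num_steps)
    inv_rho_max, inv_rho_min = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    sigmas = (inv_rho_max + i / (num_steps - 1) * (inv_rho_min - inv_rho_max)) ** rho
    return np.append(sigmas, 0.0)


def heun_sample(denoise: Denoiser, shape, rng, num_steps: int = 18) -> np.ndarray:
    """Deterministic Heun (2nd-order) integration of the probability-flow ODE."""
    sigmas = edm_sigmas(num_steps)
    x = rng.standard_normal(shape) * sigmas[0]  # start from pure noise at sigma_max
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoise(x, sigma)) / sigma          # ODE slope dx/dsigma
        x_next = x + (sigma_next - sigma) * d        # Euler step
        if sigma_next > 0:                           # Heun correction, skipped at sigma = 0
            d_next = (x_next - denoise(x_next, sigma_next)) / sigma_next
            x_next = x + (sigma_next - sigma) * 0.5 * (d + d_next)
        x = x_next
    return x


# Usage: exact denoiser for data ~ N(0, data_std^2), i.e. E[x0 | x] = x * s^2 / (s^2 + sigma^2).
data_std = 0.5
def gaussian_denoiser(x, sigma):
    return x * data_std ** 2 / (data_std ** 2 + sigma ** 2)

samples = heun_sample(gaussian_denoiser, (20000,), np.random.default_rng(0))
print(samples.std())  # should be close to data_std
```

Each Heun step costs two network evaluations (one for the final step to σ = 0), which is why sampler choice matters when comparing FID at a fixed budget.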

Conclusion

  • The diffusion training process can be viewed as a multi-task learning problem
  • Introduces a novel weighting strategy, Min-SNR-γ, to balance the different timesteps
  • Experiments demonstrate faster convergence and state-of-the-art FID score
  • Pareto optimality: a point from which any change increases at least one loss term
  • Equation 15 is converted to $\min_t \langle g_t, u \rangle \ge 0$
  • $C_T = \{(w_1, w_2, \dots, w_T) \mid w_1, w_2, \dots, w_T \ge 0\}$
  • Leveraging a non-conflicting weighting strategy speeds up convergence by 3.4× compared with previous weighting strategies
  • Ablation studies with a UNet backbone show robustness to the hyperparameter γ
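The min-norm problem behind Theorem 1 (weights constrained to sum to 1 and be non-negative) can be solved with Frank-Wolfe iterations and an exact line search, as in multiple-gradient descent (MGDA) approaches. The sketch below is our illustration, not the paper's solver: `G` is a hypothetical matrix whose rows are per-timestep (or per-timestep-group) gradients, and the regularization term that keeps weights from becoming too small is omitted. If the returned combination has near-zero norm, the current point is Pareto stationary; otherwise its negative is a descent direction for every task.

```python
import numpy as np


def min_norm_weights(G: np.ndarray, max_iter: int = 250, tol: float = 1e-8) -> np.ndarray:
    """Frank-Wolfe for min_w ||sum_t w_t g_t||^2  s.t.  sum_t w_t = 1, w_t >= 0.

    G has one row per task (here: per timestep group), so g_t = G[t].
    """
    num_tasks = G.shape[0]
    M = G @ G.T                                # Gram matrix of gradient inner products
    w = np.full(num_tasks, 1.0 / num_tasks)    # start from uniform weights
    for _ in range(max_iter):
        Mw = M @ w
        t_hat = int(np.argmin(Mw))             # vertex minimizing the linearized objective
        vv = float(w @ Mw)                     # ||v||^2 with v = sum_t w_t g_t
        vu = float(Mw[t_hat])                  # <v, g_t_hat>
        uu = float(M[t_hat, t_hat])            # ||g_t_hat||^2
        denom = vv - 2.0 * vu + uu             # ||v - g_t_hat||^2
        if denom <= tol:
            break                              # v already coincides with g_t_hat
        step = min(max((vv - vu) / denom, 0.0), 1.0)  # exact line search, clipped to [0, 1]
        if step <= tol:
            break                              # moving toward g_t_hat no longer reduces the norm
        w = (1.0 - step) * w
        w[t_hat] += step
    return w


G = np.array([[1.0, 0.0],
              [0.0, 2.0]])
w = min_norm_weights(G)
print(w, np.linalg.norm(w @ G))  # weights and norm of the combined gradient
```

For these two orthogonal gradients the solver puts more weight on the smaller one (0.8 vs. 0.2), which is the balancing effect a per-iteration solve would provide, at the cost of computing every timestep's gradient.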