Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Diffusion models (DMs) have been successful in generative modeling tasks.
- Sampling procedure of DMs is slow and requires hundreds to thousands of time discretization steps.
- Goal is to develop a fast sampling method for DMs with fewer steps while retaining high sample quality.
- Discretization method is most crucial factor affecting sample quality.
- Proposed Diffusion Exponential Integrator Sampler (DEIS) is based on Exponential Integrator designed for discretizing ODEs.
- DEIS can generate high-fidelity samples in as few as 10 steps.
- State-of-the-art sampling performance achieved with a limited number of score function evaluations.
Paper Content
Introduction
- Diffusion Model (DM) is a generative modeling method
- DM has advantage of stable training and avoids mode-collapsing
- DM has achieved impressive performances on a variety of tasks
- DM has slow sampling compared to GANs
- Several studies aim to improve DM sampling speed
- Error analysis of numerical solvers for DM conducted
- Proposed Diffusion Exponential Integrator Sampler (DEIS)
- DEIS achieves superior sampling quality with limited number of network function evaluations
- DEIS generates high-quality synthetic images with 12 steps, compared with 1000 steps for DDIM (roughly 83x speedup)
Background on diffusion models
- DM consists of a forward diffusion process that adds noise and a backward diffusion process that removes noise
- Forward diffusion is described by a stochastic differential equation
- Two popular SDEs used by diffusion models are summarized in Tab 1
- Backward denoising diffusion matches forward diffusion in probability law
- Score network is used to approximate score and trained by minimizing denoising score matching loss
- Loss can be evaluated using empirical samples and standard stochastic optimization algorithms
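For reference, the forward SDE and the denoising score matching objective described above can be written compactly as follows. This is a sketch in commonly used notation, not a quotation of the paper's equations: the symbols F_t (linear drift), G_t (diffusion coefficient), Λ_t (time-dependent weight) and p_{0t} (transition density from time 0 to t) are assumptions introduced here.

```latex
% Forward diffusion: a linear SDE with drift F_t and diffusion coefficient G_t
dx = F_t \, x \, dt + G_t \, dw

% Denoising score matching loss with an assumed time-dependent weight \Lambda_t
\mathcal{L}(\theta) = \mathbb{E}_{t} \, \mathbb{E}_{x_0,\, x_t \mid x_0}
  \left[ \left\| s_\theta(x_t, t) - \nabla_{x_t} \log p_{0t}(x_t \mid x_0) \right\|_{\Lambda_t}^{2} \right]
```

Because p_{0t}(x_t | x_0) is Gaussian for a linear SDE, the target gradient is available in closed form, which is why the loss can be estimated from empirical samples.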
Fast sampling with learned score models
- Score network s θ (x, t) can be used to generate new samples by solving a backward SDE.
- A family of SDEs parameterized by λ ≥ 0 is considered.
- When λ = 0, SDE reduces to an ODE known as the probability flow ODE.
- Marginal distributions of SDE and ODE match when score network is accurate.
- Generative model error is caused by fitting and discretization errors.
- Fitting error is due to mismatch between learned score network and ground truth score.
- Discretization error is caused by discretization of SDE.
- Learned score network is not accurate for most x, t.
- Good discretization scheme can help reduce impact of fitting error.
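The λ-parameterized family mentioned above can be sketched as follows, assuming the forward SDE has the linear form dx = F_t x dt + G_t dw. The exact typesetting is a reconstruction in standard notation rather than a copy of the paper's Eq (4).

```latex
% Backward sampling SDE, parameterized by \lambda \ge 0
dx = \left[ F_t x - \frac{1 + \lambda^{2}}{2} \, G_t G_t^{\top} s_\theta(x, t) \right] dt + \lambda \, G_t \, dw

% \lambda = 0 removes the noise term and leaves the probability flow ODE
\frac{dx}{dt} = F_t x - \frac{1}{2} \, G_t G_t^{\top} s_\theta(x, t)
```

Setting λ = 1 gives a drift of F_t x − G_t G_tᵀ s_θ, which is the usual reverse-time SDE, so the family interpolates between deterministic and fully stochastic sampling.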
Discretization error
- Investigating the discretization error of solving the probability flow ODE (λ = 0)
- Exact solution to the ODE is expressed through Ψ(t, s), the transition matrix from time s to t associated with F_τ
- Different numerical solvers for Eq (5) correspond to different discretization schemes for approximating Eq (6)
- To achieve fast sampling with Eq (5), need to approximately solve it with a small number of discretization steps, and thus large step size
- Systematically study the discretization error in solving Eq (5), both theoretically and empirically
- Develop an efficient algorithm for Eq (5) that requires a small number of NFEs
- Euler method is widely used in numerical software, but has low accuracy and is sometimes unstable
- Exponential Integrator (EI) leverages the semilinear structure of Eq (5)
- EI performs worse than the Euler method, suggesting other major factors contribute to the error
- Parameterization ∇ log p_t(x) ≈ −L_t^{−T} ε_θ(x, t) leads to significant improvements in accuracy
- EI discretization Eq (11) coincides with the deterministic DDIM when the forward diffusion Eq (1) is VPSDE
- Use high-order polynomial extrapolation of ε_θ in the EI method to reduce approximation error
- Compare linear timesteps and quadratic timesteps to balance the approximation error and the accumulation error
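The difference between an Euler step and an exponential-integrator step with the noise parameterization can be illustrated on a toy problem. The sketch below is not the paper's experiment: it assumes a scalar VPSDE with a linear β schedule (the constants are illustrative) and a data distribution that is a point mass at x0, so the exact noise prediction is known in closed form. All function names are hypothetical.

```python
import math

# Assumed linear VPSDE noise schedule; the constants are illustrative only.
BETA_MIN, BETA_MAX = 0.1, 20.0

def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)

def alpha(t):
    # alpha_t = exp(-integral_0^t beta(s) ds)
    return math.exp(-(BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t * t))

def eps_exact(x, t, x0):
    # Exact noise prediction when the data distribution is a point mass at x0.
    a = alpha(t)
    return (x - math.sqrt(a) * x0) / math.sqrt(1.0 - a)

def euler_step(x, t, t_next, x0):
    # Explicit Euler on dx/dt = -0.5*beta*x + 0.5*beta*eps/sigma.
    sigma = math.sqrt(1.0 - alpha(t))
    drift = -0.5 * beta(t) * x + 0.5 * beta(t) * eps_exact(x, t, x0) / sigma
    return x + (t_next - t) * drift

def ei_step(x, t, t_next, x0):
    # Exponential-integrator step with eps held constant over the step;
    # for VPSDE this has the same form as the deterministic DDIM update.
    a, a_next = alpha(t), alpha(t_next)
    ratio = math.sqrt(a_next / a)
    coef = math.sqrt(1.0 - a_next) - ratio * math.sqrt(1.0 - a)
    return ratio * x + coef * eps_exact(x, t, x0)

def sample(step, x_T, x0, n_steps):
    # Integrate backward from t = 1 to t = 0 on a uniform time grid.
    ts = [1.0 - i / n_steps for i in range(n_steps + 1)]
    x = x_T
    for t, t_next in zip(ts[:-1], ts[1:]):
        x = step(x, t, t_next, x0)
    return x

if __name__ == "__main__":
    for name, step in [("euler", euler_step), ("ei", ei_step)]:
        print(name, abs(sample(step, 2.0, 1.0, 10) - 1.0))
```

In this toy setting the noise prediction is constant along each exact trajectory, so the EI step recovers x0 up to floating-point error for any number of steps, while Euler with the same step budget does not. With a learned network the noise prediction is only approximately constant, which is the gap that higher-order extrapolation targets.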
Exponential integrator: simplify probability flow ODE
- Probability flow ODE Eq (5) can be transformed into a simple non-stiff ODE
- Off-the-shelf ODE solvers can be applied to solve the ODE
- Transformation can be improved by taking into account the analytical form of Ψ, G_t, L_t
- Transformation results in a black-box ODE that can only be solved by generic ODE solvers
- Two variants of DEIS proposed: ρRK-DEIS and ρAB-DEIS
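For the VPSDE case, the change of variables can be checked numerically. The sketch below assumes the rescaled state y = x / sqrt(α_t) and the rescaled time ρ = sqrt((1 − α_t) / α_t), under which the probability flow ODE reads dy/dρ = ε_θ. These particular definitions, the schedule constants, and the function names are assumptions for illustration. A single Euler step in (y, ρ) coordinates then matches the deterministic DDIM update in the original coordinates.

```python
import math

def alpha(t, beta_min=0.1, beta_max=20.0):
    # alpha_t for an assumed linear VPSDE schedule (illustrative constants).
    return math.exp(-(beta_min * t + 0.5 * (beta_max - beta_min) * t * t))

def rho(t):
    # Rescaled time: rho = sqrt((1 - alpha_t) / alpha_t).
    a = alpha(t)
    return math.sqrt((1.0 - a) / a)

def ddim_step(x, eps, t, t_next):
    # Deterministic DDIM update for a given noise prediction eps.
    a, a_next = alpha(t), alpha(t_next)
    ratio = math.sqrt(a_next / a)
    coef = math.sqrt(1.0 - a_next) - ratio * math.sqrt(1.0 - a)
    return ratio * x + coef * eps

def rho_euler_step(x, eps, t, t_next):
    # Map to y = x / sqrt(alpha_t), take one Euler step on dy/drho = eps,
    # then map back to the original coordinates.
    y = x / math.sqrt(alpha(t))
    y_next = y + (rho(t_next) - rho(t)) * eps
    return math.sqrt(alpha(t_next)) * y_next

if __name__ == "__main__":
    x, eps = 0.7, -1.3  # arbitrary state and noise prediction
    print(ddim_step(x, eps, 0.8, 0.6), rho_euler_step(x, eps, 0.8, 0.6))
```

Since the transformed right-hand side is just the network output, higher-order Runge-Kutta or Adams-Bashforth schemes can be applied in ρ directly, which is the idea behind the two named variants.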
Discretization error of SDE sampling
- Solving SDE Eq (4) with λ > 0 gives a stochastic sampling scheme for DMs
- Exact solution to Eq (4) consists of a linear term, a nonlinear term involving the score network, and a noise term
- Goal is to approximate Eq (19) through discretization
- Stochastic DDIM is a numerical solver for Eq (19)
More experiments
- DEIS algorithm compared to several fast diffusion model sampling algorithms on image generation tasks
- Quantitative comparisons on CIFAR10, CelebA, ImageNet 32x32 datasets
- DDIM, A-DDIM, PNDM used as baselines
- Improved PNDM (iPNDM) proposed to avoid expensive warming start
- Performance evaluations of various DEIS with VPSDE on CIFAR10 in Tab 2
- DEIS algorithms can generate high-fidelity images with a very small NFE budget
- High order polynomial approximation can significantly outperform DDIM
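The polynomial approximation referred to above can be sketched as Lagrange extrapolation: fit a polynomial through the most recent network evaluations and evaluate it at a new time. The helper below is hypothetical and shows only the extrapolation step. In an exponential-integrator sampler the fitted polynomial would additionally be integrated against the integrator's weight to obtain step coefficients, which is omitted here.

```python
def lagrange_extrapolate(ts, values, t):
    """Evaluate at t the polynomial through the points (ts[j], values[j])."""
    total = 0.0
    for j, (tj, vj) in enumerate(zip(ts, values)):
        # Lagrange basis polynomial for node j, evaluated at t.
        weight = 1.0
        for k, tk in enumerate(ts):
            if k != j:
                weight *= (t - tk) / (tj - tk)
        total += weight * vj
    return total

if __name__ == "__main__":
    # Stand-in for a network output that varies smoothly with time.
    def f(t):
        return 3.0 * t * t - 2.0 * t + 1.0

    past_times = [0.9, 0.8, 0.7]
    past_values = [f(t) for t in past_times]
    print(lagrange_extrapolate(past_times, past_values, 0.6), f(0.6))
```

With a single past point the function returns that value unchanged, which corresponds to holding the network output constant over the step. Using more past points raises the order without extra network evaluations, since the earlier outputs are reused.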
Conclusion and broader impact
- Fast sampling algorithm for DMs developed
- Generates high-fidelity samples with fewer than 10 NFEs
- Potential ways to improve algorithm
- Can reduce computational cost for generative tasks
- Potential for malicious usage
A. More related work
- Research has been conducted to speed up sampling of DMs
- Authors in [23,44] optimize denoising process
- Authors in [34] use non-Markovian forward noising
- Authors in [3] optimize backward Markovian process
- Authors in [46,40,7,43,4] use Schrödinger bridge
- Authors in [25] interpret update step in stochastic DDIM as combination of gradient estimation and transfer step
E.7 More results on VESDE
- VESDE sampling can be significantly accelerated compared to previous methods.
- Implementation is in JAX and PyTorch; code release is planned.
- Used code from a number of sources.
- FID results of DEIS on VESDE CIFAR10.
- Fitting error on a toy demo.
- Pixel difference between ground truth and numerical solution.
- Approximation error along ground truth solutions.
- Score approximation error when parameterization is used.
- Developed fast sampling algorithms called DEIS.
- Extrapolation error with N = 10, 20.
- High order polynomial can reduce approximation error.
- Sampling quality measured by FID of DEIS on CIFAR10.
- Relative changes of ε_θ with respect to t.
- Generated images with DEIS on VPSDE CIFAR10, ImageNet 32x32, and CelebA.
- Comparison between quadratic timesteps and linear timesteps.
- Optimizing time discretization.
- More results of DEIS for VPSDE on CIFAR10 with limited NFE.
- Comparison with A-DDIM on CIFAR10.
- Black-box ODE solver reports FID 8.34 with ODE tolerance 1×10⁻⁵.
- More results of VPSDE on CIFAR10 with DOPRI5 ODE solver.
- Mean and standard deviation of multiple runs with 4 different random seeds on CelebA.