Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Denoising Diffusion models are proficient at generative sampling, but generating good samples often requires many iterations.
- Techniques such as binary time-distillation (BTD) have been proposed to reduce the number of network calls.
- TRAnsitive Closure Time-distillation (TRACT) is a new method that extends BTD.
- For single-step diffusion, TRACT improves FID by up to 2.4x on the same architecture.
- TRACT achieves a new state-of-the-art single-step FID for Denoising Diffusion Implicit Models (DDIMs).
- A PyTorch implementation will be released soon.
Paper Content
Introduction
- Diffusion models are state-of-the-art generative models for many domains and applications
- Diffusion models are simpler to train than models from other generative modeling approaches
- Diffusion models demonstrate excellent scalability to large models and datasets
- Inference efficiency is a major challenge for diffusion models
- Existing efforts to speed up inference of diffusion models can be categorized into three classes
- Denoising Diffusion Implicit Models (DDIMs) sample deterministically along a noise schedule, which allows fewer inference steps
- DDIM sampling can be viewed as numerically solving a neural ODE
- Advanced ODE solvers can therefore be used to accelerate sampling
- Binary Time-Distillation (BTD) distills a teacher model into a student model that needs half as many sampling steps
- Because each BTD phase halves the step count, log2 T training phases are required to distill a T-step teacher into a single-step model
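Because each BTD phase halves the number of sampling steps, the phase count follows directly from the teacher's step count. A minimal Python sketch, assuming a power-of-two teacher step count; the 1024-step teacher is a hypothetical example, not a value taken from this summary:

```python
import math

def btd_step_schedule(teacher_steps: int) -> list[int]:
    """Step counts visited by binary time-distillation (BTD).

    Each phase trains a student with half as many sampling steps as its
    teacher, so a T-step teacher reaches one step after log2(T) phases.
    """
    if teacher_steps < 1 or teacher_steps & (teacher_steps - 1):
        raise ValueError("teacher_steps must be a positive power of two")
    schedule = [teacher_steps]
    while schedule[-1] > 1:
        schedule.append(schedule[-1] // 2)  # one BTD phase halves the step count
    return schedule

T = 1024  # hypothetical teacher step count, for illustration only
schedule = btd_step_schedule(T)
print(schedule)                         # [1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1]
print(len(schedule) - 1, math.log2(T))  # 10 10.0
```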
Method
Adapting TRACT to a Runge-Kutta teacher and variance-exploding noise schedule
- TRACT is applied to teacher models from EDM, which use a variance-exploding (VE) noise schedule and a Runge-Kutta (RK) sampler
- VE noise schedules are parameterized by a sequence of noise standard deviations
- Algorithm 1 distills a T-step teacher into a T/S-step student by splitting the timesteps into groups of size S
- RK and DDIM-VE step functions are used to estimate x_{t'} from x_t, where t' < t
- The loss is a weighted distance between the student network's prediction and the target
- The loss weighting and network preconditioning follow the strategies introduced in the EDM paper
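The VE step functions can be sketched in Python. This assumes a denoiser D(x, σ) that predicts the clean sample, the probability-flow ODE dx/dσ = (x - D(x, σ))/σ, and EDM's default schedule (σ_min = 0.002, σ_max = 80, ρ = 7). The 1-D Gaussian toy denoiser and all function names are hypothetical stand-ins for a trained network, not TRACT's implementation:

```python
import math

def edm_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """n decreasing VE noise levels from sigma_max to sigma_min, then a final 0."""
    a, b = sigma_max ** (1.0 / rho), sigma_min ** (1.0 / rho)
    sigmas = [(a + i / (n - 1) * (b - a)) ** rho for i in range(n)]
    return sigmas + [0.0]

def ddim_ve_step(denoise, x, sigma, sigma_next):
    """DDIM step under a VE schedule; identical to one Euler step of the ODE."""
    x0_hat = denoise(x, sigma)
    return x0_hat + (sigma_next / sigma) * (x - x0_hat)

def heun_step(denoise, x, sigma, sigma_next):
    """Second-order Runge-Kutta (Heun) step of dx/dsigma = (x - D(x, sigma)) / sigma."""
    d = (x - denoise(x, sigma)) / sigma
    x_euler = x + (sigma_next - sigma) * d
    if sigma_next == 0.0:
        return x_euler  # no correction at sigma = 0 (it would divide by zero)
    d_next = (x_euler - denoise(x_euler, sigma_next)) / sigma_next
    return x + (sigma_next - sigma) * 0.5 * (d + d_next)

# Hypothetical toy problem: 1-D Gaussian data with std S_DATA, for which the
# exact denoiser E[x0 | x] and the exact ODE solution are closed-form.
S_DATA = 0.5

def gaussian_denoiser(x, sigma):
    return S_DATA ** 2 / (S_DATA ** 2 + sigma ** 2) * x

def exact_flow(x, sigma, sigma_next):
    return x * math.sqrt((S_DATA ** 2 + sigma_next ** 2) / (S_DATA ** 2 + sigma ** 2))

def sample(step_fn, x, sigmas):
    """Integrate from sigmas[0] down to sigmas[-1] with the given step function."""
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        x = step_fn(gaussian_denoiser, x, sigma, sigma_next)
    return x

sigmas = edm_sigmas(18)
x_T = 80.0  # a noise sample at sigma_max
target = exact_flow(x_T, sigmas[0], 0.0)
err_ddim = abs(sample(ddim_ve_step, x_T, sigmas) - target)
err_heun = abs(sample(heun_step, x_T, sigmas) - target)
print(f"DDIM-VE error: {err_ddim:.4f}, Heun error: {err_heun:.4f}")
```

Under this ODE a DDIM-VE step is exactly one Euler step, while the Heun (second-order RK) step spends a second denoiser call on a correction. On the toy problem, Heun lands closer to the closed-form endpoint than DDIM-VE on the same schedule, which is why an RK teacher can be more accurate per step.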
Experiments
- TRACT is tested on two image generation benchmarks: CIFAR-10 and class-conditional 64x64 ImageNet
- TRACT improves single-step FID from 9.1 to 4.5 on CIFAR-10 and from 17.5 to 7.4 on 64x64 ImageNet
- Distilling EDM teacher models further improves CIFAR-10 FID to 3.8
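The "up to 2.4x" figure can be checked against the single-step FIDs above. A trivial Python sketch; the dictionary layout and function name are illustrative only:

```python
def fid_ratio(before: float, after: float) -> float:
    """How many times lower the new FID is (lower FID is better)."""
    return before / after

# Single-step FIDs quoted above: (previous result, TRACT).
single_step_fid = {"CIFAR-10": (9.1, 4.5), "ImageNet 64x64": (17.5, 7.4)}
for dataset, (before, after) in single_step_fid.items():
    print(f"{dataset}: {fid_ratio(before, after):.1f}x lower FID")
# CIFAR-10: 2.0x lower FID
# ImageNet 64x64: 2.4x lower FID
```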
Image generation results with BTD teachers
- Teacher models are initialized from the teacher checkpoints of the BTD paper
- A two-phase distillation schedule is used
- In each phase, the student weights are initialized from the teacher
- Experiments use two training lengths: 96M and 256M samples
- TRACT-96M reaches a 1-step CIFAR-10 FID of 5.02, almost half the previous state of the art
- TRACT-256M further improves 1-step FID to 4.45
- TRACT obtains state-of-the-art models at every number of sampling steps
- On 64x64 ImageNet, the TRACT-96M student achieves a 1-step FID of 7.43, 2.4x better than BTD
Image generation results with EDM teachers
- EDM models are based on the NCSN++ and ADM architectures.
- Results for TRACT-EDM models are presented in Tables 1, 2, 7 and 8.
- Experimental details can be found in Appendix A.4.
Stochastic weight averaging ablations
- TRACT maintains two exponential moving averages (EMAs) of the student weights: one serves as the self-teacher and one is the student model used at inference.
- The momentum parameters µ_S (self-teacher) and µ_I (inference) are studied through ablations on CIFAR-10.
- With low µ_S, the self-teacher weights adapt rapidly to training updates but incorporate noise from the optimization process.
- High µ_S values yield stable self-teacher targets but introduce latency between the student model's state and that of its self-teacher.
- A slow-moving EMA of the student weights (high µ_I) at inference time yields better test-time performance.
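The two EMAs can be sketched with the standard update ema = µ·ema + (1 - µ)·θ. In this Python sketch the scalar "weight", its noisy trajectory, and both momentum values are hypothetical; they only illustrate the trade-off described above, where a fast EMA (low µ) tracks the student closely but stays noisy, and a slow EMA (high µ) is smooth but lags:

```python
import random

def ema_update(ema, params, momentum):
    """Element-wise EMA: momentum * ema + (1 - momentum) * params."""
    return [momentum * e + (1.0 - momentum) * p for e, p in zip(ema, params)]

MU_S = 0.9    # hypothetical self-teacher momentum (fast-moving)
MU_I = 0.999  # hypothetical inference momentum (slow-moving)

rng = random.Random(0)
student = [0.0]
self_teacher = list(student)
inference = list(student)
for _ in range(1000):
    # Stand-in for an optimizer step: the weight jumps to about 1.0 and keeps jittering.
    student = [1.0 + rng.gauss(0.0, 0.1)]
    self_teacher = ema_update(self_teacher, student, MU_S)
    inference = ema_update(inference, student, MU_I)

# Roughly the last 1 / (1 - mu) steps dominate each average: ~10 vs ~1000.
print(f"self-teacher EMA: {self_teacher[0]:.3f}, inference EMA: {inference[0]:.3f}")
```

In this toy run the fast EMA sits near 1.0 with visible jitter, while the slow EMA has only covered about 1 - 0.999^1000 ≈ 0.63 of the jump.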
Influence of the number of distillation phases
- TRACT performs best with a 2-phase distillation schedule
- Schedules with more phases suffer from objective degeneracy
- Single-phase distillation gives the worst results
- Beyond two phases, results get worse as more distillation phases are added
- BTD outperforms TRACT with a 10-phase distillation schedule
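The number of phases fixes how many steps each phase removes. Assuming every phase shrinks the step count by the same group size S (so S^P = T for P phases), this Python sketch with a hypothetical 1024-step teacher shows that a 2-phase schedule uses groups of 32, while a 10-phase schedule visits the same step counts as BTD's halving:

```python
def group_size(teacher_steps, phases):
    """Equal group size S with S ** phases == teacher_steps, or None if none exists."""
    s = round(teacher_steps ** (1.0 / phases))
    return s if s ** phases == teacher_steps else None

def phase_schedule(teacher_steps, phases):
    """Step counts after each phase, assuming equal group sizes (hypothetical)."""
    s = group_size(teacher_steps, phases)
    if s is None:
        return None
    return [teacher_steps // s ** k for k in range(phases + 1)]

for p in (1, 2, 3, 5, 10):
    print(p, phase_schedule(1024, p))
# 1 [1024, 1]
# 2 [1024, 32, 1]
# 3 None
# 5 [1024, 256, 64, 16, 4, 1]
# 10 [1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1]
```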
Beyond time distillation
- TRACT can reduce the quality degradation that comes with using fewer sampling steps.
- TRACT can also be used for knowledge distillation into smaller architectures.
- 1-step FID on CIFAR-10 increased from 5.02 to 6.47 when distilling from 60.0M to 19.4M parameters.
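For scale, the trade-off above works out as follows. This is simple Python arithmetic, assuming the 5.02 FID comes from the same-size (60.0M-parameter) student:

```python
teacher_params_m, student_params_m = 60.0, 19.4  # parameters, in millions
fid_same_size, fid_smaller = 5.02, 6.47          # 1-step CIFAR-10 FID

size_reduction = teacher_params_m / student_params_m
fid_increase_pct = (fid_smaller / fid_same_size - 1.0) * 100.0
print(f"{size_reduction:.1f}x fewer parameters, {fid_increase_pct:.0f}% higher FID")
# 3.1x fewer parameters, 29% higher FID
```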
Conclusion
- Generating samples in a single step can improve the tractability of diffusion models
- We introduce TRAnsitive Closure Time-distillation (TRACT) to improve the quality of generated samples
- TRACT improves single-step FID by up to 2.4x
- TRACT can be applied to other types of data
- TRACT can improve the quality-efficiency trade-off
- We use the Adam optimizer with a constant learning rate of 2 × 10^-4
- We present random samples from distilled models with varying sampling steps
- We present a heuristic to pick the EMA momentum parameter
- We compare TRACT with the heuristically chosen momentum to a direct grid search over 5 fixed values of µ_I
- We derive equations for the training target that yields a perfect student
- TRACT can be used for knowledge distillation from one architecture to another
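One way to write such a target, assuming the student takes a DDIM-VE step from x_t at noise level σ_t with denoised prediction \hat{x}: require the step to land exactly on a target point x_{t'} and solve for \hat{x}. This is a sketch under that assumption, not necessarily the paper's exact formulation:

```latex
\begin{aligned}
x_{t'} &= \hat{x} + \frac{\sigma_{t'}}{\sigma_t}\left(x_t - \hat{x}\right) \\
\hat{x}\left(1 - \frac{\sigma_{t'}}{\sigma_t}\right) &= x_{t'} - \frac{\sigma_{t'}}{\sigma_t}\, x_t \\
\hat{x} &= \frac{\sigma_t\, x_{t'} - \sigma_{t'}\, x_t}{\sigma_t - \sigma_{t'}}
\end{aligned}
```

When σ_{t'} = 0 the target reduces to \hat{x} = x_{t'}, so a single-step student should predict the clean endpoint directly.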