Abstract

  • Denoising diffusion models are proficient at generative sampling.
  • Techniques such as binary time-distillation (BTD) have been proposed to reduce the number of network calls.
  • TRACT is a new method that extends BTD.
  • TRACT improves FID by up to 2.4x on the same architecture.
  • TRACT achieves a new state-of-the-art FID for single-step Denoising Diffusion Implicit Models (DDIM).
  • PyTorch implementation will be released soon.

Paper Content

Introduction

  • Diffusion models are state-of-the-art generative models for many domains and applications
  • Training a diffusion model is simpler than other generative modeling approaches
  • Diffusion models demonstrate excellent scalability to large models and datasets
  • Inference efficiency is a major challenge for diffusion models
  • Existing efforts to speed up inference of diffusion models can be categorized into three classes
  • Denoising Diffusion Implicit Models (DDIMs) use a noise schedule for inference
  • Neural ODEs can be used to solve DDIMs
  • Advanced ODE solvers can be used to accelerate sampling
  • Binary Time-Distillation (BTD) can be used to distill a teacher model to a student model with fewer steps
  • log_2(T) training phases are required to distill a T-step teacher to a single-step model
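
The phase count in the last bullet follows from halving the number of steps in every phase. The following is a minimal sketch, assuming the teacher step count is a power of two; the function name `btd_schedule` and the value 1024 are illustrative and not taken from the paper.

```python
def btd_schedule(teacher_steps: int) -> list[int]:
    """Step counts visited when each phase halves the number of sampling steps."""
    # A positive power of two has exactly one bit set.
    if teacher_steps < 1 or teacher_steps & (teacher_steps - 1):
        raise ValueError("teacher_steps must be a positive power of two")
    schedule = [teacher_steps]
    while schedule[-1] > 1:
        schedule.append(schedule[-1] // 2)
    return schedule


schedule = btd_schedule(1024)      # illustrative teacher step count
num_phases = len(schedule) - 1     # one phase per halving, i.e. log2(1024)
print(schedule, num_phases)
```

Each phase trains a new student from the previous one, so the number of phases grows with the logarithm of the teacher step count.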

Method

Adapting TRACT to a Runge-Kutta teacher and Variance Exploding noise schedule

  • TRACT is applied to teachers from EDM that use a VE noise schedule and an RK sampler
  • VE noise schedules are parameterized by a sequence of noise standard deviations
  • Algorithm 1 is used to train from T timesteps to T/S timesteps for groups of size S
  • RK and DDIM-VE step functions are used to estimate x_{t'} from x_t, with t' < t
  • Loss is a weighted loss between the student network prediction and the target
  • Weighting and network preconditioning strategies are introduced in the EDM paper
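
To make the step function and the grouping concrete, here is a minimal sketch of a deterministic DDIM step under a VE schedule, where a noisy sample is modeled as x_t = x_0 + sigma_t * eps. It is not the paper's implementation: the names `ddim_ve_step` and `group_end_index`, the 1-indexed timestep convention, and the list-of-floats representation are assumptions made for illustration.

```python
import random


def ddim_ve_step(x_t, x0_hat, sigma_t, sigma_next):
    """One deterministic DDIM step under a VE schedule.

    x_t and x0_hat are lists of floats: the noisy sample and the network's
    estimate of the clean sample. sigma_t and sigma_next are the noise
    standard deviations at the current and the target timestep.
    """
    if sigma_t <= 0:
        raise ValueError("sigma_t must be positive")
    ratio = sigma_next / sigma_t
    # Keep the clean estimate and rescale the implied noise to the new level.
    return [x0 + ratio * (xt - x0) for xt, x0 in zip(x_t, x0_hat)]


def group_end_index(t: int, group_size: int) -> int:
    """Timestep index that closes the group containing step t.

    Assumed convention: steps are numbered 1..T and split into contiguous
    groups of group_size steps, so steps 1..S all map to index 0.
    """
    return ((t - 1) // group_size) * group_size


rng = random.Random(0)
x0 = [0.5, -1.0, 2.0]
eps = [rng.gauss(0.0, 1.0) for _ in x0]
sigma_t, sigma_next = 10.0, 4.0
x_t = [a + sigma_t * e for a, e in zip(x0, eps)]

# With a perfect clean estimate, the step recovers the same noise at sigma_next.
x_next = ddim_ve_step(x_t, x0, sigma_t, sigma_next)
expected = [a + sigma_next * e for a, e in zip(x0, eps)]
```

In a real setup `x0_hat` would come from the teacher network; the exact clean sample is used here only so the result can be checked by hand.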

Experiments

  • TRACT is tested on two image generation benchmarks: CIFAR-10 and class-conditional 64x64 ImageNet
  • TRACT improves FID from 9.1 to 4.5 on CIFAR-10 and from 17.5 to 7.4 on 64x64 ImageNet
  • Results further improved to 3.8 on CIFAR-10 when distilling EDM teacher models

Image generation results with BTD teachers

  • Teacher model initialized from teacher checkpoints of BTD paper
  • Two-phase distillation schedule used
  • Student weights initialized from teacher in each phase
  • Experiments with two training lengths: 96M and 256M samples
  • TRACT-96M reaches a 1-step FID of 5.02 on CIFAR-10, almost half the previous state of the art
  • TRACT-256M further improves 1-step FID on CIFAR-10 to 4.45
  • State-of-the-art models obtained at all steps
  • On 64x64 ImageNet, the TRACT-96M student achieves an FID of 7.43, 2.4x better than BTD
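
The improvement factors quoted above can be checked with simple arithmetic. This sketch assumes that the 7.43 and 4.45 results pair with the 17.5 (64x64 ImageNet) and 9.1 (CIFAR-10) baselines quoted in the Experiments summary; the helper name `fid_ratio` is hypothetical.

```python
def fid_ratio(baseline_fid: float, new_fid: float) -> float:
    """How many times lower the new FID is than the baseline (lower is better)."""
    return baseline_fid / new_fid


imagenet_ratio = fid_ratio(17.5, 7.43)  # baseline vs TRACT-96M, 64x64 ImageNet
cifar_ratio = fid_ratio(9.1, 4.45)      # baseline vs TRACT-256M, CIFAR-10
print(f"ImageNet: {imagenet_ratio:.2f}x, CIFAR-10: {cifar_ratio:.2f}x")
```

The ImageNet ratio rounds to the 2.4x figure, and the CIFAR-10 ratio of roughly 2x matches the "almost half" description.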

Image generation results with EDM teachers

  • EDM models are based on NCSN++ and ADM architectures.
  • Results for TRACT-EDM models are presented in Tables 1, 2, 7 and 8.
  • Experimental details can be found in Appendix A.4.

Stochastic weight averaging ablations

  • TRACT uses two different EMAs: one for the self-teacher and one for the student model.
  • Momentum parameters μ_S (self-teacher) and μ_I (inference-time student) are studied across ablations on CIFAR-10.
  • Low μ_S values yield self-teacher weights that adapt rapidly to training updates but incorporate noise from the optimization process.
  • High μ_S values yield stable self-teacher targets but introduce latency between the student model state and that of its self-teacher.
  • Slow-moving EMA of student weights at inference time yields better test time performance.
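
The trade-off between the two averages can be illustrated with a toy exponential moving average. This is a sketch only: the momentum values 0.5 and 0.99, the single scalar weight, and the name `ema_update` are illustrative assumptions, not settings from the paper.

```python
def ema_update(average: dict, current: dict, momentum: float) -> dict:
    """Return momentum * average + (1 - momentum) * current for each parameter."""
    if not 0.0 <= momentum < 1.0:
        raise ValueError("momentum must be in [0, 1)")
    return {
        name: momentum * average[name] + (1.0 - momentum) * current[name]
        for name in current
    }


mu_s, mu_i = 0.5, 0.99          # illustrative momenta: fast self-teacher, slow inference copy
student = {"w": 1.0}
self_teacher = dict(student)    # EMA that produces training targets
inference = dict(student)       # EMA used at test time

for step in range(3):
    student = {"w": student["w"] + 1.0}  # stand-in for an optimizer update
    self_teacher = ema_update(self_teacher, student, mu_s)
    inference = ema_update(inference, student, mu_i)

print(student["w"], self_teacher["w"], inference["w"])
```

The copy with the lower momentum tracks the student closely, while the copy with the higher momentum lags behind but smooths out step-to-step noise.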

Influence of the number of distillation phases

  • TRACT performs best with a 2-phase distillation schedule
  • Schedules with more phases suffer from objective degeneracy
  • Worst results obtained with single-phase distillation
  • Beyond two phases, results get worse as more distillation phases are added
  • BTD outperforms TRACT with a 10-phase distillation schedule
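
The schedules compared above differ only in how many phases are used to go from the teacher step count down to one step. A minimal sketch, assuming every phase shrinks the step count by the same factor S; the name `phase_schedule` and the 1024-step teacher are assumptions for illustration.

```python
def phase_schedule(teacher_steps: int, num_phases: int) -> list[int]:
    """Step counts for num_phases phases that each divide the step count by the same factor."""
    if teacher_steps < 2 or num_phases < 1:
        raise ValueError("need teacher_steps >= 2 and num_phases >= 1")
    # Per-phase group size S, chosen so that S ** num_phases == teacher_steps.
    factor = round(teacher_steps ** (1.0 / num_phases))
    if factor ** num_phases != teacher_steps:
        raise ValueError("teacher_steps must be a perfect num_phases-th power")
    return [teacher_steps // factor**i for i in range(num_phases + 1)]


single_phase = phase_schedule(1024, 1)   # one jump from 1024 steps to 1
two_phase = phase_schedule(1024, 2)      # group size 32 in both phases
ten_phase = phase_schedule(1024, 10)     # group size 2, the binary schedule
print(single_phase, two_phase, ten_phase)
```

Fewer phases mean larger groups per phase, so each student must learn a longer jump; more phases mean shorter jumps but more generations of students trained on approximate targets.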

Beyond time distillation

  • TRACT can reduce quality degradation with fewer sampling steps.
  • TRACT can be used for knowledge distillation to smaller architectures.
  • FID degraded from 5.02 to 6.47 when distilling from 60.0M to 19.4M parameters on CIFAR-10.

Conclusion

  • Generating samples in a single step can improve the tractability of diffusion models
  • We introduce TRAnsitive Closure Time-distillation (TRACT) to improve the quality of generated samples
  • TRACT improves single-step FID by up to 2.4x
  • TRACT can be applied to other types of data
  • TRACT can improve the quality-efficiency trade-off
  • We use the Adam optimizer with a constant learning rate of 2 × 10^-4
  • We present random samples from distilled models with varying sampling steps
  • We present a heuristic to pick the EMA momentum parameter
  • We compare TRACT to a direct grid search for a fixed μ_I parameter
  • We develop equations for the training target to obtain a perfect student
  • TRACT can be used for knowledge distillation from one architecture to another
  • We compare results to a grid search for μ_I with 5 different values