Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- The public model zoo contains many powerful pretrained model families
- Open question: how to efficiently assemble these models for dynamic accuracy-efficiency trade-offs at runtime
- SN-Net framework for model deployment produces networks with different complexity and performance
- SN-Net splits pretrained networks (anchors) and stitches them together
- SN-Net can adapt to dynamic resource constraints
- SN-Net can challenge hundreds of models with a single network
Paper Content
Introduction
- Computational resources and data have enabled researchers to build powerful deep neural networks
- There are thousands of models available to download and execute
- Existing scalable deep learning frameworks are limited to a single model design space
- Stitchable Neural Network (SN-Net) is a novel scalable deep learning framework for efficient model design and deployment
- SN-Net stitches an off-the-shelf pretrained model family with much less training effort
- SN-Net covers a fine-grained level of model complexity/performance for a wide range of deployment scenarios
- SN-Net breaks the limit of a single pretrained model or supernet design
- Training SN-Net is as easy as training individual models
- SN-Net performance is almost predictable, since stitches roughly interpolate between the performance of their anchors
- SN-Net is a new universal paradigm with a “many-to-many” pipeline
- SN-Net is a general approach for utilising the pretrained model families in the large-scale model zoo
Method
- Introduce model stitching
- Describe proposed stitchable neural networks
Preliminaries of model stitching
- Model parameters of a pretrained neural network are indicated by θ
- A feed-forward neural network can be defined as a composition of functions
- Model stitching involves splitting a neural network into two portions of functions at a layer index l
- A stitching layer is used to implement a transformation between the activation space of two different networks
- Model stitching can produce a sequence of stitched networks
- Different architectures can be stitched together without significant performance drop
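The splitting and stitching described above can be written out as a short worked formulation. This is a sketch under stated assumptions: θ and l come from the summary, while the second network's parameters φ, its split index m, the depth L, and the names H (head), T (tail), S (stitching layer) and F_S (stitched network) are notation chosen here for illustration.

```latex
\begin{align}
f_\theta &= f_L \circ \cdots \circ f_1
  && \text{(network as a composition of $L$ layers)} \\
H_{\theta,l}(X) &= f_l \circ \cdots \circ f_1 (X)
  && \text{(head: layers $1$ to $l$)} \\
T_{\theta,l}(X_l) &= f_L \circ \cdots \circ f_{l+1} (X_l)
  && \text{(tail: layers $l+1$ to $L$)} \\
F_S(X) &= T_{\phi,m}\bigl(S\bigl(H_{\theta,l}(X)\bigr)\bigr)
  && \text{(head of $f_\theta$, stitching layer $S$, tail of $f_\phi$)}
\end{align}
```

Varying l and m while keeping the two pretrained networks fixed is what yields a sequence of stitched networks.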
Stitchable neural networks
- SN-Net is a new “many-to-many” elastic model paradigm
- It is motivated by the increasing number of pretrained models in the publicly available model zoo
- SN-Net inserts a few stitching layers to connect a family of pretrained models
- Anchors should be consistent in terms of the pretrained domain
- Stitching layers are 1x1 convolutional layers
- Least-squares solution is used as the default initialization approach
- Fast-to-Slow is the default stitching direction
- Nearest stitching strategy is used
- Stitching is done as sliding windows
- Training strategy uses knowledge distillation
- Experiments are conducted on ImageNet-1K
- Models studied are DeiT, Swin Transformer, ResNet, and stitching CNNs with ViTs
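To make the Fast-to-Slow direction and the nearest stitching strategy concrete, here is a minimal sketch that enumerates stitch configurations. It is a simplification, not the paper's procedure: it assumes all anchors share the same depth and pairs blocks at equal indices, and it omits the sliding-window scheme, so it does not reproduce the paper's stitch counts. `Stitch` and `enumerate_stitches` are hypothetical names introduced for illustration.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Stitch:
    small: str  # anchor providing the first blocks (the "fast" model)
    large: str  # anchor providing the remaining blocks (the "slow" model)
    split: int  # number of blocks taken from the small anchor


def enumerate_stitches(anchors: List[str], depth: int) -> List[Stitch]:
    """Enumerate stitches for anchors sorted from smallest to largest.

    Assumption: every anchor has the same number of blocks (`depth`).
    """
    stitches = []
    # Nearest stitching: only adjacent anchors in the sorted list are paired.
    for small, large in zip(anchors, anchors[1:]):
        # Fast-to-Slow: start in the small anchor, finish in the large one.
        for split in range(1, depth):
            stitches.append(Stitch(small, large, split))
    return stitches


configs = enumerate_stitches(["DeiT-Ti", "DeiT-S", "DeiT-B"], depth=12)
print(len(configs))  # 22: 11 split points for each of the 2 adjacent pairs
```

Because only adjacent anchors are paired, no configuration jumps directly from DeiT-Ti to DeiT-B.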
Main results
- Generate stitching configuration set by assembling ImageNet-1K pretrained DeiT-Ti/S/B
- Jointly train the stitches in the DeiT-based SN-Net on ImageNet for 50 epochs
- Visualize performance of all 71 stitches, including 3 anchors
- Performance increases when stitching more blocks from larger anchor
- Model-level interpolation between two anchors
- SN-Net achieves better performance than the same models trained individually from scratch
- SN-Net reduces training cost and disk storage compared to training and saving all individual networks
Ablation study
- Ablations use the default training strategy of 50 epochs on ImageNet
- LS Init serves as a good starting point for learning stitching layers compared to Kaiming Init
- Fast-to-Slow helps to ensure better performance for most stitches
- Nearest stitching strategy limits a stitch to connect with a pair of anchors with nearest model complexity/performance
- Tuning only the stitching layers is promising for just some stitches
- Tuning the full model improves the performance of stitches
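The least-squares initialization (LS Init) compared against Kaiming Init above can be stated as a small worked equation. The symbols are assumptions chosen for illustration: A holds activations from the split point of the first anchor for N samples, B holds the activations the second anchor expects at its split point, and M is the weight matrix of the stitching layer.

```latex
\begin{align}
A &\in \mathbb{R}^{N \times D_1}, \qquad
B \in \mathbb{R}^{N \times D_2}, \qquad
M \in \mathbb{R}^{D_1 \times D_2} \\
M_o &= \arg\min_{M} \, \lVert A M - B \rVert_F = A^{\dagger} B
\end{align}
```

Here A^† is the Moore-Penrose pseudoinverse of A, so M_o maps the first anchor's activation space onto the second's as closely as a linear map can on those samples. This gives the stitching layer a better starting point than a random initialization before any gradient step is taken.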
Conclusion
- Introduced Stitchable Neural Networks (SN-Net)
- Framework for developing elastic neural networks
- Inherit knowledge from pretrained model families
- Deliver fast and flexible accuracy-efficiency trade-offs
- Low cost for massive deployment of deep models
- Extendable to natural language processing, dense prediction and transfer learning
- Limitations: large stitching space requires more training epochs
- Nearest stitching strategy limits stitches to two types
- Different sliding-window settings produce different numbers of stitches
- Default setting of 50 training epochs
- 15 epochs still produces good performance
- Simple training strategy of randomly sampling a stitch
- Sandwich sampling rule and inplace distillation explored
- Pretrained weights of anchors are necessary for convergence
- Default of 100 training images to initialize stitching layers
- More samples do not bring further performance gain
- SN-Net is able to switch network topology at runtime
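The random stitch sampling and runtime switching mentioned above can be illustrated with a toy sketch. Everything here is an assumption made for illustration: `ToySNNet` is a hypothetical class, blocks are plain functions on floats, and a single scalar per split point stands in for a 1x1 convolutional stitching layer. The loss computation and parameter update are omitted, so this shows only the control flow, not the paper's training recipe.

```python
import random
from typing import Callable, Dict, List

Block = Callable[[float], float]


class ToySNNet:
    """Two equal-depth anchors joined by one stitching scalar per split point."""

    def __init__(self, small: List[Block], large: List[Block]) -> None:
        assert len(small) == len(large), "sketch assumes equal depth"
        self.small, self.large = small, large
        self.depth = len(small)
        # One scalar per interior split point (stand-in for a stitching layer).
        self.stitch_scale: Dict[int, float] = {s: 1.0 for s in range(1, self.depth)}
        # split == depth runs the small anchor; split == 0 runs the large anchor.
        self.active_split = self.depth

    def set_stitch(self, split: int) -> None:
        if not 0 <= split <= self.depth:
            raise ValueError(f"split must be in [0, {self.depth}]")
        self.active_split = split

    def forward(self, x: float) -> float:
        s = self.active_split
        for block in self.small[:s]:   # first s blocks from the small anchor
            x = block(x)
        if 0 < s < self.depth:         # stitching layer only between anchors
            x = self.stitch_scale[s] * x
        for block in self.large[s:]:   # remaining blocks from the large anchor
            x = block(x)
        return x


net = ToySNNet(small=[lambda x: x + 1.0] * 3, large=[lambda x: 2.0 * x] * 3)

# Runtime switching: the same object serves different topologies.
net.set_stitch(3)
out_small = net.forward(1.0)   # 4.0, small anchor only
net.set_stitch(0)
out_large = net.forward(1.0)   # 8.0, large anchor only
net.set_stitch(2)
out_stitch = net.forward(1.0)  # 6.0, two small blocks then one large block

# Training control flow: sample one stitch at random per iteration.
rng = random.Random(0)
counts = {s: 0 for s in range(1, net.depth)}
for step in range(100):
    split = rng.choice(sorted(counts))
    net.set_stitch(split)
    _ = net.forward(1.0)
    counts[split] += 1
    # A real loop would compute a loss here and update the sampled stitch.
```

Sampling one stitch per iteration keeps the per-step cost close to that of training a single model, while all stitches still receive updates over the course of training.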