Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Designing machine learning architectures that process neural networks in their raw weight matrix form is a new research direction.
  • This design is challenging due to the unique symmetry structure of deep weight spaces.
  • If successful, these architectures could be used for a range of tasks, such as adapting a pre-trained network to a new domain.
  • A novel network architecture is presented that is equivariant to the natural permutation symmetry of an MLP's weights.
  • This architecture is composed of layers that are implemented using pooling, broadcasting, and fully connected layers.
  • The effectiveness of the architecture is demonstrated in a variety of learning tasks.

Paper Content

Introduction

  • Deep neural networks are used to learn functions from data and represent data samples
  • It is desirable to operate directly over the weights of a pre-trained deep model
  • Learning in neural weight spaces is still in its infancy
  • Few studies have used generic architectures to predict model accuracy or hyperparameters
  • Three papers have partially addressed the question in the context of INRs

Preliminaries

  • Notation used: [n], [k, m], $\Pi_d$, $S_d$, $\mathbf{1}$ (all-ones vector)
  • Group representations and equivariance: vector space V, group G, homomorphism ρ, G-linear map, invariant function, sub-representation, direct sum of representations
  • MultiLayer Perceptrons: sequential neural networks, fully connected layers, M-layer MLP, weight matrices, bias vectors, pointwise non-linearity

Permutation symmetries of neural networks

  • MLPs have permutation symmetries
  • Weight-space of an M-layer MLP is defined
  • Symmetry group of the weight space is the direct product of symmetric groups for all the intermediate dimensions
  • Transformed set of parameters represents the same function as the initial set
  • Symmetries have been used to investigate the loss landscape of neural networks
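
The claim that a permuted set of parameters represents the same function can be checked numerically. The sketch below is illustrative only: the layer sizes, the ReLU non-linearity, and the helper names `mlp_forward` and `permute_hidden` are assumptions made for this example, not taken from the paper.

```python
import numpy as np

def mlp_forward(weights, biases, x):
    """Forward pass of an MLP with ReLU between layers (none after the last)."""
    h = x
    for i, (W, b) in enumerate(zip(weights, biases)):
        h = W @ h + b
        if i < len(weights) - 1:
            h = np.maximum(h, 0.0)
    return h

def permute_hidden(weights, biases, perms):
    """Apply one permutation per hidden layer; perms[i] reorders layer i's neurons."""
    new_w = [W.copy() for W in weights]
    new_b = [b.copy() for b in biases]
    for i, p in enumerate(perms):
        new_w[i] = new_w[i][p, :]          # reorder the rows producing layer i's neurons
        new_b[i] = new_b[i][p]             # reorder the matching bias entries
        new_w[i + 1] = new_w[i + 1][:, p]  # reorder the columns that consume them
    return new_w, new_b

rng = np.random.default_rng(0)
dims = [3, 5, 4, 2]  # input, two hidden layers, output (hypothetical sizes)
weights = [rng.normal(size=(dims[i + 1], dims[i])) for i in range(3)]
biases = [rng.normal(size=dims[i + 1]) for i in range(3)]
perms = [rng.permutation(5), rng.permutation(4)]  # one permutation per hidden layer

x = rng.normal(size=3)
pw, pb = permute_hidden(weights, biases, perms)
max_diff = np.max(np.abs(mlp_forward(weights, biases, x) - mlp_forward(pw, pb, x)))
```

Only the intermediate dimensions are permuted here; the input and output dimensions stay fixed, which matches the symmetry group described above.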

A characterization of linear invariant and equivariant layers for weight-spaces

  • Overview of the DWS-layers
  • Explanation of how results can be formally proved

Overview and main results

  • Characterizing affine equivariant and invariant maps for the weight space V requires finding bases for three linear spaces.
  • A decomposition of the weight space V into multiple sub-representations is used to simplify the characterization.
  • The layer L is split into four linear maps between the weight and bias spaces.
  • Equivariant layers between W, B are constructed by splitting them into sub-representations.
  • All equivariant maps are either previously characterized linear equivariant layers or simple combinations of pooling, broadcasting and fully connected linear layers.
  • The layer L is implemented by executing all the blocks independently and then summing the outputs.
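
The four-block structure described above can be illustrated with a toy example. This is a simplified sketch under stated assumptions, not the paper's full parameterization: it uses a single weight matrix `W` of shape (n, d) whose rows are permuted by S_n, a bias `b` of length n, and hypothetical block parameters (`A1`, `A2`, `u1`, `u2`, `c1`, `c2`, `alpha`, `beta`). Each block is built only from pooling, broadcasting, and linear maps, and the layer sums the block outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 3  # n permutable neurons (set index), d free input features

# Hypothetical parameters of each block.
A1, A2 = rng.normal(size=(d, d)), rng.normal(size=(d, d))
u1, u2 = rng.normal(size=d), rng.normal(size=d)
c1, c2 = rng.normal(size=d), rng.normal(size=d)
alpha, beta = rng.normal(size=2)

def w_to_w(W):
    pooled = W.sum(axis=0, keepdims=True)            # pool over the set index
    return W @ A1 + np.ones((n, 1)) @ (pooled @ A2)  # linear map + broadcast

def b_to_w(b):
    return np.outer(b, u1) + b.sum() * np.outer(np.ones(n), u2)

def w_to_b(W):
    return W @ c1 + (W.sum(axis=0) @ c2) * np.ones(n)

def b_to_b(b):
    return alpha * b + beta * b.sum() * np.ones(n)

def layer(W, b):
    """Run the four blocks independently and sum the outputs per target space."""
    return w_to_w(W) + b_to_w(b), w_to_b(W) + b_to_b(b)

# Equivariance check: permuting the input rows permutes the output rows.
W, b = rng.normal(size=(n, d)), rng.normal(size=n)
p = rng.permutation(n)
out_W, out_b = layer(W, b)
perm_W, perm_b = layer(W[p], b[p])
equivariant = np.allclose(perm_W, out_W[p]) and np.allclose(perm_b, out_b[p])
```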

Linear equivariant maps for direct sums

  • Every equivariant linear operator between direct sums of representations can be written in a block matrix form.
  • Proposition 5.2 reduces the problem of characterizing equivariant maps between direct sums of representations to simpler problems of characterizing equivariant maps between the constituent sub-representations.
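
The block form referred to above can be written out explicitly. The notation below (the index ranges m and m', and the names of the sub-representations) is assumed for illustration and may differ from the paper's.

```latex
% Assumed notation: V = V_1 \oplus \dots \oplus V_m with representation \rho = \bigoplus_i \rho_i,
% and V' = V'_1 \oplus \dots \oplus V'_{m'} with representation \rho' = \bigoplus_j \rho'_j.
L =
\begin{pmatrix}
L_{11} & \cdots & L_{1m} \\
\vdots & \ddots & \vdots \\
L_{m'1} & \cdots & L_{m'm}
\end{pmatrix},
\qquad L_{ji} : V_i \to V'_j .

% Because \rho and \rho' are block diagonal, the equivariance condition holds for L
% exactly when it holds block by block:
L \rho(g) = \rho'(g) L \ \ \forall g \in G
\quad\Longleftrightarrow\quad
L_{ji}\, \rho_i(g) = \rho'_j(g)\, L_{ji} \ \ \forall g \in G,\ \forall i, j .
```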

Linear equivariant layers for deep weight spaces

  • Constructing a basis for linear equivariant functions from a weight space to itself
  • Bias-to-bias layers composed of blocks that map between bias spaces
  • Three examples of bias-to-bias layers: fully connected linear layer, DeepSets layer, and a layer that sums and multiplies
  • Basic operations for constructing layers between sub-representations: pooling, broadcasting, and fully-connected linear maps
  • Rules for constructing equivariant layers between sub-representations
  • Dimension-counting argument to prove that layers form a basis
  • Generalization of Theorem 5.1 to multiple input and output channels in Appendix B
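
The dimension-counting idea can be tried on the simplest case, a bias-to-bias map from R^n to itself under S_n. The sketch below is an illustration, not the paper's proof: it solves the linear system P L = L P over all permutation matrices by brute force and compares the solution-space dimension with the two DeepSets basis elements (identity and all-ones matrix). The function name `equivariant_dim` is hypothetical.

```python
import numpy as np
from itertools import permutations

def equivariant_dim(n):
    """Dimension of {L in R^{n x n} : P L = L P for every permutation matrix P}."""
    eye = np.eye(n)
    rows = []
    for perm in permutations(range(n)):
        P = eye[list(perm)]
        # vec(P L - L P) = (I kron P - P^T kron I) vec(L), using column-major vec
        rows.append(np.kron(eye, P) - np.kron(P.T, eye))
    system = np.vstack(rows)
    return n * n - np.linalg.matrix_rank(system)

n = 4
dim = equivariant_dim(n)

# The two DeepSets basis elements: identity and sum-then-broadcast.
basis = [np.eye(n), np.ones((n, n))]
P = np.eye(n)[[1, 2, 3, 0]]  # one example permutation matrix
basis_is_equivariant = all(np.allclose(P @ L, L @ P) for L in basis)
basis_rank = np.linalg.matrix_rank(np.stack([L.ravel() for L in basis]))
```

Since the two basis elements are equivariant and linearly independent, and their count matches the computed dimension, they span the whole space in this small case.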

Linear invariant maps for weight-spaces

  • Linear G-invariant maps can be characterized by Proposition 5.3.
  • Weight vectors must obey an equation for all group elements.
  • Table 8 summarizes the orbits of the action on the indices of V.
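
The link between orbits and invariant maps can be sketched on a small case. The example assumes a single weight matrix of shape (n, d) where S_n permutes the row index and the column index is free; the function names are hypothetical. A linear functional is invariant exactly when its coefficients are constant on each orbit of indices, so there is one free coefficient per orbit.

```python
from itertools import permutations

def index_orbits(n, d):
    """Orbits of index pairs (i, j) when S_n permutes the row index i and j is free."""
    seen, orbits = set(), []
    for i in range(n):
        for j in range(d):
            if (i, j) in seen:
                continue
            orbit = {(perm[i], j) for perm in permutations(range(n))}
            seen |= orbit
            orbits.append(sorted(orbit))
    return orbits

def invariant_functional(W, coeffs, orbits):
    """Linear map with one coefficient per orbit: sum entries in each orbit, then weight."""
    return sum(c * sum(W[i][j] for i, j in orbit) for c, orbit in zip(coeffs, orbits))

orbits = index_orbits(3, 2)                 # one orbit per column
W = [[1, 2], [3, 4], [5, 6]]
W_permuted = [W[2], W[0], W[1]]             # rows reordered
coeffs = [1.0, 10.0]                        # hypothetical learnable coefficients
value = invariant_functional(W, coeffs, orbits)
value_permuted = invariant_functional(W_permuted, coeffs, orbits)
```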

Extension to other architectures

  • Focus on MLP architectures as input for DWSNets
  • Possible extensions to CNNs and Transformers

Expressive power

  • Restricting the hypothesis class to equivariant networks can impair expressive power.
  • Graph neural networks are an example of this.
  • DWSNets can approximate feed-forward procedures on input networks.

Experiments

  • Evaluate DWSNets in two families of tasks
  • Train model to classify INRs and predict continuous properties of objects
  • Train model to operate on standard input-output mappings
  • Compare different architectures that operate directly on weight spaces
  • Train all input networks independently with different random seeds

Analysis of the results

  • DWSNets outperform all other methods
  • DWSNets scale better with data than network alignment
  • Network alignment is hard
  • DWSNets are equivariant to natural symmetries of weight spaces
  • Limitations include difficulty training on some tasks and complicated implementation
  • Future work includes modeling other weight space symmetries, understanding weight initialization, and data augmentation
  • Processing neural networks is studied to infer the final performance of a model
  • Predicting properties of trained NNs based on their weights is studied
  • Introduces useful inductive biases for learning tasks
  • Neural networks can be processed by applying a neural network to a concatenation of their high-order spatial derivatives
  • Ability of these networks to handle more general tasks is not well understood
  • Solving the alignment tasks is hard and these strategies suffer from scaling issues to large datasets
  • Complex data types are associated with groups of transformations that change data representation without changing the underlying data
  • Functions defined on these objects are often invariant or equivariant to these transformations
  • Constraining learning models to be equivariant or invariant to these transformations has many advantages
  • Spaces of linear equivariant layers can be solved by solving a system of linear equations
  • Learning on set-structured data is studied

B features and biases for equivariant maps

  • Equivariant maps between weight spaces with multiple features and bias terms are discussed.
  • Deep networks often use multiple feature channels to represent their input objects.
  • Equivariant layers for multiple input and output channels can be obtained using Proposition 5.2.
  • Bias terms are added to each output channel of the linear equivariant maps to create affine transformations.

C specification of all affine equivariant layers between sub-representations

  • G acts on the indices of V
  • Orbits define linear invariant layers and equivariant bias layers
  • Show that layers in Tables 4-5 are linear, G-equivariant and parameters are linearly independent
  • Number of parameters in layers matches dimension of space of G-equivariant maps
  • Use the fact that $\frac{1}{|S_n|}\sum_{\sigma \in S_n} \operatorname{tr}(P(\sigma))^k = \mathrm{bell}(k)$
  • Index on which G acts by permutation is called set index, other indices are free indices
  • Shared set dimensions add multiplicative factor of bell(2) = 2 to dimension of equivariant layer space
  • Free dimensions add multiplicative factor equal to their dimensionality
  • Unshared set dimensions add multiplicative factor of bell(1) = 1
  • Dimension of space of equivariant maps is $\frac{1}{|G|}\sum_{g \in G} \operatorname{tr}(\rho(g))\,\operatorname{tr}(\rho'(g))$
  • Elements in B are linear and linearly independent
  • Equivariance is straightforward
  • Show that B is a basis by showing number of elements is equal to dimension of space of linear maps
  • Identity transformation, summation, broadcasting, feature-wise Hadamard product, and non-linearity operations used
  • Approximate Hadamard product using universal approximation theorem
  • Construct sequence of equivariant layers to mimic propagation of x through MLP
  • Uniform approximation is preserved by composition of continuous functions on compact domains
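
The Bell-number identity used in the dimension counting above can be verified by brute force for small cases. The sketch relies on two facts: the trace of a permutation matrix equals the number of fixed points of the permutation, and the identity holds when n is at least k. The function names are chosen for this example.

```python
from fractions import Fraction
from itertools import permutations
from math import factorial

def fixed_points(perm):
    """Trace of the permutation matrix of perm, i.e. its number of fixed points."""
    return sum(1 for i, p in enumerate(perm) if i == p)

def average_trace_power(n, k):
    """Exact value of (1 / |S_n|) * sum over sigma in S_n of tr(P(sigma))^k."""
    total = sum(fixed_points(perm) ** k for perm in permutations(range(n)))
    return Fraction(total, factorial(n))

def bell(k):
    """k-th Bell number, computed with the Bell triangle."""
    row = [1]
    for _ in range(k):
        new_row = [row[-1]]
        for value in row:
            new_row.append(new_row[-1] + value)
        row = new_row
    return row[0]

# Compare both sides for n = 5 and k = 1..4.
checks = [(k, average_trace_power(5, k), bell(k)) for k in range(1, 5)]
```

The case k = 2 is also the trace formula for the dimension of equivariant maps from R^n to itself under S_n, which gives bell(2) = 2, in line with the factor contributed by a shared set dimension.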

H extensions

  • Extension to nonlinear aggregation mechanisms
  • Extension to different architectures (MLP, CNN, Transformers)
  • Permuting the channel dimensions of adjacent layers does not change the function represented by the CNN

I alternative characterization strategies

  • Worked directly with the direct sum of the weight and bias spaces
  • Other strategies can be employed to characterize spaces of linear equivariant layers
  • Schur’s Lemma states that linear equivariant maps between irreducible representations are either zero or a scaled identity map

J experimental and technical details

  • Data is normalized and split into train, test and validation sets
  • Hyperparameters are optimized and early stopping is used
  • Data augmentation is employed for all experiments
  • Initialization is important for tasks where output parameterizes a network
  • Linear transformations are used to control the complexity of DWSNets
  • Approximation is used for model alignment
  • ReLU activations and Batch-Normalization layers are used
  • AdamW optimizer with weight-decay of 5e-4 is used
  • Sine wave regression uses a model with 4 hidden layers and 8 hidden features
  • Predicting the generalization error of neural networks uses a model with 4M parameters
  • Learning to adapt networks to new domains uses a model with 4M parameters
  • Self-supervised learning for dense representation uses a model with 100K parameters
  • DWSNet is used to predict generalization performance of MLP classifiers

K.2 dense representation

  • Learning a dense representation
  • Generates an embedding with a clear and intuitive 2D structure
  • Investigating the effect of using only part of the blocks in the proposed architecture
  • Classification task of MNIST INRs
  • W2W block is the most contributing factor to the overall performance
  • Adding other blocks increases the overall performance
  • Two challenging cases encountered while experimenting with the method
  • Symmetries of deep weight spaces
  • Block matrix structure for linear equivariant maps between weight spaces
  • Sine wave regression
  • MLP, MLP + augmentations, MLP + weight alignment, INR2Vec, Transformer
  • 2D t-SNE of the resulting low-dimensional space
  • Supervised learning for dense representation
  • INR classification
  • Proposition 6.1
  • Adapting a network to a new domain
  • MSE of a linear regressor that predicts frequency and amplitude