
Abstract

  • Transformers have been successful across many domains, but their learning mechanics are not well understood.
  • Recent research has begun to study the representational aspects of transformers (the size, depth, and complexity needed to perform certain tasks), but there is no guarantee that the learning dynamics will converge to the proposed constructions.
  • This paper provides a mechanistic understanding of how transformers learn “semantic structure”, understood as the co-occurrence structure of words, by combining mathematical analysis with experiments on synthetic data and Wikipedia data.

Paper Content

Introduction

  • Transformer architecture is used in many areas of computer science
  • Little is known about how transformers learn from training data
  • Semantic structure is studied through word co-occurrences, specifically topical structure
  • A BERT model produces token embeddings that are more similar for words that belong to the same topic
  • Synthetic LDA data used to understand how embeddings and attention learn topical structure
  • Mathematical proof of why such structure arises
  • Two-stage training dynamics verified for a variety of optimizers and hyperparameter settings

Overview of results

  • Focused on understanding optimization dynamics of transformers in a single-layer architecture
  • Validated results transfer to real data
  • Topic structure can be encoded in embedding layer and attention mechanism
  • Even if one component is not trained, the other can compensate
  • Characterized how topic structure is learned in two extremal cases
  • Empirically verified characterization on synthetic and real data

Topic structure is encoded in token embeddings

  • A theorem states that the optimal embedding layer of a single-layer transformer encodes the topic structure in its embedding weights.
  • The inner product of the embeddings of two words is larger when they belong to the same topic and smaller when they belong to different topics.

Topic structure is encoded in self-attention

  • Self-attention behavior is studied in a transformer trained on a topic-modeling distribution without trainable token embeddings
  • The training process is broken down into two stages
  • In the first stage, attention is frozen to be uniform and the matrix W_V is trained
  • In the second stage, W_V is frozen to its optimal value and the optimal attention weights are analyzed
  • Intuitively, the two-stage approximation is reasonable
  • The self-attention function is Attn(Z) := W_V Z A(Z), where A(Z) ∈ R^{N×N} holds the attention weights
  • The optimal W_V has a block-wise structure, with each block corresponding to a topic
  • The optimal attention weights are a convex combination of attention to the same word and attention to different words of the same topic
  • Experiments show that the learned attention weights follow the relations described in the theorem

Empirical results

  • Results remain robust under more complex and realistic settings
  • Tested on synthetic data using Latent Dirichlet Allocation (LDA)
  • Results reported for model pre-trained on Wikipedia corpus
  • Detailed experimental setup and results in Section 6 and Appendix E

Problem setup

Topic models

  • Data distribution generated by a topic model
  • Special case of an LDA model
  • Each document consists of words
  • Each word belongs to exactly one topic
  • Topics do not share common words
  • Distribution of documents follows a generative process
  • Documents are represented by one-hot encodings of their words
  • The analysis uses the infinitely-long-document setting
  • In this setting, the empirical token distribution equals the ground-truth token distribution
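A minimal Python sketch of sampling from this kind of disjoint-topic model, under stated assumptions: topic proportions are drawn from a symmetric Dirichlet with a hypothetical concentration `alpha`, and each word is drawn uniformly from the v words of its topic. These choices and the helper names `sample_document`, `one_hot` and `word_topic` are illustrative, not taken from the paper; only the disjoint-vocabulary structure (under our indexing, topic t owns word ids t*v through t*v + v - 1) reflects the setup described here.

```python
import numpy as np

def sample_document(rng, num_topics, words_per_topic, doc_len, alpha=0.1):
    """Sample one document from a disjoint-topic, LDA-style model.

    Assumptions (illustrative only): topic proportions ~ Dirichlet(alpha),
    and each word is uniform over the v words of its topic.
    """
    theta = rng.dirichlet(np.full(num_topics, alpha))
    topics = rng.choice(num_topics, size=doc_len, p=theta)
    offsets = rng.integers(0, words_per_topic, size=doc_len)
    # Topic t owns word ids t*v .. t*v + v - 1, so topics share no words.
    return topics * words_per_topic + offsets

def one_hot(doc, vocab_size):
    """Return a vocab_size x N matrix whose columns are one-hot word encodings."""
    X = np.zeros((vocab_size, len(doc)))
    X[doc, np.arange(len(doc))] = 1.0
    return X

def word_topic(word, words_per_topic):
    # Each word belongs to exactly one topic.
    return word // words_per_topic

rng = np.random.default_rng(0)
T, v = 10, 10
doc = sample_document(rng, T, v, doc_len=120)
X = one_hot(doc, T * v)
```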

Training objective

  • Train a transformer network using the masked language modeling objective
  • Define a [MASK] token and three constant probabilities p_m, p_c, p_r
  • Train the model to predict the original words at the masked positions
  • Consider a regularized version of the masked language modeling objective
  • Use the squared loss for the theoretical analysis and the cross-entropy loss for the empirical analysis
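The masking step and an unregularized squared loss can be sketched as below. This assumes the BERT-style reading of the three probabilities (select a position with probability p_m; then keep it unchanged with probability p_c, replace it with a uniformly random word with probability p_r, or replace it with [MASK] otherwise), which matches the values p_m = 0.15, p_c = 0.1, p_r = 0.1 used in the experiments but is our interpretation. `mask_document` and `masked_squared_loss` are hypothetical names, and the regularizer is omitted.

```python
import numpy as np

def mask_document(rng, doc, vocab_size, p_m=0.15, p_c=0.1, p_r=0.1):
    """BERT-style corruption under our assumed reading of (p_m, p_c, p_r)."""
    mask_id = vocab_size                       # [MASK] takes the id after the real words
    doc = np.asarray(doc)
    selected = rng.random(len(doc)) < p_m      # positions the loss is computed on
    u = rng.random(len(doc))
    corrupted = doc.copy()
    to_random = selected & (u >= p_c) & (u < p_c + p_r)
    to_mask = selected & (u >= p_c + p_r)      # selected positions with u < p_c stay unchanged
    corrupted[to_random] = rng.integers(0, vocab_size, size=to_random.sum())
    corrupted[to_mask] = mask_id
    return corrupted, selected

def masked_squared_loss(outputs, doc, selected):
    """Squared loss between outputs (V x N) and one-hot targets of the
    original words, averaged over selected positions only."""
    V, N = outputs.shape
    targets = np.zeros((V, N))
    targets[doc, np.arange(N)] = 1.0
    diff = (outputs - targets)[:, selected]
    return float((diff ** 2).sum(axis=0).mean()) if selected.any() else 0.0

rng = np.random.default_rng(0)
doc = rng.integers(0, 100, size=500)
corrupted, selected = mask_document(rng, doc, vocab_size=100)
perfect = np.eye(100)[:, doc]                  # outputs equal to the one-hot targets
```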

Transformer network architecture

  • One-layer transformer model with no residual connection or normalization layers
  • Input representation Z ∈ R^{d×N}
  • Decoder weights W_dec ∈ R^{V×d} and bias b_dec ∈ R^V, where V is the vocabulary size
  • Value matrix W_V ∈ R^{d×d}
  • Attention head size d_a
  • Key matrix W_K ∈ R^{d_a×d}
  • Query matrix W_Q ∈ R^{d_a×d}
  • Input Z is the embedding of the masked document, with embedding matrix W_E ∈ R^{d×(Tv+1)} (Tv words plus [MASK])
  • Parts of the analysis and experiments freeze the word embeddings to one-hot vectors
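A NumPy sketch of the forward pass of this one-layer architecture (no residual connection, no normalization). The 1/√d_a scaling inside the softmax and the column-wise softmax convention (column j holds the attention weights of query position j) are assumptions following standard practice; shapes follow the list above. With W_K = W_Q = 0 every score is zero, so attention is exactly uniform, which is the frozen-attention regime of the first stage.

```python
import numpy as np

def column_softmax(S):
    """Softmax over each column, i.e. over key positions for a fixed query."""
    S = S - S.max(axis=0, keepdims=True)       # numerical stability
    E = np.exp(S)
    return E / E.sum(axis=0, keepdims=True)

def one_layer_transformer(tokens, W_E, W_K, W_Q, W_V, W_dec, b_dec):
    """Single attention layer without residual connection or normalization.

    tokens: length-N int array of (possibly masked) ids indexing W_E's columns.
    Shapes: W_E (d, Tv+1), W_K and W_Q (d_a, d), W_V (d, d), W_dec (V, d), b_dec (V,).
    """
    Z = W_E[:, tokens]                                  # d x N input representation
    d_a = W_K.shape[0]
    scores = (W_K @ Z).T @ (W_Q @ Z) / np.sqrt(d_a)     # entry (i, j): key i, query j
    A = column_softmax(scores)                          # N x N attention weights A(Z)
    attn = W_V @ Z @ A                                  # Attn(Z) = W_V Z A(Z)
    return W_dec @ attn + b_dec[:, None], A             # V x N outputs

# Frozen one-hot embeddings (d = Tv + 1 = 101), zero key/query: uniform attention.
tokens = np.array([3, 17, 100, 42])                     # id 100 plays the role of [MASK]
out, A = one_layer_transformer(
    tokens, np.eye(101), np.zeros((8, 101)), np.zeros((8, 101)),
    np.eye(101), np.eye(100, 101), np.zeros(100))
```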

Topic structure can be encoded in token embeddings

  • Theorem 1 states that the embedding layer can encode the topic structure
  • The 0-th row of E satisfies a certain condition
  • The key conclusion is point 3 of the theorem: same-topic words have more similar embeddings than different-topic words
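One way to check this property numerically is to compare the average embedding inner product over same-topic pairs with the average over different-topic pairs. `topic_similarity_gap` is a hypothetical helper, and the toy embedding matrix is constructed by hand with an exact block structure rather than taken from a trained model.

```python
import numpy as np

def topic_similarity_gap(W_E, word_topics):
    """Mean embedding inner product for same-topic vs. different-topic word pairs.

    W_E: d x n embedding matrix, one column per word ([MASK] column excluded).
    word_topics: length-n array giving the topic of each word.
    """
    G = W_E.T @ W_E                                     # all pairwise inner products
    same = word_topics[:, None] == word_topics[None, :]
    off_diag = ~np.eye(len(word_topics), dtype=bool)    # skip a word paired with itself
    return G[same & off_diag].mean(), G[~same].mean()

# Toy example: T = 3 topics, v = 4 words each; every word embeds as its topic's basis vector.
word_topics = np.arange(12) // 4
W_E = np.repeat(np.eye(3), 4, axis=1)                   # d = 3, n = 12
same_mean, diff_mean = topic_similarity_gap(W_E, word_topics)
```

The property in Theorem 1 corresponds to `same_mean > diff_mean`; for a trained embedding the gap would be smaller than in this idealized case.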

Topic structure can be encoded in self-attention

  • Self-attention can encode topic structures when token embeddings are frozen.
  • Token embedding layer can learn the topic-modeling distribution.

The two-stage optimization process of self-attention

  • Training dynamics of a one-layer transformer on the topic-modeling data distribution follow a two-stage process
  • Stage 1: key and query matrices stay close to 0, value matrix norm increases significantly
  • Stage 2: value matrix norm plateaus, key and query matrices start to move
  • Simplification of two-stage process: Stage 1 attention frozen to be uniform, only value matrix trained; Stage 2 value matrix frozen, key and query matrices trained
  • Theorem 2 characterizes the optimal value matrix under mild L2 regularization when attention is frozen to be uniform
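To illustrate why a block-wise value matrix is plausible in Stage 1, the sketch below solves the uniform-attention problem in closed form under heavy simplifications that are our assumptions, not the paper's exact setting: one-hot embeddings, a single masked position per document whose [MASK] token contributes nothing to the averaged input, a Dirichlet(`alpha`) topic mixture, squared loss with an L2 penalty `lam`, and one matrix W standing in for the composed decoder and value maps. With uniform attention the output is W applied to the average of the input columns, so the problem reduces to ridge regression.

```python
import numpy as np

def sample_doc(rng, T, v, N, alpha=0.1):
    # Disjoint-topic sampler (assumed Dirichlet mixture, uniform words within a topic).
    theta = rng.dirichlet(np.full(T, alpha))
    return rng.choice(T, size=N, p=theta) * v + rng.integers(0, v, size=N)

def stage1_value_matrix(rng, T=5, v=5, N=100, num_docs=2000, lam=1e-3):
    """Ridge-regression minimizer of (1/m) sum_i ||W x_i - y_i||^2 + lam ||W||_F^2,
    where x_i is the uniform-attention input and y_i the masked word's one-hot."""
    n = T * v
    X = np.zeros((n, num_docs))    # mean one-hot vector of the unmasked words
    Y = np.zeros((n, num_docs))    # one-hot target of the masked word
    for i in range(num_docs):
        doc = sample_doc(rng, T, v, N)
        masked = rng.integers(N)
        rest = np.delete(doc, masked)
        X[:, i] = np.bincount(rest, minlength=n) / len(rest)
        Y[doc[masked], i] = 1.0
    m = num_docs
    return (Y @ X.T / m) @ np.linalg.inv(X @ X.T / m + lam * np.eye(n))

W = stage1_value_matrix(np.random.default_rng(0))
topics = np.arange(25) // 5
same_topic = topics[:, None] == topics[None, :]
print(W[same_topic].mean(), W[~same_topic].mean())
```

On such samples the same-topic entries of W are expected to dominate, in line with the block-wise structure of Theorem 2; the sketch makes no claim about the exact block values derived in the paper.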

Optimal attention weights

  • Freezing WV to representative optima from stage 1
  • Comparing 3 types of attention weights
  • The trained W_V often converges to uniform blocks
  • Simplified setting: Assumption 2
  • Asymptotic setting: T → ∞, τ → ∞, and v above a certain threshold
  • Optimal attention weights minimize loss
  • Attention between same-topic words larger than different-topic words
  • WV with uniform blocks sums up attention on all words in each topic
  • For finite T and τ there is no closed form, so the loss landscape is computed numerically
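The statement that a W_V with uniform blocks sums up attention on all words of each topic can be checked directly: with one-hot inputs, the output coordinate for word w equals the block value times the total attention mass on w's topic. The block value `c`, the toy document, and the attention vector below are hypothetical.

```python
import numpy as np

T, v = 3, 4
n = T * v
topic_of = np.arange(n) // v
c = 1.0 / v   # hypothetical block value; only the block-constant structure matters here

# Uniform-block W_V: entry (w, w') = c if w and w' share a topic, else 0.
W_V = c * (topic_of[:, None] == topic_of[None, :]).astype(float)

doc = np.array([0, 1, 5, 6, 9, 2])             # word ids with frozen one-hot embeddings
Z = np.eye(n)[:, doc]                          # n x N one-hot input
a = np.array([0.3, 0.2, 0.1, 0.1, 0.2, 0.1])   # one query's attention weights (sum to 1)

out = W_V @ Z @ a
# Total attention mass on each topic, summed over all positions of that topic.
topic_mass = np.bincount(topic_of[doc], weights=a, minlength=T)
assert np.allclose(out, c * topic_mass[topic_of])
```

In this toy setting, moving attention between words of the same topic leaves the output unchanged, so only the attention mass per topic can affect the loss.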

Experiments

  • Analyzed properties of training dynamics
  • Used extensive experimental analysis
  • Setup for synthetic and Wikipedia data

Results on synthetic lda-generated data

  • Generate data with T = 10, v = 10, N uniformly randomly chosen from [100,150]
  • Training objective follows Section 3.2 with p_m = 0.15, p_c = 0.1, p_r = 0.1
  • Model architecture follows Section 3.3 with bias terms
  • The learned embedding weight W_E displays a block-wise pattern
  • W_V has a block-wise pattern when word embeddings are frozen and attention weights are uniform
  • W_V converges to a block-wise pattern when attention weights are trained
  • Theorem 3 holds whether W_V is frozen or trained
  • Words pay more attention to words of same topic than different topics
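The synthetic setup can be collected into a small configuration object. Field names are our own; the values are those listed for the synthetic experiments, and reading N's range [100, 150] as inclusive at both ends is an assumption.

```python
from dataclasses import dataclass
import numpy as np

@dataclass(frozen=True)
class SyntheticLDAConfig:
    """Synthetic-experiment settings (hypothetical field names)."""
    num_topics: int = 10          # T
    words_per_topic: int = 10     # v
    min_doc_len: int = 100        # N ~ Uniform[100, 150]
    max_doc_len: int = 150
    p_m: float = 0.15
    p_c: float = 0.1
    p_r: float = 0.1

    @property
    def vocab_size(self) -> int:
        return self.num_topics * self.words_per_topic   # Tv real words

    @property
    def embedding_columns(self) -> int:
        return self.vocab_size + 1                      # one extra column for [MASK]

    def sample_doc_len(self, rng: np.random.Generator) -> int:
        # endpoint=True makes the upper bound inclusive (our reading of [100, 150]).
        return int(rng.integers(self.min_doc_len, self.max_doc_len, endpoint=True))

cfg = SyntheticLDAConfig()
N = cfg.sample_doc_len(np.random.default_rng(0))
```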

Results on natural language data

  • Compared pre-trained transformer-based models and tokenizers from Hugging Face
  • Used an LDA model to determine topics in the Wikipedia corpus
  • Filtered out stop words and kept the fraction of tokens with the highest likelihood in each topic
  • Compared embedding similarity and attention weights between same-topic and different-topic tokens
  • Results showed that same-topic embedding similarity and attention weight were consistently greater than different-topic counterparts

Related work

  • Transformers have been successful in solving a range of tasks
  • Prior works have combined theoretical constructions and experiments to explain the expressive power of transformers
  • Other works characterize the capacity of neural network models by assessing their ability to learn simple models of the data
  • Multiple reasonable representational optima may exist
  • Prior work shows that transformers learn the spatial structure of image-type datasets through gradient-descent-based optimization algorithms
  • This work theoretically analyzes the optimization process of the transformer architecture on a topic-modeling data distribution
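The same-topic vs. different-topic attention comparison on Wikipedia tokens can be sketched as follows, given one head's attention matrix and a topic label per token, with -1 marking filtered tokens (stop words, low-likelihood tokens). How labels are derived from the LDA model and the row-per-query layout of the matrix are assumptions; `attention_topic_gap` is a hypothetical helper, not part of any library.

```python
import numpy as np

def attention_topic_gap(A, token_topics):
    """Average attention weight for same-topic vs. different-topic token pairs.

    A: N x N attention matrix of one head; A[i, j] is the weight query i puts
       on key j (row-per-query layout is assumed).
    token_topics: length-N int array of topic ids, -1 for filtered tokens.
    Returns (same_mean, diff_mean); NaN marks an empty group.
    """
    kept = token_topics >= 0
    both_kept = kept[:, None] & kept[None, :]
    same = (token_topics[:, None] == token_topics[None, :]) & both_kept
    diff = (token_topics[:, None] != token_topics[None, :]) & both_kept
    np.fill_diagonal(same, False)              # exclude a token attending to itself
    same_mean = A[same].mean() if same.any() else np.nan
    diff_mean = A[diff].mean() if diff.any() else np.nan
    return same_mean, diff_mean

A = np.array([[0.10, 0.60, 0.20, 0.10],
              [0.50, 0.10, 0.30, 0.10],
              [0.20, 0.20, 0.20, 0.40],
              [0.25, 0.25, 0.25, 0.25]])
token_topics = np.array([0, 0, 1, -1])         # last token was filtered out
same_mean, diff_mean = attention_topic_gap(A, token_topics)
```

Averaging these two numbers over documents, layers, and heads gives the kind of aggregate comparison reported for the pre-trained models.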

Discussion

The two-stage optimization process

  • The analysis uses an alternating optimization procedure to train W_V, W_K, W_Q
  • In practice, W_K, W_Q, W_V are typically trained jointly rather than alternately
  • In the initial steps of training, W_V grows faster than W_K and W_Q
  • W_V plateaus in Stage 2
  • The two-stage phenomenon is sensitive to hyperparameters
  • Figure 5 shows the two-stage learning dynamics of a 4-layer transformer with 4 heads per layer
  • Stage 1: the norm of W_V increases while W_K and W_Q stay close to 0
  • Stage 2: the norms of W_K and W_Q start increasing significantly
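A standard chain-rule calculation gives one heuristic for the Stage 1 behavior. It assumes the one-layer model of Section 3.3 with the usual 1/√d_a scaling, writes Y for the model output, and ignores the regularizer; it is our illustration, not the paper's argument.

```latex
S = \frac{(W_K Z)^\top (W_Q Z)}{\sqrt{d_a}}, \qquad
Y = W_{\mathrm{dec}} W_V Z A + b_{\mathrm{dec}} \mathbf{1}^\top, \qquad
G := \frac{\partial L}{\partial S}, \qquad
H := \frac{\partial L}{\partial Y}

\frac{\partial L}{\partial W_K} = \frac{1}{\sqrt{d_a}}\, W_Q Z\, G^\top Z^\top,
\qquad
\frac{\partial L}{\partial W_Q} = \frac{1}{\sqrt{d_a}}\, W_K Z\, G\, Z^\top,
\qquad
\frac{\partial L}{\partial W_V} = W_{\mathrm{dec}}^\top H\, A^\top Z^\top
```

Each of ∂L/∂W_K and ∂L/∂W_Q is linear in the other matrix, so both vanish at W_K = W_Q = 0 and stay small while both matrices are small, whereas ∂L/∂W_V is generally nonzero under the uniform attention A = (1/N) 1 1^⊤ obtained at that point. This is consistent with W_V moving first, but it is only a heuristic.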

Are topic-wise behaviors fully explained by co-occurrence counts?

  • Topic-wise behavior of token embeddings and attention weights cannot be fully explained by co-occurrence
  • Compared average attention weights and average embedding dot products between same-topic word pairs and the N pairs of words that co-occur the most
  • Results show that same-topic token pairs have a smaller average attention weight than the most co-occurring pairs, but a larger average embedding cosine similarity
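The N most frequently co-occurring word pairs can be extracted with a short counting routine. Counting a pair once per document that contains both words is our assumption; the paper may define co-occurrence differently (for example, within a window), and `top_cooccurring_pairs` is a hypothetical helper.

```python
from collections import Counter
from itertools import combinations

def top_cooccurring_pairs(docs, n):
    """Return the n unordered word pairs that co-occur in the most documents."""
    counts = Counter()
    for doc in docs:
        # Sorted unique words, so each unordered pair is counted once per document.
        for pair in combinations(sorted(set(doc)), 2):
            counts[pair] += 1
    return [pair for pair, _ in counts.most_common(n)]

docs = [["cat", "dog", "fish"], ["cat", "dog"], ["dog", "fish"], ["cat", "dog", "bird"]]
pairs = top_cooccurring_pairs(docs, 2)
```

The attention weights and embedding similarities of these pairs can then be averaged and compared with the same-topic averages.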

Conclusion

  • The analysis splits into cases according to whether β > 100c_4/(vτ) or β ≤ 100c_4/(vτ). Comparing the cases, L(α, β) in Case 3 is strictly smaller than L(α, β) in Case 1 and Case 2, because: comparing equation D.37 with equation D.39: (1; comparing equation D.38 with equation D.39: the former contains the term p_r(100c_4/(101v))^2 > 0 as an extra constant relative to the latter, of scale Ω(1), i.e. non-vanishing even under the asymptotic assumptions of Assumption 3.
  • This work initiated the study of the training dynamics of transformers in the presence of semantic structure
  • Extending analysis to data distributions that capture “syntactic” structure
  • Disentangling different aspects of data (semantic and syntactic elements) learned through different parts of model architecture
  • The position of a word in a document is irrelevant in the topic model
  • Single-head attention used
  • Lemma on optimal linear transform when freezing uniform attention
  • Optimal linear transform when freezing uniform attention consists of all W that satisfy certain conditions
  • Optimal token embedding satisfies certain conditions
  • Optimal attention weights when freezing block-wise W V satisfy certain conditions
  • Optimal attention weights when freezing diagonal W V satisfy certain conditions