Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Transformers have been successful across many domains, but their learning mechanics are not well understood.
- Recent research has begun to study the representational aspects of transformers, but there is no guarantee the learning dynamics will converge to the proposed constructions.
- This paper provides a mechanistic understanding of how transformers learn “semantic structure” by combining experiments on synthetic data, Wikipedia data, and mathematical analysis.
Paper Content
Introduction
- Transformer architecture is used in many areas of computer science
- Little is known about how transformers learn from training data
- Studying semantic structure through co-occurrences of words and topical structure
- BERT model produces token embeddings that are more similar if they belong to the same topic
- Synthetic LDA data used to understand how embeddings and attention learn topical structure
- Mathematical proof of why such structure arises
- Two-stage training dynamics verified for a variety of optimizers and hyperparameter settings
Overview of results
- Focused on understanding optimization dynamics of transformers in a single-layer architecture
- Validated results transfer to real data
- Topic structure can be encoded in embedding layer and attention mechanism
- Even if one component is not trained, the other can compensate
- Characterized how topic structure is learned in two extremal cases
- Empirically verified characterization on synthetic and real data
Topic structure is encoded in token embeddings
- Theorem states that optimal embedding layer of single layer transformer encodes topic structure in embedding weights.
- Inner product of embeddings of a pair of words is larger when they belong to same topic, and smaller when they belong to different topics.
Topic structure is encoded in self-attention
- Self-attention behavior studied in a transformer trained on a topic modeling distribution without token embeddings
- Training process broken down into two stages
- In the first stage, attention is frozen to be uniform and the matrix $W^V$ is trained
- In the second stage, $W^V$ is frozen to its optimal value and the optimal attention weights are analyzed
- Intuitively, two-stage approximation is reasonable
- Self-attention function is $\mathrm{Attn}(Z) := W^V Z A(Z)$, where $A(Z)$ denotes the attention weights
- Optimal $W^V$ has block-wise structure, each block corresponding to a topic
- Optimal attention weights have convex combination of same-word and same-topic-different-words attention
- Empirically show that behavior of attention weights follows relations described in theorem
Empirical results
- Results remain robust under more complex and realistic settings
- Tested on synthetic data using Latent Dirichlet Allocation (LDA)
- Results reported for model pre-trained on Wikipedia corpus
- Detailed experimental setup and results in Section 6 and Appendix E
Problem setup
Topic models
- Data distribution generated by a topic model
- Special case of an LDA model
- Each document consists of words
- Each word belongs to exactly one topic
- Topics do not share common words
- Distribution of documents follows a generative process
- Document one-hot encodings
- Infinitely-long-document setting
- Empirical token distribution is equal to the groundtruth token distribution
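The data assumptions above can be made concrete with a small sampler. This is a minimal sketch, not the paper's generator: the uniform choice of topics per document, the uniform choice of words within a topic, and the names `sample_document`, `one_hot`, and `topics_per_doc` are all assumptions made for illustration. Word ids are laid out so that topic `t` owns a contiguous range, which guarantees that topics share no words.

```python
import numpy as np

def sample_document(num_topics, words_per_topic, doc_len, topics_per_doc, rng):
    """Sample one document as an array of word ids.

    Topic t owns word ids t * words_per_topic ... (t + 1) * words_per_topic - 1,
    so every word belongs to exactly one topic.
    """
    # Assumption: the document's topics are drawn uniformly without replacement.
    topics = rng.choice(num_topics, size=topics_per_doc, replace=False)
    # Assumption: each position picks one of those topics, then a word in it, uniformly.
    word_topics = rng.choice(topics, size=doc_len)
    offsets = rng.integers(0, words_per_topic, size=doc_len)
    return word_topics * words_per_topic + offsets

def one_hot(doc, vocab_size):
    """Return the (vocab_size, doc_len) one-hot encoding of a document."""
    return np.eye(vocab_size)[:, doc]

rng = np.random.default_rng(0)
doc = sample_document(num_topics=10, words_per_topic=10, doc_len=120,
                      topics_per_doc=3, rng=rng)
X = one_hot(doc, vocab_size=100)
```

As the document length grows, the empirical word frequencies `X.mean(axis=1)` approach the document's true token distribution, which is the idealization behind the infinitely-long-document setting.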
Training objective
- Train a transformer network using the masked language modeling objective
- Define a token [MASK] and three constant probabilities $p_m$, $p_c$, $p_r$
- Train the model to predict the original words at the masked positions
- Consider a regularized version of the masked language modeling objective
- Use the squared loss and cross entropy loss for theoretical and empirical analysis
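The masking step can be sketched as follows. The summary does not spell out how the three probabilities are used, so this sketch assumes a BERT-style scheme: each position is selected for prediction with probability `p_m`, and a selected position keeps its word with probability `p_c`, is replaced by a random word with probability `p_r`, and becomes [MASK] otherwise. Reserving id 0 for [MASK] and the function name `mask_document` are also assumptions.

```python
import numpy as np

MASK_ID = 0  # assumed id for [MASK]; real words use ids 1..vocab_size

def mask_document(doc, vocab_size, p_m, p_c, p_r, rng):
    """Return (masked_doc, selected), where `selected` marks prediction targets."""
    doc = np.asarray(doc)
    selected = rng.random(doc.shape) < p_m      # positions the model must predict
    u = rng.random(doc.shape)                   # decides what happens to a selected position
    randomize = selected & (u >= p_c) & (u < p_c + p_r)
    to_mask = selected & (u >= p_c + p_r)
    masked = doc.copy()                         # selected positions with u < p_c keep their word
    masked[randomize] = rng.integers(1, vocab_size + 1, size=randomize.sum())
    masked[to_mask] = MASK_ID
    return masked, selected

rng = np.random.default_rng(0)
original = np.arange(1, 101)
masked, selected = mask_document(original, vocab_size=100,
                                 p_m=0.15, p_c=0.1, p_r=0.1, rng=rng)
```

The loss is then computed only at positions where `selected` is true, comparing the model's prediction against `original` at those positions.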
Transformer network architecture
- One-layer transformer model with no residual connection or normalization layers
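- Input representation $Z \in \mathbb{R}^{d \times N}$
- Decoder weights $W^{dec}$ and biases $b^{dec}$, with $W^{dec} \in \mathbb{R}^{V \times d}$ and $V$ the vocabulary size
- Value matrix $W^V \in \mathbb{R}^{d \times d}$
- Attention head size $d_a$
- Key matrix $W^K \in \mathbb{R}^{d_a \times d}$
- Query matrix $W^Q \in \mathbb{R}^{d_a \times d}$
- Input $Z$ is the embedding of the masked document, with embedding matrix $W^E \in \mathbb{R}^{d \times (Tv+1)}$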
- Part of analysis and experiments freeze one-hot word embeddings
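A forward pass through such a one-layer model might look like the sketch below. Several details are assumptions rather than statements from the summary: the attention weights are taken to be a column-wise softmax of scaled dot products between keys and queries, the bias is treated as a vector of length $V$ added to every position, and the function names are hypothetical. The line computing `attn_out` corresponds to $\mathrm{Attn}(Z) = W^V Z A(Z)$.

```python
import numpy as np

def softmax_columns(scores):
    """Softmax over each column, shifted by the column max for numerical stability."""
    scores = scores - scores.max(axis=0, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=0, keepdims=True)

def attention_weights(Z, W_K, W_Q):
    """A(Z): an (N, N) matrix whose column j holds the weights used by position j."""
    d_a = W_K.shape[0]
    scores = (W_K @ Z).T @ (W_Q @ Z) / np.sqrt(d_a)  # assumed scaled dot-product form
    return softmax_columns(scores)

def one_layer_transformer(Z, W_V, W_K, W_Q, W_dec, b_dec):
    """No residual connection and no normalization layer, as in the setup above."""
    A = attention_weights(Z, W_K, W_Q)
    attn_out = W_V @ Z @ A                    # Attn(Z) = W^V Z A(Z), shape (d, N)
    return W_dec @ attn_out + b_dec[:, None]  # per-position scores, shape (V, N)

rng = np.random.default_rng(0)
d, N, V, d_a = 4, 6, 5, 3
Z = rng.normal(size=(d, N))
zeros = np.zeros((d_a, d))
out = one_layer_transformer(Z, np.eye(d), zeros, zeros,
                            rng.normal(size=(V, d)), np.zeros(V))
```

With the key and query matrices at zero, every score is zero and the attention weights are exactly uniform, which is the situation the first training stage is approximated by.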
Topic structure can be encoded in token embeddings
- Theorem 1 states that the embedding layer can encode the topic structure
- The 0-th row of E satisfies a certain condition
- Point 3 is the important one among the list of conclusions
- Same-topic words have more similar embeddings than different-topic words
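The claim that same-topic words have more similar embeddings can be checked on any embedding matrix with a helper like the one below. It is a sketch with a hypothetical name and layout (one embedding per row, a topic id per word), not the paper's evaluation code.

```python
import numpy as np

def same_vs_diff_topic_similarity(E, topics):
    """Average inner product over same-topic pairs and over different-topic pairs.

    E: (num_words, d) array with one embedding per row (a layout chosen here).
    topics: (num_words,) array with the topic id of each word.
    """
    E = np.asarray(E, dtype=float)
    topics = np.asarray(topics)
    gram = E @ E.T                                    # all pairwise inner products
    same = topics[:, None] == topics[None, :]
    off_diag = ~np.eye(len(topics), dtype=bool)       # exclude a word paired with itself
    return float(gram[same & off_diag].mean()), float(gram[~same].mean())

# Toy embeddings with two topics of two words each.
E = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
same_avg, diff_avg = same_vs_diff_topic_similarity(E, topics=[0, 0, 1, 1])
```

In this toy case the same-topic average is 1 and the different-topic average is 0, the idealized version of the block-wise pattern described above.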
Topic structure can be encoded in self-attention
- Self-attention can encode topic structures when token embeddings are frozen.
- Token embedding layer can learn the topic-modeling distribution.
The two-stage optimization process of self-attention
- Training dynamics of one-layer transformer on topic modeling data distribution observed in two-stage process
- Stage 1: key and query matrices stay close to 0, value matrix norm increases significantly
- Stage 2: value matrix norm plateaus, key and query matrices start to move
- Simplification of two-stage process: Stage 1 attention frozen to be uniform, only value matrix trained; Stage 2 value matrix frozen, key and query matrices trained
- Theorem 2: Optimal value matrix with mild L2-regularization when freezing uniform attention
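To build intuition for why a block-wise value matrix appears under frozen uniform attention, the following numerical sketch solves a heavily simplified version of stage 1 in closed form. It is not Theorem 2. The assumptions are: one-hot embeddings with identity decoder, squared loss, no masking, infinitely long documents (so the uniform-attention average `z_bar` is exactly the document's word distribution), every size-`topics_per_doc` subset of topics being equally likely, and words uniform within the chosen topics. Under these assumptions, setting the gradient of the regularized loss to zero gives $W (C + \lambda I) = C$ with $C = \mathbb{E}[\bar z \bar z^\top]$.

```python
import itertools
import numpy as np

def stage1_value_matrix(num_topics, words_per_topic, topics_per_doc, lam):
    """Closed-form minimiser of E||W z_bar - x||^2 + lam * ||W||_F^2 in the toy setting."""
    vocab = num_topics * words_per_topic
    subsets = list(itertools.combinations(range(num_topics), topics_per_doc))
    C = np.zeros((vocab, vocab))
    for subset in subsets:
        z_bar = np.zeros(vocab)
        for t in subset:
            z_bar[t * words_per_topic:(t + 1) * words_per_topic] = 1.0
        z_bar /= z_bar.sum()                   # uniform over words of the chosen topics
        C += np.outer(z_bar, z_bar) / len(subsets)
    # The expected target equals z_bar, so the normal equations read W (C + lam I) = C.
    return C @ np.linalg.inv(C + lam * np.eye(vocab))

W = stage1_value_matrix(num_topics=4, words_per_topic=5, topics_per_doc=2, lam=1e-3)
topics = np.arange(20) // 5
same = topics[:, None] == topics[None, :]
same_avg, diff_avg = W[same].mean(), W[~same].mean()
```

In this toy problem the entries of `W` are constant within each topic block, and same-topic entries are much larger than different-topic entries, which matches the qualitative block-wise picture.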
Optimal attention weights
- Freezing $W^V$ to representative optima from stage 1
- Comparing 3 types of attention weights
- Model often converges to uniform blocks
- Simplified setting: Assumption 2
- Asymptotic setting: T→∞, τ→∞, v greater than a certain value
- Optimal attention weights minimize loss
- Attention between same-topic words larger than different-topic words
- $W^V$ with uniform blocks sums up attention on all words in each topic
- No closed form for finite T, τ, loss landscape computed numerically
Experiments
- Analyzed properties of training dynamics
- Used extensive experimental analysis
- Setup for synthetic and Wikipedia data
Results on synthetic LDA-generated data
- Generate data with $T = 10$, $v = 10$, and $N$ chosen uniformly at random from [100, 150]
- Training objective follows Section 3.2 with $p_m = 0.15$, $p_c = 0.1$, $p_r = 0.1$
- Model architecture follows Section 3.3 with bias terms
- Learned embedding weight $W^E$ displays a block-wise pattern
- $W^V$ has a block-wise pattern when word embeddings are frozen and attention weights are uniform
- $W^V$ converges to a block-wise pattern when attention weights are trained
- Theorem 3 holds when $W^V$ is frozen or trained
- Words pay more attention to words of same topic than different topics
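A comparison of attention between pair types can be computed from a single attention matrix as sketched below. The split into same-word, same-topic-different-word, and different-topic pairs follows the three types of attention mentioned earlier. The function name and the convention that topic `t` owns a contiguous block of word ids are assumptions.

```python
import numpy as np

def attention_by_pair_type(A, word_ids, words_per_topic):
    """Average attention weight for each type of position pair in one document.

    A: (N, N) attention matrix; word_ids: (N,) word id at each position.
    """
    A = np.asarray(A, dtype=float)
    word_ids = np.asarray(word_ids)
    topics = word_ids // words_per_topic       # assumed contiguous id layout per topic
    same_word = word_ids[:, None] == word_ids[None, :]
    same_topic = topics[:, None] == topics[None, :]
    groups = {
        "same_word": same_word,
        "same_topic_diff_word": same_topic & ~same_word,
        "diff_topic": ~same_topic,
    }
    # Skip pair types that do not occur in this document.
    return {name: float(A[mask].mean()) for name, mask in groups.items() if mask.any()}

A = np.array([[0.5, 0.3, 0.1],
              [0.3, 0.5, 0.1],
              [0.2, 0.2, 0.8]])
stats = attention_by_pair_type(A, word_ids=[0, 1, 10], words_per_topic=10)
```

Averaging these per-document statistics over a corpus gives the same-topic versus different-topic comparison described in the results.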
Results on natural language data
- Compared pre-trained transformer-based models and tokenizers from Huggingface
- Used LDA model to determine topics from Wikipedia corpus
- Filtered out stop words and kept fraction of tokens with highest likelihood in each topic
- Compared embedding similarity and attention weights between same-topic and different-topic tokens
- Results showed that same-topic embedding similarity and attention weight were consistently greater than different-topic counterparts
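The token filtering step can be sketched as follows, assuming a fitted LDA model has already produced a topic-by-token likelihood matrix. The function name, the `keep_fraction` parameter, and the choice to measure the fraction against the full vocabulary are hypothetical. The summary does not give the fraction used or how tokens ranked highly by several topics are handled.

```python
import numpy as np

def top_tokens_per_topic(topic_word, keep_fraction, stop_word_ids=()):
    """For each topic, keep the highest-likelihood tokens, never selecting stop words.

    topic_word: (num_topics, vocab) likelihood of each token under each topic.
    """
    topic_word = np.array(topic_word, dtype=float)          # copy so the input is untouched
    stop = np.asarray(list(stop_word_ids), dtype=int)
    topic_word[:, stop] = -np.inf                           # stop words sort last
    num_keep = max(1, int(keep_fraction * topic_word.shape[1]))
    kept = {}
    for t, row in enumerate(topic_word):
        top = np.argsort(row)[::-1][:num_keep]              # ids in descending likelihood
        kept[t] = [int(i) for i in top if np.isfinite(row[i])]
    return kept

topic_word = [[0.50, 0.30, 0.15, 0.05],
              [0.10, 0.10, 0.20, 0.60]]
kept = top_tokens_per_topic(topic_word, keep_fraction=0.5, stop_word_ids=[0])
```

The resulting token sets define which pairs count as same-topic or different-topic when comparing embedding similarity and attention weights.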
Related works
- Transformers have been successful in solving a range of tasks
- Prior works have combined theoretical constructions and experiments to explain the expressive power of transformers
- Characterizing the capacity of neural network models by assessing their abilities in learning simple models of the data
- Multiple reasonable representational optima may exist
- Transformers learn spatial structure of image-type datasets through gradient-descent-based optimization algorithms
- Theoretically analyzing the optimization process of the transformer architecture on topic modeling data distribution
Discussion
The two-stage optimization process
- Alternating optimization procedure used to train $W^V$, $W^K$, $W^Q$
- $W^K$, $W^Q$, $W^V$ are typically trained jointly rather than in alternation
- Initial steps of the training process show $W^V$ growing faster than $W^K$ and $W^Q$
- $W^V$ plateaus in Stage 2
- Two-stage phenomenon sensitive to hyperparameters
- Figure 5 shows two-stage learning dynamics of 4-layer, 4-head-per-layer transformer
- Stage 1: $W^V$ norm increases, $W^K$ and $W^Q$ stay close to 0
- Stage 2: $W^K$ and $W^Q$ norms start increasing significantly
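One simple way to observe the two stages is to log the Frobenius norm of each weight matrix during training and look for the step at which the key and query norms leave zero. The sketch below uses made-up snapshots in place of real checkpoints. The helper names and the threshold value are hypothetical.

```python
import numpy as np

def frobenius_norms(params):
    """Map each parameter name to the Frobenius norm of its matrix."""
    return {name: float(np.linalg.norm(W)) for name, W in params.items()}

def first_step_above(history, name, threshold):
    """Return the first logged step whose norm for `name` exceeds `threshold`, or None."""
    for step, norms in history:
        if norms[name] > threshold:
            return step
    return None

# Made-up snapshots standing in for checkpoints saved during training.
history = [
    (0, frobenius_norms({"W_V": np.zeros((2, 2)), "W_K": np.zeros((2, 2))})),
    (100, frobenius_norms({"W_V": np.ones((2, 2)), "W_K": 0.01 * np.ones((2, 2))})),
    (200, frobenius_norms({"W_V": np.ones((2, 2)), "W_K": np.ones((2, 2))})),
]
v_step = first_step_above(history, "W_V", threshold=0.5)
k_step = first_step_above(history, "W_K", threshold=0.5)
```

A gap between `v_step` and `k_step` is the signature of the two-stage behavior: the value matrix norm rises first, and the key matrix norm rises later.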
Are topic-wise behaviors fully explained by co-occurrence counts?
- Topic-wise behavior of token embeddings and attention weights cannot be fully explained by co-occurrence
- Compared average attention weights and average embedding dot products between same-topic word pairs and the N pairs of words that co-occur the most
- Results show same-topic tokens have smaller average attention weight, but larger average embedding cosine similarity
Conclusion
- From the appendix case analysis (equations D.37 to D.39): the cases $\beta > \frac{100 c_4}{v\tau}$ and $\beta \le \frac{100 c_4}{v\tau}$ are treated separately. Comparing the cases, $L(\alpha, \beta)$ in Case 3 is strictly smaller than $L(\alpha, \beta)$ in Case 1 and Case 2. This follows from comparing equation D.37 and equation D.38 with equation D.39: in equation D.38, the term $p_r \left(\frac{100 c_4}{101 v}\right)^2 > 0$ is an extra constant (of scale $\Omega(1)$, i.e. non-vanishing even under the asymptotic assumptions of Assumption 3) compared with equation D.39.
- Initiated study of understanding training dynamics of transformers in presence of semantic structure
- Extending analysis to data distributions that capture “syntactic” structure
- Disentangling different aspects of data (semantic and syntactic elements) learned through different parts of model architecture
- Position information of word in document irrelevant to topic model
- Single-head attention used
- Lemma on optimal linear transform when freezing uniform attention
- Optimal linear transform when freezing uniform attention consists of all W that satisfy certain conditions
- Optimal token embedding satisfies certain conditions
- Optimal attention weights when freezing block-wise W V satisfy certain conditions
- Optimal attention weights when freezing diagonal W V satisfy certain conditions