Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Data from many fields can be represented by discrete, compositional structures.
- Latent structure models are useful for learning to extract such representations.
- Three strategies for learning with discrete latent structure are explored.
Paper Content
Motivation
- Machine learning is used to analyze data such as images, text, and sound.
- Natural language sentences can be analyzed in terms of their dependency structure.
- ML systems are often structured as pipelines with off-the-shelf components.
- Deep neural networks can outperform pipelines and learn dense, continuous representations of the data.
Notation
- Notation for vectors, matrices, and indexing is introduced
- u, v, W, X denote a scalar, a vector, a matrix, and a set, respectively
- v_i is the i-th element of vector v
- w_j is the j-th column of matrix W
- Partial derivative of f is identified with a d_i × d matrix
- The Jacobian transpose is more convenient for backpropagation
- ∂_θ(expr.) interprets the expression as a single-argument function of θ
- R^d_+ is the non-negative orthant
- △_d = {α ∈ R^d_+ : ∑_i α_i = 1} is the simplex with d bins
- conv(Z) is the convex hull of Z
Supervised learning
- Predictive machine learning is used to learn a model to predict labels from inputs
- Model is implemented as a feed-forward network with input and output
- Predictions are made using a Gibbs distribution
- Parameters are learned by minimizing an average loss
- Gradient and stochastic gradient methods are used to minimize a differentiable function
- Gradient and stochastic gradient methods can be extended using momentum and adaptivity
- Backpropagation is used to propagate gradients from output to input
- Popular software frameworks provide composable modules with forward and backward calls
Latent representations
- Main motivation is to go beyond direct mappings x → y
- Latent representation z captures relevant property of data point x
- Latent representation designed with downstream tasks in mind
- Problem is to learn to predict z from x using an encoder
- Latent representation captures specific aspect of x relevant to modeler
- Pretraining and pipelines can be used to predict z
- Probabilistic latent variables can be used for generative modelling
Further history and scope
- Latent variable models have a long history in ML, specifically for unsupervised learning.
- Factor analysis and Gaussian mixture model are two popular models.
- Probabilistic PCA is a special case of Factor Analysis.
- Variational Auto-Encoder is a nonlinear model parametrized by neural networks.
- ELBO is used to approximate the objective.
Roadmap
- Chapter 2 revisits tools of supervised structure prediction
- Chapter 3 explores deterministic approach to learning latent structure using relaxation
Overview
- Structure prediction is the learning setting wherein an unobserved structured-valued variable must be predicted.
- Structured representations consist of a combination of interdependent parts, which can be seen as a binary vector.
- Examples of structures include unconstrained, one-of-K, linear assignment, and non-projective dependency parsing.
- Structure prediction provides a prediction rule for selecting a best structure and, if possible, a probability distribution over all possible structures.
- Two main strategies for structure prediction are incremental prediction and global prediction.
Incremental prediction
- Model a structured variable z with a probabilistic model P(z | x)
- Z is discrete and finite and may depend on x
- Factor the distribution of z with the probability chain rule
- Multi-label setting is a particular instance of unconstrained structure prediction
- Blockwise prediction involves predicting multiple interdependent one-of-K choices
- Variable-length target structures can be achieved by designating a specific variable assignment as a “stop” control token
- Neural conditional language generation model parametrizes the probability of the next word
- Transition-based systems enable the extension of incremental prediction to more complex and constrained structured problems
- Shift-reduce parsing encodes candidate binary parse trees as bit vectors
- Uninformed shift-reduce model prefers highly skewed structures
- Multiple sequences of transitions can lead to the same target structure
Computation
- Exact search for the highest-probability structure is not feasible in general.
- Beam search algorithm is a popular heuristic algorithm that keeps track of the k best partial prefixes at each time step.
- Sampling from the distribution can be done by sampling z_1, ..., z_|P| sequentially.
- Approximate maximization can be done by way of sampling.
- Monte Carlo methods provide a good approximation for computing expectations.
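The computation strategies above can be sketched on a toy incremental model. Everything in this sketch is a hypothetical illustration: the three-token vocabulary, the STOP id, and the probability tables START and NEXT are assumptions made for the example, not taken from the paper.

```python
import numpy as np

STOP = 0  # hypothetical id of the "stop" control token
# Hypothetical toy model: the next-token distribution depends only on the last token.
START = np.array([0.1, 0.5, 0.4])
NEXT = {1: np.array([0.6, 0.1, 0.3]), 2: np.array([0.3, 0.3, 0.4])}

def next_probs(prefix):
    """P(z_t | z_1, ..., z_{t-1}) for the toy model."""
    return NEXT[prefix[-1]] if prefix else START

def sample(rng, max_len=10):
    """Ancestral sampling: draw z_1, z_2, ... in order until STOP (or max_len)."""
    z = []
    while len(z) < max_len:
        tok = int(rng.choice(3, p=next_probs(z)))
        z.append(tok)
        if tok == STOP:
            break
    return z

def beam_search(k=2, max_len=10):
    """Keep the k best prefixes at each step; return the best hypothesis found."""
    beams, finished = [([], 0.0)], []
    for _ in range(max_len):
        cand = [(prefix + [tok], lp + np.log(next_probs(prefix)[tok]))
                for prefix, lp in beams for tok in range(3)]
        cand.sort(key=lambda c: -c[1])
        beams = []
        for prefix, lp in cand[:k]:
            (finished if prefix[-1] == STOP else beams).append((prefix, lp))
        if not beams:
            break
    return max(finished + beams, key=lambda c: c[1])

best, log_prob = beam_search(k=2)
```

Beam search is a heuristic: with a small beam it can miss the true maximizer, so on small problems it is worth comparing against exhaustive enumeration.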
Global prediction
- Structured models can be built with global optimality guarantees.
- Score of a structure is a measure of compatibility with an input.
- Score of a structure can decompose into parts.
- Example 8 is an arc-factored dependency parser.
- Example 9 is a Markov sequence tagging model.
The marginal polytope
- Decomposition assumption allows structure prediction problems to be solved as an optimization problem over the convex hull of possible structures.
- The simplest possible marginal polytope is the convex hull of one-of-K indicator vectors.
- Any other finite marginal polytope can be written as a linear transformation of the |Z|-simplex.
- Vectors α ∈ △_|Z| can be interpreted as distributions over the set of structures Z.
- Any point ẑ ∈ conv(Z) can be seen as an expected structure.
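As a small illustration of the "expected structure" view, the sketch below takes Z to be the 3 × 3 permutation matrices (the linear assignment example) and averages them under a random distribution α. The choice of n = 3 and the Dirichlet draw are assumptions made only for the example.

```python
import itertools
import numpy as np

n = 3
# Z: all n x n permutation matrices (the structures of a linear assignment problem).
Z = np.array([np.eye(n)[list(p)] for p in itertools.permutations(range(n))])

rng = np.random.default_rng(0)
alpha = rng.dirichlet(np.ones(len(Z)))    # a distribution over Z, i.e. a point in the |Z|-simplex
z_hat = np.tensordot(alpha, Z, axes=1)    # expected structure, a point in conv(Z)
```

Here z_hat is a doubly stochastic matrix: each entry is the probability, under α, that the corresponding assignment is selected.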
Computation
- Finding the highest-scoring structure is a difficult combinatorial optimization problem.
- Specialized algorithms are available for specific problems.
- Finding the highest-scoring structure is equivalent to maximizing a linear function over a polytope.
- Computing expectations is intractable in general, but efficient exact algorithms are available for many structures of interest.
- Sampling from the exact Gibbs distribution is difficult, but powerful general frameworks exist.
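For the Markov sequence tagging example, one well-known specialized algorithm is Viterbi dynamic programming. The sketch below is a minimal version under assumed inputs: `unary` and `trans` are hypothetical score arrays, and the brute-force maximum is computed only to check the result on a tiny instance.

```python
import itertools
import numpy as np

def viterbi(unary, trans):
    """unary: (T, K) tag scores; trans: (K, K) score of tag i followed by tag j."""
    T, K = unary.shape
    best = unary[0].copy()
    back = np.zeros((T, K), dtype=int)
    for t in range(1, T):
        cand = best[:, None] + trans + unary[t][None, :]   # indexed (previous, current)
        back[t] = cand.argmax(axis=0)
        best = cand.max(axis=0)
    tags = [int(best.argmax())]
    for t in range(T - 1, 0, -1):
        tags.append(int(back[t, tags[-1]]))
    return tags[::-1], float(best.max())

def score(tags, unary, trans):
    """Total score of a tag sequence: sum of its unary and transition parts."""
    s = sum(unary[t, k] for t, k in enumerate(tags))
    return s + sum(trans[a, b] for a, b in zip(tags, tags[1:]))

rng = np.random.default_rng(0)
unary, trans = rng.normal(size=(5, 3)), rng.normal(size=(3, 3))
tags, best_score = viterbi(unary, trans)
brute = max(score(z, unary, trans) for z in itertools.product(range(3), repeat=5))
```

The dynamic program costs O(T K^2) operations, while the brute-force check enumerates all K^T sequences.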
Challenges of deterministic choices
- Deterministic latent representations define a mapping from x to a representation ẑ
- End-to-end downstream model is obtained through function composition of g with ẑ
- The encoder parameters θ_f and the downstream parameters θ_g are trained together
- Gradient with respect to decoder parameters poses no problems
- Dependency on θ_f requires application of the chain rule
- Deterministic end-to-end learning has two requirements
- Assumption 2 constrains the possible downstream models
- Gradient-based learning of deterministic latent variables requires that z ∈ R^d
- Downstream g must accept weighted averages of structures
- Downstream g must be almost everywhere differentiable w.r.t. z
- Breakdown in Eq. 3.3 highlights two terms
- Relaxation moves expectation inside the model
- Examples demonstrate implications of relaxations
- Relaxations cannot be applied if Assumption 2 does not hold
Regularized argmax operators
- Model is a well-trained pipeline
- Assignment to latent variable is highest-scoring one
- Dependency of s on x and θ_f is omitted
- Mapping θ_f → s (the encoder) is typically continuous and differentiable
- Argmax in right-hand side is a discrete optimization problem
- Mapping θ_f → ẑ_0 is discontinuous and almost everywhere flat
- Example 13 (Single Coinflip) illustrates this
- Take Z = {0, 1} and let s ∈ R
- Problem ẑ(s) = argmax_{z ∈ {0,1}} z·s can be solved by considering the sign of s
- Mapping ẑ(s) is set-valued at s = 0, so it is not always a function
- Derivative ∂ẑ(s) is zero almost everywhere
- Theorem of Dantzig et al. (1955) is fundamental to linear programming
- Maximum of optimization problem over polytope is always achieved at one of the vertices
- Relax a discrete problem to a continuous but constrained one
- Smoothing relies on the insight that the argmax mapping of an optimization problem with a strongly concave objective over a convex domain is a smooth function
- Objective of Eq. 3.8 is linear, thus concave but not strongly so
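The single coinflip example can be made concrete with a short numerical sketch. It assumes the regularizer is the binary Shannon entropy, in which case the relaxed maximizer over [0, 1] has the closed form sigmoid(s); the grid search is included only as a sanity check, and all function names are hypothetical.

```python
import numpy as np

def z_hard(s):
    """Unregularized argmax over {0, 1}: piecewise constant in s (tie at s = 0 broken toward 0)."""
    return 1.0 if s > 0 else 0.0

def z_entropy(s):
    """Closed-form maximizer of z*s + H(z) over [0, 1], with H the binary Shannon entropy."""
    return 1.0 / (1.0 + np.exp(-s))

def z_entropy_grid(s, n=100001):
    """Numerical check: maximize the same objective on a fine grid."""
    z = np.linspace(1e-6, 1 - 1e-6, n)
    objective = z * s - z * np.log(z) - (1 - z) * np.log(1 - z)
    return z[objective.argmax()]
```

Unlike z_hard, the relaxed mapping z_entropy has a nonzero derivative everywhere, which is what makes it usable for gradient-based learning.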
Categorical relaxation and attention
- Consider the categorical (one-of-K) case, where Z = {e_1, ..., e_K} and conv(Z) = △_K
- Points z ∈ △_K can be interpreted as discrete probability distributions
- Shannon entropy is a meaningful regularization
- Softmax is a differentiable relaxation of the one-of-K argmax mapping
- Softmax output is fully dense
- Sparsemax is a sparse alternative to softmax
- Sparsemax corresponds to regularizing the argmax with the Gini entropy
- α-entmax family of mappings recovers softmax and sparsemax
- Solutions can be written in a thresholded form
- Attention mechanisms in deep learning extract a single contextual representation of a set of K objects
- Deterministic key-value lookup retrieves the value corresponding to the key with highest dot product with the query
- Attention is the entropy-regularized relaxation of lookup, replacing ẑ_0 with ẑ_{H_1} = softmax(s)
- Entropy regularization forces all attention weights to be nonzero
- Sparse attention using H_α can induce combinations between only a few objects
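The mappings discussed above can be sketched in a few lines of NumPy. The sparsemax routine uses the standard sort-and-threshold computation of the Euclidean projection onto the simplex; the `attention` and `hard_lookup` helpers and their argument names are hypothetical and only illustrate how the mapping is swapped.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def sparsemax(s):
    """Euclidean projection of s onto the simplex (sort-based thresholding)."""
    z = np.sort(s)[::-1]
    cssv = np.cumsum(z) - 1.0
    k = np.arange(1, len(s) + 1)
    support = z - cssv / k > 0
    tau = cssv[support][-1] / k[support][-1]   # threshold
    return np.maximum(s - tau, 0.0)

def attention(query, keys, values, mapping=softmax):
    """Weighted average of values; weights come from query-key dot products."""
    return mapping(keys @ query) @ values

def hard_lookup(query, keys, values):
    """Deterministic lookup: value of the key with the highest dot product."""
    return values[(keys @ query).argmax()]

scores = np.array([1.0, 0.8, -1.0])
dense = softmax(scores)      # every weight is strictly positive
sparse = sparsemax(scores)   # the lowest-scoring item gets exactly zero weight
```

Passing `mapping=sparsemax` to `attention` gives a sparse attention variant that combines only the items in the support.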
Global structured relaxations and structured attention
- conv(Z) is a polytope that lacks a compact description
- Two ways to generalize softmax and sparsemax-style mappings in the structured case
- H measures the highest entropy among all decompositions of z as convex combinations of the elements of Z
- Solution ẑ_H is the average structure under the Gibbs distribution
- Kim et al. (2017) and Liu et al. (2018) propose structured attention mechanisms by differentiating through dynamic programs
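The "average structure under the Gibbs distribution" can be checked by brute force on a tiny case. The sketch assumes unconstrained bit vectors with a score that decomposes as ⟨s, z⟩; in that special case the Gibbs distribution factorizes and the mean structure is the elementwise sigmoid of s. The score vector is an arbitrary illustrative choice.

```python
import itertools
import numpy as np

s = np.array([0.5, -1.0, 2.0])                                   # hypothetical part scores
Z = np.array(list(itertools.product([0, 1], repeat=len(s))))     # all 8 bit vectors, shape (8, 3)

logits = Z @ s                        # score of each structure
p = np.exp(logits - logits.max())
p /= p.sum()                          # Gibbs distribution over Z
z_mean = p @ Z                        # average structure under that distribution
```

For structures such as sequences or trees, enumeration is not an option, and the same quantity is obtained with dynamic programming instead.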
Mean structure regularization: Sinkhorn and SparseMAP
- Regularizer can be applied directly to marginal vector z instead of decomposition α
- Entropy-inspired regularizer can be applied to coordinates of z
- Solutions of the linear assignment problem can be represented as permutation matrices
- Unbalanced assignments between different-sized sets can be represented by replacing equality constraint with inequality constraint
- Discrete optimal transport problem can be represented by constraining row and column sums to match given marginal distributions
- Quadratic regularization yields projection onto marginal polytope
- SparseMAP strategy relies on quadratic regularization and efficient general-purpose algorithm
- Solutions tend to be sparse
- Active set algorithm used to compute solution and gradients
- Implicit perturbation-based regularization uses random noise variable U
- Maximizations can be performed in parallel
- Regularizer chosen based on availability of efficient and numerically stable algorithm and desired sparsity
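For the linear assignment case with an entropy-style regularizer, the relaxed solution can be approximated by Sinkhorn iterations, which alternately normalize rows and columns. The sketch below is a minimal log-domain version under assumed settings: the iteration count, the unit temperature, and the random score matrix are illustrative choices.

```python
import numpy as np

def logsumexp(a, axis):
    m = a.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))

def sinkhorn(scores, n_iter=200):
    """Alternate row and column normalization in the log domain."""
    log_p = np.array(scores, dtype=float)
    for _ in range(n_iter):
        log_p = log_p - logsumexp(log_p, axis=1)   # make each row sum to 1
        log_p = log_p - logsumexp(log_p, axis=0)   # make each column sum to 1
    return np.exp(log_p)

rng = np.random.default_rng(0)
P = sinkhorn(rng.normal(size=(4, 4)))   # approximately doubly stochastic
```

The output is dense (all entries positive), in contrast with quadratic regularization as in SparseMAP, whose solutions tend to be sparse.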
Summary
- Probabilistic formulation allows for flexible handling of discrete latent structure
- Price to pay in complexity due to entire distributions needing to be considered
- Correia et al. (2020) compare variance-reduced SFE, Gumbel-STE, top-k sparsemax and SparseMAP on a variational autoencoder
- Gumbel-STE and related methods perform well in practice for bit vectors
- SparseMAP applies to permutation latent variables, but SFE and Gumbel-STE do not
Straight-through gradients
- Discrete latent variable assignment cannot be plugged into a downstream model.
- Encoder generates factorized scores.
- Discrete variable assignment is determined by arg max.
- Backward pass replaces problematic derivative with different function.
- Straight-through strategy passes gradient directly through argmax.
- Straight-through strategy is easy to implement.
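The straight-through idea can be written out by hand without an autodiff framework. In this sketch the downstream loss 0.5 · ‖V z − y‖², the matrix V, the target y, and the step size are all hypothetical; the only point is that the backward pass treats the argmax as if its Jacobian were the identity.

```python
import numpy as np

def one_hot_argmax(s):
    """Forward pass: discrete assignment as a one-hot vector."""
    z = np.zeros_like(s)
    z[s.argmax()] = 1.0
    return z

def downstream(z, V, y):
    """Hypothetical downstream loss 0.5 * ||V z - y||^2 and its gradient w.r.t. z."""
    r = V @ z - y
    return 0.5 * r @ r, V.T @ r

def ste_grad(s, V, y):
    """Forward: hard argmax. Backward: pass dL/dz through as if dz/ds were the identity."""
    z = one_hot_argmax(s)
    loss, gamma = downstream(z, V, y)
    return loss, gamma

V, y = np.eye(3), np.array([0.0, 0.0, 1.0])
s = np.array([1.0, 0.0, 0.0])
for _ in range(3):
    loss, g = ste_grad(s, V, y)
    s = s - 0.6 * g          # gradient step on the scores using the surrogate gradient
final_loss, _ = ste_grad(s, V, y)
```

In this toy run the scores move until the argmax selects the assignment with zero downstream loss, even though the true gradient through the argmax is zero almost everywhere.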
Straight-through variants
- Softmax is a differentiable relaxation of the argmax operator
- Softmax can be used in the backward pass of a neural network
- SPIGOT computes an intermediate approximation of the gradient and projects it onto the marginal polytope
- LI is a linear interpolation of the downstream value at two points
- Surrogate gradient methods override the actual gradient computation
- PyTorch custom functions can be used to implement surrogate gradient methods
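One of the variants above, using the softmax in the backward pass, amounts to multiplying the incoming gradient by the softmax Jacobian. The following is a framework-free NumPy sketch; the function names and the example vectors are hypothetical, and `gamma` stands for the downstream gradient evaluated at the hard argmax output.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def st_softmax_grad(s, gamma):
    """Surrogate for dL/ds: apply the softmax Jacobian to gamma = dL/dz.

    gamma is assumed to have been computed at the hard argmax used in the forward pass.
    """
    p = softmax(s)
    return p * (gamma - p @ gamma)   # equals (diag(p) - p p^T) @ gamma

s = np.array([1.0, 0.0, -1.0])
gamma = np.array([1.0, 0.0, -1.0])
g = st_softmax_grad(s, gamma)
```

In an autodiff framework the same override would live in the backward method of a custom operation, with the hard argmax kept in the forward method.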
Quantization: straight-through friendly models
- Neural networks can be rearranged to make challenging computations more linear.
- Specific model choices can be beneficial and perform well.
Rounding
- Discrete latent variable with domain of integers
- Assigning a separate score to each integer is infeasible
- Natural ordering of integers suggests a single real-valued score
- Floor operation is ignored in the backward pass
- Proven useful in constructing integer normalizing flow models and learning efficient neural networks
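A minimal sketch of rounding with a straight-through backward pass follows. The squared loss, the target value, the starting score, and the step size are hypothetical choices made so that the loop is easy to trace by hand.

```python
import numpy as np

def floor_forward(s):
    """Forward pass: map the real-valued score to an integer."""
    return np.floor(s)

def floor_backward(grad_z):
    """Straight-through: ignore the floor and pass the gradient unchanged."""
    return grad_z

target = 3.0   # hypothetical target integer
s = 0.2        # single real-valued score
for _ in range(10):
    z = floor_forward(s)
    grad_z = z - target                 # gradient of 0.5 * (z - target)^2 w.r.t. z
    s -= 0.5 * floor_backward(grad_z)   # update the score with the surrogate gradient
```

The score increases step by step until floor(s) reaches the target, after which the surrogate gradient is zero and s stops changing.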
Vector quantization
- VQ-VAE construction introduced by van den Oord et al. (2017) draws ideas from vector quantization.
- Encoder-decoder architecture with discrete categorical variables represented by embeddings.
- Quantization step snaps encoder output to nearest neighbor anchor point.
- STE applied to quantization mapping.
- VQ-VAE loss includes term to learn embeddings and encourage encoder to output values close to allowed embeddings.
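The quantization step can be sketched as a nearest-neighbor search over a codebook. The names `quantize` and `vq_terms`, the weight `beta`, and the toy arrays are assumptions for illustration; NumPy has no stop-gradient, so the sketch only computes the values of the auxiliary terms and notes in comments how they would be routed.

```python
import numpy as np

def quantize(h, codebook):
    """Snap each encoder output (row of h) to its nearest codebook embedding."""
    d2 = ((h[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    idx = d2.argmin(axis=1)
    return codebook[idx], idx

def vq_terms(h, codebook, beta=0.25):
    """Squared-distance terms added to the reconstruction loss.

    In an autodiff framework the first term would update only the codebook and the
    second only the encoder (via stop-gradients); here only their values are shown.
    beta is a hypothetical weighting hyperparameter.
    """
    e, _ = quantize(h, codebook)
    dist = ((h - e) ** 2).sum(axis=-1).mean()
    return dist, beta * dist

codebook = np.array([[0.0, 0.0], [1.0, 1.0]])   # two hypothetical anchor points
h = np.array([[0.1, -0.1], [0.9, 1.2]])         # hypothetical encoder outputs
e, idx = quantize(h, codebook)
```

With the straight-through estimator, the gradient of the reconstruction loss with respect to `e` would be copied to `h` unchanged.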
Interpretation via pulled-back labels
- STE and SPIGOT are practical solutions for learning models with null gradients
- Consider a multitask learning scenario where a latent variable is predicted in addition to the downstream task
- If ground truth supervision were available for the latent variable, the parameters could be trained jointly with an auxiliary loss
- As such supervision is not available, a best-guess pulled-back label is induced by pulling back the downstream loss
- This strategy recovers the STE and SPIGOT estimators
- The pulled-back label is a sensible approximation to the value for z that minimizes the downstream loss
- SPIGOT minimizes the perceptron loss w.r.t. a pulled-back label computed by one projected gradient step
- Relaxing the constraints, STE can be interpreted as minimizing a perceptron loss on a pulled-back label
- Probabilistic form of discrete latent variable learning problem can be computed or approximated without assumptions on the downstream model
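The pulled-back label view can be illustrated in the one-of-K case, where conv(Z) is the simplex. The sketch assumes the SPIGOT update takes one gradient step from ẑ, projects onto the simplex, and uses the difference as the surrogate gradient; the step size `eta` and the example vectors are hypothetical.

```python
import numpy as np

def project_simplex(v):
    """Euclidean projection onto the probability simplex."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u) - 1.0
    k = np.arange(1, len(v) + 1)
    support = u - css / k > 0
    tau = css[support][-1] / k[support][-1]
    return np.maximum(v - tau, 0.0)

def spigot_grad(z_hat, gamma, eta=1.0):
    """SPIGOT-style surrogate for dL/ds in the one-of-K case."""
    pulled_back = project_simplex(z_hat - eta * gamma)   # pulled-back label
    return z_hat - pulled_back

def ste_grad(z_hat, gamma, eta=1.0):
    """Same recipe without the projection: z_hat - (z_hat - eta * gamma)."""
    return eta * gamma

z_hat = np.array([1.0, 0.0, 0.0])     # current discrete assignment
gamma = np.array([2.0, 0.0, -1.0])    # hypothetical downstream gradient dL/dz
g_spigot = spigot_grad(z_hat, gamma)
g_ste = ste_grad(z_hat, gamma)
```

The two estimators agree whenever the unprojected step already lands inside the simplex, and differ otherwise, as in this example.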
Explicit marginalization by enumeration
- Probabilistic framework optimizes expected loss
- Sum can be over exponentially large set
- Expectation pushed inside
- Discrete mapping approximated
- Explicit marginalization not required for discrete index
- Enumeration possible for manageable size
- Explicit marginalization recommended for testing
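When the set of assignments is small, the expected loss and its exact gradient can be computed by enumeration, which gives a ground truth for testing stochastic estimators. The sketch assumes a categorical latent variable with p(z | x) = softmax(s) and a hypothetical vector of per-assignment losses; the finite-difference check is included only as a sanity test.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def expected_loss(s, losses):
    """Sum over all K assignments of p(z | x) * loss(z)."""
    return softmax(s) @ losses

def expected_loss_grad(s, losses):
    """Exact gradient of the expected loss w.r.t. the scores s."""
    p = softmax(s)
    return p * (losses - p @ losses)

s = np.array([0.5, -0.3, 1.2])
losses = np.array([1.0, 3.0, 0.5])   # hypothetical downstream loss of each assignment
eps = 1e-6
fd = np.array([(expected_loss(s + eps * e, losses) - expected_loss(s - eps * e, losses)) / (2 * eps)
               for e in np.eye(3)])  # central finite differences
```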
Monte Carlo gradient estimation
- Monte Carlo method can be used to approximate expectations
- When the distribution does not depend on the parameters, differentiation is linear and can be moved inside the expectation
- Differentiating expectations is harder if the distribution depends on parameters
- Two main strategies used for Monte Carlo gradient estimation with discrete and structured latent variables
Path gradient estimator (the reparametrization trick)
- Estimating gradients of expectations is challenging in both discrete and continuous settings
- An alternative to the score function estimator (SFE) is the path gradient estimator (PGE)
- The challenge is that the distribution of the latent variable depends on the parameters
- The PGE circumvents this problem by extracting the source of randomness into a parameter-free random variable
- For any function F(z) we may then write an expectation over a distribution that does not depend on θ
- We can then obtain the Monte Carlo gradient estimator
- If U is a continuous random variable, the right-hand-side expectation is an integral
- The Gumbel-argmax construction is a reparametrization of a categorical distribution
- The Gumbel-Softmax is a perturbed version of the softmax relaxation
- The Straight-Through Gumbel-Softmax estimator combines ideas of the PGE with the surrogate gradient strategies
- The Implicit Maximum Likelihood Estimation (I-MLE) approach combines structured perturbation methods with surrogate gradient approaches
- Mixed random variables are hybrids of discrete and continuous random variables
- Examples of mixed random variables are the Hard Concrete and Hard-Kuma distributions
- The Gaussian-Sparsemax is a “mixed” counterpart of the multivariate logistic normal distribution
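The Gumbel-argmax reparametrization and its softmax relaxation can be sketched as follows. The scores, sample count, and temperature are illustrative assumptions; the empirical frequencies are compared with softmax(s) only as a sanity check.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def gumbel_argmax(s, rng, n):
    """n samples from Categorical(softmax(s)) via argmax of Gumbel-perturbed scores."""
    u = rng.gumbel(size=(n, len(s)))    # parameter-free noise
    return (s + u).argmax(axis=1)

def gumbel_softmax(s, rng, temperature=1.0):
    """Relaxed sample: softmax of the perturbed scores instead of argmax."""
    return softmax((s + rng.gumbel(size=len(s))) / temperature)

rng = np.random.default_rng(0)
s = np.array([1.0, 0.0, -1.0])
freq = np.bincount(gumbel_argmax(s, rng, 20000), minlength=3) / 20000
relaxed = gumbel_softmax(s, rng, temperature=0.5)
```

The straight-through Gumbel-Softmax variant would use the hard argmax sample in the forward pass and the relaxed sample's gradient in the backward pass.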
Score function estimator
- Score Function Estimator (SFE) relies on the fact that (log t)′ = 1/t
- Expectation in Eq. 5.4 can be approximated with Monte Carlo
- Gradient-based learning with discrete or structured latent variables is possible
- SFE has high variance and requires a sampling oracle for p(Z | x)
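A minimal sketch of the score function estimator for a categorical latent variable follows, compared against the exact gradient obtained by enumeration. It assumes p(z | x) = softmax(s), for which the gradient of log p(z = k) with respect to s is e_k − p; the loss values and sample size are hypothetical.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def sfe_grad(s, losses, rng, n):
    """Monte Carlo estimate of the gradient of E_{z ~ softmax(s)}[loss(z)] w.r.t. s."""
    p = softmax(s)
    z = rng.choice(len(s), size=n, p=p)     # sampling oracle for p(z | x)
    score = np.eye(len(s))[z] - p           # grad_s log p(z), one row per sample
    return (losses[z][:, None] * score).mean(axis=0)

s = np.array([0.5, -0.3, 1.2])
losses = np.array([1.0, 3.0, 0.5])          # hypothetical loss of each assignment
p = softmax(s)
exact = p * (losses - p @ losses)           # gradient by explicit enumeration
estimate = sfe_grad(s, losses, np.random.default_rng(0), 200000)
```

The estimator only evaluates the loss at sampled assignments, so it places no differentiability requirement on the downstream model.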
Variance reduction
- SFE has high variance
- Variance reduction techniques exist
- Control variates, learned data-driven baselines, and NVIL are three main directions
- Control variates subtract a function h with known expectation, scaled by a coefficient β, to reduce variance
- Learned data-driven baselines use gradient descent with a squared loss
- Relaxation-based control variates use a deterministic mapping
- Rao-Blackwellization uses tractable conditioning to reduce variance
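The effect of a baseline can be seen on a toy categorical problem. The sketch subtracts a constant baseline from the loss before multiplying by the score; because the score has zero mean, this leaves the estimator unbiased. Using the exact expected loss as the baseline is an assumption that is only possible here because the toy problem is enumerable; in practice a running average or a learned baseline would be used.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def sfe_samples(s, losses, rng, n, baseline=0.0):
    """Per-sample SFE terms (loss(z) - baseline) * grad_s log p(z)."""
    p = softmax(s)
    z = rng.choice(len(s), size=n, p=p)
    score = np.eye(len(s))[z] - p
    return (losses[z] - baseline)[:, None] * score

s = np.array([0.5, -0.3, 1.2])
losses = np.array([10.0, 11.0, 12.0])   # hypothetical losses with a large common offset
p = softmax(s)
plain = sfe_samples(s, losses, np.random.default_rng(0), 5000)
with_baseline = sfe_samples(s, losses, np.random.default_rng(0), 5000, baseline=p @ losses)

var_plain = plain.var(axis=0).sum()
var_baseline = with_baseline.var(axis=0).sum()
```

Both sets of samples estimate the same gradient, but the baseline removes the large common offset in the losses, which is the main source of variance in this example.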