Abstract

  • Data from many fields can be represented by discrete, compositional structures.
  • Latent structure models are useful for learning to extract such representations.
  • Three strategies for learning with discrete latent structure are explored: continuous relaxation, surrogate gradients, and probabilistic estimation.

Paper Content

Motivation

  • Machine Learning is used to analyze data such as images, text, and sound.
  • Natural language sentences can be analyzed in terms of their dependency structure.
  • ML systems are often structured as pipelines with off-the-shelf components.
  • Deep neural networks can outperform pipelines and learn dense, continuous representations of the data.

Notation

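  • u, v, W and X denote a scalar, a vector, a matrix and a set, respectively; indexing accesses their elements
  • $v_i$ is the $i$-th element of vector $v$
  • $w_j$ is the $j$-th column of matrix $W$
  • For a function $f$ with values in $\mathbb{R}^d$, the partial derivative with respect to an argument in $\mathbb{R}^{d_i}$ is identified with a $d_i \times d$ matrix
  • This is the transpose of the Jacobian, a convention that is more convenient for backpropagation
  • $\partial_\theta(\text{expr.})$ interprets the expression as a single-argument function of $\theta$ and differentiates it
  • $\mathbb{R}^d_+$ is the non-negative orthant
  • $\triangle^d = \{\alpha \in \mathbb{R}^d_+ : \sum_i \alpha_i = 1\}$ is the simplex with $d$ bins
  • $\operatorname{conv}(Z)$ is the convex hull of $Z$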

Supervised learning

  • Predictive machine learning is used to learn a model to predict labels from inputs
  • The model is a feed-forward network that maps an input x to scores over the possible outputs
  • Predictions are made using a Gibbs distribution
  • Parameters are learned by minimizing an average loss
  • Gradient and stochastic gradient methods are used to minimize a differentiable function
  • Gradient and stochastic gradient methods can be extended using momentum and adaptivity
  • Backpropagation is used to propagate gradients from output to input
  • Popular software frameworks provide composable modules with forward and backward calls
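The composable-module idea can be sketched from scratch in plain Python: each module caches what it needs in `forward` and returns an input gradient in `backward`, and backpropagation simply calls the backward methods in reverse order. This is a minimal illustration, not any framework's API; the class and function names (`Affine`, `Tanh`, `loss_and_grad`) and the toy data are hypothetical.

```python
import math

class Affine:
    """Scalar module y = w * x + b; caches its input for the backward call."""
    def __init__(self, w, b):
        self.w, self.b = w, b

    def forward(self, x):
        self.x = x
        return self.w * x + self.b

    def backward(self, grad_y):
        # Store parameter gradients, return the gradient w.r.t. the input.
        self.grad_w = grad_y * self.x
        self.grad_b = grad_y
        return grad_y * self.w

class Tanh:
    def forward(self, x):
        self.y = math.tanh(x)
        return self.y

    def backward(self, grad_y):
        return grad_y * (1.0 - self.y ** 2)

layers = [Affine(0.5, 0.1), Tanh(), Affine(-1.2, 0.3)]

def forward(x):
    for layer in layers:              # forward calls, input to output
        x = layer.forward(x)
    return x

def backward(grad):
    for layer in reversed(layers):    # backward calls, output to input
        grad = layer.backward(grad)
    return grad

def loss_and_grad(x, y):
    """Squared loss; also fills in every module's parameter gradients."""
    pred = forward(x)
    backward(2.0 * (pred - y))        # d loss / d pred
    return (pred - y) ** 2

# A few steps of plain gradient descent on all parameters.
x, y, lr = 0.7, 1.0, 0.1
initial_loss = loss_and_grad(x, y)
for _ in range(50):
    loss_and_grad(x, y)
    for layer in layers:
        if isinstance(layer, Affine):
            layer.w -= lr * layer.grad_w
            layer.b -= lr * layer.grad_b
final_loss = loss_and_grad(x, y)
```

Momentum or adaptive methods would only change the update lines inside the loop; the forward/backward decomposition stays the same.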

Latent representations

  • Main motivation is to go beyond direct mappings x → y
  • Latent representation z captures relevant property of data point x
  • Latent representation designed with downstream tasks in mind
  • Problem is to learn to predict z from x using an encoder
  • Latent representation captures specific aspect of x relevant to modeler
  • Pretraining and pipelines can be used to predict z
  • Probabilistic latent variables can be used for generative modelling

Further history and scope

  • Latent variable models have a long history in ML, specifically for unsupervised learning.
  • Factor analysis and Gaussian mixture model are two popular models.
  • Probabilistic PCA is a special case of Factor Analysis.
  • Variational Auto-Encoder is a nonlinear model parametrized by neural networks.
  • The evidence lower bound (ELBO) is used as a tractable approximation of the log-likelihood objective, which it bounds from below.

Roadmap

  • Chapter 2 revisits the tools of supervised structure prediction
  • Chapter 3 explores a deterministic approach to learning latent structure, based on continuous relaxation

Overview

  • Structure prediction is the learning setting wherein an unobserved structured-valued variable must be predicted.
  • Structured representations consist of a combination of interdependent parts and can be encoded as a binary vector indicating which parts are present.
  • Examples of structures include unconstrained bit vectors, one-of-K choices, linear assignments (permutations), and non-projective dependency trees.
  • Structure prediction provides a prediction rule for selecting a best structure and, if possible, a probability distribution over all possible structures.
  • Two main strategies for structure prediction are incremental prediction and global prediction.

Incremental prediction

  • Model a structured variable z with a probabilistic model P(z | x)
  • The set of structures Z is discrete and finite and may depend on x
  • Factor the distribution of z with the probability chain rule: $P(z \mid x) = \prod_j P(z_j \mid z_{<j}, x)$
  • Multi-label setting is a particular instance of unconstrained structure prediction
  • Blockwise prediction involves predicting multiple interdependent one-of-K choices
  • Variable-length target structures can be achieved by designating a specific variable assignment as a “stop” control token
  • Neural conditional language generation model parametrizes the probability of the next word
  • Transition-based systems enable the extension of incremental prediction to more complex and constrained structured problems
  • Shift-reduce parsing encodes candidate binary parse trees as bit vectors
  • An uninformed (uniform) shift-reduce model does not induce a uniform distribution over trees: it prefers highly skewed structures
  • Multiple sequences of transitions can lead to the same target structure
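The chain-rule factorization can be checked on a tiny unconstrained (bit-vector) problem: multiplying the per-step conditionals gives a distribution that sums to one over all structures. This is a minimal sketch; the conditional `p_bit_on` and the input scores are hypothetical toy choices, made only so that each bit depends on the earlier ones.

```python
import itertools
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def p_bit_on(x, prefix):
    """Hypothetical conditional P(z_j = 1 | z_<j, x): the logit is x[j], lowered
    by 0.5 for every bit already switched on (so the bits are interdependent)."""
    return sigmoid(x[len(prefix)] - 0.5 * sum(prefix))

def prob(x, z):
    """P(z | x) = prod_j P(z_j | z_<j, x), by the probability chain rule."""
    p = 1.0
    for j, zj in enumerate(z):
        q = p_bit_on(x, z[:j])
        p *= q if zj == 1 else 1.0 - q
    return p

x = [0.3, -1.0, 2.0]
all_z = list(itertools.product([0, 1], repeat=len(x)))
total = sum(prob(x, z) for z in all_z)   # normalized by construction
```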

Computation

  • Finding the highest-probability structure exactly is not feasible in general.
  • Beam search algorithm is a popular heuristic algorithm that keeps track of the k best partial prefixes at each time step.
  • Sampling from the distribution can be done by sampling $z_1, \dots, z_{|P|}$ sequentially.
  • Approximate maximization can be done by way of sampling.
  • Monte Carlo methods provide a good approximation for computing expectations.
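Beam search and sequential (ancestral) sampling can both be sketched on the same kind of toy chain-rule model. The model below is a hypothetical stand-in; with a beam as wide as the number of structures the search is exhaustive, while with small k it is only a heuristic. The Monte Carlo average is compared against exact enumeration, which is only possible because the toy space is tiny.

```python
import itertools
import math
import random

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def p_bit_on(x, prefix):
    """Hypothetical conditional P(z_j = 1 | z_<j, x), as in a chain-rule model."""
    return sigmoid(x[len(prefix)] - 0.5 * sum(prefix))

def log_p_next(x, prefix, bit):
    q = p_bit_on(x, prefix)
    return math.log(q if bit == 1 else 1.0 - q)

def log_p(x, z):
    return sum(log_p_next(x, z[:j], z[j]) for j in range(len(z)))

def beam_search(x, k):
    """Keep the k highest-scoring partial prefixes at each step (a heuristic)."""
    beam = [((), 0.0)]
    for _ in range(len(x)):
        candidates = [(prefix + (b,), lp + log_p_next(x, prefix, b))
                      for prefix, lp in beam for b in (0, 1)]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beam = candidates[:k]
    return beam[0][0]

def ancestral_sample(x, rng):
    """Exact sample: draw z_1, ..., z_n one at a time, each given the prefix."""
    z = ()
    for _ in range(len(x)):
        z += (1 if rng.random() < p_bit_on(x, z) else 0,)
    return z

x = [0.3, -1.0, 2.0, 0.1]
structures = list(itertools.product([0, 1], repeat=len(x)))
exact_best = max(structures, key=lambda z: log_p(x, z))
greedy = beam_search(x, 1)             # k = 1 is greedy decoding

# Monte Carlo estimate of E[number of active bits] versus exact enumeration.
rng = random.Random(0)
samples = [ancestral_sample(x, rng) for _ in range(20000)]
mc_mean = sum(sum(z) for z in samples) / len(samples)
exact_mean = sum(math.exp(log_p(x, z)) * sum(z) for z in structures)
```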

Global prediction

  • Structured models can be built with global optimality guarantees.
  • Score of a structure is a measure of compatibility with an input.
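  • When the score of a structure decomposes into parts, it is linear in the binary encoding of the structure: $\mathrm{score}(z) = \langle s, z \rangle$, with one score per part.
  • Example 8 is an arc-factored dependency parser, where the score of a tree is the sum of the scores of its arcs.
  • Example 9 is a Markov sequence tagging model, where (in the first-order case) scores decompose over individual tags and pairs of adjacent tags.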
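For a first-order sequence tagger, the sum of part scores equals a dot product between a score vector and the binary indicator vector of the structure. The sketch below checks this for every tag sequence of a tiny problem; the unary and transition scores and the layout of the indicator vector are hypothetical choices.

```python
import itertools

# Hypothetical scores: 3 positions, 2 tags.
unary = [[1.0, -0.5], [0.2, 0.3], [-1.0, 0.8]]   # unary[i][t]: tag t at position i
trans = [[0.5, -0.2], [-0.3, 0.4]]               # trans[a][b]: tag a followed by tag b
n, T = len(unary), len(trans)

def score_by_parts(tags):
    """Sum of the scores of the parts used by the tag sequence."""
    return (sum(unary[i][t] for i, t in enumerate(tags))
            + sum(trans[a][b] for a, b in zip(tags, tags[1:])))

def indicator(tags):
    """Binary encoding z: n*T unary parts followed by (n-1)*T*T transition parts."""
    z = [0] * (n * T + (n - 1) * T * T)
    for i, t in enumerate(tags):
        z[i * T + t] = 1
    for i, (a, b) in enumerate(zip(tags, tags[1:])):
        z[n * T + i * T * T + a * T + b] = 1
    return z

# Score vector s, laid out in the same order as the parts in z.
s = [unary[i][t] for i in range(n) for t in range(T)]
s += [trans[a][b] for _ in range(n - 1) for a in range(T) for b in range(T)]

def score_linear(tags):
    """<s, z>: the same score, written as a linear function of the encoding."""
    return sum(si * zi for si, zi in zip(s, indicator(tags)))
```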

The marginal polytope

  • Decomposition assumption allows structure prediction problems to be solved as an optimization problem over the convex hull of possible structures.
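  • The simplest possible marginal polytope is the convex hull of one-of-K indicator vectors, i.e., the simplex $\triangle^K$.
  • Any other finite marginal polytope is a linear image of the $|Z|$-simplex: $\operatorname{conv}(Z) = \{M\alpha : \alpha \in \triangle^{|Z|}\}$, where the columns of $M$ are the structures in $Z$.
  • Vectors $\alpha \in \triangle^{|Z|}$ can be interpreted as distributions over the set of structures $Z$.
  • Any point $\hat z \in \operatorname{conv}(Z)$ can thus be seen as an expected structure, $\hat z = M\alpha = \mathbb{E}_{z \sim \alpha}[z]$.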
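The map from a distribution α over structures to a point of the marginal polytope can be computed directly when Z is small enough to enumerate. The sketch below, using unconstrained bit vectors of length 2 as a hypothetical example, also shows that the decomposition is not unique: two different distributions yield the same expected structure.

```python
import itertools

# Unconstrained structures of length 2, enumerated as the columns of M.
Z = list(itertools.product([0, 1], repeat=2))   # (0,0), (0,1), (1,0), (1,1)

def mean_structure(alpha):
    """z_hat = M alpha = sum over z of alpha(z) * z: an expected structure."""
    return [sum(a * z[i] for a, z in zip(alpha, Z)) for i in range(2)]

uniform = [0.25, 0.25, 0.25, 0.25]
extremes = [0.5, 0.0, 0.0, 0.5]     # mass only on (0,0) and (1,1)
z_uniform = mean_structure(uniform)
z_extremes = mean_structure(extremes)
```

Both distributions map to the same point (0.5, 0.5), which is why any point of conv(Z) can generally be written as many different convex combinations of structures.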

Computation

  • Finding the highest-scoring structure is, in general, a difficult combinatorial optimization problem.
  • Specialized algorithms are available for specific problems.
  • Finding the highest-scoring structure is equivalent to maximizing a linear function over the marginal polytope $\operatorname{conv}(Z)$.
  • Computing expectations is intractable in general, but efficient exact algorithms are available for many structures of interest.
  • Sampling from the exact Gibbs distribution is difficult, but powerful general frameworks exist.
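For Markov sequence tagging, two classic dynamic programs illustrate the specialized algorithms mentioned above: Viterbi finds the highest-scoring sequence, and the forward algorithm computes the log-partition function needed for expectations under the Gibbs distribution. This is a minimal sketch with hypothetical scores, checked against brute-force enumeration.

```python
import itertools
import math

# Hypothetical scores: 4 positions, 3 tags.
unary = [[1.0, -0.5, 0.0], [0.2, 0.3, -0.1], [-1.0, 0.8, 0.4], [0.0, 0.1, 0.9]]
trans = [[0.5, -0.2, 0.1], [-0.3, 0.4, 0.0], [0.2, 0.0, -0.6]]
T = len(trans)

def score(tags):
    return (sum(unary[i][t] for i, t in enumerate(tags))
            + sum(trans[a][b] for a, b in zip(tags, tags[1:])))

def viterbi(unary, trans):
    """Highest-scoring tag sequence by dynamic programming, O(n T^2)."""
    best = list(unary[0])              # best[t]: top score of a prefix ending in t
    backpointers = []
    for i in range(1, len(unary)):
        prev = [max(range(T), key=lambda a: best[a] + trans[a][b]) for b in range(T)]
        best = [best[prev[b]] + trans[prev[b]][b] + unary[i][b] for b in range(T)]
        backpointers.append(prev)
    tags = [max(range(T), key=lambda t: best[t])]
    for prev in reversed(backpointers):
        tags.append(prev[tags[-1]])
    return tuple(reversed(tags))

def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def log_partition(unary, trans):
    """Forward algorithm: log of the sum of exp(score) over all tag sequences."""
    alpha = list(unary[0])
    for i in range(1, len(unary)):
        alpha = [logsumexp([alpha[a] + trans[a][b] for a in range(T)]) + unary[i][b]
                 for b in range(T)]
    return logsumexp(alpha)
```

Both run in time linear in the sentence length, whereas the brute-force check used for testing grows as T to the power n.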

Challenges of deterministic choices

  • Deterministic latent representations define a mapping from x to a single representation $\hat z$
  • The end-to-end model is obtained by composing the downstream model g with $\hat z$, i.e., $g(\hat z(x))$
  • The encoder parameters $\theta_f$ and the downstream parameters $\theta_g$ are trained together
  • The gradient with respect to the decoder (downstream) parameters poses no problems
  • The dependency on $\theta_f$ requires applying the chain rule through $\hat z$
  • Deterministic end-to-end learning has two requirements
  • Assumption 2 constrains the possible downstream models
  • Gradient-based learning of deterministic latent variables requires that structures be represented as real vectors, $z \in \mathbb{R}^d$
  • Downstream g must accept weighted averages of structures
  • Downstream g must be almost everywhere differentiable w.r.t. z
  • Breakdown in Eq. 3.3 highlights two terms
  • A relaxation moves the expectation inside the model, computing $g(\mathbb{E}[z])$ instead of $\mathbb{E}[g(z)]$
  • Examples demonstrate implications of relaxations
  • Relaxations cannot be applied if Assumption 2 does not hold
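A two-line example shows why moving the expectation inside the model matters: for a nonlinear downstream g, feeding the expected structure can give a very different output from averaging g over structures. The function g below is a hypothetical choice picked to make the gap obvious; note that the relaxed input 0.5 is not even a valid bit.

```python
def g(z):
    """Hypothetical nonlinear downstream model: g(0) = g(1) = 1 but g(0.5) = 0."""
    return (2.0 * z - 1.0) ** 2

p = 0.5                                        # P(z = 1)
expected_output = (1 - p) * g(0) + p * g(1)    # E[g(z)] over the two structures
relaxed_output = g(p)                          # g(E[z]): the relaxation's output
```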

Regularized argmax operators

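  • The starting point is a model that behaves like a well-trained pipeline
  • The latent variable is assigned the highest-scoring structure, $\hat z_0(s) = \arg\max_{z \in Z} \langle s, z \rangle$
  • The dependency of s on x and $\theta_f$ is omitted from the notation
  • The mapping $\theta_f \mapsto s$ computed by the encoder is continuous and differentiable
  • The argmax on the right-hand side, however, is a discrete optimization problem
  • As a result, the mapping $\theta_f \mapsto \hat z_0$ is discontinuous and flat almost everywhere
  • Example 13 (Single Coinflip) illustrates this
  • Take $Z = \{0, 1\}$ and let $s \in \mathbb{R}$
  • The problem $\hat z(s) = \arg\max_{z \in \{0,1\}} zs$ is solved by the sign of s: $\hat z(s) = 1$ if $s > 0$ and $\hat z(s) = 0$ if $s < 0$
  • At $s = 0$ both values are maximizers, so $\hat z(s)$ is not a function there
  • The derivative $\partial \hat z(s)$ is zero for every $s \neq 0$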
  • Theorem of Dantzig et al. (1955) is fundamental to linear programming
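  • The maximum of a linear function over a polytope is always achieved at one of its vertices
  • The discrete problem over Z can therefore be relaxed to a continuous but constrained one over $\operatorname{conv}(Z)$
  • Smoothing relies on the insight that the argmax mapping of an optimization problem with a strongly concave objective over a convex domain is a smooth function
  • The objective of Eq. 3.8 is linear, thus concave but not strongly so; adding a strongly concave regularizer H gives $\hat z_H(s) = \arg\max_{z \in \operatorname{conv}(Z)} \langle s, z \rangle + H(z)$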
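In the single coinflip, taking H to be the binary Shannon entropy and setting the derivative $s - \log z + \log(1 - z)$ to zero gives $\hat z_H(s) = 1/(1 + e^{-s})$, the sigmoid. The sketch below confirms this numerically by brute-force grid search over [0, 1] and contrasts it with the flat, discontinuous hard argmax. The grid search is only an illustration, not how such mappings are computed in practice.

```python
import math

def hard_argmax(s):
    """argmax over z in {0, 1} of z * s; the tie at s = 0 is broken toward 0."""
    return 1.0 if s > 0 else 0.0

def binary_entropy(z):
    if z <= 0.0 or z >= 1.0:
        return 0.0
    return -z * math.log(z) - (1.0 - z) * math.log(1.0 - z)

def regularized_argmax(s, steps=20000):
    """Brute-force grid search for argmax over z in [0, 1] of z * s + H(z)."""
    grid = (i / steps for i in range(steps + 1))
    return max(grid, key=lambda z: z * s + binary_entropy(z))

def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))
```

Unlike `hard_argmax`, whose derivative is zero wherever it exists, the regularized solution changes smoothly with s, so gradients can reach the encoder.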

Categorical relaxation and attention

  • Consider the categorical (one-of-K) case, where $Z = \{e_1, \dots, e_K\}$ and $\operatorname{conv}(Z) = \triangle^K$
  • Points z ∈ △ K can be interpreted as discrete probability distributions
  • Shannon entropy, $H_1(z) = -\sum_k z_k \log z_k$, is a meaningful regularizer
  • Softmax is a differentiable relaxation of the one-of-K argmax mapping
  • Softmax output is fully dense
  • Sparsemax is a sparse alternative to softmax
  • Sparsemax corresponds to regularizing the argmax with the Gini entropy
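  • The α-entmax family of mappings recovers softmax ($\alpha = 1$) and sparsemax ($\alpha = 2$)
  • For $\alpha > 1$, solutions can be written in a thresholded form, $[(\alpha - 1)s - \tau \mathbf{1}]_+^{1/(\alpha - 1)}$, with the threshold $\tau$ chosen so that the output sums to one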
  • Attention mechanisms in deep learning extract a single contextual representation of a set of K objects
  • Deterministic key-value lookup retrieves the value corresponding to the key with highest dot product with the query
  • Attention is the entropy-regularized relaxation of lookup, replacing $\hat z_0$ with $\hat z_{H_1} = \operatorname{softmax}(s)$
  • Entropy regularization forces all attention weights to be nonzero
  • Sparse attention using $H_\alpha$ with $\alpha > 1$ can induce combinations of only a few objects
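Softmax, sparsemax, and attention as a weighted average of values can be sketched in a few lines. Sparsemax is computed here as the Euclidean projection onto the simplex via sorting; the query, keys and values are hypothetical toy data, and the general α-entmax case is not shown.

```python
import math

def softmax(s):
    m = max(s)                          # subtract the max for numerical stability
    e = [math.exp(v - m) for v in s]
    total = sum(e)
    return [v / total for v in e]

def sparsemax(s):
    """Euclidean projection of s onto the simplex, computed by sorting."""
    cumsum, k, support_sum = 0.0, 0, 0.0
    for i, v in enumerate(sorted(s, reverse=True), start=1):
        cumsum += v
        if 1.0 + i * v > cumsum:        # v is still in the support
            k, support_sum = i, cumsum
    tau = (support_sum - 1.0) / k       # threshold
    return [max(v - tau, 0.0) for v in s]

def attention(query, keys, values, mapping=softmax):
    """Weighted average of values, with weights mapping(<query, key_i>)."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = mapping(scores)
    dim = len(values[0])
    context = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return context, weights

query = [1.0, 0.0]
keys = [[2.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
dense_ctx, dense_w = attention(query, keys, values)               # all weights > 0
sparse_ctx, sparse_w = attention(query, keys, values, sparsemax)  # exact zeros possible
```

With these scores (2, 0, -1), sparsemax puts all weight on the first object, so the sparse context equals its value exactly, while softmax mixes all three.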

Global structured relaxations and structured attention

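  • In the structured case, $\operatorname{conv}(Z)$ is a polytope that lacks a compact description
  • There are two ways to generalize softmax- and sparsemax-style mappings to the structured case: regularize the distribution α over structures, or regularize the marginal vector z directly
  • The entropy H(z) is the highest Shannon entropy among all decompositions of z as a convex combination of the elements of Z
  • The solution $\hat z_H$ is the average structure under the Gibbs distribution $p(z) \propto \exp \langle s, z \rangle$, i.e., its vector of marginals
  • Kim et al. (2017) and Liu et al. (2018) differentiate through dynamic programs to obtain structured attention mechanisms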
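For a very small linear assignment problem, the Gibbs marginals $\hat z_H$ can be computed by brute-force enumeration of all permutations. This is only a sketch with a hypothetical 3×3 score matrix; enumeration grows factorially, so it is feasible only for tiny n. The result is a doubly stochastic matrix, i.e., a point of the Birkhoff polytope.

```python
import itertools
import math

# Hypothetical 3x3 score matrix for a linear assignment problem.
S = [[1.0, 0.2, -0.5], [0.0, 0.7, 0.3], [-0.2, 0.4, 0.9]]
n = len(S)
perms = list(itertools.permutations(range(n)))   # perm[i] = column assigned to row i

def score(perm):
    """<S, z> for the permutation matrix z encoded by perm."""
    return sum(S[i][perm[i]] for i in range(n))

weights = [math.exp(score(p)) for p in perms]
partition = sum(weights)

# Expected permutation matrix under p(z) proportional to exp(<S, z>).
marginals = [[0.0] * n for _ in range(n)]
for p, w in zip(perms, weights):
    for i in range(n):
        marginals[i][p[i]] += w / partition
```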

Mean structure regularization: Sinkhorn and SparseMAP

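  • A regularizer can be applied directly to the marginal vector z instead of the decomposition α
  • An entropy-inspired regularizer can be applied to the coordinates of z
  • Solutions of the linear assignment problem are permutation matrices, whose convex hull is the Birkhoff polytope of doubly stochastic matrices
  • Unbalanced assignments between different-sized sets can be represented by replacing the equality constraints with inequality constraints
  • The discrete optimal transport problem can be represented by constraining row and column sums to match given marginal distributions; with entropic regularization, the balanced assignment and transport problems can be solved by the Sinkhorn algorithm (alternating row and column normalization)
  • Quadratic regularization yields the Euclidean projection of s onto the marginal polytope
  • The SparseMAP strategy relies on quadratic regularization and an efficient general-purpose algorithm that only requires a MAP (argmax) oracle
  • Solutions tend to be sparse: combinations of only a few structures
  • An active set algorithm is used to compute the solution and its gradients
  • Implicit perturbation-based regularization uses a random noise variable U, averaging perturbed maximizations: $\hat z(s) = \mathbb{E}_U[\arg\max_{z \in Z} \langle s + U, z \rangle]$
  • The maximizations for different noise samples are independent and can be performed in parallel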
  • Regularizer chosen based on availability of efficient and numerically stable algorithm and desired sparsity
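The entropy-regularized assignment can be sketched with the Sinkhorn algorithm: exponentiate the scores and alternately normalize rows and columns until the matrix is (approximately) doubly stochastic. This is a minimal sketch; the score matrix, the temperature `tau` and the fixed iteration count are assumptions rather than tuned values.

```python
import math

def sinkhorn(S, tau=1.0, n_iter=200):
    """Approximate argmax over doubly stochastic P of <S, P> - tau * sum_ij P_ij log P_ij,
    by alternately normalizing the rows and columns of exp(S / tau)."""
    P = [[math.exp(v / tau) for v in row] for row in S]
    for _ in range(n_iter):
        row_sums = [sum(row) for row in P]
        P = [[v / r for v in row] for row, r in zip(P, row_sums)]        # rows sum to 1
        col_sums = [sum(col) for col in zip(*P)]
        P = [[v / c for v, c in zip(row, col_sums)] for row in P]        # columns sum to 1
    return P

S = [[1.0, 0.2, -0.5], [0.0, 0.7, 0.3], [-0.2, 0.4, 0.9]]   # hypothetical scores
P = sinkhorn(S)
```

Smaller values of `tau` push the result toward a permutation matrix but typically need more iterations to converge.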

Summary

  • Probabilistic formulation allows for flexible handling of discrete latent structure
  • The price to pay is added complexity, since entire distributions over structures must be considered
  • Correia et al. (2020) compare variance-reduced SFE, Gumbel-STE, top-k sparsemax and SparseMAP on a variational autoencoder
  • Gumbel-STE and related methods perform well in practice for bit vectors
  • SparseMAP applies to permutation latent variables, whereas SFE and Gumbel-STE do not readily apply

Straight-through gradients

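  • A discrete latent variable assignment cannot simply be plugged into a downstream model and trained end to end, since no useful gradient flows through it.
  • The encoder generates factorized scores s.
  • The discrete variable assignment is determined by argmax: $\hat z = \arg\max_{z \in Z} \langle s, z \rangle$.
  • The backward pass replaces the problematic derivative, which is zero almost everywhere, with a different function.
  • The straight-through strategy passes the gradient directly through the argmax, treating it as the identity in the backward pass.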
  • Straight-through strategy is easy to implement.
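The straight-through strategy can be sketched without any framework on independent bits: the forward pass thresholds the scores, and the backward pass hands the downstream gradient to the scores unchanged. The scores, the target and the squared loss standing in for a downstream model are all hypothetical.

```python
def ste_step(s, target):
    """Forward: z_i = argmax over {0, 1} of z * s_i, i.e. 1 if s_i > 0.
    Backward: straight-through, i.e. dL/ds is taken to be dL/dz."""
    z = [1.0 if si > 0 else 0.0 for si in s]
    loss = sum((zi - ti) ** 2 for zi, ti in zip(z, target))
    grad_z = [2.0 * (zi - ti) for zi, ti in zip(z, target)]
    grad_s = grad_z        # the true derivative dz/ds is 0 almost everywhere
    return z, loss, grad_s

s = [0.3, -0.2, 0.5, -0.1]           # hypothetical encoder scores
target = [0.0, 1.0, 1.0, 0.0]        # hypothetical downstream target
for _ in range(10):
    z, loss, grad_s = ste_step(s, target)
    s = [si - 0.1 * gi for si, gi in zip(s, grad_s)]
```

With the true (zero) derivative, s would never move; with the straight-through surrogate, the wrongly signed scores are pushed across zero within a few steps.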

Straight-through variants

  • Softmax is a differentiable relaxation of the argmax operator
  • Softmax can be used in the backward pass of a neural network: the forward pass keeps the argmax, while the backward pass uses the softmax derivative
  • SPIGOT computes an intermediate approximation of the gradient and projects it onto the marginal polytope
  • LI is a linear interpolation of the downstream value at two points
  • Surrogate gradient methods override the actual gradient computation
  • Custom autograd functions in PyTorch (torch.autograd.Function) can be used to implement surrogate gradient methods
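For a categorical variable, the SPIGOT surrogate described above can be sketched as: take a gradient step away from the one-hot argmax, project back onto the simplex (here via the sort-based Euclidean projection), and use the difference as the gradient for the scores. This is a framework-free sketch; the step size `eta = 1.0`, the scores and the downstream gradient are assumptions.

```python
def project_simplex(v):
    """Euclidean projection onto the probability simplex (same as sparsemax)."""
    cumsum, k, support_sum = 0.0, 0, 0.0
    for i, x in enumerate(sorted(v, reverse=True), start=1):
        cumsum += x
        if 1.0 + i * x > cumsum:
            k, support_sum = i, cumsum
    tau = (support_sum - 1.0) / k
    return [max(x - tau, 0.0) for x in v]

def one_hot_argmax(s):
    k = max(range(len(s)), key=lambda i: s[i])
    return [1.0 if i == k else 0.0 for i in range(len(s))]

def ste_grad(s, grad_z):
    """Straight-through: pass dL/dz unchanged."""
    return list(grad_z)

def spigot_grad(s, grad_z, eta=1.0):
    """SPIGOT: step from z_hat along -grad_z, project back onto the simplex,
    and use z_hat - mu as the surrogate gradient w.r.t. s."""
    z_hat = one_hot_argmax(s)
    mu = project_simplex([z - eta * g for z, g in zip(z_hat, grad_z)])
    return [z - m for z, m in zip(z_hat, mu)]

s = [1.0, 0.5, -1.0]
grad_z = [0.4, -0.6, 0.0]    # hypothetical downstream gradient favoring category 1
g_spigot = spigot_grad(s, grad_z)
```

Because both the one-hot vector and the projected point lie in the simplex, the SPIGOT surrogate always sums to zero, unlike the raw straight-through gradient.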

Quantization: straight-through friendly models

  • Neural networks can be rearranged to make challenging computations more linear.
  • Specific model choices can be beneficial and perform well.

Rounding

  • Consider a discrete latent variable whose domain is the integers
  • Assigning a score to each integer is infeasible
  • The natural ordering of integers suggests using a single real-valued score, which is then rounded
  • The floor operation is ignored (treated as the identity) in the backward pass
  • This has proven useful in constructing integer normalizing flow models and in learning efficient neural networks
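Rounding with a straight-through backward pass can be sketched on a single scalar: the forward pass uses the integer, and the backward pass pretends the rounding was the identity. Rounding to the nearest integer is written here as floor(s + 0.5); the target integer, learning rate and step count are hypothetical.

```python
import math

def quantize(s):
    """Forward pass: round to the nearest integer (floor(s + 0.5), halves go up)."""
    return math.floor(s + 0.5)

target = 3                 # hypothetical integer the downstream loss prefers
s, lr = 0.0, 0.1
for _ in range(50):
    z = quantize(s)
    grad_z = 2.0 * (z - target)   # derivative of (z - target) ** 2 w.r.t. z
    s -= lr * grad_z              # straight-through: the rounding is skipped in backward
```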

Vector quantization

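  • The VQ-VAE construction introduced by van den Oord et al. (2017) draws ideas from vector quantization.
  • It is an encoder-decoder architecture whose discrete categorical variables are represented by embeddings.
  • The quantization step snaps the encoder output to its nearest neighbor among the anchor points (the embeddings).
  • The STE is applied to the quantization mapping.
  • The VQ-VAE loss adds a term $\lVert \mathrm{sg}[h] - e \rVert^2$ that learns the embeddings and a commitment term $\beta \lVert h - \mathrm{sg}[e] \rVert^2$ that encourages the encoder to output values close to the allowed embeddings, where h is the encoder output, e its nearest embedding, and sg the stop-gradient operator.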
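The three ingredients of the quantization step can be sketched with plain lists: nearest-neighbor snapping in the forward pass, a straight-through copy of the gradient in the backward pass, and the gradients of the two auxiliary terms. The codebook, the encoder output and the value `beta = 0.25` are assumptions for illustration.

```python
def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def quantize(h, codebook):
    """Forward pass: snap encoder output h to its nearest codebook embedding."""
    k = min(range(len(codebook)), key=lambda i: sq_dist(h, codebook[i]))
    return k, codebook[k]

def straight_through(grad_quantized):
    """Backward pass: copy the decoder's gradient w.r.t. the embedding onto h."""
    return list(grad_quantized)

def auxiliary_grads(h, e, beta=0.25):
    """Gradients of the codebook term ||sg[h] - e||^2 (w.r.t. e only) and of the
    commitment term beta * ||h - sg[e]||^2 (w.r.t. h only); sg = stop-gradient."""
    grad_e = [2.0 * (ei - hi) for hi, ei in zip(h, e)]
    grad_h = [2.0 * beta * (hi - ei) for hi, ei in zip(h, e)]
    return grad_e, grad_h

codebook = [[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]]   # hypothetical embeddings
h = [0.9, 0.8]                                     # hypothetical encoder output
k, e = quantize(h, codebook)
grad_e, grad_h = auxiliary_grads(h, e)
```

A gradient step on `grad_e` moves the selected embedding toward the encoder output, while `grad_h` pulls the encoder output toward the embedding it was snapped to.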

Interpretation via pulled-back labels

  • STE and SPIGOT are practical solutions for learning models with null gradients
  • Consider a multitask learning scenario where a latent variable is predicted in addition to the downstream task
  • If ground truth supervision were available for the latent variable, the parameters could be trained jointly with an auxiliary loss
  • As such supervision is not available, a best-guess pulled-back label is induced by pulling back the downstream loss
  • This strategy recovers the STE and SPIGOT estimators
  • The pulled-back label is a sensible approximation to the value for z that minimizes the downstream loss
  • SPIGOT minimizes the perceptron loss w.r.t. a pulled-back label computed by one projected gradient step
  • Dropping the constraints (no projection), STE can likewise be interpreted as minimizing a perceptron loss on a pulled-back label
  • Probabilistic form of discrete latent variable learning problem can be computed or approximated without assumptions on the downstream model
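The pulled-back label view of STE can be verified numerically on unconstrained bit vectors: with the label $\mu = \hat z - \eta \nabla_z L$ (no projection), the gradient of the perceptron loss $\max_z \langle s, z \rangle - \langle s, \mu \rangle$ with respect to s is $\hat z - \mu = \eta \nabla_z L$, exactly the straight-through update. The scores and downstream gradient below are hypothetical; SPIGOT would additionally project μ onto the marginal polytope.

```python
def map_bits(s):
    """argmax over unconstrained bit vectors {0,1}^n: switch on every positive score."""
    return [1.0 if si > 0 else 0.0 for si in s]

def perceptron_loss(s, mu):
    """max_z <s, z> - <s, mu>; over {0,1}^n the max equals the sum of positive scores."""
    return sum(max(si, 0.0) for si in s) - sum(si * mi for si, mi in zip(s, mu))

s = [0.8, -0.3, 0.1]          # hypothetical encoder scores
grad_z = [0.5, -1.0, 0.2]     # hypothetical downstream gradient dL/dz at z_hat
eta = 1.0
z_hat = map_bits(s)
mu = [z - eta * g for z, g in zip(z_hat, grad_z)]   # pulled-back label, no projection

# Central finite differences of the perceptron loss w.r.t. s.
eps = 1e-6
fd_grad = []
for i in range(len(s)):
    s_plus, s_minus = list(s), list(s)
    s_plus[i] += eps
    s_minus[i] -= eps
    fd_grad.append((perceptron_loss(s_plus, mu) - perceptron_loss(s_minus, mu)) / (2 * eps))
```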

Explicit marginalization by enumeration

  • Probabilistic framework optimizes expected loss
  • Sum can be over exponentially large set
  • Expectation pushed inside
  • Discrete mapping approximated
  • Explicit marginalization not required for discrete index
  • Explicit enumeration is possible when Z is of manageable size
  • Explicit marginalization is recommended for testing, for instance as an exact reference for gradient estimates
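For a categorical latent variable with K outcomes, the expected loss and its exact gradient can be computed by enumeration: since $\partial p_j / \partial s_k = p_j(\delta_{jk} - p_k)$, the gradient is $p_k (L_k - \mathbb{E}[L])$. The sketch below uses hypothetical per-outcome losses and checks the formula by finite differences; such an exact reference is what makes enumeration useful for testing estimators.

```python
import math

def softmax(s):
    m = max(s)
    e = [math.exp(v - m) for v in s]
    total = sum(e)
    return [v / total for v in e]

def expected_loss(s, losses):
    """E_{z ~ softmax(s)}[L(z)], summing explicitly over all K outcomes."""
    return sum(p * l for p, l in zip(softmax(s), losses))

def expected_loss_grad(s, losses):
    """Exact gradient: d E[L] / d s_k = p_k * (L_k - E[L])."""
    p = softmax(s)
    mean = sum(pk * lk for pk, lk in zip(p, losses))
    return [pk * (lk - mean) for pk, lk in zip(p, losses)]

s = [0.5, -0.2, 1.0]
losses = [1.0, 3.0, -2.0]     # hypothetical downstream loss for each outcome
grad = expected_loss_grad(s, losses)
```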

Monte carlo gradient estimation

  • Monte Carlo method can be used to approximate expectations
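  • Differentiation is linear, so when the distribution does not depend on the parameters it passes directly through the expectation: $\nabla_\theta \mathbb{E}_{p}[F(z; \theta)] = \mathbb{E}_{p}[\nabla_\theta F(z; \theta)]$
  • Differentiating expectations is harder if the distribution depends on parameters
  • Two main strategies are used for Monte Carlo gradient estimation with discrete and structured latent variables: the score function estimator (SFE) and the path gradient estimator (PGE)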

Path gradient estimator (the reparametrization trick)

  • Estimating gradients of expectations is challenging in both discrete and continuous settings
  • An alternative to the SFE is the Path Gradient Estimator (PGE)
  • The challenge is that the distribution of the latent variable depends on the parameters
  • The PGE circumvents this problem by extracting the source of randomness into a parameter-free random variable
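  • Writing $z = T_\theta(U)$ for a deterministic map $T_\theta$, for any function F(z) we may then write an expectation over a distribution that does not depend on θ: $\mathbb{E}_{p_\theta}[F(z)] = \mathbb{E}_U[F(T_\theta(U))]$
  • We can then obtain the Monte Carlo gradient estimator $\nabla_\theta \mathbb{E}_{p_\theta}[F(z)] \approx \frac{1}{N} \sum_{n=1}^{N} \nabla_\theta F(T_\theta(u_n))$, with $u_n$ drawn from U
  • If U is a continuous random variable, the right-hand-side expectation is an integral
  • The Gumbel-argmax construction, $z = \arg\max_k (s_k + G_k)$ with $G_k$ i.i.d. Gumbel(0, 1), is a reparametrization of the categorical distribution with probabilities $\operatorname{softmax}(s)$
  • The Gumbel-Softmax, $\operatorname{softmax}((s + G)/\tau)$ with temperature $\tau$, is a perturbed version of the softmax relaxation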
  • The Straight-Through Gumbel-Softmax estimator combines ideas of the PGE with the surrogate gradient strategies
  • The Implicit Maximum Likelihood Estimation (I-MLE) approach combines structured perturbation methods with surrogate gradient approaches
  • Mixed random variables are hybrids of discrete and continuous random variables
  • Examples of mixed random variables are the Hard Concrete and Hard-Kuma distributions
  • The Gaussian-Sparsemax is a “mixed” counterpart of the multivariate logistic normal distribution
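Two reparametrizations can be sketched with the standard library: a Gaussian path gradient, where $z = \mu + \sigma\epsilon$ and the per-sample derivative of $F(z) = z^2$ with respect to μ is $2z$ (exact answer $2\mu$), and Gumbel-argmax sampling, whose frequencies should match softmax(s). A Gumbel-Softmax relaxed sample is included for comparison. All parameter values, the temperature 0.5 and the sample sizes are assumptions.

```python
import math
import random

rng = random.Random(0)

# Path gradient estimator for a Gaussian: z = mu + sigma * eps, eps ~ N(0, 1).
# With F(z) = z ** 2, E[F(z)] = mu ** 2 + sigma ** 2, so the exact dE/dmu is 2 * mu.
mu, sigma, n = 1.5, 0.7, 20000
pge_estimate = sum(2.0 * (mu + sigma * rng.gauss(0.0, 1.0)) for _ in range(n)) / n

def gumbel(rng):
    """Sample Gumbel(0, 1) by inverting its CDF."""
    u = 0.0
    while u <= 0.0:                  # random() can return 0.0; avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_argmax(s, rng):
    """Exact categorical sample with probabilities softmax(s)."""
    perturbed = [sk + gumbel(rng) for sk in s]
    return max(range(len(s)), key=lambda k: perturbed[k])

def gumbel_softmax(s, tau, rng):
    """Relaxed sample softmax((s + G) / tau): a point in the simplex, not a vertex."""
    y = [(sk + gumbel(rng)) / tau for sk in s]
    m = max(y)
    e = [math.exp(v - m) for v in y]
    total = sum(e)
    return [v / total for v in e]

s = [0.5, -0.2, 1.0]
counts = [0, 0, 0]
for _ in range(n):
    counts[gumbel_argmax(s, rng)] += 1
freqs = [c / n for c in counts]
relaxed = gumbel_softmax(s, 0.5, rng)
```

The argmax sample itself has no useful derivative with respect to s, which is why the relaxed Gumbel-Softmax sample, or its straight-through combination, is used when gradients are needed.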

Score function estimator

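  • The score function estimator (SFE) relies on the fact that $(\log t)' = 1/t$, so that $\nabla_\theta p_\theta(z) = p_\theta(z) \nabla_\theta \log p_\theta(z)$ and therefore $\nabla_\theta \mathbb{E}_{p_\theta}[F(z)] = \mathbb{E}_{p_\theta}[F(z) \nabla_\theta \log p_\theta(z)]$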
  • Expectation in Eq. 5.4 can be approximated with Monte-Carlo
  • Gradient-based learning with discrete or structured latent variables is possible
  • SFE has high variance and requires a sampling oracle for p(Z | x)
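The SFE for a categorical distribution with probabilities softmax(s) only needs samples and the score $\nabla_s \log p(z) = e_z - p$. The sketch below compares a Monte Carlo estimate with the exact gradient $p_k (F_k - \mathbb{E}[F])$; the downstream values F and the sample size are hypothetical.

```python
import math
import random

def softmax(s):
    m = max(s)
    e = [math.exp(v - m) for v in s]
    total = sum(e)
    return [v / total for v in e]

def sfe_grad(s, F, n_samples, rng):
    """Score function estimate of d/ds E_{z ~ softmax(s)}[F(z)]:
    the average of F(z) * grad_s log p(z), where grad_s log p(z) = e_z - p."""
    p = softmax(s)
    K = len(s)
    grad = [0.0] * K
    for _ in range(n_samples):
        z = rng.choices(range(K), weights=p)[0]   # only a sampling oracle is needed
        for k in range(K):
            grad[k] += F[z] * ((1.0 if k == z else 0.0) - p[k]) / n_samples
    return grad

s = [0.5, -0.2, 1.0]
F = [1.0, 3.0, -2.0]                 # hypothetical downstream values F(z)
estimate = sfe_grad(s, F, 50000, random.Random(0))
p = softmax(s)
mean_F = sum(pk * fk for pk, fk in zip(p, F))
exact = [pk * (fk - mean_F) for pk, fk in zip(p, F)]
```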

Variance reduction

  • SFE has high variance
  • Variance reduction techniques exist
  • Control variates, learned data-driven baselines, and NVIL are three main directions
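  • Control variates subtract a term $\beta\,(h(z) - \mathbb{E}[h(z)])$, for a function h with known expectation and a coefficient β, which keeps the estimator unbiased and can reduce its variance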
  • Learned data-driven baselines use gradient descent with a squared loss
  • Relaxation-based control variates use a deterministic mapping
  • Rao-Blackwellization uses tractable conditioning to reduce variance
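The simplest variance reduction, a constant baseline b subtracted from F(z), keeps the SFE unbiased because $\mathbb{E}[e_z - p] = 0$. The sketch below measures the per-sample variance of one gradient coordinate with and without a baseline, using hypothetical values F that share a large common offset. Using the exact $\mathbb{E}[F]$ as baseline is only possible in this toy; a learned baseline would approximate it in practice.

```python
import math
import random

def softmax(s):
    m = max(s)
    e = [math.exp(v - m) for v in s]
    total = sum(e)
    return [v / total for v in e]

def sfe_samples(s, F, n, rng, baseline=0.0):
    """Per-sample SFE terms for coordinate 0: (F(z) - baseline) * ([z == 0] - p_0)."""
    p = softmax(s)
    out = []
    for _ in range(n):
        z = rng.choices(range(len(s)), weights=p)[0]
        out.append((F[z] - baseline) * ((1.0 if z == 0 else 0.0) - p[0]))
    return out

def mean_and_variance(xs):
    m = sum(xs) / len(xs)
    return m, sum((x - m) ** 2 for x in xs) / len(xs)

s = [0.5, -0.2, 1.0]
F = [10.0, 11.0, 12.0]               # hypothetical values sharing a large offset
p = softmax(s)
b = sum(pk * fk for pk, fk in zip(p, F))   # constant baseline; here simply E[F]
rng = random.Random(0)
plain_mean, plain_var = mean_and_variance(sfe_samples(s, F, 20000, rng))
base_mean, base_var = mean_and_variance(sfe_samples(s, F, 20000, rng, baseline=b))
exact = p[0] * (F[0] - b)
```

Both estimators target the same gradient, but the baseline removes the shared offset that otherwise dominates each sample.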