Abstract

  • FTRs are natural language explanations of reasoning processes.
  • Recent works have studied how to use FTRs to improve language model generalization.
  • KNIFE distills FTR knowledge from an FTR-augmented teacher LM to a student LM.
  • KNIFE outperforms existing FTR learning methods on two question answering datasets.

Paper Content

Introduction

  • Pretrained language models (LMs) are typically finetuned on downstream tasks using only task labels as supervision.
  • Task labels provide feedback on whether the LM’s outputs are correct, but not on whether the LM’s reasoning process is correct.
  • Explanation-based learning (EBL) aims to improve LM generalization by also explaining the correct reasoning process behind a given correct output.
  • Most EBL works have focused on learning from extractive rationales, which explain how to solve a task by highlighting important features in the task input.
  • There is growing interest in learning from free-text rationales (FTRs), which explain reasoning processes via natural language.
  • FTRs can reference things beyond the task input and support high flexibility in content, style, and length.
  • Existing works have attempted to learn from FTRs by appending FTRs to the LM’s input or jointly training the LM to generate FTRs.
  • Both paradigms have key limitations.
  • KNIFE distills FTR knowledge from an FTR-augmented teacher LM to a student LM, which is used for inference.
  • KNIFE does not suffer from input distribution shift or conflicting learning objectives.
  • KNIFE achieves significant performance gains over existing FTR-based EBL methods.
  • KNIFE can outperform baselines in low-resource settings.

Background

  • LM for text classification predicts a confidence score for each (x, y_i) pair
  • Free-text rationales (FTRs) explain the reasoning process for predicting label y_i
  • FTRs can be more intuitive and understandable to humans than extractive rationales
  • FTRs can reference things beyond the task input
  • FTRs support high flexibility in content, style, and length
  • Extractive rationales assign an importance score to each input feature
  • FTRs do not have such a feature-score correspondence
  • Existing works have attempted to learn from FTRs by appending FTRs to the LM’s input or jointly training the LM to generate FTRs
  • These paradigms have key limitations
  • KNIFE is proposed as an approach for injecting FTR knowledge into LMs
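
The classification setup in the first bullet above can be sketched as follows. This is a minimal illustration, not the paper's code: `score_fn` is a hypothetical stand-in for the LM's confidence score, and the toy scorer below only counts shared words so that the example runs without a model.

```python
from typing import Callable, Sequence


def predict_label(x: str, choices: Sequence[str],
                  score_fn: Callable[[str, str], float]) -> int:
    """Return the index i of the choice y_i with the highest score for (x, y_i)."""
    scores = [score_fn(x, y) for y in choices]
    return max(range(len(choices)), key=lambda i: scores[i])


def toy_overlap_score(x: str, y: str) -> float:
    """Hypothetical scorer: number of words shared by the input and the choice."""
    return float(len(set(x.lower().split()) & set(y.lower().split())))


# Made-up example, used only to exercise the function.
question = "which object conducts heat best"
choices = ["a wooden spoon", "a metal spoon conducts heat", "a plastic cup"]
best = predict_label(question, choices, toy_overlap_score)
```

With a real LM, `score_fn` would be replaced by the model's score for each candidate label given the input.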

LM designs

  • KNIFE consists of a teacher LM and a student LM
  • Both LMs use text-to-text Transformer architectures
  • Task input is denoted as an n_x-token sequence x
  • Decoding is done by taking a special start token as input
  • Teacher LM takes both task input and FTR as input
  • Teacher LM has a bottleneck stage where FTR states are masked out
  • Teacher LM’s decoder output states are called task output states
  • Student LM takes only task input as input
  • Student LM is trained to align with teacher LM’s task input/output states
  • Goal is to maximize P(y* | x)
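
The FTR bottleneck described above can be pictured with a small sketch. It assumes the teacher encodes the task input followed by the FTR, so that the first n_x encoder states belong to the task input. Whether the masking is done by slicing or by an attention mask is an implementation assumption here, and both function names are hypothetical.

```python
from typing import List

Vector = List[float]


def apply_ftr_bottleneck(encoder_states: List[Vector], n_x: int) -> List[Vector]:
    """Keep only the states of the first n_x (task input) tokens; drop FTR states."""
    if not 0 <= n_x <= len(encoder_states):
        raise ValueError("n_x must lie within the sequence length")
    return encoder_states[:n_x]


def cross_attention_mask(n_x: int, n_r: int) -> List[int]:
    """Equivalent mask view: 1 = visible to the decoder, 0 = masked out."""
    return [1] * n_x + [0] * n_r


# Toy states for 3 task input tokens followed by 2 FTR tokens.
states = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8], [0.9, 1.0]]
visible = apply_ftr_bottleneck(states, n_x=3)
mask = cross_attention_mask(n_x=3, n_r=2)
```

Because the decoder never sees the FTR states directly, any FTR information it uses has to pass through the task input states.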

Knowledge distillation

  • Teacher LM T is trained to predict compatibility of (y_i, r) pairs
  • FTR bottleneck design causes FTR knowledge to be transferred to other parts of T
  • T_enc must be trained to distill knowledge from r into task input states
  • T_dec must contain useful knowledge for reasoning process to compute compatibility scores
  • Student LM S is trained with KD losses and task loss
  • Total loss combines the task (cross-entropy) loss, the task input states KD loss, and the task output states KD loss
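
A rough sketch of how the three loss terms could be combined is given below. It rests on assumptions not stated in this summary: mean squared error is used as the distance between student and teacher states, and the weights `w_in` and `w_out` are hypothetical hyperparameters. Student and teacher states are assumed to have the same shape.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Task loss: negative log-softmax of the target class (numerically stable)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z - logits[target]


def mse(student: Matrix, teacher: Matrix) -> float:
    """Assumed KD distance: mean squared error over all state entries."""
    diffs = [(s - t) ** 2
             for s_row, t_row in zip(student, teacher)
             for s, t in zip(s_row, t_row)]
    return sum(diffs) / len(diffs)


def total_loss(logits: Sequence[float], target: int,
               s_in: Matrix, t_in: Matrix,
               s_out: Matrix, t_out: Matrix,
               w_in: float = 1.0, w_out: float = 1.0) -> float:
    """Task loss plus weighted KD losses on task input and task output states."""
    return (cross_entropy(logits, target)
            + w_in * mse(s_in, t_in)
            + w_out * mse(s_out, t_out))


# Toy check: identical student and teacher states leave only the task loss.
loss = total_loss([0.0, 0.0], 0, [[1.0, 2.0]], [[1.0, 2.0]], [[0.5]], [[0.5]])
```

Setting `w_in` or `w_out` to zero gives a variant that uses only one of the two KD losses.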

Experiments

  • KNIFE achieves better performance than existing FTR-based EBL methods in fully-supervised settings
  • KNIFE outperforms baselines in low-resource settings
  • Extensive ablation studies validate KNIFE design choices

Evaluation protocol

  • We consider two types of text classification tasks: multi-choice and closed-set
  • We focus on two popular QA datasets: OBQA and StrategyQA
  • We report the mean and standard deviation of accuracy over three random seeds
  • We consider four different modes of LM input/output settings

Baselines

  • Experiments consider a range of finetuning baselines
  • Vanilla Finetuning refers to finetuning methods without KD
  • Vanilla Finetuning methods distinguished by input-output mode
  • Vanilla Finetuning + I→O and Vanilla Finetuning + IR→O trained using task loss
  • Vanilla Finetuning + I→OR and Vanilla Finetuning + I→RO trained using greedy search decoding
  • Three KNIFE variants, each using different combination of KD-in and KD-out losses
  • KNIFE Student Finetuning (In), (Out), (In+Out) trained without task loss
  • KNIFE Teacher Finetuning as upper bound
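
To make the four input/output modes concrete, here is a small sketch that builds (source, target) text pairs, reading I as the task input, R as the rationale, and O as the task output. The function name and the `rationale:` and `answer:` separators are hypothetical choices for illustration, not the paper's formatting.

```python
from typing import Tuple


def format_example(mode: str, x: str, y: str, r: str) -> Tuple[str, str]:
    """Return (source, target) text for one input/output mode."""
    if mode == "I->O":      # input only, predict the label
        return x, y
    if mode == "IR->O":     # rationale appended to the input
        return f"{x} rationale: {r}", y
    if mode == "I->OR":     # generate the label, then the rationale
        return x, f"{y} rationale: {r}"
    if mode == "I->RO":     # generate the rationale, then the label
        return x, f"rationale: {r} answer: {y}"
    raise ValueError(f"unknown mode: {mode}")


# Made-up example, used only to show the four layouts.
pairs = {m: format_example(m, "is ice cold?", "yes", "ice is frozen water")
         for m in ("I->O", "IR->O", "I->OR", "I->RO")}
```

Only IR→O changes the input side, which is why it needs an FTR at inference time, while I→OR and I→RO change what the LM must generate.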

Implementation details

  • Uses T5-Base and T5-Large as the LM architecture
  • Teacher and student LMs use same architecture by default
  • Student LM’s parameters initialized as teacher LM’s finetuned parameters
  • T5-Large → T5-Base setting has different embedding dimensionalities
  • Linear projection layer transforms student LM’s states to have same dimensionality as teacher LM’s
  • Student LM uses teacher LM’s language modeling head

Main results

  • Vanilla Finetuning + I→O and Vanilla Finetuning (KNIFE Teacher Init.) + I→O outperform other Vanilla Finetuning baselines
  • Vanilla Finetuning + I→OR and Vanilla Finetuning + I→RO are not trained/evaluated with label-specific teacher-forcing
  • Vanilla Finetuning + I→RO performs poorly
  • Vanilla Finetuning (KNIFE Teacher Init.) + I→O does not outperform Vanilla Finetuning + I→O
  • KNIFE Student Finetuning consistently achieves significant improvements
  • KNIFE Teacher Finetuning outperforms all other methods
  • T5-Large based methods generally outperform T5-Base counterparts

Low-resource learning

  • Low-resource learning experiments consider OBQA dataset and T5-Base architecture
  • Vanilla Finetuning + I→O and Vanilla Finetuning (KNIFE Teacher Init.) + I→O perform competitively
  • Vanilla Finetuning + IR→O is the best performing baseline
  • KNIFE Student Finetuning significantly outperforms all baselines
  • KNIFE Teacher Finetuning upper bound outperforms all other methods
  • FTR type/quality matters a lot and gold FTRs are the source of KNIFE’s performance improvements
  • FTR bottleneck is effective at distilling knowledge from the teacher LM’s FTR states to its task input/output states
  • Student LM trained on both the KD losses and the task loss may get confused and learn a suboptimal reasoning process
  • Extractive rationales identify important features in the input to explain how the decision should be made
  • Free-text rationales explain the decision in a more flexible way
  • FTRs can be appended to the input or used as the target output in addition to the label
  • KNIFE distills the knowledge from an FTR-augmented teacher model to a student model
  • Knowledge distillation has been used to develop efficient deep neural networks
  • Types of knowledge for distillation include the logits of teacher model or the representations extracted from the intermediate layers
  • KNIFE obtains additional knowledge from FTRs
  • KNIFE encodes the FTR knowledge into the task input/output states of a teacher model and distills the knowledge to a student model