Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Free-text rationales (FTRs) are natural language explanations of reasoning processes.
- Recent works have studied how to use FTRs to improve language model generalization.
- KNIFE distills FTR knowledge from an FTR-augmented teacher LM to a student LM.
- KNIFE outperforms existing FTR learning methods on two question answering datasets.
Paper Content
Introduction
- Pretrained language models (LMs) are typically finetuned on downstream tasks using only task labels as supervision.
- Task labels provide feedback on whether the LM’s outputs are correct, but not on whether the LM’s reasoning process is correct.
- Explanation-based learning (EBL) aims to improve LM generalization by also explaining the correct reasoning process behind a given correct output.
- Most EBL works have focused on learning from extractive rationales, which explain how to solve a task by highlighting important features in the task input.
- There is growing interest in learning from free-text rationales (FTRs), which explain reasoning processes via natural language.
- FTRs can reference things beyond the task input and support high flexibility in content, style, and length.
- Existing works have attempted to learn from FTRs by appending FTRs to the LM’s input or jointly training the LM to generate FTRs.
- Both paradigms have key limitations.
- KNIFE distills FTR knowledge from an FTR-augmented teacher LM to a student LM, which is used for inference.
- KNIFE does not suffer from input distribution shift or conflicting learning objectives.
- KNIFE achieves significant performance gains over existing FTR-based EBL methods.
- KNIFE can outperform baselines in low-resource settings.
Background
- An LM for text classification predicts a confidence score for each (x, y_i) pair
- Free-text rationales (FTRs) explain the reasoning process for predicting label y_i
- FTRs can be more intuitive and understandable to humans than extractive rationales
- FTRs can reference things beyond the task input
- FTRs support high flexibility in content, style, and length
- Extractive rationales assign an importance score to each input feature
- FTRs do not have such a feature-score correspondence
- Existing works have attempted to learn from FTRs by appending FTRs to the LM’s input or jointly training the LM to generate FTRs
- These paradigms have key limitations
- Proposed KNIFE approach for injecting FTR knowledge into LMs
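As a rough illustration of the classification setup above, the sketch below turns one confidence score per (x, y_i) pair into a prediction. The softmax normalization, the function name `predict`, and the example scores are assumptions for illustration, not details taken from the paper.

```python
import math

def predict(scores):
    """Return the index of the highest-scoring choice and softmax probabilities.

    `scores` holds one confidence score per (x, y_i) pair.
    Softmax normalization is an assumption made for this sketch.
    """
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # subtract max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(scores)), key=lambda i: scores[i])
    return best, probs

# Hypothetical scores for four answer choices y_0..y_3
best, probs = predict([0.2, 1.5, -0.3, 0.9])
print(best, [round(p, 3) for p in probs])
```

The predicted label is simply the choice with the highest score; the probabilities are only needed if a cross-entropy loss is computed over the choices.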
LM designs
- KNIFE consists of a teacher LM and a student LM
- Both LMs use text-to-text Transformer architectures
- Task input is denoted as an n_x-token sequence x
- Decoding is done by taking a special start token as input
- Teacher LM takes both task input and FTR as input
- Teacher LM has a bottleneck stage where FTR states are masked out
- Teacher LM’s decoder states are called task output states
- Student LM takes only task input as input
- Student LM is trained to align with teacher LM’s task input/output states
- Goal is to maximize P(y* | x)
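The bottleneck idea above can be sketched as an attention mask: the encoder sees both the task input and the FTR, but the decoder is only allowed to attend to the task input positions. This is a minimal single-query sketch under that assumption; `bottleneck_mask` and `masked_attention` are hypothetical names, and the real model uses multi-head attention inside a text-to-text Transformer.

```python
import math

def bottleneck_mask(n_x, n_r):
    """1 for task input positions, 0 for FTR positions (hidden from the decoder)."""
    return [1] * n_x + [0] * n_r

def masked_attention(query, states, mask):
    """Single-query dot-product attention that skips masked encoder states."""
    logits = [
        sum(q * s for q, s in zip(query, state)) if keep else -math.inf
        for state, keep in zip(states, mask)
    ]
    m = max(logits)
    weights = [math.exp(l - m) for l in logits]  # exp(-inf) evaluates to 0.0
    total = sum(weights)
    weights = [w / total for w in weights]
    dim = len(query)
    context = [sum(w * st[d] for w, st in zip(weights, states)) for d in range(dim)]
    return weights, context

# Two task input states followed by one FTR state (toy 2-dimensional values)
states = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
mask = bottleneck_mask(n_x=2, n_r=1)
weights, context = masked_attention([1.0, 0.0], states, mask)
print(weights, context)
```

Because the FTR position receives zero attention weight, any FTR information the decoder uses has to be carried by the task input states.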
Knowledge distillation
- Teacher LM T is trained to predict compatibility of (y_i, r) pairs
- FTR bottleneck design causes FTR knowledge to be transferred to other parts of T
- T_enc must be trained to distill knowledge from r into task input states
- T_dec must contain useful knowledge about the reasoning process in order to compute compatibility scores
- Student LM S is trained with KD losses and task loss
- Total loss combines the cross-entropy task loss, the task input states KD loss, and the task output states KD loss
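A minimal sketch of how the three loss terms might be combined is shown below. The use of mean squared error for the two KD terms and the weights `w_task`, `w_in`, and `w_out` are assumptions for illustration; the summary above only states which terms make up the total loss.

```python
import math

def cross_entropy(logits, gold_index):
    """Task loss: negative log-probability of the gold choice."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z - logits[gold_index]

def mse(student_states, teacher_states):
    """KD loss (assumed MSE): mean squared difference between aligned states."""
    diffs = [
        (s - t) ** 2
        for s_row, t_row in zip(student_states, teacher_states)
        for s, t in zip(s_row, t_row)
    ]
    return sum(diffs) / len(diffs)

def total_loss(task, kd_in, kd_out, w_task=1.0, w_in=1.0, w_out=1.0):
    """Weighted sum of the task loss and the two KD losses (weights are hypothetical)."""
    return w_task * task + w_in * kd_in + w_out * kd_out

# Toy values: two answer choices, one 2-dimensional state per side
task = cross_entropy([0.0, 0.0], gold_index=0)
kd_in = mse([[1.0, 2.0]], [[1.0, 4.0]])   # student vs. teacher task input states
kd_out = mse([[0.5, 0.5]], [[0.5, 1.5]])  # student vs. teacher task output states
print(total_loss(task, kd_in, kd_out))
```

Setting `w_task=0.0` gives a student trained with the KD losses only.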
Experiments
- KNIFE achieves better performance than existing FTR-based EBL methods in fully-supervised settings
- KNIFE outperforms baselines in low-resource settings
- Extensive ablation studies validate KNIFE design choices
Evaluation protocol
- We consider two types of text classification tasks: multi-choice and closed-set
- We focus on two popular QA datasets: OBQA and StrategyQA
- We report the mean and standard deviation of accuracy over three random seeds
- We consider four different modes of LM input/output settings
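The reporting convention above (mean and standard deviation over three seeds) can be reproduced with the standard library. The accuracy values below are placeholders, not results from the paper, and the choice of the sample standard deviation (`statistics.stdev`) rather than the population version is an assumption.

```python
import statistics

def summarize(accuracies):
    """Return (mean, sample standard deviation) of per-seed accuracies."""
    return statistics.mean(accuracies), statistics.stdev(accuracies)

# Placeholder accuracies for three random seeds (illustrative values only)
mean_acc, std_acc = summarize([60.0, 62.0, 61.0])
print(f"{mean_acc:.2f} ± {std_acc:.2f}")
```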
Baselines
- Experiments consider a range of finetuning baselines
- Vanilla Finetuning refers to finetuning methods without KD
- Vanilla Finetuning methods distinguished by input-output mode
- Vanilla Finetuning + I→O and Vanilla Finetuning + IR→O trained using task loss
- Vanilla Finetuning + I→OR and Vanilla Finetuning + I→RO trained using greedy search decoding
- Three KNIFE variants, each using different combination of KD-in and KD-out losses
- KNIFE Student Finetuning (In), (Out), (In+Out) trained without task loss
- KNIFE Teacher Finetuning as upper bound
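To make the four input/output modes concrete, the sketch below builds an (input text, target text) pair for each mode. The separator strings, the function name `build_example`, and the reading of "OR" as label-then-rationale and "RO" as rationale-then-label are assumptions for illustration. ASCII `->` stands in for the arrow used in the mode names.

```python
def build_example(x, label, rationale, mode):
    """Return (input_text, target_text) for one of the four modes.

    I = task input, R = free-text rationale, O = task output (label).
    The separator strings are hypothetical formatting choices.
    """
    if mode == "I->O":
        return x, label
    if mode == "IR->O":
        return f"{x} rationale: {rationale}", label
    if mode == "I->OR":
        return x, f"{label} rationale: {rationale}"
    if mode == "I->RO":
        return x, f"rationale: {rationale} answer: {label}"
    raise ValueError(f"unknown mode: {mode}")

# Toy example (not taken from either dataset)
x = "question: Can ice float on water?"
for mode in ("I->O", "IR->O", "I->OR", "I->RO"):
    print(mode, build_example(x, "yes", "Ice is less dense than water.", mode))
```

Only the IR->O mode requires a rationale at inference time, which is where the input distribution shift mentioned in the introduction comes from.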
Implementation details
- Uses T5-Base and T5-Large as the LM architecture
- Teacher and student LMs use same architecture by default
- Student LM’s parameters initialized as teacher LM’s finetuned parameters
- T5-Large → T5-Base setting has different embedding dimensionalities
- Linear projection layer transforms student LM’s states to have same dimensionality as teacher LM’s
- Student LM uses teacher LM’s language modeling head
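The projection step for the T5-Large → T5-Base setting can be sketched as a single linear layer that maps each student state to the teacher's dimensionality before the KD losses are computed. The initialization scheme and the names `make_projection` and `project` are assumptions for illustration; the toy dimensions below stand in for the actual hidden sizes (768 for T5-Base and 1024 for T5-Large).

```python
import random

def make_projection(d_student, d_teacher, seed=0):
    """Create a (weight, bias) pair for a linear map from d_student to d_teacher.

    Small Gaussian initialization is a hypothetical choice for this sketch.
    """
    rng = random.Random(seed)
    weight = [[rng.gauss(0.0, 0.02) for _ in range(d_student)] for _ in range(d_teacher)]
    bias = [0.0] * d_teacher
    return weight, bias

def project(state, weight, bias):
    """Apply the linear map to one student state vector."""
    return [sum(w * s for w, s in zip(row, state)) + b for row, b in zip(weight, bias)]

# Toy sizes standing in for the student and teacher hidden dimensions
weight, bias = make_projection(d_student=3, d_teacher=4)
student_state = [0.1, -0.2, 0.3]
projected = project(student_state, weight, bias)
print(len(projected))
```

After projection, the student and teacher states have matching shapes, so an element-wise KD loss can be applied directly.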
Main results
- Vanilla Finetuning + I→O and Vanilla Finetuning (KNIFE Teacher Init.) + I→O outperform other Vanilla Finetuning baselines
- Vanilla Finetuning + I→OR and Vanilla Finetuning + I→RO are not trained/evaluated with label-specific teacher-forcing
- Vanilla Finetuning + I→RO performs poorly
- Vanilla Finetuning (KNIFE Teacher Init.) + I→O does not outperform Vanilla Finetuning + I→O
- KNIFE Student Finetuning consistently achieves significant improvements
- KNIFE Teacher Finetuning outperforms all other methods
- T5-Large based methods generally outperform T5-Base counterparts
Low-resource learning
- Low-resource learning experiments consider OBQA dataset and T5-Base architecture
- Vanilla Finetuning + I→O and Vanilla Finetuning (KNIFE Teacher Init.) + I→O perform competitively
- Vanilla Finetuning + IR→O is the best performing baseline
- KNIFE Student Finetuning significantly outperforms all baselines
- KNIFE Teacher Finetuning upper bound outperforms all other methods
Ablation studies
- FTR type and quality strongly affect performance, and gold FTRs are the source of KNIFE’s performance improvements
- FTR bottleneck is effective at distilling knowledge from the teacher LM’s FTR states to its task input/output states
- Student LM trained on both the KD losses and the task loss may get confused and learn a suboptimal reasoning process
Related work
- Extractive rationales identify important features in the input to explain how the decision should be made
- Free-text rationales explain the decision in a more flexible way
- FTRs can be appended to the input or used as the target output in addition to the label
- KNIFE distills the knowledge from an FTR-augmented teacher model to a student model
- Knowledge distillation has been used to develop efficient deep neural networks
- Types of knowledge for distillation include the logits of teacher model or the representations extracted from the intermediate layers
- KNIFE obtains additional knowledge from FTRs
- KNIFE encodes the FTR knowledge into the task input/output states of a teacher model and distills the knowledge to a student model