Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Chain of thought prompting improves reasoning capabilities of large language models.
  • Reasoning capabilities only appear in models with over 100 billion parameters.
  • Knowledge distillation can transfer reasoning capabilities to models with fewer than 100 billion parameters.
  • Experiments show improved task performance across arithmetic, commonsense and symbolic reasoning datasets.
  • Accuracy of T5 XXL on GSM8K improves from 8.11% to 21.99% when finetuned on PaLM-540B generated chains of thought.

Paper Content

Introduction

  • Chain of thought (CoT) prompting increases accuracy of large language models (LLMs)
  • Smaller language models (LMs) do not improve with CoT prompting
  • CoT prompting reduces accuracy of models with fewer than 10 billion parameters
  • Research question: can reasoning capabilities of LLMs be transferred to smaller LMs via finetuning?
  • CoT prompting encourages models to break down problems into intermediate steps
  • Wei et al. (2022) show that prefixing an input with CoT reasoning can improve performance
  • Wang et al. (2022) show that task accuracy can be improved by using self-consistency in CoT prompting
  • Chung et al. (2022) explore finetuning on CoT data
  • Huang et al. (2022) explore self-improvement by finetuning on self-labelled solutions
  • Our work explores student-teacher knowledge distillation across multiple datasets and model architectures

Method

  • Proposed pipeline has two steps for CoT knowledge distillation
  • First step involves annotating existing supervised dataset with CoT reasoning generated by teacher model
  • Teacher model is an LLM, such as PaLM 540B or GPT-3 175B
  • Few-shot prompting with 8 exemplars is used to generate CoTs
  • Key modification to CoT prompts is to provide model with solution to task before providing example CoT
  • Remove incorrect CoT based on target answer to prevent student model from learning from bad examples
  • Second step of pipeline involves finetuning student model via teacher forcing
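The first step of the pipeline can be sketched roughly as follows. This is a minimal illustration, not the authors' implementation: the prompt layout, the "The answer is X." convention, and the names `Example`, `build_prompt`, `final_answer`, `annotate` and the `teacher` callable are all assumptions made for the example. The teacher is passed in as a plain function so that no particular LLM API is implied.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass
class Example:
    question: str
    answer: str  # gold target from the existing supervised dataset


def build_prompt(exemplars: List[Dict[str, str]], example: Example) -> str:
    """Few-shot prompt where each question is followed by its known answer.

    The layout is an assumption; the source only says the solution is given
    to the model before the example CoT.
    """
    parts = []
    for ex in exemplars:
        parts.append(
            f"Q: {ex['question']}\n"
            f"The answer is {ex['answer']}.\n"
            f"Reasoning: {ex['cot']}\n"
        )
    # The new question also gets its target answer before the CoT is requested.
    parts.append(
        f"Q: {example.question}\nThe answer is {example.answer}.\nReasoning:"
    )
    return "\n".join(parts)


def final_answer(cot: str) -> Optional[str]:
    """Assumed convention: the CoT ends with 'The answer is X.'"""
    marker = "The answer is "
    idx = cot.rfind(marker)
    if idx == -1:
        return None
    return cot[idx + len(marker):].strip().rstrip(".")


def annotate(
    dataset: List[Example],
    exemplars: List[Dict[str, str]],
    teacher: Callable[[str], str],
) -> List[Dict[str, str]]:
    """Generate one CoT per example and keep it only if it reaches the target."""
    kept = []
    for example in dataset:
        cot = teacher(build_prompt(exemplars, example))
        if final_answer(cot) == example.answer:
            kept.append({"input": example.question, "target": cot})
    return kept
```

The resulting input/target pairs would then be used for the second step, standard sequence-to-sequence finetuning of the student with teacher forcing, which is not shown here.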

Experimental setup

  • Followed an experimental setup similar to that of Wei et al. (2022)
  • Tasks cover arithmetic, commonsense and symbolic reasoning

Arithmetic reasoning

  • Benchmarked proposed method on three math word problem datasets: GSM8K, MAWPS and ASDiv
  • Used official training and testing split for GSM8K
  • Performed 5-fold cross validation for MAWPS and ASDiv
  • Did not evaluate on SVAMP or AQuA
  • Evaluated task accuracy by checking if target answer is provided in CoT
  • Computed task accuracy with external calculator to account for arithmetic mistakes
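The two accuracy measures above could be approximated as in the sketch below. It is only an illustration under stated assumptions: the CoT is assumed to end with "The answer is X.", equations are assumed to be simple binary expressions of the form "a op b = c", and the calculator variant simply recomputes the last such equation. The summary does not specify how the external calculator was actually applied, and the function names are hypothetical.

```python
import re
from typing import Optional

NUMBER = r"-?\d+(?:\.\d+)?"
ANSWER_RE = re.compile(rf"The answer is\s*({NUMBER})")
EQUATION_RE = re.compile(rf"({NUMBER})\s*([+\-*/])\s*({NUMBER})\s*=\s*({NUMBER})")

OPS = {
    "+": lambda a, b: a + b,
    "-": lambda a, b: a - b,
    "*": lambda a, b: a * b,
    "/": lambda a, b: a / b,
}


def extract_answer(cot: str) -> Optional[float]:
    """Return the number in the last 'The answer is X' statement, if any."""
    matches = ANSWER_RE.findall(cot)
    return float(matches[-1]) if matches else None


def answer_with_calculator(cot: str) -> Optional[float]:
    """Recompute the last 'a op b = c' equation instead of trusting c.

    Falls back to the stated final answer when no equation is found or
    when the equation divides by zero.
    """
    equations = EQUATION_RE.findall(cot)
    if not equations:
        return extract_answer(cot)
    a, op, b, _stated = equations[-1]
    if op == "/" and float(b) == 0:
        return extract_answer(cot)
    return OPS[op](float(a), float(b))


def is_correct(pred: Optional[float], target: float, tol: float = 1e-6) -> bool:
    """Compare a predicted answer with the target, allowing float tolerance."""
    return pred is not None and abs(pred - target) < tol
```

For a CoT such as "She sold 48 / 2 = 25 clips in May. The answer is 25." with target 24, the plain check fails while the calculator variant recovers the correct value. Numbers with thousands separators and multi-operator expressions are not handled by this sketch.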

Commonsense reasoning

  • Benchmarked model’s ability to perform commonsense reasoning on StrategyQA dataset
  • 80% of dataset used for training, 10% for validation, 10% for testing
  • Did not benchmark on CSQA dataset, as inference on it is too costly
  • Did not evaluate on Date and Sports tasks from BIG-bench or on SayCan dataset, as they are too small
  • Task accuracy computed by checking for target answer in CoT

Symbolic reasoning

  • Benchmarked model on two symbolic reasoning tasks: Last Letter Concatenation and Coinflip
  • Evaluated task accuracy by checking CoT for target answer
  • Evaluated model’s generalisability to out-of-distribution examples
  • Finetuned model on examples of length two, evaluated model’s ability to generalise to sequences of length three and four
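To make the in-distribution versus out-of-distribution setup concrete, the sketch below generates toy examples for the two symbolic tasks at a chosen length. The question wording, the name pool and the function names are assumptions for illustration; the actual datasets used in the paper may be phrased differently.

```python
import random
from typing import List, Tuple

NAMES = ["Amy", "Bob", "Cara", "Dev", "Elle", "Finn"]  # hypothetical name pool


def last_letter_example(words: List[str]) -> Tuple[str, str]:
    """Question/answer pair: concatenate the last letter of each word."""
    phrase = " ".join(words)
    question = f'Take the last letters of the words in "{phrase}" and concatenate them.'
    answer = "".join(word[-1] for word in words)
    return question, answer


def coinflip_example(flips: List[Tuple[str, bool]]) -> Tuple[str, str]:
    """Question/answer pair: is the coin still heads up after these steps?"""
    steps = [
        f"{name} {'flips' if did_flip else 'does not flip'} the coin."
        for name, did_flip in flips
    ]
    question = "A coin is heads up. " + " ".join(steps) + " Is the coin still heads up?"
    n_flips = sum(1 for _, did_flip in flips if did_flip)
    answer = "yes" if n_flips % 2 == 0 else "no"
    return question, answer


def make_coinflip_split(n_examples: int, length: int, rng: random.Random):
    """Build n_examples coinflip problems, each with `length` people."""
    examples = []
    for _ in range(n_examples):
        flips = [(rng.choice(NAMES), rng.random() < 0.5) for _ in range(length)]
        examples.append(coinflip_example(flips))
    return examples


rng = random.Random(0)
train = make_coinflip_split(100, length=2, rng=rng)        # in-distribution
ood = {n: make_coinflip_split(100, length=n, rng=rng) for n in (3, 4)}
```

A model finetuned on `train` would then be evaluated separately on `ood[3]` and `ood[4]` to measure length generalisation.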

Baselines and setup

  • PaLM 540B and GPT-3 175B selected as teacher models
  • Teacher models prompted as described in Method (Section 3 of the paper)
  • T5 selected as student model
  • Student models trained on PaLM 540B or GPT-3 175B generated CoT data
  • T5 XXL model finetuned on original target as baseline
  • 5-fold cross validation for MAWPS and ASDiv datasets
  • 10% of training set as validation set for all other datasets
  • Dataset not shuffled to allow for reproducibility
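The unshuffled splits described above can be sketched as follows. This is an assumed implementation: the source does not say which 10% of the training set is held out or how fold boundaries are chosen, so taking the last 10% and using contiguous folds are choices made for this example, and the function names are hypothetical.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def holdout_split(
    data: Sequence[T], val_fraction: float = 0.1
) -> Tuple[List[T], List[T]]:
    """Hold out the last val_fraction of the data as validation, keeping order."""
    n_val = int(len(data) * val_fraction)
    cut = len(data) - n_val
    return list(data[:cut]), list(data[cut:])


def k_fold(data: Sequence[T], k: int = 5) -> List[Tuple[List[T], List[T]]]:
    """Contiguous, unshuffled folds; each fold serves as the test set once."""
    n = len(data)
    bounds = [round(i * n / k) for i in range(k + 1)]
    folds = []
    for i in range(k):
        lo, hi = bounds[i], bounds[i + 1]
        test = list(data[lo:hi])
        train = list(data[:lo]) + list(data[hi:])
        folds.append((train, test))
    return folds
```

Because no random state is involved, repeated runs produce identical splits, which is the reproducibility property the setup relies on.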

Results

Arithmetic reasoning

  • Providing an LLM with the answer to the question during CoT generation is beneficial
  • Accuracy is higher when the model is conditioned on the expected answer
  • Most of the benefits come from the model correcting its CoT

Commonsense reasoning

  • CoT finetuning improved accuracy from 68.12% to 71.98% on the StrategyQA dataset.
  • Improvement was not as significant as on the arithmetic reasoning datasets.
  • This is likely due to the model lacking the factual knowledge required by the dataset.

Symbolic reasoning

  • Neither traditional finetuning nor the proposed method improves accuracy on the Last Letter Concatenation task
  • Model fails to generalise to longer sequence length
  • Proposed method increases accuracy for Coinflip dataset
  • Finetuning on CoT does not improve generalisation to four coinflips

Replicating results using different teacher models

  • Demonstrated robustness of proposed method using GPT-3 175B
  • Compared accuracy of PaLM 540B and GPT-3 175B

Ablation study on model size

  • Finetuning T5 of different sizes on the GSM8K dataset can lead to performance gain.
  • T5 base with 44 times fewer parameters than T5 XXL can match the performance of the baseline T5 XXL.
  • T5 small can outperform the baseline T5 XXL with an external calculator.

Ablation study on dataset size

  • Performance improved through student-teacher knowledge distillation
  • Investigating trade-off between performance gain and dataset size
  • Training T5 XXL on smaller subset of finetuning data

Discussion

  • Finetuning smaller language models on data generated by larger language models can improve task accuracy.
  • Improvements are task dependent.
  • Knowledge distillation pipeline allows trade-off between model and dataset size and accuracy.
  • Future work could explore improving reasoning of small models and generating new training data.

Conclusion

  • Explored knowledge distillation from large language models to smaller language models
  • Proposed a knowledge distillation pipeline with two steps
  • Finetuning on CoT improves task accuracy across benchmarking datasets
  • Evaluated model’s ability to generalise to out-of-distribution examples
  • Reported accuracy of PaLM 540B for reference