Abstract
- Chain of thought prompting improves the reasoning capabilities of large language models.
- These reasoning capabilities appear to emerge only in models with over 100 billion parameters.
- Knowledge distillation can transfer reasoning capabilities to models with fewer than 100 billion parameters.
- Experiments show improved task performance across arithmetic, commonsense and symbolic reasoning datasets.
- Accuracy of T5 XXL on GSM8K improves from 8.11% to 21.99% when finetuned on PaLM-540B generated chains of thought.
Paper Content
Introduction
- Chain of thought (CoT) prompting increases accuracy of large language models (LLMs)
- Smaller language models (LMs) do not benefit from CoT prompting
- CoT prompting reduces the accuracy of models with fewer than 10 billion parameters
- Research question: can reasoning capabilities of LLMs be transferred to smaller LMs via finetuning?
Related work
- CoT prompting encourages models to break down problems into intermediate steps
- Wei et al. (2022) show that prefixing an input with exemplars of CoT reasoning can improve performance
- Wang et al. (2022) show that task accuracy can be improved by using self-consistency in CoT prompting
- Chung et al. (2022) explore finetuning on CoT data
- Huang et al. (2022) explore self-improvement by finetuning on self-labelled solutions
- Our work explores student-teacher knowledge distillation across multiple datasets and model architectures
Method
- The proposed pipeline for CoT knowledge distillation has two steps
- Step one annotates an existing supervised dataset with CoT reasoning generated by a teacher model
- The teacher model is an LLM, such as PaLM 540B or GPT-3 175B
- CoTs are generated via few-shot prompting with 8 exemplars
- The key modification to the CoT prompts is to give the model the solution to the task before the example CoT
- CoTs that do not reach the target answer are removed, so that the student model does not learn from bad examples
- Step two finetunes the student model on the annotated data via teacher forcing
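The two steps above can be sketched as follows. This is a minimal illustration, not the authors' code: the teacher is any callable mapping a prompt to text (standing in for PaLM 540B or GPT-3 175B), and the prompt template, the `Example` class, the "The answer is X." convention used to read off the final answer, and every function name are hypothetical.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass
class Example:
    question: str
    answer: str                # target answer from the existing supervised dataset
    cot: Optional[str] = None  # chain of thought written by the teacher


def build_prompt(exemplars: List[Example], question: str, answer: str) -> str:
    """Few-shot prompt in which the answer is shown before each example CoT."""
    blocks = [f"Q: {e.question}\nAnswer: {e.answer}\nA: {e.cot}" for e in exemplars]
    # The target answer of the new question is given to the teacher as well.
    blocks.append(f"Q: {question}\nAnswer: {answer}\nA:")
    return "\n\n".join(blocks)


def final_answer(cot: str) -> str:
    """Assumed convention: the CoT ends with 'The answer is X.'"""
    marker = "The answer is"
    if marker not in cot:
        return ""
    return cot.rsplit(marker, 1)[1].strip().rstrip(".")


def annotate(dataset: List[Example], exemplars: List[Example],
             teacher: Callable[[str], str]) -> List[Example]:
    """Step one: annotate with teacher CoTs, dropping CoTs with a wrong answer."""
    kept = []
    for ex in dataset:
        cot = teacher(build_prompt(exemplars, ex.question, ex.answer))
        if final_answer(cot) == ex.answer:
            kept.append(Example(ex.question, ex.answer, cot))
    return kept


def teacher_forcing_pair(target_ids: List[int], start_id: int) -> Tuple[List[int], List[int]]:
    """Step two: the decoder is fed the gold CoT tokens shifted right by one
    and is trained to predict the unshifted tokens at every position."""
    decoder_inputs = [start_id] + target_ids[:-1]
    labels = list(target_ids)
    return decoder_inputs, labels


def fake_teacher(prompt: str) -> str:
    """Stand-in for the teacher LLM: always returns the same CoT."""
    return "2 + 3 = 5. The answer is 5."


exemplars = [Example("What is 1 + 1?", "2", "1 + 1 = 2. The answer is 2.")] * 8
data = [Example("What is 2 + 3?", "5"), Example("What is 2 + 4?", "6")]
student_data = annotate(data, exemplars, fake_teacher)  # the second CoT is filtered out
```

Showing the target answer before each exemplar CoT and before the new question mirrors the key modification described above, while the final-answer check is what keeps incorrect CoTs out of the student's training data.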
Experimental setup
- Focused on a set of tasks similar to those of Wei et al. (2022)
- Tasks cover arithmetic, commonsense and symbolic reasoning
Arithmetic reasoning
- Benchmarked the proposed method on three math word problem datasets: GSM8K, MAWPS and ASDiv
- Used the official training and test split for GSM8K
- Performed 5-fold cross validation on MAWPS and ASDiv
- Did not evaluate on SVAMP or AQuA
- Evaluated task accuracy by checking whether the target answer is given in the CoT
- Also computed task accuracy with an external calculator, to account for arithmetic mistakes made by the model
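A sketch of the two accuracy measures, assuming that CoTs write equations as `expression = result` and that the last number in the CoT is the final answer. The calculator re-evaluates the left-hand side of each equation and carries a corrected result into the rest of the CoT by simple string matching. The regular expressions and function names are hypothetical, and the procedure used in the paper may differ.

```python
import ast
import operator
import re

# Binary operators the calculator is allowed to evaluate.
_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.Div: operator.truediv}


def safe_eval(expr: str) -> float:
    """Evaluate numbers combined with + - * / and parentheses, without eval()."""
    def walk(node):
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -walk(node.operand)
        raise ValueError(f"unsupported expression: {expr!r}")
    return walk(ast.parse(expr, mode="eval"))


def fmt(x: float) -> str:
    return str(int(x)) if float(x).is_integer() else str(round(x, 2))


# "lhs = rhs", where lhs starts with a digit or "(" and uses only arithmetic symbols.
EQUATION = re.compile(r"(\(?\d[\d.\s+\-*/()]*?)\s*=\s*(-?\d+(?:\.\d+)?)")
NUMBER = re.compile(r"-?\d+(?:\.\d+)?")


def apply_calculator(cot: str) -> str:
    """Recompute each equation's result and propagate fixes to the later text."""
    pos = 0
    while (m := EQUATION.search(cot, pos)):
        lhs, rhs = m.group(1), m.group(2)
        try:
            value = safe_eval(lhs.strip())
        except (ValueError, SyntaxError, ZeroDivisionError):
            pos = m.end()
            continue
        new = rhs
        if abs(value - float(rhs)) > 1e-6:
            new = fmt(value)
            # Replace the wrong result here and wherever it is reused later on.
            wrong = rf"(?<![\d.]){re.escape(rhs)}(?!\.?\d)"
            cot = cot[:m.start(2)] + re.sub(wrong, new, cot[m.start(2):])
        pos = m.start(2) + len(new)
    return cot


def is_correct(cot: str, target: float) -> bool:
    """Assumption: the last number in the CoT is the model's final answer."""
    nums = NUMBER.findall(cot.replace(",", ""))
    return bool(nums) and abs(float(nums[-1]) - target) < 1e-6


cot = "He has 3 + 4 = 8 apples. So 8 * 2 = 16. The answer is 16."
fixed = apply_calculator(cot)  # the 3 + 4 slip is corrected and carried forward
```

Here the raw CoT is scored as wrong for a target of 14, but the calculator-corrected CoT is scored as right, which is the kind of arithmetic-only mistake the second measure is meant to discount.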
Commonsense reasoning
- Benchmarked the model’s commonsense reasoning ability on the StrategyQA dataset
- 80% of the dataset is used for training, 10% for validation and 10% for testing
- CSQA is not used, as running inference on it is too costly
- The Date and Sports tasks from the BIG-bench effort and the SayCan dataset are not used, as they are too small
- Task accuracy is computed by checking for the target answer in the CoT
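A sketch of the StrategyQA split and scoring, assuming a contiguous (unshuffled) 80/10/10 split and that the last standalone "yes" or "no" in the CoT is the model's answer. Both choices, and all function names, are illustrative assumptions rather than details taken from the paper.

```python
import re
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_80_10_10(data: Sequence[T]) -> Tuple[List[T], List[T], List[T]]:
    """Contiguous train/validation/test split; the original order is kept."""
    n = len(data)
    n_train, n_val = n * 8 // 10, n // 10  # integer arithmetic avoids float rounding
    train = list(data[:n_train])
    val = list(data[n_train:n_train + n_val])
    test = list(data[n_train + n_val:])  # any remainder goes to the test split
    return train, val, test


def yes_no_answer(cot: str) -> str:
    """Return the last standalone 'yes' or 'no' in the CoT, or '' if there is none."""
    matches = re.findall(r"\b(yes|no)\b", cot.lower())
    return matches[-1] if matches else ""


def accuracy(cots: List[str], targets: List[bool]) -> float:
    """Fraction of CoTs whose yes/no answer matches the boolean target."""
    hits = sum(yes_no_answer(c) == ("yes" if t else "no") for c, t in zip(cots, targets))
    return hits / len(targets)
```

The word-boundary pattern keeps words such as "not" or "know" from being read as a "no" answer.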
Symbolic reasoning
- Benchmarked the model on two symbolic reasoning tasks: Last Letter Concatenation and Coinflip
- Evaluated task accuracy by checking the CoT for the target answer
- Evaluated the model’s ability to generalise to out-of-distribution examples
- Finetuned the model on examples of length two and evaluated its ability to generalise to sequences of length three and four
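To illustrate the out-of-distribution setup, the sketch below generates questions loosely modelled on the Coinflip and Last Letter Concatenation tasks of Wei et al. (2022), taking "length" to be the number of people (or words). The templates, the name list and the sample counts are hypothetical.

```python
import random
from typing import Dict, List

NAMES = ["Ka", "Sherrie", "Jamey", "Maybelle", "Tom", "Ada"]  # hypothetical name pool


def coin_flip(length: int, rng: random.Random) -> Dict[str, str]:
    """One Coinflip question with `length` people, each of whom may or may not flip."""
    people = rng.sample(NAMES, length)
    flips = [rng.random() < 0.5 for _ in people]
    steps = [f"{p} flips the coin." if f else f"{p} does not flip the coin."
             for p, f in zip(people, flips)]
    question = "A coin is heads up. " + " ".join(steps) + " Is the coin still heads up?"
    # The coin is heads up again exactly when it was flipped an even number of times.
    return {"question": question, "answer": "yes" if sum(flips) % 2 == 0 else "no"}


def last_letter_concat(words: List[str]) -> Dict[str, str]:
    """One Last Letter Concatenation question; its length is the number of words."""
    joined = " ".join(words)
    question = f'Take the last letters of the words in "{joined}" and concatenate them.'
    return {"question": question, "answer": "".join(w[-1] for w in words)}


rng = random.Random(0)
# In-distribution training data: length two only.
train = [coin_flip(2, rng) for _ in range(1000)]
# Out-of-distribution evaluation: lengths three and four.
ood_test = {n: [coin_flip(n, rng) for _ in range(100)] for n in (3, 4)}
```

Because the answer depends only on the parity of the flips, a model that has learned the rule from two-person examples should in principle handle three or four people; the results below report whether CoT finetuning actually achieves this.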
Baselines and setup
- PaLM 540B and GPT-3 175B were selected as teacher models
- Teacher models were prompted as described in the Method section
- T5 was selected as the student model
- Student models were trained on CoT data generated by PaLM 540B or GPT-3 175B
- As a baseline, T5 XXL was finetuned on the original targets (without CoT)
- 5-fold cross validation for the MAWPS and ASDiv datasets
- For all other datasets, 10% of the training set is held out as a validation set
- Datasets are not shuffled, to allow for reproducibility
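The evaluation splits can be sketched without any shuffling as below. Using contiguous folds, and holding out the last 10% of the training set for validation, are assumptions: the summary does not say which examples make up each fold or the validation set. The function names are hypothetical.

```python
from typing import Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def k_fold(data: Sequence[T], k: int = 5) -> Iterator[Tuple[List[T], List[T]]]:
    """Yield (train, test) pairs over k contiguous folds, keeping the original order."""
    n = len(data)
    bounds = [i * n // k for i in range(k + 1)]  # fold boundaries, as even as possible
    for i in range(k):
        lo, hi = bounds[i], bounds[i + 1]
        yield list(data[:lo]) + list(data[hi:]), list(data[lo:hi])


def split_validation(train: Sequence[T]) -> Tuple[List[T], List[T]]:
    """Hold out the last 10% of the (unshuffled) training set for validation."""
    n_val = len(train) // 10
    cut = len(train) - n_val
    return list(train[:cut]), list(train[cut:])
```

Since neither function draws random numbers, rerunning them on the same file always yields identical splits, which is the reproducibility property the setup relies on.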
Results
Arithmetic reasoning
- Providing an LLM with the answer to the question during CoT generation is beneficial
- Accuracy is higher when the model is conditioned on the expected answer
- Most of the benefits come from the model correcting its CoT
Commonsense reasoning
- CoT finetuning improved accuracy from 68.12% to 71.98% on the StrategyQA dataset.
- Improvement was not as significant as on the arithmetic reasoning datasets.
- This is likely due to the model lacking the factual knowledge required by the dataset.
Symbolic reasoning
- Neither traditional finetuning nor the proposed method improves accuracy on the Last Letter Concatenation task
- The model fails to generalise to longer sequence lengths
- The proposed method increases accuracy on the Coinflip dataset
- Finetuning on CoT does not improve generalisation to four coin flips
Replicating results using different teacher models
- Demonstrated the robustness of the proposed method by using GPT-3 175B as the teacher model
- Compared the accuracy of PaLM 540B and GPT-3 175B
Ablation study on model size
- Finetuning T5 models of different sizes on the GSM8K dataset can lead to performance gains.
- T5 base, with 44 times fewer parameters than T5 XXL, can match the performance of the baseline T5 XXL.
- T5 small can outperform the baseline T5 XXL with an external calculator.
Ablation study on dataset size
- Performance improves through student-teacher knowledge distillation
- Investigated the trade-off between performance gain and dataset size
- Trained T5 XXL on smaller subsets of the finetuning data
Discussion
- Finetuning smaller language models on data generated by larger language models can improve task accuracy.
- Improvements are task dependent.
- The knowledge distillation pipeline allows model size and dataset size to be traded off against accuracy.
- Future work could explore improving reasoning of small models and generating new training data.
Conclusion
- Explored knowledge distillation from large language models to smaller language models
- Proposed a knowledge distillation pipeline with two steps
- Finetuning on CoT improves task accuracy across benchmarking datasets
- Evaluated model’s ability to generalise to out-of-distribution examples
- Reported accuracy of PaLM 540B for reference