This repository contains the official code and data for our ACL 2024 paper Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning.
Teaching small language models (e.g., T5-large) chain-of-thought reasoning by distilling from larger models like GPT-4 has been shown to be effective. However, relying on such proprietary large models can be both economically and computationally costly. Our paper demonstrates that small language models can learn from their own generations in a self-training manner, starting from a limited amount of high-quality, human-annotated training data. Additionally, we present an efficient method for integrating external calculators during inference to boost performance.
Our approach achieves superior performance while keeping the required compute cost low.
DPO-augmented Self-Training builds upon the conventional self-training framework. Unlike the traditional framework, where pseudo-labels are generated by the SFT model, we add a DPO step to each self-training iteration and generate pseudo-labels with the DPO model. We empirically find that the DPO model produces more diverse pseudo-labels of higher quality.
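At a high level, one run of DPO-augmented Self-Training can be sketched as follows. This is illustrative pseudocode only; names such as sft_train, dpo_train, make_preference_pairs, and sample_pseudo_labels are hypothetical and do not correspond to APIs in this repo.

# Illustrative pseudocode of DPO-augmented Self-Training; all names are hypothetical.
model = sft_train(pretrained_lm, labeled_data)  # warm-up SFT on the human-annotated seed data
for t in range(num_iterations):
    # DPO step: build preference pairs from the current model's own samples, then run DPO.
    pairs = make_preference_pairs(model, labeled_data)
    dpo_model = dpo_train(model, pairs)
    # Unlike conventional self-training, pseudo-labels are sampled from the DPO model.
    pseudo_data = sample_pseudo_labels(dpo_model, labeled_data)
    # Re-run SFT on the labeled data augmented with the pseudo-labeled data.
    model = sft_train(pretrained_lm, labeled_data + pseudo_data)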
Integrating external calculators during model inference can enhance math reasoning performance. However, many previous implementations support only a batch size of 1, which significantly slows down inference. In this work, we present an efficient calculator integration that supports larger inference batch sizes. Specifically, we design a LogitsProcessor that modifies the model's output logits during decoding. More details about our implementation can be found in generate.py.
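For illustration, the general shape of such a processor is sketched below. This is not the implementation in generate.py: the calculator annotation format (e.g., <<12*3=36>>), the regular expression, and the use of Python's eval as the calculator are assumptions made purely for this example.

# Hedged sketch of calculator-assisted decoding with a custom LogitsProcessor.
# Not the implementation in generate.py; the annotation format and helpers are assumptions.
import re
import torch
from transformers import LogitsProcessor

class CalculatorLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Matches an unfinished annotation such as "<<12*3=" or a partially written result "<<12*3=3".
        self.pattern = re.compile(r"<<([0-9+\-*/().\s]+)=([0-9.\-]*)$")

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Called once per decoding step; each row of input_ids is one sequence in the batch.
        for i, seq in enumerate(input_ids):
            text = self.tokenizer.decode(seq, skip_special_tokens=True)
            match = self.pattern.search(text)
            if match is None:
                continue
            expression, partial_result = match.group(1), match.group(2)
            try:
                result = str(eval(expression))  # the "external calculator" (unsafe; for illustration only)
            except Exception:
                continue
            if result == partial_result or not result.startswith(partial_result):
                continue
            # Force the next piece of the computed result for this sequence only.
            # (A real implementation needs tokenizer-specific handling of subword boundaries.)
            forced_id = self.tokenizer.encode(result[len(partial_result):], add_special_tokens=False)[0]
            scores[i, :] = float("-inf")
            scores[i, forced_id] = 0.0
        return scores

A processor like this can be passed to model.generate via a transformers.LogitsProcessorList, so the calculator runs per sequence inside a larger decoding batch instead of restricting generation to batch size 1.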
Inference speed-up comparison with Flan-T5-Large on a single A40 GPU.
Please follow these steps before running our code.
- Use Conda to create a Python virtual environment:
conda create -n dpo-st python=3.10
conda activate dpo-st
- Install the Python dependencies with pip:
pip install -r requirements.txt
- Log in to Hugging Face to download pre-trained model weights:
huggingface-cli login --token "${your_hf_token}"
- Set the environment variable DATA_DIR and download pre-trained model weights from Hugging Face into DATA_DIR/hf_models. For example:
DATA_DIR='.'
huggingface-cli download meta-llama/Llama-2-7b-hf --local-dir $DATA_DIR/hf_models/llama-2
We recommend using python-dotenv to define DATA_DIR in your .env file, as this environment variable will be used in the subsequent steps.
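For example, a .env file containing the line DATA_DIR=/path/to/data can be loaded at the top of a script as follows (a minimal, generic python-dotenv example, not a snippet taken from this repo's scripts):

# Minimal python-dotenv example: load DATA_DIR from a .env file in the working directory.
import os
from dotenv import load_dotenv

load_dotenv()                      # reads KEY=VALUE pairs from .env into the process environment
data_dir = os.environ["DATA_DIR"]  # e.g., the directory that contains hf_models/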
The first step of DPO-ST is to warm up the pre-trained language model by fine-tuning it on the labeled dataset.
For Flan-T5-Large, run the following command:
ACC_CONFIG='acc_config/ddp8.yaml'
accelerate launch --config_file $ACC_CONFIG sft.py --config-name=sft-0
For Llama-2-7b, run the following command:
ACC_CONFIG='acc_config/fsdp.yaml'
accelerate launch --config_file $ACC_CONFIG sft.py --config-path=exp_config/llama --config-name=sft-0
First, sample pseudo-labels from the SFT model:
ARGS='+data.split="train" eval.mode="sampling" eval.sampling.max_seed=5'
torchrun --nproc_per_node 8 generate.py --config-name=sft-0 $ARGS
python3 eval_sampling.py --config-name=sft-0 $ARGS
Then, make DPO training data from the SFT model generations:
python3 utils/make_dpo_data.py --config-name=sft-0
Note that the above commands are for T5 models. For Llama, add --config-path=exp_config/llama to each command.
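Conceptually, the make_dpo_data step pairs sampled rationales that reach the reference answer (chosen) with ones that do not (rejected). The sketch below illustrates one plausible construction; the actual logic lives in utils/make_dpo_data.py and likely differs in details such as filtering and deduplication, and extract_answer and the data layout here are assumptions.

# Hedged sketch of building DPO preference pairs from sampled SFT generations.
def make_preference_pairs(samples, extract_answer):
    # samples: {question: {"gold": gold_answer, "generations": [rationale, ...]}}
    pairs = []
    for question, item in samples.items():
        generations = item["generations"]
        correct = [g for g in generations if extract_answer(g) == item["gold"]]
        wrong = [g for g in generations if extract_answer(g) != item["gold"]]
        for chosen, rejected in zip(correct, wrong):
            pairs.append({"prompt": question, "chosen": chosen, "rejected": rejected})
    return pairs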
Next, train the model with DPO on the preference data constructed above. For T5:
ACC_CONFIG='acc_config/ddp8.yaml'
accelerate launch --config_file $ACC_CONFIG dpo.py --config-name=dpo-1
For Llama:
ACC_CONFIG='acc_config/fsdp.yaml'
accelerate launch --config_file $ACC_CONFIG dpo.py --config-path=exp_config/llama --config-name=dpo-1
Next, sample pseudo-labels from the DPO model and build the training data for the next SFT iteration:
ARGS='+data.split="train" eval.mode="sampling" eval.sampling.max_seed=3'
torchrun --nproc_per_node 8 greedy_decode.py --config-name=dpo-1 $ARGS
python3 eval_sampling.py --config-name=dpo-1 $ARGS
python3 utils/make_rft_data.py --config-name=dpo-1
You can control the number of sampled generations per question by adjusting eval.sampling.max_seed.
Then, fine-tune the model again on the labeled data augmented with the pseudo-labeled data. For T5:
ACC_CONFIG='acc_config/ddp8.yaml'
accelerate launch --config_file $ACC_CONFIG sft.py --config-name=sft-1
For Llama:
ACC_CONFIG='acc_config/fsdp.yaml'
accelerate launch --config_file $ACC_CONFIG sft.py --config-path=exp_config/llama --config-name=sft-1
To evaluate a fine-tuned model, run greedy decoding followed by the evaluation script:
CONFIG_PATH='exp_config/t5'
SPLIT='test'
torchrun --nproc_per_node 8 generate.py --config-path=$CONFIG_PATH --config-name=dpo-1 +data.split=$SPLIT
python3 eval_greedy.py --config-path=$CONFIG_PATH --config-name=dpo-1 +data.split=$SPLIT
- CONFIG_PATH: set it to exp_config/t5 for T5 models and exp_config/llama for Llama models.
- SPLIT: set it to dev for dev set results and test for test set results.
If you find this paper useful, please consider citing our work:
@inproceedings{wang2024dpost,
  title = {Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning},
  author = {Tianduo Wang and Shichen Li and Wei Lu},
  booktitle = {Proceedings of ACL},
  year = {2024},
}
This repo is largely inspired by GSM8K-ScRel and TRL. We are grateful to the authors for their brilliant work.