
InContextTD

Welcome to the InContextTD repository, which accompanies the paper: Transformers Learn Temporal Difference Methods for In-Context Reinforcement Learning.

Table of Contents

  • Introduction
  • Dependencies
  • Usage
  • Demo
  • Theory Verification
  • Complete Replication
  • Test
  • License

Introduction

This repository provides the code to empirically demonstrate how transformers can learn to implement temporal difference (TD) methods for in-context policy evaluation. The experiments explore transformers' ability to apply TD learning during inference without requiring parameter updates.
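As a point of reference, the TD(0) update for linear policy evaluation can be sketched as follows (a minimal NumPy illustration with a hypothetical feature map and trajectory, not code from this repository):

import numpy as np

def td0_policy_evaluation(trajectory, phi, dim_feature=4, gamma=0.9, lr=0.1):
    # Linear TD(0): estimate v(s) ≈ phi(s) @ w from (state, reward, next state) transitions.
    w = np.zeros(dim_feature)
    for s, r, s_next in trajectory:
        # TD error: reward plus discounted bootstrapped value minus current estimate
        td_error = r + gamma * phi(s_next) @ w - phi(s) @ w
        w += lr * td_error * phi(s)  # semi-gradient update
    return w

In-context policy evaluation means a trained transformer produces such value estimates directly from a prompt of transitions, with its own weights frozen.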

Dependencies

To install the required dependencies, first clone this repository, then run the following command in the root directory of the project:

pip install .
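For example, starting from scratch (the GitHub path is inferred from the repository name; adjust it if you are working from a fork):

git clone https://github.com/Sequential-Intelligence-Lab/InContextTD.git
cd InContextTD
pip install .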

Usage

Quick Start

To quickly replicate the experiments from Figure 2 of the paper, execute the following command:

python main.py --suffix=linear_standard -v

This will generate the following plots:

  • P metrics
  • Q metrics
  • Final learned P and Q
  • Batch TD comparison

The generated figures will be saved in:
  • ./logs/YYYY-MM-DD-HH-MM-SS/linear_standard/averaged_figures/ (aggregated results across all seeds)

  • ./logs/YYYY-MM-DD-HH-MM-SS/linear_standard/seed_SEED/figures/ (diagnostic figures for each individual seed)

Custom Experiment Settings

To run experiments with custom configurations, use:

python main.py [options]

Below is a list of the command-line arguments available for main.py:

  • -d, --dim_feature: Feature dimension (default: 4)
  • -s, --num_states: Number of states (default: 10)
  • -n, --context_length: Context length (default: 30)
  • -l, --num_layers: Number of transformer layers (default: 3)
  • --gamma: Discount factor (default: 0.9)
  • --activation: Activation function (choices: ['identity', 'softmax', 'relu'])
  • --representable: Flag to randomly sample a true weight vector that allows the value function to be fully represented by the features
  • --n_mrps: Number of MRPs used for training (default: 4000)
  • --batch_size: Mini-batch size (default: 64)
  • --n_batch_per_mrp: Number of mini-batches sampled per MRP (default: 5)
  • --lr: Learning rate (default: 0.001)
  • --weight_decay: Regularization term (default: 1e-6)
  • --log_interval: Frequency of logging during training (default: 10)
  • --mode: Training mode, either auto-regressive or sequential (choices: ['auto', 'sequential'], default: 'auto')
  • --seed: Random seeds (default: list(range(1, 31)))
  • --save_dir: Directory to save logs (default: None)
  • --suffix: Suffix to append to the log save directory (default: None)
  • --gen_gif: Flag to generate a GIF showing the evolution of weights (under construction)
  • -v, --verbose: Flag to print detailed training progress

If no --save_dir is specified, logs will be saved in ./logs/YYYY-MM-DD-HH-MM-SS. If a --suffix is provided, logs will be saved in ./logs/YYYY-MM-DD-HH-MM-SS/SUFFIX.
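For example, the following trains 2-layer transformers with softmax attention and feature dimension 8 on 1000 MRPs, with verbose logging (all flags as documented above):

python main.py -d 8 -l 2 --activation=softmax --n_mrps=1000 --suffix=softmax_small -v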

Demo

We provide a demo script that demonstrates the performance of the TD algorithm implemented by the linear transformer under our theoretical construction. The script plots the mean squared value error (MSVE), averaged over randomly generated MRPs, against a sequence of increasing context lengths. Note that we use fully representable value functions here to ensure the minimum achievable MSVE is zero.
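As a point of reference, the MSVE for a single MRP can be sketched as follows (an illustrative snippet assuming uniform weighting over states, not code from demo.py):

import numpy as np

def msve(v_estimate, v_true, state_dist=None):
    # Mean squared value error between estimated and true state values.
    # state_dist: optional weighting over states; defaults to uniform.
    sq_err = (v_estimate - v_true) ** 2
    return sq_err.mean() if state_dist is None else state_dist @ sq_err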


To run the script, use

python demo.py [options]

Below is a list of the command-line arguments available for demo.py:

  • -d, --dim_feature: Feature dimension (default: 5)
  • -l, --num_layers: Number of transformer layers/TD updates (default: 15)
  • -smin, --min_state_num: Minimum possible state number of the randomly generated MRP (default: 5)
  • -smax, --max_state_num: Maximum possible state number of the randomly generated MRP (default: 15)
  • --gamma: Discount factor (default: 0.9)
  • --lr: Learning rate of the implemented in-context TD algorithm (default: 0.2)
  • --n_mrps: Number of randomly generated MRPs to test on (default: 300)
  • -nmin, --min_ctxt_len: Minimum context length (default: 1)
  • -nmax, --max_ctxt_len: Maximum context length (default: 40)
  • --ctxt_step: Context length increment step (default: 2)
  • --seed: Random seed (default: 42)
  • --save_dir: Directory to save demo results (default: 'logs')

By default, the result is saved to ./logs/demo.
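For example, to average over 100 MRPs with 20 layers (i.e., 20 in-context TD updates) and context lengths from 1 to 30:

python demo.py -l 20 --n_mrps=100 --min_ctxt_len=1 --max_ctxt_len=30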

Theory Verification

We provide a script to numerically verify our theories. The script computes the absolute errors, on a log scale, between the value predictions of the linear transformers and those of direct implementations of their corresponding in-context algorithms.
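In essence, the reported quantity can be sketched as follows (hypothetical names; the two prediction arrays would come from the transformer's forward pass and from running the corresponding algorithm directly):

import numpy as np

def log_abs_error(v_transformer, v_algorithm):
    # Element-wise absolute error, reported on a log10 scale.
    return np.log10(np.abs(v_transformer - v_algorithm))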


To run the script, use

python verify.py [options]

Below is a list of the command-line arguments available for verify.py:

  • -d, --dim_feature: Feature dimension (default: 3)
  • -n, --context_length: Context length (default: 100)
  • -l, --num_layers: Number of transformer layers/TD updates (default: 40)
  • --num_trials: Number of trials to run (default: 30)
  • --seed: Random seed (default: 42)
  • --save_dir: Directory to save theory verification results (default: 'logs')

By default, the result is saved to ./logs/theory.
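For example, to run 10 trials with feature dimension 4 and context length 50:

python verify.py -d 4 -n 50 --num_trials=10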

Complete Replication

To run all the experiments from the paper in one go, execute the following shell script:

./run.sh

Test

To run the experiments at a small scale as a quick test, execute the following shell script:

./test.sh

The test results are stored in ./logs/test.

License

This project is licensed under the MIT License. See the LICENSE file for more information.