[Core] Support TPU v6 (#4220)
* init

* fix

* nit

* format

* add readme

* add inference example

* nit

* add multi-host training

* rephrase catalog doc

* Update examples/tpu/v6e/README.md

Co-authored-by: Zhanghao Wu <zhanghao.wu@outlook.com>

---------

Co-authored-by: Zhanghao Wu <zhanghao.wu@outlook.com>
cblmemo and Michaelvll authored Oct 31, 2024
1 parent 5dda9cf commit c4eeeb5
Showing 9 changed files with 290 additions and 9 deletions.
128 changes: 128 additions & 0 deletions examples/tpu/v6e/README.md
@@ -0,0 +1,128 @@
# TPU v6e

Trillium (also referred to as v6e) is Cloud TPU's latest-generation AI accelerator. SkyPilot supports TPU v6e for provisioning, training, and serving.

## Catalogs

Currently, the public APIs for TPU v6e regions and pricing have not been released, and pricing info for `us-central1`, `us-central2`, and `us-south1` is not available. We set the price to `0.0` in those regions for now.
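
To see what the catalog currently lists for TPU v6e, you can query it with `sky show-gpus` (a sketch; the exact output depends on your SkyPilot version, and the name filter can be dropped to list all GCP accelerators):

```bash
# Query the SkyPilot catalog for TPU v6e offerings on GCP.
$ sky show-gpus tpu-v6e-16 --cloud gcp
```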

## Provisioning

To provision TPU v6e, use the following command:
```bash
$ sky launch --gpus tpu-v6e-16 -c tpu-v6e
```

After that, you can SSH to the instance and start developing your model:

```bash
$ ssh tpu-v6e
```
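
You can also check the cluster from your local machine with SkyPilot's standard status command (the output format may vary slightly by version):

```bash
# List SkyPilot-managed clusters; `tpu-v6e` should show as UP once provisioning finishes.
$ sky status
```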

## Training

The example in this directory (`train-llama3-8b.yaml`) shows how to train a Llama 3 8B model with PyTorch/XLA on the wikitext dataset using TPU v6e. To start training, use the following command:

```bash
$ HF_TOKEN=hf_xxx sky launch train-llama3-8b.yaml -c train-llama3-8b --env HF_TOKEN
```
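
`sky launch` streams the job's logs; if you detach from the stream, you can re-attach later with the standard log-tailing command (a sketch; with no job ID it tails the most recent job on the cluster):

```bash
# Re-attach to the logs of the latest job on the training cluster.
$ sky logs train-llama3-8b
```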

### Single-Host Training

The training throughput for a `tpu-v6e-8` instance should be around 0.5 samples/s:

```bash
(task, pid=17499) ***** train metrics *****
(task, pid=17499) epoch = 1.1765
(task, pid=17499) total_flos = 109935420GF
(task, pid=17499) train_loss = 10.6011
(task, pid=17499) train_runtime = 0:11:12.77
(task, pid=17499) train_samples = 282
(task, pid=17499) train_samples_per_second = 0.476
(task, pid=17499) train_steps_per_second = 0.03
INFO: Job finished (status: SUCCEEDED).
```

### Multi-Host Training

By changing the TPU type to `tpu-v6e-16` and `--per_device_train_batch_size` to `32`, the training throughput increases to around 1 sample/s:

```bash
(head, rank=0, pid=17894) ***** train metrics *****
(head, rank=0, pid=17894) epoch = 2.5
(head, rank=0, pid=17894) total_flos = 219870840GF
(head, rank=0, pid=17894) train_loss = 10.1527
(head, rank=0, pid=17894) train_runtime = 0:11:13.18
(head, rank=0, pid=17894) train_samples = 282
(head, rank=0, pid=17894) train_samples_per_second = 0.951
(head, rank=0, pid=17894) train_steps_per_second = 0.03

(worker1, rank=1, pid=15406, ip=10.164.0.57) ***** train metrics *****
(worker1, rank=1, pid=15406, ip=10.164.0.57) epoch = 2.5
(worker1, rank=1, pid=15406, ip=10.164.0.57) total_flos = 219870840GF
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_loss = 10.1527
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_runtime = 0:11:15.08
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples = 282
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples_per_second = 0.948
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_steps_per_second = 0.03

(worker2, rank=2, pid=16552, ip=10.164.0.58) ***** train metrics *****
(worker2, rank=2, pid=16552, ip=10.164.0.58) epoch = 2.5
(worker2, rank=2, pid=16552, ip=10.164.0.58) total_flos = 219870840GF
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_loss = 10.1527
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_runtime = 0:11:15.61
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples = 282
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples_per_second = 0.947
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_steps_per_second = 0.03

(worker3, rank=3, pid=17469, ip=10.164.0.59) ***** train metrics *****
(worker3, rank=3, pid=17469, ip=10.164.0.59) epoch = 2.5
(worker3, rank=3, pid=17469, ip=10.164.0.59) total_flos = 219870840GF
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_loss = 10.1527
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_runtime = 0:11:15.10
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples = 282
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples_per_second = 0.948
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_steps_per_second = 0.03

INFO: Job finished (status: SUCCEEDED).
```
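
To reproduce the multi-host run, the accelerator can be overridden from the command line, while `--per_device_train_batch_size` is changed by editing the `run` section of the YAML (a sketch; the cluster name here is arbitrary):

```bash
# Launch on a 16-chip (4-host) v6e slice; set --per_device_train_batch_size to 32
# in train-llama3-8b.yaml before launching.
$ HF_TOKEN=hf_xxx sky launch train-llama3-8b.yaml --gpus tpu-v6e-16 \
    -c train-llama3-8b-multi --env HF_TOKEN
```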

## Serving

TPU v6e also supports serving. The example in this directory (`serve-llama2-7b.yaml`) shows how to serve a Llama 2 7B model with PyTorch/XLA and the JetStream library on TPU v6e. To start serving, use the following command:

```bash
$ HF_TOKEN=hf_xxx sky launch serve-llama2-7b.yaml -c serve-llama2-7b --env HF_TOKEN
```

After the server is ready, you should see the following message:

```bash
(task, pid=26431) 2024-09-24 19:58:15,160 - root - INFO - Starting server on port 9000 with 64 threads
(task, pid=26431) I0924 19:58:15.160293 140454572087296 server_lib.py:155] Starting server on port 9000 with 64 threads
(task, pid=26431) 2024-09-24 19:58:15,161 - root - INFO - Not starting JAX profiler server: False
(task, pid=26431) I0924 19:58:15.161907 140454572087296 server_lib.py:164] Not starting JAX profiler server: False
(task, pid=26431) Started jetstream_server....
```

You can now start a benchmark to test the serving performance:

```bash
$ sky exec serve-llama2-7b benchmark-llama2-7b.yaml
... (emitted logs)
(task, pid=25491) Successful requests: 100
(task, pid=25491) Benchmark duration: 8.753792 s
(task, pid=25491) Total input tokens: 21888
(task, pid=25491) Total generated tokens: 18803
(task, pid=25491) Request throughput: 11.42 requests/s
(task, pid=25491) Input token throughput: 2500.40 tokens/s
(task, pid=25491) Output token throughput: 2147.98 tokens/s
(task, pid=25491) Mean TTFT: 1981.93 ms
(task, pid=25491) Median TTFT: 1829.33 ms
(task, pid=25491) P99 TTFT: 4511.95 ms
(task, pid=25491) Mean TPOT: 130.71 ms
(task, pid=25491) Median TPOT: 18.88 ms
(task, pid=25491) P99 TPOT: 2487.37 ms
```
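
When you are done with these examples, tear down the clusters to stop billing (the names match the `-c` flags used above):

```bash
# Tear down the provisioning, training, and serving clusters.
$ sky down tpu-v6e train-llama3-8b serve-llama2-7b
```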
10 changes: 10 additions & 0 deletions examples/tpu/v6e/benchmark-llama2-7b.yaml
@@ -0,0 +1,10 @@
envs:
  model_name: llama-2
  tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model

run: |
  cd JetStream
  python benchmarks/benchmark_serving.py \
    --tokenizer=$tokenizer_path --num-prompts=100 \
    --dataset openorca --save-request-outputs \
    --warmup-mode=sampled --model=$model_name
27 changes: 27 additions & 0 deletions examples/tpu/v6e/config-8B.json
@@ -0,0 +1,27 @@
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.0.dev0",
  "use_cache": true,
  "vocab_size": 128256
}
8 changes: 8 additions & 0 deletions examples/tpu/v6e/fsdp_config.json
@@ -0,0 +1,8 @@
{
  "fsdp_transformer_layer_cls_to_wrap": [
    "LlamaDecoderLayer"
  ],
  "xla": true,
  "xla_fsdp_v2": true,
  "xla_fsdp_grad_ckpt": true
}
60 changes: 60 additions & 0 deletions examples/tpu/v6e/serve-llama2-7b.yaml
@@ -0,0 +1,60 @@
resources:
  accelerators: tpu-v6e-8  # Fill in the accelerator type you want to use

envs:
  HF_TOKEN: # fill in your huggingface token
  HF_REPO_ID: meta-llama/Llama-2-7b
  model_name: llama-2
  input_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original
  output_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/converted
  tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model

setup: |
  pip3 install huggingface_hub
  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

  # Setup TPU
  pip3 install cloud-tpu-client
  sudo apt update
  sudo apt install -y libopenblas-base
  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
    -f https://storage.googleapis.com/libtpu-releases/index.html
  pip install torch_xla[pallas] \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

  # Setup runtime for serving
  git clone https://github.com/google/JetStream.git
  cd JetStream
  git checkout main
  git pull origin main
  pip install -e .
  cd benchmarks
  pip install -r requirements.in
  cd ../..
  git clone https://github.com/google/jetstream-pytorch.git
  cd jetstream-pytorch/
  git checkout jetstream-v0.2.3
  source install_everything.sh
  pip3 install -U --pre jax jaxlib libtpu-nightly requests \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

  # Prepare checkpoint, inside jetstream-pytorch repo
  mkdir -p ${input_ckpt_dir}
  python3 -c "import huggingface_hub; huggingface_hub.snapshot_download('${HF_REPO_ID}', local_dir='${input_ckpt_dir}')"
  mkdir -p ${output_ckpt_dir}
  python -m convert_checkpoints --model_name=$model_name \
    --input_checkpoint_dir=$input_ckpt_dir \
    --output_checkpoint_dir=$output_ckpt_dir

run: |
  cd jetstream-pytorch
  python run_server.py --model_name=$model_name \
    --size=7b --batch_size=24 --max_cache_length=2048 \
    --checkpoint_path=$output_ckpt_dir \
    --tokenizer_path=$tokenizer_path \
    --sharding_config="default_shardings/llama.yaml"
53 changes: 53 additions & 0 deletions examples/tpu/v6e/train-llama3-8b.yaml
@@ -0,0 +1,53 @@
resources:
  accelerators: tpu-v6e-8  # Fill in the accelerator type you want to use

envs:
  HF_TOKEN: # fill in your huggingface token

workdir: .

setup: |
  pip3 install huggingface_hub
  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

  # Setup TPU
  pip3 install cloud-tpu-client
  sudo apt update
  sudo apt install -y libopenblas-base
  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
    -f https://storage.googleapis.com/libtpu-releases/index.html
  pip install torch_xla[pallas] \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

  # Setup runtime for training
  git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
  cd transformers
  pip3 install -e .
  pip3 install datasets evaluate scikit-learn accelerate

run: |
  unset LD_PRELOAD
  PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true \
  python3 transformers/examples/pytorch/language-modeling/run_clm.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 16 \
    --do_train \
    --output_dir /home/$USER/tmp/test-clm \
    --overwrite_output_dir \
    --config_name /home/$USER/sky_workdir/config-8B.json \
    --cache_dir /home/$USER/cache \
    --tokenizer_name meta-llama/Meta-Llama-3-8B \
    --block_size 8192 \
    --optim adafactor \
    --save_strategy no \
    --logging_strategy no \
    --fsdp "full_shard" \
    --fsdp_config /home/$USER/sky_workdir/fsdp_config.json \
    --torch_dtype bfloat16 \
    --dataloader_drop_last yes \
    --flash_attention \
    --max_steps 20
2 changes: 1 addition & 1 deletion sky/backends/cloud_vm_ray_backend.py
@@ -2467,7 +2467,7 @@ def num_ips_per_node(self) -> int:
"""Returns number of IPs per node in the cluster, handling TPU Pod."""
is_tpu_vm_pod = gcp_utils.is_tpu_vm_pod(self.launched_resources)
if is_tpu_vm_pod:
num_ips = gcp_utils.get_num_tpu_devices(self.launched_resources)
num_ips = len(self.internal_ips())
else:
num_ips = 1
return num_ips
8 changes: 0 additions & 8 deletions sky/clouds/utils/gcp_utils.py
@@ -49,14 +49,6 @@ def is_tpu_vm_pod(resources: Optional['resources_lib.Resources']) -> bool:
    return not acc.endswith('-8')


-def get_num_tpu_devices(resources: Optional['resources_lib.Resources']) -> int:
-    if resources is None or not is_tpu(resources):
-        raise ValueError('resources must be a valid TPU resource.')
-    acc, _ = list(resources.accelerators.items())[0]
-    num_tpu_devices = int(int(acc.split('-')[2]) / 8)
-    return num_tpu_devices
-
-
@dataclasses.dataclass
class SpecificReservation:
    count: int
3 changes: 3 additions & 0 deletions sky/resources.py
@@ -602,6 +602,9 @@ def _get_default_runtime_version() -> str:
            # TPU V5 requires a newer runtime version.
            if acc.startswith('tpu-v5'):
                return 'v2-alpha-tpuv5'
+           # TPU V6e requires a newer runtime version.
+           if acc.startswith('tpu-v6e'):
+               return 'v2-alpha-tpuv6e'
            return 'tpu-vm-base'

        accelerator_args['runtime_version'] = (
