* init
* fix
* nit
* format
* add readme
* add inference example
* nit
* add multi-host training
* rephrase catalog doc
* Update examples/tpu/v6e/README.md

Co-authored-by: Zhanghao Wu <zhanghao.wu@outlook.com>
1 parent 5dda9cf · commit c4eeeb5 · Showing 9 changed files with 290 additions and 9 deletions.
examples/tpu/v6e/README.md
# TPU v6e

Trillium (also referred to as v6e) is Cloud TPU's latest-generation AI accelerator. SkyPilot supports TPU v6e for provisioning, training, and serving.

## Catalogs

Currently, the public APIs for TPU v6e regions and pricing have not been released yet, and pricing information for `us-central1`, `us-central2`, and `us-south1` is not available. We set the price to `0.0` in those regions for now.
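To see which TPU v6e variants SkyPilot can provision, along with the (placeholder) prices discussed above, you can query the catalog with `sky show-gpus`; the exact output columns may vary across SkyPilot versions:

```bash
# List TPU v6e entries in the GCP catalog; regions without released
# pricing will show the placeholder price of 0.0 mentioned above.
$ sky show-gpus tpu-v6e-16 --cloud gcp
```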
## Provisioning

To provision TPU v6e, use the following command:

```bash
$ sky launch --gpus tpu-v6e-16 -c tpu-v6e
```

After that, you can SSH to the instance and start developing your model:

```bash
$ ssh tpu-v6e
```
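You can also confirm from your local machine that the cluster is up; a quick check:

```bash
# Show the status of all SkyPilot clusters, including tpu-v6e.
$ sky status
```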
## Training

The example in this directory (`train-llama3-8b.yaml`) shows how to use TPU v6e to train a Llama 3 8B model with PyTorch (XLA) on the wikitext dataset. To start the training, use the following command:

```bash
$ HF_TOKEN=hf_xxx sky launch train-llama3-8b.yaml -c train-llama3-8b --env HF_TOKEN
```
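While the job is running, you can stream its output (including the metrics shown below) from your local machine:

```bash
# Tail the logs of the most recent job on the training cluster.
$ sky logs train-llama3-8b
```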
### Single-Host Training

The training throughput for a `tpu-v6e-8` instance should be around 0.5 samples/s:

```bash
(task, pid=17499) ***** train metrics *****
(task, pid=17499) epoch = 1.1765
(task, pid=17499) total_flos = 109935420GF
(task, pid=17499) train_loss = 10.6011
(task, pid=17499) train_runtime = 0:11:12.77
(task, pid=17499) train_samples = 282
(task, pid=17499) train_samples_per_second = 0.476
(task, pid=17499) train_steps_per_second = 0.03
INFO: Job finished (status: SUCCEEDED).
```
### Multi-Host Training

By changing the TPU type to `tpu-v6e-16` and `--per_device_train_batch_size` to `32`, the training throughput increases to around 1 sample/s:

```bash
(head, rank=0, pid=17894) ***** train metrics *****
(head, rank=0, pid=17894) epoch = 2.5
(head, rank=0, pid=17894) total_flos = 219870840GF
(head, rank=0, pid=17894) train_loss = 10.1527
(head, rank=0, pid=17894) train_runtime = 0:11:13.18
(head, rank=0, pid=17894) train_samples = 282
(head, rank=0, pid=17894) train_samples_per_second = 0.951
(head, rank=0, pid=17894) train_steps_per_second = 0.03

(worker1, rank=1, pid=15406, ip=10.164.0.57) ***** train metrics *****
(worker1, rank=1, pid=15406, ip=10.164.0.57) epoch = 2.5
(worker1, rank=1, pid=15406, ip=10.164.0.57) total_flos = 219870840GF
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_loss = 10.1527
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_runtime = 0:11:15.08
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples = 282
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples_per_second = 0.948
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_steps_per_second = 0.03

(worker2, rank=2, pid=16552, ip=10.164.0.58) ***** train metrics *****
(worker2, rank=2, pid=16552, ip=10.164.0.58) epoch = 2.5
(worker2, rank=2, pid=16552, ip=10.164.0.58) total_flos = 219870840GF
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_loss = 10.1527
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_runtime = 0:11:15.61
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples = 282
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples_per_second = 0.947
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_steps_per_second = 0.03

(worker3, rank=3, pid=17469, ip=10.164.0.59) ***** train metrics *****
(worker3, rank=3, pid=17469, ip=10.164.0.59) epoch = 2.5
(worker3, rank=3, pid=17469, ip=10.164.0.59) total_flos = 219870840GF
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_loss = 10.1527
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_runtime = 0:11:15.10
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples = 282
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples_per_second = 0.948
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_steps_per_second = 0.03

INFO: Job finished (status: SUCCEEDED).
```
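If you prefer not to edit the `accelerators` field in the YAML, the TPU type can also be overridden at launch time (the batch-size change still needs to be made in the task's training flags, since it is an argument to the training script rather than a SkyPilot setting); a minimal sketch:

```bash
# Launch the same task on a 16-chip (multi-host) TPU v6e slice by
# overriding the accelerators declared in the YAML.
$ HF_TOKEN=hf_xxx sky launch train-llama3-8b.yaml -c train-llama3-8b-16 \
    --gpus tpu-v6e-16 --env HF_TOKEN
```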
## Serving

TPU v6e also supports serving. The example in this directory (`serve-llama2-7b.yaml`) shows how to use TPU v6e to serve a Llama 2 7B model with PyTorch (XLA) and the JetStream library. To start serving, use the following command:

```bash
$ HF_TOKEN=hf_xxx sky launch serve-llama2-7b.yaml -c serve-llama2-7b --env HF_TOKEN
```
After the server is ready, you should see the following message:

```bash
(task, pid=26431) 2024-09-24 19:58:15,160 - root - INFO - Starting server on port 9000 with 64 threads
(task, pid=26431) I0924 19:58:15.160293 140454572087296 server_lib.py:155] Starting server on port 9000 with 64 threads
(task, pid=26431) 2024-09-24 19:58:15,161 - root - INFO - Not starting JAX profiler server: False
(task, pid=26431) I0924 19:58:15.161907 140454572087296 server_lib.py:164] Not starting JAX profiler server: False
(task, pid=26431) Started jetstream_server....
```
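If you want to reach the server directly from your laptop, one option is to forward its port over the SSH connection SkyPilot configures; a sketch, assuming the default port 9000 shown in the log above:

```bash
# Forward local port 9000 to the JetStream server running on the cluster.
$ ssh -L 9000:localhost:9000 serve-llama2-7b
```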
You can now start a benchmark to test the serving performance:

```bash
$ sky exec serve-llama2-7b benchmark-llama2-7b.yaml
... (emitted logs)
(task, pid=25491) Successful requests: 100
(task, pid=25491) Benchmark duration: 8.753792 s
(task, pid=25491) Total input tokens: 21888
(task, pid=25491) Total generated tokens: 18803
(task, pid=25491) Request throughput: 11.42 requests/s
(task, pid=25491) Input token throughput: 2500.40 tokens/s
(task, pid=25491) Output token throughput: 2147.98 tokens/s
(task, pid=25491) Mean TTFT: 1981.93 ms
(task, pid=25491) Median TTFT: 1829.33 ms
(task, pid=25491) P99 TTFT: 4511.95 ms
(task, pid=25491) Mean TPOT: 130.71 ms
(task, pid=25491) Median TPOT: 18.88 ms
(task, pid=25491) P99 TPOT: 2487.37 ms
```
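When you are done experimenting, you can tear down the clusters to stop incurring charges, for example:

```bash
# Terminate the serving cluster (use the same command for the training cluster).
$ sky down serve-llama2-7b
```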
examples/tpu/v6e/benchmark-llama2-7b.yaml
# Benchmark task: run against the cluster started by serve-llama2-7b.yaml
# via `sky exec serve-llama2-7b benchmark-llama2-7b.yaml`.
envs:
  model_name: llama-2
  tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model

run: |
  cd JetStream
  python benchmarks/benchmark_serving.py \
    --tokenizer=$tokenizer_path --num-prompts=100 \
    --dataset openorca --save-request-outputs \
    --warmup-mode=sampled --model=$model_name
examples/tpu/v6e/config-8B.json
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.0.dev0",
  "use_cache": true,
  "vocab_size": 128256
}
examples/tpu/v6e/fsdp_config.json
{
  "fsdp_transformer_layer_cls_to_wrap": [
    "LlamaDecoderLayer"
  ],
  "xla": true,
  "xla_fsdp_v2": true,
  "xla_fsdp_grad_ckpt": true
}
examples/tpu/v6e/serve-llama2-7b.yaml
resources:
  accelerators: tpu-v6e-8 # Fill in the accelerator type you want to use

envs:
  HF_TOKEN: # fill in your huggingface token
  HF_REPO_ID: meta-llama/Llama-2-7b
  model_name: llama-2
  input_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original
  output_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/converted
  tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model

setup: |
  pip3 install huggingface_hub
  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"
  # Setup TPU
  pip3 install cloud-tpu-client
  sudo apt update
  sudo apt install -y libopenblas-base
  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
    -f https://storage.googleapis.com/libtpu-releases/index.html
  pip install torch_xla[pallas] \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
  # Setup runtime for serving
  git clone https://github.com/google/JetStream.git
  cd JetStream
  git checkout main
  git pull origin main
  pip install -e .
  cd benchmarks
  pip install -r requirements.in
  cd ../..
  git clone https://github.com/google/jetstream-pytorch.git
  cd jetstream-pytorch/
  git checkout jetstream-v0.2.3
  source install_everything.sh
  pip3 install -U --pre jax jaxlib libtpu-nightly requests \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  # Prepare checkpoint, inside jetstream-pytorch repo
  mkdir -p ${input_ckpt_dir}
  python3 -c "import huggingface_hub; huggingface_hub.snapshot_download('${HF_REPO_ID}', local_dir='${input_ckpt_dir}')"
  mkdir -p ${output_ckpt_dir}
  python -m convert_checkpoints --model_name=$model_name \
    --input_checkpoint_dir=$input_ckpt_dir \
    --output_checkpoint_dir=$output_ckpt_dir

run: |
  cd jetstream-pytorch
  python run_server.py --model_name=$model_name \
    --size=7b --batch_size=24 --max_cache_length=2048 \
    --checkpoint_path=$output_ckpt_dir \
    --tokenizer_path=$tokenizer_path \
    --sharding_config="default_shardings/llama.yaml"
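Since the model and checkpoint locations above are plain `envs` entries, they can be overridden at launch time without editing the YAML; a hedged sketch (assuming the alternative repo converts with the same `convert_checkpoints` flow):

```bash
# Hypothetical override: serve the chat variant instead of the base model.
$ HF_TOKEN=hf_xxx sky launch serve-llama2-7b.yaml -c serve-llama2-7b \
    --env HF_TOKEN --env HF_REPO_ID=meta-llama/Llama-2-7b-chat
```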
examples/tpu/v6e/train-llama3-8b.yaml
resources:
  accelerators: tpu-v6e-8 # Fill in the accelerator type you want to use

envs:
  HF_TOKEN: # fill in your huggingface token

workdir: .

setup: |
  pip3 install huggingface_hub
  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"
  # Setup TPU
  pip3 install cloud-tpu-client
  sudo apt update
  sudo apt install -y libopenblas-base
  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
    -f https://storage.googleapis.com/libtpu-releases/index.html
  pip install torch_xla[pallas] \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
  # Setup runtime for training
  git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
  cd transformers
  pip3 install -e .
  pip3 install datasets evaluate scikit-learn accelerate

run: |
  unset LD_PRELOAD
  PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true \
  python3 transformers/examples/pytorch/language-modeling/run_clm.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 16 \
    --do_train \
    --output_dir /home/$USER/tmp/test-clm \
    --overwrite_output_dir \
    --config_name /home/$USER/sky_workdir/config-8B.json \
    --cache_dir /home/$USER/cache \
    --tokenizer_name meta-llama/Meta-Llama-3-8B \
    --block_size 8192 \
    --optim adafactor \
    --save_strategy no \
    --logging_strategy no \
    --fsdp "full_shard" \
    --fsdp_config /home/$USER/sky_workdir/fsdp_config.json \
    --torch_dtype bfloat16 \
    --dataloader_drop_last yes \
    --flash_attention \
    --max_steps 20
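Before kicking off the full training job, a quick sanity check that PyTorch/XLA can see the TPU can save a failed run; this is an illustrative snippet to run on the provisioned cluster (e.g. over SSH), not part of the example files:

```bash
# Should print an XLA device such as xla:0 if the TPU runtime is set up correctly.
$ PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"
```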