diff --git a/examples/tpu/v6e/README.md b/examples/tpu/v6e/README.md
new file mode 100644
index 00000000000..c5eca93ca91
--- /dev/null
+++ b/examples/tpu/v6e/README.md
@@ -0,0 +1,128 @@
+# TPU v6e
+
+Trillium (also referred to as v6e) is Cloud TPU’s latest-generation AI accelerator. SkyPilot supports TPU v6e for provisioning, training, and serving.
+
+## Catalogs
+
+Currently, the public APIs for TPU v6e regions and pricing have not been released yet, and pricing info for `us-central1`, `us-central2`, and `us-south1` is not available. We set the price to `0.0` in those regions for now.
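+
+Because the regions and prices for v6e may change once the public pricing APIs are released, you can check what your local SkyPilot catalog currently lists with `sky show-gpus` (the exact output depends on your SkyPilot version and fetched catalogs):
+
+```bash
+$ sky show-gpus tpu-v6e-16 --cloud gcp
+```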
+
+## Provisioning
+
+To provision TPU v6e, use the following command:
+
+```bash
+$ sky launch --gpus tpu-v6e-16 -c tpu-v6e
+```
+
+After that, you can SSH into the instance and start developing your model:
+
+```bash
+$ ssh tpu-v6e
+```
+
+## Training
+
+The example in this directory (`train-llama3-8b.yaml`) shows how to use TPU v6e to train a Llama 3 8B model with PyTorch/XLA on the wikitext dataset. To start the training, use the following command:
+
+```bash
+$ HF_TOKEN=hf_xxx sky launch train-llama3-8b.yaml -c train-llama3-8b --env HF_TOKEN
+```
+
+### Single-Host Training
+
+The training throughput for a `tpu-v6e-8` instance should be around 0.5 samples/s:
+
+```bash
+(task, pid=17499) ***** train metrics *****
+(task, pid=17499) epoch = 1.1765
+(task, pid=17499) total_flos = 109935420GF
+(task, pid=17499) train_loss = 10.6011
+(task, pid=17499) train_runtime = 0:11:12.77
+(task, pid=17499) train_samples = 282
+(task, pid=17499) train_samples_per_second = 0.476
+(task, pid=17499) train_steps_per_second = 0.03
+INFO: Job finished (status: SUCCEEDED).
+```
+
+### Multi-Host Training
+
+By changing the TPU type to `tpu-v6e-16` and `--per_device_train_batch_size` to `32`, the training throughput increases to around 1 sample/s:
+
+```bash
+(head, rank=0, pid=17894) ***** train metrics *****
+(head, rank=0, pid=17894) epoch = 2.5
+(head, rank=0, pid=17894) total_flos = 219870840GF
+(head, rank=0, pid=17894) train_loss = 10.1527
+(head, rank=0, pid=17894) train_runtime = 0:11:13.18
+(head, rank=0, pid=17894) train_samples = 282
+(head, rank=0, pid=17894) train_samples_per_second = 0.951
+(head, rank=0, pid=17894) train_steps_per_second = 0.03
+
+(worker1, rank=1, pid=15406, ip=10.164.0.57) ***** train metrics *****
+(worker1, rank=1, pid=15406, ip=10.164.0.57) epoch = 2.5
+(worker1, rank=1, pid=15406, ip=10.164.0.57) total_flos = 219870840GF
+(worker1, rank=1, pid=15406, ip=10.164.0.57) train_loss = 10.1527
+(worker1, rank=1, pid=15406, ip=10.164.0.57) train_runtime = 0:11:15.08
+(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples = 282
+(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples_per_second = 0.948
+(worker1, rank=1, pid=15406, ip=10.164.0.57) train_steps_per_second = 0.03
+
+(worker2, rank=2, pid=16552, ip=10.164.0.58) ***** train metrics *****
+(worker2, rank=2, pid=16552, ip=10.164.0.58) epoch = 2.5
+(worker2, rank=2, pid=16552, ip=10.164.0.58) total_flos = 219870840GF
+(worker2, rank=2, pid=16552, ip=10.164.0.58) train_loss = 10.1527
+(worker2, rank=2, pid=16552, ip=10.164.0.58) train_runtime = 0:11:15.61
+(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples = 282
+(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples_per_second = 0.947
+(worker2, rank=2, pid=16552, ip=10.164.0.58) train_steps_per_second = 0.03
+
+(worker3, rank=3, pid=17469, ip=10.164.0.59) ***** train metrics *****
+(worker3, rank=3, pid=17469, ip=10.164.0.59) epoch = 2.5
+(worker3, rank=3, pid=17469, ip=10.164.0.59) total_flos = 219870840GF
+(worker3, rank=3, pid=17469, ip=10.164.0.59) train_loss = 10.1527
+(worker3, rank=3, pid=17469, ip=10.164.0.59) train_runtime = 0:11:15.10
+(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples = 282
+(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples_per_second = 0.948
+(worker3, rank=3, pid=17469, ip=10.164.0.59) train_steps_per_second = 0.03
+
+INFO: Job finished (status: SUCCEEDED).
+```
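+
+For the multi-host run above, the accelerator can be overridden from the CLI without editing `resources` in the YAML (a sketch; `--per_device_train_batch_size` still has to be changed in the YAML's `run` command, and the cluster name is arbitrary):
+
+```bash
+$ HF_TOKEN=hf_xxx sky launch train-llama3-8b.yaml -c train-llama3-8b-pod \
+    --gpus tpu-v6e-16 --env HF_TOKEN
+```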
+
+## Serving
+
+TPU v6e also supports serving. The example in this directory (`serve-llama2-7b.yaml`) shows how to use TPU v6e to serve a Llama 2 7B model with PyTorch/XLA and the JetStream library. To start serving, use the following command:
+
+```bash
+$ HF_TOKEN=hf_xxx sky launch serve-llama2-7b.yaml -c serve-llama2-7b --env HF_TOKEN
+```
+
+After the server is ready, you should see the following message:
+
+```bash
+(task, pid=26431) 2024-09-24 19:58:15,160 - root - INFO - Starting server on port 9000 with 64 threads
+(task, pid=26431) I0924 19:58:15.160293 140454572087296 server_lib.py:155] Starting server on port 9000 with 64 threads
+(task, pid=26431) 2024-09-24 19:58:15,161 - root - INFO - Not starting JAX profiler server: False
+(task, pid=26431) I0924 19:58:15.161907 140454572087296 server_lib.py:164] Not starting JAX profiler server: False
+(task, pid=26431) Started jetstream_server....
+```
+
+You can now start a benchmark to test the serving performance:
+
+```bash
+$ sky exec serve-llama2-7b benchmark-llama2-7b.yaml
+... (emitted logs)
+(task, pid=25491) Successful requests: 100
+(task, pid=25491) Benchmark duration: 8.753792 s
+(task, pid=25491) Total input tokens: 21888
+(task, pid=25491) Total generated tokens: 18803
+(task, pid=25491) Request throughput: 11.42 requests/s
+(task, pid=25491) Input token throughput: 2500.40 tokens/s
+(task, pid=25491) Output token throughput: 2147.98 tokens/s
+(task, pid=25491) Mean TTFT: 1981.93 ms
+(task, pid=25491) Median TTFT: 1829.33 ms
+(task, pid=25491) P99 TTFT: 4511.95 ms
+(task, pid=25491) Mean TPOT: 130.71 ms
+(task, pid=25491) Median TPOT: 18.88 ms
+(task, pid=25491) P99 TPOT: 2487.37 ms
+```
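+
+When you are done experimenting, tear down the clusters to stop incurring charges (`sky down` accepts multiple cluster names):
+
+```bash
+$ sky status
+$ sky down train-llama3-8b serve-llama2-7b
+```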
diff --git a/examples/tpu/v6e/benchmark-llama2-7b.yaml b/examples/tpu/v6e/benchmark-llama2-7b.yaml
new file mode 100644
index 00000000000..d6fa002e160
--- /dev/null
+++ b/examples/tpu/v6e/benchmark-llama2-7b.yaml
@@ -0,0 +1,10 @@
+envs:
+  model_name: llama-2
+  tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model
+
+run: |
+  cd JetStream
+  python benchmarks/benchmark_serving.py \
+    --tokenizer=$tokenizer_path --num-prompts=100 \
+    --dataset openorca --save-request-outputs \
+    --warmup-mode=sampled --model=$model_name
diff --git a/examples/tpu/v6e/config-8B.json b/examples/tpu/v6e/config-8B.json
new file mode 100644
index 00000000000..175a749e1bd
--- /dev/null
+++ b/examples/tpu/v6e/config-8B.json
@@ -0,0 +1,27 @@
+{
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 128000,
+  "eos_token_id": 128001,
+  "hidden_act": "silu",
+  "hidden_size": 4096,
+  "initializer_range": 0.02,
+  "intermediate_size": 14336,
+  "max_position_embeddings": 8192,
+  "model_type": "llama",
+  "num_attention_heads": 32,
+  "num_hidden_layers": 32,
+  "num_key_value_heads": 8,
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": null,
+  "rope_theta": 500000.0,
+  "tie_word_embeddings": false,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.40.0.dev0",
+  "use_cache": true,
+  "vocab_size": 128256
+}
diff --git a/examples/tpu/v6e/fsdp_config.json b/examples/tpu/v6e/fsdp_config.json
new file mode 100644
index 00000000000..a37a04352ee
--- /dev/null
+++ b/examples/tpu/v6e/fsdp_config.json
@@ -0,0 +1,8 @@
+{
+  "fsdp_transformer_layer_cls_to_wrap": [
+    "LlamaDecoderLayer"
+  ],
+  "xla": true,
+  "xla_fsdp_v2": true,
+  "xla_fsdp_grad_ckpt": true
+}
diff --git a/examples/tpu/v6e/serve-llama2-7b.yaml b/examples/tpu/v6e/serve-llama2-7b.yaml
new file mode 100644
index 00000000000..49d0bf9fcd2
--- /dev/null
+++ b/examples/tpu/v6e/serve-llama2-7b.yaml
@@ -0,0 +1,60 @@
+resources:
+  accelerators: tpu-v6e-8  # Fill in the accelerator type you want to use
+
+envs:
+  HF_TOKEN: # fill in your huggingface token
+  HF_REPO_ID: meta-llama/Llama-2-7b
+  model_name: llama-2
+  input_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original
+  output_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/converted
+  tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model
+
+setup: |
+  pip3 install huggingface_hub
+  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"
+
+  # Setup TPU
+  pip3 install cloud-tpu-client
+  sudo apt update
+  sudo apt install -y libopenblas-base
+  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
+    --index-url https://download.pytorch.org/whl/nightly/cpu
+  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
+    -f https://storage.googleapis.com/libtpu-releases/index.html
+  pip install torch_xla[pallas] \
+    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+
+  # Setup runtime for serving
+  git clone https://github.com/google/JetStream.git
+  cd JetStream
+  git checkout main
+  git pull origin main
+  pip install -e .
+  cd benchmarks
+  pip install -r requirements.in
+  cd ../..
+  git clone https://github.com/google/jetstream-pytorch.git
+  cd jetstream-pytorch/
+  git checkout jetstream-v0.2.3
+  source install_everything.sh
+  pip3 install -U --pre jax jaxlib libtpu-nightly requests \
+    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+
+
+  # Prepare checkpoint, inside jetstream-pytorch repo
+  mkdir -p ${input_ckpt_dir}
+  python3 -c "import huggingface_hub; huggingface_hub.snapshot_download('${HF_REPO_ID}', local_dir='${input_ckpt_dir}')"
+  mkdir -p ${output_ckpt_dir}
+  python -m convert_checkpoints --model_name=$model_name \
+    --input_checkpoint_dir=$input_ckpt_dir \
+    --output_checkpoint_dir=$output_ckpt_dir
+
+run: |
+  cd jetstream-pytorch
+  python run_server.py --model_name=$model_name \
+    --size=7b --batch_size=24 --max_cache_length=2048 \
+    --checkpoint_path=$output_ckpt_dir \
+    --tokenizer_path=$tokenizer_path \
+    --sharding_config="default_shardings/llama.yaml"
diff --git a/examples/tpu/v6e/train-llama3-8b.yaml b/examples/tpu/v6e/train-llama3-8b.yaml
new file mode 100644
index 00000000000..3acdbdbbe0b
--- /dev/null
+++ b/examples/tpu/v6e/train-llama3-8b.yaml
@@ -0,0 +1,53 @@
+resources:
+  accelerators: tpu-v6e-8  # Fill in the accelerator type you want to use
+
+envs:
+  HF_TOKEN: # fill in your huggingface token
+
+workdir: .
+
+setup: |
+  pip3 install huggingface_hub
+  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"
+
+  # Setup TPU
+  pip3 install cloud-tpu-client
+  sudo apt update
+  sudo apt install -y libopenblas-base
+  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
+    --index-url https://download.pytorch.org/whl/nightly/cpu
+  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
+    -f https://storage.googleapis.com/libtpu-releases/index.html
+  pip install torch_xla[pallas] \
+    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+
+  # Setup runtime for training
+  git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
+  cd transformers
+  pip3 install -e .
+  pip3 install datasets evaluate scikit-learn accelerate
+
+run: |
+  unset LD_PRELOAD
+  PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true \
+  python3 transformers/examples/pytorch/language-modeling/run_clm.py \
+    --dataset_name wikitext \
+    --dataset_config_name wikitext-2-raw-v1 \
+    --per_device_train_batch_size 16 \
+    --do_train \
+    --output_dir /home/$USER/tmp/test-clm \
+    --overwrite_output_dir \
+    --config_name /home/$USER/sky_workdir/config-8B.json \
+    --cache_dir /home/$USER/cache \
+    --tokenizer_name meta-llama/Meta-Llama-3-8B \
+    --block_size 8192 \
+    --optim adafactor \
+    --save_strategy no \
+    --logging_strategy no \
+    --fsdp "full_shard" \
+    --fsdp_config /home/$USER/sky_workdir/fsdp_config.json \
+    --torch_dtype bfloat16 \
+    --dataloader_drop_last yes \
+    --flash_attention \
+    --max_steps 20
diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py
index b0a064afe7c..a256277085f 100644
--- a/sky/backends/cloud_vm_ray_backend.py
+++ b/sky/backends/cloud_vm_ray_backend.py
@@ -2467,7 +2467,7 @@ def num_ips_per_node(self) -> int:
         """Returns number of IPs per node in the cluster, handling TPU Pod."""
         is_tpu_vm_pod = gcp_utils.is_tpu_vm_pod(self.launched_resources)
         if is_tpu_vm_pod:
-            num_ips = gcp_utils.get_num_tpu_devices(self.launched_resources)
+            num_ips = len(self.internal_ips())
         else:
             num_ips = 1
         return num_ips
diff --git a/sky/clouds/utils/gcp_utils.py b/sky/clouds/utils/gcp_utils.py
index 68e6192d351..cfb893c8cb4 100644
--- a/sky/clouds/utils/gcp_utils.py
+++ b/sky/clouds/utils/gcp_utils.py
@@ -49,14 +49,6 @@ def is_tpu_vm_pod(resources: Optional['resources_lib.Resources']) -> bool:
     return not acc.endswith('-8')
 
 
-def get_num_tpu_devices(resources: Optional['resources_lib.Resources']) -> int:
-    if resources is None or not is_tpu(resources):
-        raise ValueError('resources must be a valid TPU resource.')
-    acc, _ = list(resources.accelerators.items())[0]
-    num_tpu_devices = int(int(acc.split('-')[2]) / 8)
-    return num_tpu_devices
-
-
 @dataclasses.dataclass
 class SpecificReservation:
     count: int
diff --git a/sky/resources.py b/sky/resources.py
index 9d853ac81f0..3b33476713b 100644
--- a/sky/resources.py
+++ b/sky/resources.py
@@ -602,6 +602,9 @@ def _get_default_runtime_version() -> str:
             # TPU V5 requires a newer runtime version.
             if acc.startswith('tpu-v5'):
                 return 'v2-alpha-tpuv5'
+            # TPU V6e requires a newer runtime version.
+            if acc.startswith('tpu-v6e'):
+                return 'v2-alpha-tpuv6e'
             return 'tpu-vm-base'
 
         accelerator_args['runtime_version'] = (
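
The default runtime version selected above can likely also be pinned explicitly per task through `accelerator_args` in a SkyPilot task YAML. A minimal sketch (not part of this PR; the file and cluster names are arbitrary):

```bash
# Sketch: pin the TPU runtime version explicitly instead of relying on the default.
cat > v6e-runtime-check.yaml <<'EOF'
resources:
  accelerators: tpu-v6e-8
  accelerator_args:
    runtime_version: v2-alpha-tpuv6e  # Same value as the default added in sky/resources.py.
run: |
  echo "Hello from a TPU v6e VM"
EOF
sky launch v6e-runtime-check.yaml -c v6e-runtime-check
```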