[Docs] TPU v6 docs (#4221)
* Update TPU v6 docs

* tpu v6 docs

* add TPU v6

* update

* Fix tpu docs

* fix indents

* restructure TPU doc

* Fix

* Fix

* fix

* Fix TPU

* fix docs

* Update docs/source/reference/tpu.rst

Co-authored-by: Tian Xia <cblmemo@gmail.com>

---------

Co-authored-by: Tian Xia <cblmemo@gmail.com>
Michaelvll and cblmemo authored Oct 31, 2024
1 parent 22fd238 commit 1a9e90d
docs/source/reference/tpu.rst (179 additions, 144 deletions)
Use one command to quickly get TPU nodes for development:

.. code-block:: bash

   # Use latest TPU v6 (Trillium) VMs:
   sky launch --gpus tpu-v6e-8

   # Use TPU v4 (Titan) VMs:
   sky launch --gpus tpu-v4-8

   # Preemptible TPUs:
   sky launch --gpus tpu-v6e-8 --use-spot

After the command finishes, you will be dropped into a TPU host VM and can start developing code right away.
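
For repeatable use, the same request can be written as a task YAML. A minimal sketch (the ``use_spot`` line is optional and mirrors ``--use-spot``; the runtime version is left at its default):

.. code-block:: yaml

   resources:
     accelerators: tpu-v6e-8
     use_spot: true  # optional: request preemptible TPUs, as with --use-spot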

Below, we show examples of using SkyPilot to (1) train LLMs on TPU VMs/Pods and (2) train MNIST on TPU Nodes (legacy).

TPU Architectures
=================

Two different TPU architectures are available on GCP:

- `TPU VMs/Pods <https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu-vm>`_
- `TPU Nodes <https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu-node>`_

Both are supported by SkyPilot. We recommend TPU VMs and Pods, the newer architectures encouraged by GCP.

The two architectures differ as follows.

* For TPU VMs/Pods, you can directly SSH into the "TPU host" VM that is physically connected to the TPU device.
* For TPU Nodes, a user VM (an `n1` instance) must be separately provisioned to communicate with an inaccessible TPU host over gRPC.

More details can be found in the GCP `documentation <https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu-arch>`_.

.. _tpu-vms:

TPU VMs/Pods
------------

Google's latest TPU v6 (Trillium) VMs offer great performance and are now supported by SkyPilot.

To use TPU VMs/Pods, set the following in a task YAML's ``resources`` field:

.. code-block:: yaml

   resources:
     accelerators: tpu-v6e-8
     accelerator_args:
       runtime_version: v2-alpha-tpuv6e  # optional

The ``accelerators`` field specifies the TPU type, and the :code:`accelerator_args` dict includes the optional :code:`tpu_vm` bool (defaults to true, meaning a TPU VM is used) and an optional TPU ``runtime_version`` field.
To show what TPU types are supported, run :code:`sky show-gpus`.
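
For example, a ``resources`` block that spells out these fields might look like the following sketch (values are illustrative; both ``accelerator_args`` entries can usually be omitted to use the defaults):

.. code-block:: yaml

   resources:
     accelerators: tpu-v6e-8
     accelerator_args:
       tpu_vm: true  # default; set to false for the legacy TPU Node architecture
       runtime_version: v2-alpha-tpuv6e  # optional TPU software version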

Here is a complete task YAML that trains a `Llama 3 model <https://ai.meta.com/blog/meta-llama-3/>`_ on a TPU VM using Torch XLA.

.. code-block:: yaml

   resources:
     accelerators: tpu-v6e-8  # Fill in the accelerator type you want to use

   envs:
     HF_TOKEN: # fill in your huggingface token

   workdir: .

   setup: |
     pip3 install huggingface_hub
     python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

     # Setup TPU
     pip3 install cloud-tpu-client
     sudo apt update
     sudo apt install -y libopenblas-base
     pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
       --index-url https://download.pytorch.org/whl/nightly/cpu
     pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
       -f https://storage.googleapis.com/libtpu-releases/index.html
     pip install torch_xla[pallas] \
       -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
       -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

     # Setup runtime for training
     git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
     cd transformers
     pip3 install -e .
     pip3 install datasets evaluate scikit-learn accelerate

   run: |
     unset LD_PRELOAD
     PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true \
     python3 transformers/examples/pytorch/language-modeling/run_clm.py \
       --dataset_name wikitext \
       --dataset_config_name wikitext-2-raw-v1 \
       --per_device_train_batch_size 16 \
       --do_train \
       --output_dir /home/$USER/tmp/test-clm \
       --overwrite_output_dir \
       --config_name /home/$USER/sky_workdir/config-8B.json \
       --cache_dir /home/$USER/cache \
       --tokenizer_name meta-llama/Meta-Llama-3-8B \
       --block_size 8192 \
       --optim adafactor \
       --save_strategy no \
       --logging_strategy no \
       --fsdp "full_shard" \
       --fsdp_config /home/$USER/sky_workdir/fsdp_config.json \
       --torch_dtype bfloat16 \
       --dataloader_drop_last yes \
       --flash_attention \
       --max_steps 20

This YAML lives under the `SkyPilot repo <https://github.com/skypilot-org/skypilot/blob/tpu-v6/examples/tpu/v6e/train-llama3-8b.yaml>`__, or you can paste it into a local file.
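
Note that the ``run`` section above reads ``config-8B.json`` and ``fsdp_config.json`` from ``~/sky_workdir``; with ``workdir: .`` these files are synced there from your local working directory. As a sketch, the same files could also be provided explicitly via ``file_mounts`` (the local paths below are assumptions):

.. code-block:: yaml

   file_mounts:
     ~/sky_workdir/config-8B.json: ./config-8B.json
     ~/sky_workdir/fsdp_config.json: ./fsdp_config.json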

Launch it with:

.. code-block:: console

   $ HF_TOKEN=<your-huggingface-token> sky launch train-llama3-8b.yaml -c llama-3-train --env HF_TOKEN

You should see the following output when the job finishes.

.. code-block:: console

   $ sky launch train-llama3-8b.yaml -c llama-3-train
   (task, pid=17499) ***** train metrics *****
   (task, pid=17499) epoch = 1.1765
   (task, pid=17499) total_flos = 109935420GF
   (task, pid=17499) train_loss = 10.6011
   (task, pid=17499) train_runtime = 0:11:12.77
   (task, pid=17499) train_samples = 282
   (task, pid=17499) train_samples_per_second = 0.476
   (task, pid=17499) train_steps_per_second = 0.03

Multi-Host TPU Pods
-------------------

A `TPU Pod <https://cloud.google.com/tpu/docs/training-on-tpu-pods>`_ is a collection of TPU devices connected by dedicated high-speed network interfaces for high-performance training.

To use a TPU Pod, simply change the ``accelerators`` field in the task YAML (e.g., :code:`tpu-v6e-8` -> :code:`tpu-v6e-32`).

.. code-block:: yaml
   :emphasize-lines: 2-2

   resources:
     accelerators: tpu-v6e-32  # Pods have > 8 cores (the last number)

.. note::

   Both TPU architectures, TPU VMs and TPU Nodes, can be used with TPU Pods. The example below is based on TPU VMs.

To show all available TPU Pod types, run :code:`sky show-gpus` (more than 8 cores means Pods):

.. code-block:: console

   GOOGLE_TPU    AVAILABLE_QUANTITIES
   tpu-v6e-8     1
   tpu-v6e-32    1
   tpu-v6e-128   1
   tpu-v6e-256   1
   ...

After creating a TPU Pod, multiple host VMs (e.g., :code:`tpu-v6e-32` comes with 4 host VMs) are launched.
Normally, the user needs to SSH into all hosts to prepare files and set up environments, and
then launch the job on each host, which is a tedious and error-prone process.

SkyPilot automates away this complexity. From your laptop, a single :code:`sky launch` command will:

- sync the workdir/file_mounts to every host of the pod; and
- execute the setup/run commands on every host of the pod (see the sketch below).
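
Because the same ``run`` commands execute on every host, a task can branch per host using environment variables that SkyPilot sets on each node, such as ``SKYPILOT_NODE_RANK`` and ``SKYPILOT_NUM_NODES``. A minimal sketch (illustrative only, not part of the Llama example):

.. code-block:: yaml

   resources:
     accelerators: tpu-v6e-32

   run: |
     # Runs on every host of the pod; rank 0 is the head node.
     echo "Host ${SKYPILOT_NODE_RANK} of ${SKYPILOT_NUM_NODES} starting..."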

We can run the same Llama 3 training job on a TPU Pod with the following command, with a slight change to the YAML (``--per_device_train_batch_size`` from 16 to 32):

.. code-block:: console

   $ HF_TOKEN=<your-huggingface-token> sky launch -c tpu-pod --gpus tpu-v6e-32 train-llama3-8b.yaml --env HF_TOKEN

You should see the following output.

.. code-block:: console

   (head, rank=0, pid=17894) ***** train metrics *****
   (head, rank=0, pid=17894) epoch = 2.5
   (head, rank=0, pid=17894) total_flos = 219870840GF
   (head, rank=0, pid=17894) train_loss = 10.1527
   (head, rank=0, pid=17894) train_runtime = 0:11:13.18
   (head, rank=0, pid=17894) train_samples = 282
   (head, rank=0, pid=17894) train_samples_per_second = 0.951
   (head, rank=0, pid=17894) train_steps_per_second = 0.03
   (worker1, rank=1, pid=15406, ip=10.164.0.57) ***** train metrics *****
   (worker1, rank=1, pid=15406, ip=10.164.0.57) epoch = 2.5
   (worker1, rank=1, pid=15406, ip=10.164.0.57) total_flos = 219870840GF
   (worker1, rank=1, pid=15406, ip=10.164.0.57) train_loss = 10.1527
   (worker1, rank=1, pid=15406, ip=10.164.0.57) train_runtime = 0:11:15.08
   (worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples = 282
   (worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples_per_second = 0.948
   (worker1, rank=1, pid=15406, ip=10.164.0.57) train_steps_per_second = 0.03
   (worker2, rank=2, pid=16552, ip=10.164.0.58) ***** train metrics *****
   (worker2, rank=2, pid=16552, ip=10.164.0.58) epoch = 2.5
   (worker2, rank=2, pid=16552, ip=10.164.0.58) total_flos = 219870840GF
   (worker2, rank=2, pid=16552, ip=10.164.0.58) train_loss = 10.1527
   (worker2, rank=2, pid=16552, ip=10.164.0.58) train_runtime = 0:11:15.61
   (worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples = 282
   (worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples_per_second = 0.947
   (worker2, rank=2, pid=16552, ip=10.164.0.58) train_steps_per_second = 0.03
   (worker3, rank=3, pid=17469, ip=10.164.0.59) ***** train metrics *****
   (worker3, rank=3, pid=17469, ip=10.164.0.59) epoch = 2.5
   (worker3, rank=3, pid=17469, ip=10.164.0.59) total_flos = 219870840GF
   (worker3, rank=3, pid=17469, ip=10.164.0.59) train_loss = 10.1527
   (worker3, rank=3, pid=17469, ip=10.164.0.59) train_runtime = 0:11:15.10
   (worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples = 282
   (worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples_per_second = 0.948
   (worker3, rank=3, pid=17469, ip=10.164.0.59) train_steps_per_second = 0.03

To submit more jobs to the same TPU Pod, use :code:`sky exec`:

.. code-block:: console

   $ HF_TOKEN=<your-huggingface-token> sky exec tpu-pod train-llama3-8b.yaml --env HF_TOKEN

**You can find more useful examples for serving LLMs on TPUs in the** `SkyPilot repo <https://github.com/skypilot-org/skypilot/tree/master/examples/tpu/v6e>`__.
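
A served TPU task generally follows the usual ``sky serve`` YAML shape, with a ``service`` section on top of the ``resources`` shown earlier. The sketch below is only a hypothetical outline (the server command, port, and probe path are placeholders, not taken from the linked examples):

.. code-block:: yaml

   service:
     readiness_probe: /health  # placeholder probe path
     replicas: 2

   resources:
     accelerators: tpu-v6e-8
     ports: 8080  # placeholder port for the model server

   run: |
     # Placeholder: start your TPU-backed model server here.
     python3 serve.py --port 8080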



TPU Nodes (Legacy)
------------------

In a TPU Node, a normal CPU VM (an `n1` instance) needs to be provisioned to communicate with the TPU host/device.
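
For reference, here is a sketch of a ``resources`` block for a TPU Node, where ``tpu_vm: false`` selects the TPU Node architecture (the ``n1`` host VM type and the runtime version below are assumptions; adjust them to your setup):

.. code-block:: yaml

   resources:
     instance_type: n1-highmem-8  # assumed host VM type for the TPU Node
     accelerators: tpu-v2-8
     accelerator_args:
       tpu_vm: false  # use the legacy TPU Node architecture
       runtime_version: 2.12.0  # assumed TPU Node runtime version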
