diff --git a/Dockerfile_k8s b/Dockerfile_k8s index 63def8682b2..45625871078 100644 --- a/Dockerfile_k8s +++ b/Dockerfile_k8s @@ -1,4 +1,4 @@ -FROM continuumio/miniconda3:23.3.1-0 +FROM --platform=linux/amd64 continuumio/miniconda3:23.3.1-0 # TODO(romilb): Investigate if this image can be consolidated with the skypilot # client image (`Dockerfile`) @@ -33,21 +33,15 @@ ENV HOME /home/sky # Set current working directory WORKDIR /home/sky -# Install SkyPilot pip dependencies preemptively to speed up provisioning time -RUN conda init && \ - pip install wheel Click colorama cryptography jinja2 jsonschema networkx \ - oauth2client pandas pendulum PrettyTable rich tabulate filelock packaging \ - 'protobuf<4.0.0' pulp pycryptodome==3.12.0 docker kubernetes==28.1.0 \ - grpcio==1.51.3 python-dotenv==1.0.1 ray[default]==2.9.3 && \ +# Install skypilot dependencies +RUN conda init && export PIP_DISABLE_PIP_VERSION_CHECK=1 && \ + python3 -m venv ~/skypilot-runtime && \ + PYTHON_EXEC=$(echo ~/skypilot-runtime)/bin/python && \ + $PYTHON_EXEC -m pip install 'skypilot-nightly[remote,kubernetes]' 'ray[default]==2.9.3' 'pycryptodome==3.12.0' && \ + $PYTHON_EXEC -m pip uninstall skypilot-nightly -y && \ curl -LO "https://dl.k8s.io/release/v1.28.11/bin/linux/amd64/kubectl" && \ - sudo install -o root -g root -m 0755 kubectl /usr/local/bin/kubectl - -# Add /home/sky/.local/bin/ to PATH -RUN echo 'export PATH="$PATH:$HOME/.local/bin"' >> ~/.bashrc - -# Copy SkyPilot code base. This is required for the ssh jump pod to find the -# lifecycle management scripts -COPY --chown=sky . /skypilot/sky/ + sudo install -o root -g root -m 0755 kubectl /usr/local/bin/kubectl && \ + echo 'export PATH="$PATH:$HOME/.local/bin"' >> ~/.bashrc # Set PYTHONUNBUFFERED=1 to have Python print to stdout/stderr immediately ENV PYTHONUNBUFFERED=1 diff --git a/Dockerfile_k8s_gpu b/Dockerfile_k8s_gpu index f9bc7258c61..09570d102df 100644 --- a/Dockerfile_k8s_gpu +++ b/Dockerfile_k8s_gpu @@ -41,19 +41,14 @@ RUN curl https://repo.anaconda.com/miniconda/Miniconda3-py310_23.11.0-2-Linux-x8 eval "$(~/miniconda3/bin/conda shell.bash hook)" && conda init && conda config --set auto_activate_base true && conda activate base && \ grep "# >>> conda initialize >>>" ~/.bashrc || { conda init && source ~/.bashrc; } && \ rm Miniconda3-Linux-x86_64.sh && \ - pip install wheel Click colorama cryptography jinja2 jsonschema networkx \ - oauth2client pandas pendulum PrettyTable rich tabulate filelock packaging \ - 'protobuf<4.0.0' pulp pycryptodome==3.12.0 docker kubernetes==28.1.0 \ - grpcio==1.51.3 python-dotenv==1.0.1 ray[default]==2.9.3 && \ + export PIP_DISABLE_PIP_VERSION_CHECK=1 && \ + python3 -m venv ~/skypilot-runtime && \ + PYTHON_EXEC=$(echo ~/skypilot-runtime)/bin/python && \ + $PYTHON_EXEC -m pip install 'skypilot-nightly[remote,kubernetes]' 'ray[default]==2.9.3' 'pycryptodome==3.12.0' && \ + $PYTHON_EXEC -m pip uninstall skypilot-nightly -y && \ curl -LO "https://dl.k8s.io/release/v1.28.11/bin/linux/amd64/kubectl" && \ - sudo install -o root -g root -m 0755 kubectl /usr/local/bin/kubectl - -# Add /home/sky/.local/bin/ to PATH -RUN echo 'export PATH="$PATH:$HOME/.local/bin"' >> ~/.bashrc - -# Copy SkyPilot code base. This is required for the ssh jump pod to find the -# lifecycle management scripts -COPY --chown=sky . 
/skypilot/sky/ + sudo install -o root -g root -m 0755 kubectl /usr/local/bin/kubectl && \ + echo 'export PATH="$PATH:$HOME/.local/bin"' >> ~/.bashrc # Set PYTHONUNBUFFERED=1 to have Python print to stdout/stderr immediately ENV PYTHONUNBUFFERED=1 diff --git a/docs/source/examples/managed-jobs.rst b/docs/source/examples/managed-jobs.rst index a47b4345b9f..018a993f588 100644 --- a/docs/source/examples/managed-jobs.rst +++ b/docs/source/examples/managed-jobs.rst @@ -5,14 +5,17 @@ Managed Jobs .. tip:: - This feature is great for scaling out: running a single job for long durations, or running many jobs (pipelines). + This feature is great for scaling out: running a single job for long durations, or running many jobs in parallel. -SkyPilot supports **managed jobs** (:code:`sky jobs`), which can automatically recover from any spot preemptions or hardware failures. -It can be used in three modes: +SkyPilot supports **managed jobs** (:code:`sky jobs`), which can automatically recover from any underlying spot preemptions or hardware failures. +Managed jobs can be used in three modes: -#. :ref:`Managed Spot Jobs `: Jobs run on auto-recovering spot instances. This can **save significant costs** (e.g., up to 70\% for GPU VMs) by making preemptible spot instances useful for long-running jobs. -#. :ref:`On-demand `: Jobs run on auto-recovering on-demand instances. This is useful for jobs that require guaranteed resources. -#. :ref:`Pipelines `: Run pipelines that contain multiple tasks (which can have different resource requirements and ``setup``/``run`` commands). This is useful for running a sequence of tasks that depend on each other, e.g., data processing, training a model, and then running inference on it. +#. :ref:`Managed spot jobs `: Jobs run on auto-recovering spot instances. This **saves significant costs** (e.g., ~70\% for GPU VMs) by making preemptible spot instances useful for long-running jobs. +#. :ref:`Managed on-demand/reserved jobs `: Jobs run on auto-recovering on-demand or reserved instances. Useful for jobs that require guaranteed resources. +#. :ref:`Managed pipelines `: Run pipelines that contain multiple tasks (which + can have different resource requirements and ``setup``/``run`` commands). + Useful for running a sequence of tasks that depend on each other, e.g., data + processing, training a model, and then running inference on it. .. _spot-jobs: @@ -20,28 +23,12 @@ It can be used in three modes: Managed Spot Jobs ----------------- -In this mode, :code:`sky jobs launch --use-spot` is used to launch a managed spot job. SkyPilot automatically finds available spot resources across regions and clouds to maximize availability. -Any spot preemptions are automatically handled by SkyPilot without user intervention. +In this mode, jobs run on spot instances, and preemptions are auto-recovered by SkyPilot. +To launch a managed spot job, use :code:`sky jobs launch --use-spot`. +SkyPilot automatically finds available spot instances across regions and clouds to maximize availability. +Any spot preemptions are automatically handled by SkyPilot without user intervention. -Quick comparison between *unmanaged spot clusters* vs. *managed spot jobs*: - -.. list-table:: - :widths: 30 18 12 35 - :header-rows: 1 - - * - Command - - Managed? - - SSH-able? 
- - Best for - * - :code:`sky launch --use-spot` - - Unmanaged spot cluster - - Yes - - Interactive dev on spot instances (especially for hardware with low preemption rates) - * - :code:`sky jobs launch --use-spot` - - Managed spot job (auto-recovery) - - No - - Scaling out long-running jobs (e.g., data processing, training, batch inference) Here is an example of a BERT training job failing over different regions across AWS and GCP. @@ -59,6 +46,25 @@ To use managed spot jobs, there are two requirements: #. :ref:`Checkpointing ` (optional): For job recovery due to preemptions, the user application code can checkpoint its progress periodically to a :ref:`mounted cloud bucket `. The program can reload the latest checkpoint when restarted. +Quick comparison between *managed spot jobs* vs. *launching spot clusters*: + +.. list-table:: + :widths: 30 18 12 35 + :header-rows: 1 + + * - Command + - Managed? + - SSH-able? + - Best for + * - :code:`sky jobs launch --use-spot` + - Yes, preemptions are auto-recovered + - No + - Scaling out long-running jobs (e.g., data processing, training, batch inference) + * - :code:`sky launch --use-spot` + - No, preemptions are not handled + - Yes + - Interactive dev on spot instances (especially for hardware with low preemption rates) + .. _job-yaml: Job YAML @@ -93,7 +99,7 @@ We can launch it with the following: setup: | # Fill in your wandb key: copy from https://wandb.ai/authorize # Alternatively, you can use `--env WANDB_API_KEY=$WANDB_API_KEY` - # to pass the key in the command line, during `sky spot launch`. + # to pass the key in the command line, during `sky jobs launch`. echo export WANDB_API_KEY=[YOUR-WANDB-API-KEY] >> ~/.bashrc pip install -e . @@ -245,11 +251,11 @@ Real-World Examples .. _on-demand: -Using On-Demand Instances --------------------------------- +Managed On-Demand/Reserved Jobs +------------------------------- The same ``sky jobs launch`` and YAML interfaces can run jobs on auto-recovering -on-demand instances. This is useful to have SkyPilot monitor any underlying +on-demand or reserved instances. This is useful to have SkyPilot monitor any underlying machine failures and transparently recover the job. To do so, simply set :code:`use_spot: false` in the :code:`resources` section, or override it with :code:`--use-spot false` in the CLI. @@ -264,10 +270,10 @@ To do so, simply set :code:`use_spot: false` in the :code:`resources` section, o interface, while ``sky launch`` is a cluster interface (that you can launch tasks on, albeit not managed). -Either Spot Or On-Demand -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Either Spot or On-Demand/Reserved +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -You can use ``any_of`` to specify either spot or on-demand instances as +You can use ``any_of`` to specify either spot or on-demand/reserved instances as candidate resources for a job. See documentation :ref:`here ` for more details. @@ -280,12 +286,35 @@ candidate resources for a job. See documentation :ref:`here - use_spot: false In this example, SkyPilot will perform cost optimizations to select the resource to use, which almost certainly -will be spot instances. If spot instances are not available, SkyPilot will fall back to launch on-demand instances. +will be spot instances. If spot instances are not available, SkyPilot will fall back to launch on-demand/reserved instances. + + +Jobs Restarts on User Code Failure +----------------------------------- + +By default, SkyPilot will try to recover a job when its underlying cluster is preempted or failed. 
Any user code failures (non-zero exit codes) are not auto-recovered. + +In some cases, you may want a job to automatically restart on its own failures, e.g., when a training job crashes due to a Nvidia driver issue or NCCL timeouts. To specify this, you +can set :code:`max_restarts_on_errors` in :code:`resources.job_recovery` in the job YAML file. + +.. code-block:: yaml + + resources: + accelerators: A100:8 + job_recovery: + # Restart the job up to 3 times on user code errors. + max_restarts_on_errors: 3 + More advanced policies for resource selection, such as the `Can't Be Late `__ (NSDI'24) paper, may be supported in the future. +Running Many Parallel Jobs +-------------------------- + +For batch jobs such as **data processing** or **hyperparameter sweeps**, you can launch many jobs in parallel. See :ref:`many-jobs`. + Useful CLIs ----------- @@ -323,11 +352,10 @@ Cancel a managed job: If any failure happens for a managed job, you can check :code:`sky jobs queue -a` for the brief reason of the failure. For more details, it would be helpful to check :code:`sky jobs logs --controller `. - .. _pipeline: -Job Pipelines -------------- +Managed Pipelines +----------------- A pipeline is a managed job that contains a sequence of tasks running one after another. @@ -414,8 +442,8 @@ To submit the pipeline, the same command :code:`sky jobs launch` is used. The pi -Dashboard ---------- +Job Dashboard +------------- Use ``sky jobs dashboard`` to open a dashboard to see all jobs: diff --git a/docs/source/reference/faq.rst b/docs/source/reference/faq.rst index 5a966a0014f..6a8a598c1ca 100644 --- a/docs/source/reference/faq.rst +++ b/docs/source/reference/faq.rst @@ -38,7 +38,7 @@ How to ensure my workdir's ``.git`` is synced up for managed spot jobs? Currently, there is a difference in whether ``.git`` is synced up depending on the command used: - For regular ``sky launch``, the workdir's ``.git`` is synced up by default. -- For managed spot jobs ``sky spot launch``, the workdir's ``.git`` is excluded by default. +- For managed jobs ``sky jobs launch``, the workdir's ``.git`` is excluded by default. In the second case, to ensure the workdir's ``.git`` is synced up for managed spot jobs, you can explicitly add a file mount to sync it up: @@ -192,6 +192,22 @@ For example, if you have access to special regions of GCP, add the data to ``~/. Also, you can update the catalog for a specific cloud by deleting the CSV file (e.g., ``rm ~/.sky/catalogs//gcp.csv``). SkyPilot will automatically download the latest catalog in the next run. +Package Installation +--------------------- + +Unable to import PyTorch in a SkyPilot task. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +For `PyTorch `_ installation, if you are using the default SkyPilot images (not passing in `--image-id`), ``pip install torch`` should work. + +But if you use your own image which has an older NVIDIA driver (535.161.08 or lower) and you install the default PyTorch, you may encounter the following error: + +.. code-block:: bash + + ImportError: /home/azureuser/miniconda3/lib/python3.10/site-packages/torch/lib/../../nvidia/cusparse/lib/libcusparse.so.12: undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12 + +You will need to install a PyTorch version that is compatible with your NVIDIA driver, e.g., ``pip install torch --index-url https://download.pytorch.org/whl/cu121``. 
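+
+For example, in your task's ``setup`` commands (a minimal sketch: the index URL comes from above; the driver check is optional):
+
+.. code-block:: bash
+
+   # Optional: check the NVIDIA driver version on the node.
+   nvidia-smi --query-gpu=driver_version --format=csv,noheader
+
+   # Install PyTorch wheels built against CUDA 12.1, which are compatible with
+   # older drivers such as 535.
+   pip install torch --index-url https://download.pytorch.org/whl/cu121
+
+   # Verify that PyTorch imports and can see the GPU.
+   python -c "import torch; print(torch.__version__, torch.cuda.is_available())"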
+ + Miscellaneous ------------- diff --git a/docs/source/reference/kubernetes/kubernetes-deployment.rst b/docs/source/reference/kubernetes/kubernetes-deployment.rst index e9489e9149e..d3891b3df51 100644 --- a/docs/source/reference/kubernetes/kubernetes-deployment.rst +++ b/docs/source/reference/kubernetes/kubernetes-deployment.rst @@ -147,10 +147,16 @@ Deploying on Google Cloud GKE .. code-block:: console $ sky show-gpus --cloud kubernetes - GPU QTY_PER_NODE TOTAL_GPUS TOTAL_FREE_GPUS - L4 1, 2, 3, 4 8 6 - A100 1, 2 4 2 + GPU REQUESTABLE_QTY_PER_NODE TOTAL_GPUS TOTAL_FREE_GPUS + L4 1, 2, 4 8 6 + A100 1, 2 4 2 + Kubernetes per node GPU availability + NODE_NAME GPU_NAME TOTAL_GPUS FREE_GPUS + my-cluster-0 L4 4 4 + my-cluster-1 L4 4 2 + my-cluster-2 A100 2 2 + my-cluster-3 A100 2 0 .. note:: GKE autopilot clusters are currently not supported. Only GKE standard clusters are supported. @@ -196,8 +202,12 @@ Deploying on Amazon EKS .. code-block:: console $ sky show-gpus --cloud kubernetes - GPU QTY_PER_NODE TOTAL_GPUS TOTAL_FREE_GPUS - A100 1, 2 4 2 + GPU REQUESTABLE_QTY_PER_NODE TOTAL_GPUS TOTAL_FREE_GPUS + A100 1, 2 4 2 + + Kubernetes per node GPU availability + NODE_NAME GPU_NAME TOTAL_GPUS FREE_GPUS + my-cluster-0 A100 2 2 .. _kubernetes-setup-onprem: diff --git a/docs/source/reference/kubernetes/kubernetes-getting-started.rst b/docs/source/reference/kubernetes/kubernetes-getting-started.rst index d7313fba3e2..9d46acf13c0 100644 --- a/docs/source/reference/kubernetes/kubernetes-getting-started.rst +++ b/docs/source/reference/kubernetes/kubernetes-getting-started.rst @@ -156,9 +156,9 @@ You can also inspect the real-time GPU usage on the cluster with :code:`sky show $ sky show-gpus --cloud kubernetes Kubernetes GPUs - GPU QTY_PER_NODE TOTAL_GPUS TOTAL_FREE_GPUS - L4 1, 2, 4 12 12 - H100 1, 2, 4, 8 16 16 + GPU REQUESTABLE_QTY_PER_NODE TOTAL_GPUS TOTAL_FREE_GPUS + L4 1, 2, 4 12 12 + H100 1, 2, 4, 8 16 16 Kubernetes per node GPU availability NODE_NAME GPU_NAME TOTAL_GPUS FREE_GPUS @@ -174,7 +174,12 @@ You can also inspect the real-time GPU usage on the cluster with :code:`sky show Using Custom Images ------------------- -By default, we use and maintain a SkyPilot container image that has conda and a few other basic tools installed. +By default, we maintain and use two SkyPilot container images for use on Kubernetes clusters: + +1. ``us-central1-docker.pkg.dev/skypilot-375900/skypilotk8s/skypilot``: used for CPU-only clusters (`Dockerfile `__). +2. ``us-central1-docker.pkg.dev/skypilot-375900/skypilotk8s/skypilot-gpu``: used for GPU clusters (`Dockerfile `__). + +These images are pre-installed with SkyPilot dependencies for fast startup. To use your own image, add :code:`image_id: docker:` to the :code:`resources` section of your task YAML. 
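+
+For example, assuming a hypothetical image hosted at ``myregistry.example.com``, the image can also be passed directly on the command line:
+
+.. code-block:: console
+
+   $ # The image URI is a placeholder for your own registry/repository/tag.
+   $ sky launch -c mycluster --cloud kubernetes --gpus L4:1 --image-id docker:myregistry.example.com/my-image:latest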
diff --git a/docs/source/reference/kubernetes/kubernetes-setup.rst b/docs/source/reference/kubernetes/kubernetes-setup.rst index a827d49ea19..3621d1b5338 100644 --- a/docs/source/reference/kubernetes/kubernetes-setup.rst +++ b/docs/source/reference/kubernetes/kubernetes-setup.rst @@ -262,9 +262,9 @@ You can also check the GPUs available on your nodes by running: $ sky show-gpus --cloud kubernetes Kubernetes GPUs - GPU QTY_PER_NODE TOTAL_GPUS TOTAL_FREE_GPUS - L4 1, 2, 4 12 12 - H100 1, 2, 4, 8 16 16 + GPU REQUESTABLE_QTY_PER_NODE TOTAL_GPUS TOTAL_FREE_GPUS + L4 1, 2, 4 12 12 + H100 1, 2, 4, 8 16 16 Kubernetes per node GPU availability NODE_NAME GPU_NAME TOTAL_GPUS FREE_GPUS diff --git a/docs/source/reference/tpu.rst b/docs/source/reference/tpu.rst index c34d10dab3c..a753c26bd31 100644 --- a/docs/source/reference/tpu.rst +++ b/docs/source/reference/tpu.rst @@ -20,104 +20,232 @@ Use one command to quickly get TPU nodes for development: .. code-block:: bash - sky launch --gpus tpu-v2-8 + # Use latest TPU v6 (Trillium) VMs: + sky launch --gpus tpu-v6e-8 + # Use TPU v4 (Titan) VMs: + sky launch --gpus tpu-v4-8 # Preemptible TPUs: - sky launch --gpus tpu-v2-8 --use-spot - # Change TPU type to tpu-v3-8: - sky launch --gpus tpu-v3-8 - # Change the host VM type to n1-highmem-16: - sky launch --gpus tpu-v3-8 -t n1-highmem-16 + sky launch --gpus tpu-v6e-8 --use-spot After the command finishes, you will be dropped into a TPU host VM and can start developing code right away. -Below, we show examples of using SkyPilot to run MNIST training on (1) TPU VMs and (2) TPU Nodes. +Below, we show examples of using SkyPilot to (1) train LLMs on TPU VMs/Pods and (2) train MNIST on TPU Nodes (legacy). TPU Architectures ================= Two different TPU architectures are available on GCP: -- `TPU VMs `_ +- `TPU VMs/Pods `_ - `TPU Nodes `_ -Both are supported by SkyPilot. We recommend TPU VMs which is a newer architecture encouraged by GCP. +Both are supported by SkyPilot. We recommend TPU VMs and Pods which are newer architectures encouraged by GCP. The two architectures differ as follows. -For TPU VMs, you can directly SSH into the "TPU host" VM that is physically connected to the TPU device. -For TPU Nodes, a user VM (an `n1` instance) must be separately provisioned to communicate with an inaccessible TPU host over gRPC. + +* For TPU VMs/Pods, you can directly SSH into the "TPU host" VM that is physically connected to the TPU device. +* For TPU Nodes, a user VM (an `n1` instance) must be separately provisioned to communicate with an inaccessible TPU host over gRPC. + More details can be found on GCP `documentation `_. -TPU VMs -------- -To use TPU VMs, set the following in a task YAML's ``resources`` field: +.. _tpu-vms: + +TPU VMs/Pods +------------ + +Google's latest TPU v6 (Trillium) VMs offers great performance and it is now supported by SkyPilot. + +To use TPU VMs/Pods, set the following in a task YAML's ``resources`` field: .. code-block:: yaml resources: - accelerators: tpu-v2-8 + accelerators: tpu-v6e-8 accelerator_args: - runtime_version: tpu-vm-base # optional + runtime_version: v2-alpha-tpuv6e # optional The ``accelerators`` field specifies the TPU type, and the :code:`accelerator_args` dict includes the optional :code:`tpu_vm` bool (defaults to true, which means TPU VM is used), and an optional TPU ``runtime_version`` field. To show what TPU types are supported, run :code:`sky show-gpus`. -Here is a complete task YAML that runs `MNIST training `_ on a TPU VM using JAX. 
+Here is a complete task YAML that trains a `Llama 3 model `_ on a TPU VM using Torch XLA. .. code-block:: yaml - name: mnist-tpu-vm - resources: - accelerators: tpu-v2-8 - accelerator_args: - tpu_vm: True - runtime_version: tpu-vm-base + accelerators: tpu-v6e-8 # Fill in the accelerator type you want to use - setup: | - git clone https://github.com/google/flax.git + envs: + HF_TOKEN: # fill in your huggingface token - conda activate flax - if [ $? -eq 0 ]; then - echo 'conda env exists' - else - conda create -n flax python=3.8 -y - conda activate flax - # Make sure to install TPU related packages in a conda env to avoid package conflicts. - pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip install --upgrade clu - pip install -e flax - pip install tensorflow tensorflow-datasets - fi + workdir: . - run: | - conda activate flax - cd flax/examples/mnist - python3 main.py --workdir=/tmp/mnist \ - --config=configs/default.py \ - --config.learning_rate=0.05 \ - --config.num_epochs=10 + setup: | + pip3 install huggingface_hub + python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')" + + # Setup TPU + pip3 install cloud-tpu-client + sudo apt update + sudo apt install -y libopenblas-base + pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \ + --index-url https://download.pytorch.org/whl/nightly/cpu + pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \ + -f https://storage.googleapis.com/libtpu-releases/index.html + pip install torch_xla[pallas] \ + -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + + # Setup runtime for training + git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git + cd transformers + pip3 install -e . + pip3 install datasets evaluate scikit-learn accelerate -This YAML lives under the `SkyPilot repo `_ (``examples/tpu/tpuvm_mnist.yaml``), or you can paste it into a local file. + run: | + unset LD_PRELOAD + PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true \ + python3 transformers/examples/pytorch/language-modeling/run_clm.py \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --per_device_train_batch_size 16 \ + --do_train \ + --output_dir /home/$USER/tmp/test-clm \ + --overwrite_output_dir \ + --config_name /home/$USER/sky_workdir/config-8B.json \ + --cache_dir /home/$USER/cache \ + --tokenizer_name meta-llama/Meta-Llama-3-8B \ + --block_size 8192 \ + --optim adafactor \ + --save_strategy no \ + --logging_strategy no \ + --fsdp "full_shard" \ + --fsdp_config /home/$USER/sky_workdir/fsdp_config.json \ + --torch_dtype bfloat16 \ + --dataloader_drop_last yes \ + --flash_attention \ + --max_steps 20 + +This YAML lives under the `SkyPilot repo `__, or you can paste it into a local file. Launch it with: .. code-block:: console - $ sky launch examples/tpu/tpuvm_mnist.yaml -c mycluster + $ HF_TOKEN= sky launch train-llama3-8b.yaml -c llama-3-train --env HF_TOKEN You should see the following outputs when the job finishes. .. 
code-block:: console - $ sky launch examples/tpu/tpuvm_mnist.yaml -c mycluster + $ sky launch train-llama3-8b.yaml -c llama-3-train + (task, pid=17499) ***** train metrics ***** + (task, pid=17499) epoch = 1.1765 + (task, pid=17499) total_flos = 109935420GF + (task, pid=17499) train_loss = 10.6011 + (task, pid=17499) train_runtime = 0:11:12.77 + (task, pid=17499) train_samples = 282 + (task, pid=17499) train_samples_per_second = 0.476 + (task, pid=17499) train_steps_per_second = 0.03 + + +Multi-Host TPU Pods +------------------- + +A `TPU Pod `_ is a collection of TPU devices connected by dedicated high-speed network interfaces for high-performance training. + +To use a TPU Pod, simply change the ``accelerators`` field in the task YAML (e.g., :code:`tpu-v6e-8` -> :code:`tpu-v6e-32`). + +.. code-block:: yaml + :emphasize-lines: 2-2 + + resources: + accelerators: tpu-v6e-32 # Pods have > 8 cores (the last number) + +.. note:: + + Both TPU architectures, TPU VMs and TPU Nodes, can be used with TPU Pods. The example below is based on TPU VMs. + +To show all available TPU Pod types, run :code:`sky show-gpus` (more than 8 cores means Pods): + +.. code-block:: console + + GOOGLE_TPU AVAILABLE_QUANTITIES + tpu-v6e-8 1 + tpu-v6e-32 1 + tpu-v6e-128 1 + tpu-v6e-256 1 ... - (mnist-tpu-vm pid=10155) I0823 07:49:25.468526 139641357117440 train.py:146] epoch: 9, train_loss: 0.0120, train_accuracy: 99.64, test_loss: 0.0278, test_accuracy: 99.02 - (mnist-tpu-vm pid=10155) I0823 07:49:26.966874 139641357117440 train.py:146] epoch: 10, train_loss: 0.0095, train_accuracy: 99.73, test_loss: 0.0264, test_accuracy: 99.19 + +After creating a TPU Pod, multiple host VMs (e.g., :code:`tpu-v6e-32` comes with 4 host VMs) are launched. +Normally, the user needs to SSH into all hosts to prepare files and setup environments, and +then launch the job on each host, which is a tedious and error-prone process. + +SkyPilot automates away this complexity. From your laptop, a single :code:`sky launch` command will perform: + +- workdir/file_mounts syncing; and +- execute the setup/run commands on every host of the pod. + +We can run the same Llama 3 training job in on a TPU Pod with the following command, with a slight change to the YAML (``--per_device_train_batch_size`` from 16 to 32): + +.. code-block:: console + + $ HF_TOKEN= sky launch -c tpu-pod --gpus tpu-v6e-32 train-llama3-8b.yaml --env HF_TOKEN + +You should see the following output. + +.. 
code-block:: console + + (head, rank=0, pid=17894) ***** train metrics ***** + (head, rank=0, pid=17894) epoch = 2.5 + (head, rank=0, pid=17894) total_flos = 219870840GF + (head, rank=0, pid=17894) train_loss = 10.1527 + (head, rank=0, pid=17894) train_runtime = 0:11:13.18 + (head, rank=0, pid=17894) train_samples = 282 + (head, rank=0, pid=17894) train_samples_per_second = 0.951 + (head, rank=0, pid=17894) train_steps_per_second = 0.03 + + (worker1, rank=1, pid=15406, ip=10.164.0.57) ***** train metrics ***** + (worker1, rank=1, pid=15406, ip=10.164.0.57) epoch = 2.5 + (worker1, rank=1, pid=15406, ip=10.164.0.57) total_flos = 219870840GF + (worker1, rank=1, pid=15406, ip=10.164.0.57) train_loss = 10.1527 + (worker1, rank=1, pid=15406, ip=10.164.0.57) train_runtime = 0:11:15.08 + (worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples = 282 + (worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples_per_second = 0.948 + (worker1, rank=1, pid=15406, ip=10.164.0.57) train_steps_per_second = 0.03 + + (worker2, rank=2, pid=16552, ip=10.164.0.58) ***** train metrics ***** + (worker2, rank=2, pid=16552, ip=10.164.0.58) epoch = 2.5 + (worker2, rank=2, pid=16552, ip=10.164.0.58) total_flos = 219870840GF + (worker2, rank=2, pid=16552, ip=10.164.0.58) train_loss = 10.1527 + (worker2, rank=2, pid=16552, ip=10.164.0.58) train_runtime = 0:11:15.61 + (worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples = 282 + (worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples_per_second = 0.947 + (worker2, rank=2, pid=16552, ip=10.164.0.58) train_steps_per_second = 0.03 + + (worker3, rank=3, pid=17469, ip=10.164.0.59) ***** train metrics ***** + (worker3, rank=3, pid=17469, ip=10.164.0.59) epoch = 2.5 + (worker3, rank=3, pid=17469, ip=10.164.0.59) total_flos = 219870840GF + (worker3, rank=3, pid=17469, ip=10.164.0.59) train_loss = 10.1527 + (worker3, rank=3, pid=17469, ip=10.164.0.59) train_runtime = 0:11:15.10 + (worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples = 282 + (worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples_per_second = 0.948 + (worker3, rank=3, pid=17469, ip=10.164.0.59) train_steps_per_second = 0.03 + + +To submit more jobs to the same TPU Pod, use :code:`sky exec`: + +.. code-block:: console + + $ HF_TOKEN= sky exec tpu-pod train-llama3-8b.yaml --env HF_TOKEN -TPU Nodes ---------- +**You can find more useful examples for Serving LLMs on TPUs in** `SkyPilot repo `__. + + + +TPU Nodes (Legacy) +------------------ In a TPU Node, a normal CPU VM (an `n1` instance) needs to be provisioned to communicate with the TPU host/device. @@ -215,96 +343,3 @@ This YAML lives under the `SkyPilot repo `_ is a collection of TPU devices connected by dedicated high-speed network interfaces for high-performance training. - -To use a TPU Pod, simply change the ``accelerators`` field in the task YAML (e.g., :code:`v2-8` -> :code:`v2-32`). - -.. code-block:: yaml - :emphasize-lines: 2-2 - - resources: - accelerators: tpu-v2-32 # Pods have > 8 cores (the last number) - accelerator_args: - runtime_version: tpu-vm-base - -.. note:: - - Both TPU architectures, TPU VMs and TPU Nodes, can be used with TPU Pods. The example below is based on TPU VMs. - -To show all available TPU Pod types, run :code:`sky show-gpus` (more than 8 cores means Pods): - -.. 
code-block:: console - - GOOGLE_TPU AVAILABLE_QUANTITIES - tpu-v2-8 1 - tpu-v2-32 1 - tpu-v2-128 1 - tpu-v2-256 1 - tpu-v2-512 1 - tpu-v3-8 1 - tpu-v3-32 1 - tpu-v3-64 1 - tpu-v3-128 1 - tpu-v3-256 1 - tpu-v3-512 1 - tpu-v3-1024 1 - tpu-v3-2048 1 - -After creating a TPU Pod, multiple host VMs (e.g., :code:`v2-32` comes with 4 host VMs) are launched. -Normally, the user needs to SSH into all hosts (depending on the architecture used, either the ``n1`` User VMs or the TPU Host VMs) to prepare files and setup environments, and -then launch the job on each host, which is a tedious and error-prone process. - -SkyPilot automates away this complexity. From your laptop, a single :code:`sky launch` command will perform: - -- workdir/file_mounts syncing; and -- execute the setup/run commands on every host of the pod. - -Here is a task YAML for a cifar10 training job on a :code:`v2-32` TPU Pod with JAX (`code repo `_): - -.. code-block:: yaml - - name: cifar-tpu-pod - - resources: - accelerators: tpu-v2-32 - accelerator_args: - runtime_version: tpu-vm-base - - setup: | - git clone https://github.com/infwinston/tpu-example.git - cd tpu-example - pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip install -r requirements.txt - - run: | - python -u tpu-example/train.py - -Launch it with: - -.. code-block:: console - - $ sky launch examples/tpu/cifar_pod.yaml -c mycluster - -You should see the following output. - -.. code-block:: console - - (node-0 pid=57977, ip=10.164.0.24) JAX process: 1 / 4 - (node-3 pid=57963, ip=10.164.0.26) JAX process: 3 / 4 - (node-2 pid=57922, ip=10.164.0.25) JAX process: 2 / 4 - (node-1 pid=63223) JAX process: 0 / 4 - ... - (node-0 pid=57977, ip=10.164.0.24) [ 1000/100000] time 0.034 ( 0.063) data 0.008 ( 0.008) loss 1.215 ( 1.489) acc 68.750 (46.163) - -.. note:: - - By default, outputs from all hosts are shown with the ``node-`` prefix. Use :code:`jax.process_index()` to control which host to print messages. - -To submit more jobs to the same TPU Pod, use :code:`sky exec`: - -.. code-block:: console - - $ sky exec mycluster examples/tpu/cifar_pod.yaml diff --git a/docs/source/reference/yaml-spec.rst b/docs/source/reference/yaml-spec.rst index 7c298dd4079..455ee5909c9 100644 --- a/docs/source/reference/yaml-spec.rst +++ b/docs/source/reference/yaml-spec.rst @@ -107,6 +107,10 @@ Available fields: # # default: EAGER_NEXT_REGION job_recovery: none + # Or, to allow up to 3 restarts (default: 0) on user code errors: + # job_recovery: + # strategy: EAGER_NEXT_REGION + # max_restarts_on_errors: 3 # Disk size in GB to allocate for OS (mounted at /). Increase this if you # have a large working directory or tasks that write out large outputs. 
diff --git a/docs/source/reservations/existing-machines.rst b/docs/source/reservations/existing-machines.rst index 2f9ac2a2441..d8d3fb81e67 100644 --- a/docs/source/reservations/existing-machines.rst +++ b/docs/source/reservations/existing-machines.rst @@ -108,9 +108,9 @@ Deploying SkyPilot $ sky show-gpus --cloud kubernetes Kubernetes GPUs - GPU QTY_PER_NODE TOTAL_GPUS TOTAL_FREE_GPUS - L4 1, 2, 4 12 12 - H100 1, 2, 4, 8 16 16 + GPU REQUESTABLE_QTY_PER_NODE TOTAL_GPUS TOTAL_FREE_GPUS + L4 1, 2, 4 12 12 + H100 1, 2, 4, 8 16 16 Kubernetes per node GPU availability NODE_NAME GPU_NAME TOTAL_GPUS FREE_GPUS diff --git a/examples/k8s_cloud_deploy/README.md b/examples/k8s_cloud_deploy/README.md index 64519e2fa53..5ba42cbe836 100644 --- a/examples/k8s_cloud_deploy/README.md +++ b/examples/k8s_cloud_deploy/README.md @@ -44,8 +44,8 @@ NAME STATUS ROLES AGE VERSION $ sky show-gpus --cloud kubernetes Kubernetes GPUs -GPU QTY_PER_NODE TOTAL_GPUS TOTAL_FREE_GPUS -A10 1 2 2 +GPU REQUESTABLE_QTY_PER_NODE TOTAL_GPUS TOTAL_FREE_GPUS +A10 1 2 2 Kubernetes per node GPU availability NODE_NAME GPU_NAME TOTAL_GPUS FREE_GPUS diff --git a/examples/managed_job_with_storage.yaml b/examples/managed_job_with_storage.yaml index 61244c16ba0..77b69485269 100644 --- a/examples/managed_job_with_storage.yaml +++ b/examples/managed_job_with_storage.yaml @@ -3,7 +3,7 @@ # Runs a task that uses cloud buckets for uploading and accessing files. # # Usage: -# sky spot launch -c spot-storage examples/managed_job_with_storage.yaml +# sky jobs launch -c spot-storage examples/managed_job_with_storage.yaml # sky down spot-storage resources: @@ -26,8 +26,8 @@ file_mounts: name: sky-output-bucket mode: MOUNT - /imagenet-image: - source: s3://sky-imagenet-data + /public-bucket: + source: s3://fah-public-data-covid19-cryptic-pockets # File mounts for folder /tmp/workdir: ~/tmp-workdir @@ -49,7 +49,7 @@ run: | set -ex ls ~/sky_workdir/managed_job_with_storage.yaml ls ~/bucket_workdir/managed_job_with_storage.yaml - ls -l /imagenet-image/datasets + ls -l /public-bucket mkdir -p /data/logs diff --git a/examples/tpu/v6e/README.md b/examples/tpu/v6e/README.md new file mode 100644 index 00000000000..c5eca93ca91 --- /dev/null +++ b/examples/tpu/v6e/README.md @@ -0,0 +1,128 @@ +# TPU v6e + +Trillium (also refers to v6e) is Cloud TPU’s latest generation AI accelerator. SkyPilot support TPU v6e with provisioning, training and serving. + +## Catalogs + +Currently, for TPU v6e, the public APIs for regions and pricing is not released yet, and pricing info for `us-central1`, `us-central2`, `us-south1` is not available. We set the price to `0.0` in those regions for now. + +``` +## Provisioning + +To provision TPU v6e, use the following command: + +```bash +$ sky launch --gpus tpu-v6e-16 -c tpu-v6e +``` + +After that, you can SSH to the instance and start developing your model: + +```bash +$ ssh tpu-v6e +``` + +## Training + +Examples in this directory (`train-llama3-8b.yaml`) shows how to use TPU v6e to train a Llama3 8b model, using PyTorch (XLA) on the wikitext dataset. 
To start the training, use the following command: + +```bash +$ HF_TOKEN=hf_xxx sky launch train-llama3-8b.yaml -c train-llama3-8b --env HF_TOKEN +``` + +### Single-Host Training + +The training throughput for a `tpu-v6e-8` instance should around 0.5 samples/s: + +```bash +(task, pid=17499) ***** train metrics ***** +(task, pid=17499) epoch = 1.1765 +(task, pid=17499) total_flos = 109935420GF +(task, pid=17499) train_loss = 10.6011 +(task, pid=17499) train_runtime = 0:11:12.77 +(task, pid=17499) train_samples = 282 +(task, pid=17499) train_samples_per_second = 0.476 +(task, pid=17499) train_steps_per_second = 0.03 +INFO: Job finished (status: SUCCEEDED). +``` + +### Multi-Host Training + +By changing the TPU type to `tpu-v6e-16` and the `--per_device_train_batch_size` to `32`, the training throughput increased to around 1 samples/s: + +```bash +(head, rank=0, pid=17894) ***** train metrics ***** +(head, rank=0, pid=17894) epoch = 2.5 +(head, rank=0, pid=17894) total_flos = 219870840GF +(head, rank=0, pid=17894) train_loss = 10.1527 +(head, rank=0, pid=17894) train_runtime = 0:11:13.18 +(head, rank=0, pid=17894) train_samples = 282 +(head, rank=0, pid=17894) train_samples_per_second = 0.951 +(head, rank=0, pid=17894) train_steps_per_second = 0.03 + +(worker1, rank=1, pid=15406, ip=10.164.0.57) ***** train metrics ***** +(worker1, rank=1, pid=15406, ip=10.164.0.57) epoch = 2.5 +(worker1, rank=1, pid=15406, ip=10.164.0.57) total_flos = 219870840GF +(worker1, rank=1, pid=15406, ip=10.164.0.57) train_loss = 10.1527 +(worker1, rank=1, pid=15406, ip=10.164.0.57) train_runtime = 0:11:15.08 +(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples = 282 +(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples_per_second = 0.948 +(worker1, rank=1, pid=15406, ip=10.164.0.57) train_steps_per_second = 0.03 + +(worker2, rank=2, pid=16552, ip=10.164.0.58) ***** train metrics ***** +(worker2, rank=2, pid=16552, ip=10.164.0.58) epoch = 2.5 +(worker2, rank=2, pid=16552, ip=10.164.0.58) total_flos = 219870840GF +(worker2, rank=2, pid=16552, ip=10.164.0.58) train_loss = 10.1527 +(worker2, rank=2, pid=16552, ip=10.164.0.58) train_runtime = 0:11:15.61 +(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples = 282 +(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples_per_second = 0.947 +(worker2, rank=2, pid=16552, ip=10.164.0.58) train_steps_per_second = 0.03 + +(worker3, rank=3, pid=17469, ip=10.164.0.59) ***** train metrics ***** +(worker3, rank=3, pid=17469, ip=10.164.0.59) epoch = 2.5 +(worker3, rank=3, pid=17469, ip=10.164.0.59) total_flos = 219870840GF +(worker3, rank=3, pid=17469, ip=10.164.0.59) train_loss = 10.1527 +(worker3, rank=3, pid=17469, ip=10.164.0.59) train_runtime = 0:11:15.10 +(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples = 282 +(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples_per_second = 0.948 +(worker3, rank=3, pid=17469, ip=10.164.0.59) train_steps_per_second = 0.03 + +INFO: Job finished (status: SUCCEEDED). +``` + +# Serving + +TPU v6e also supports serving. Examples in this directory (`serve-llama2-7b.yaml`) shows how to use TPU v6e to serve a Llama2 7b model, using PyTorch (XLA) and the JetStream lib. 
To start the serving, use the following command: + +```bash +$ HF_TOKEN=hf_xxx sky launch serve-llama2-7b.yaml -c serve-llama2-7b --env HF_TOKEN +``` + +After the server is ready, you should see the following message: + +```bash +(task, pid=26431) 2024-09-24 19:58:15,160 - root - INFO - Starting server on port 9000 with 64 threads +(task, pid=26431) I0924 19:58:15.160293 140454572087296 server_lib.py:155] Starting server on port 9000 with 64 threads +(task, pid=26431) 2024-09-24 19:58:15,161 - root - INFO - Not starting JAX profiler server: False +(task, pid=26431) I0924 19:58:15.161907 140454572087296 server_lib.py:164] Not starting JAX profiler server: False +(task, pid=26431) Started jetstream_server.... +``` + +You can now start a benchmark to test the serving performance: + +```bash +$ sky exec serve-llama2-7b benchmark-llama2-7b.yaml +... (emitted logs) +(task, pid=25491) Successful requests: 100 +(task, pid=25491) Benchmark duration: 8.753792 s +(task, pid=25491) Total input tokens: 21888 +(task, pid=25491) Total generated tokens: 18803 +(task, pid=25491) Request throughput: 11.42 requests/s +(task, pid=25491) Input token throughput: 2500.40 tokens/s +(task, pid=25491) Output token throughput: 2147.98 tokens/s +(task, pid=25491) Mean TTFT: 1981.93 ms +(task, pid=25491) Median TTFT: 1829.33 ms +(task, pid=25491) P99 TTFT: 4511.95 ms +(task, pid=25491) Mean TPOT: 130.71 ms +(task, pid=25491) Median TPOT: 18.88 ms +(task, pid=25491) P99 TPOT: 2487.37 ms +``` diff --git a/examples/tpu/v6e/benchmark-llama2-7b.yaml b/examples/tpu/v6e/benchmark-llama2-7b.yaml new file mode 100644 index 00000000000..d6fa002e160 --- /dev/null +++ b/examples/tpu/v6e/benchmark-llama2-7b.yaml @@ -0,0 +1,10 @@ +envs: + model_name: llama-2 + tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model + +run: | + cd JetStream + python benchmarks/benchmark_serving.py \ + --tokenizer=$tokenizer_path --num-prompts=100 \ + --dataset openorca --save-request-outputs \ + --warmup-mode=sampled --model=$model_name diff --git a/examples/tpu/v6e/config-8B.json b/examples/tpu/v6e/config-8B.json new file mode 100644 index 00000000000..175a749e1bd --- /dev/null +++ b/examples/tpu/v6e/config-8B.json @@ -0,0 +1,27 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 8192, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0.dev0", + "use_cache": true, + "vocab_size": 128256 +} diff --git a/examples/tpu/v6e/fsdp_config.json b/examples/tpu/v6e/fsdp_config.json new file mode 100644 index 00000000000..a37a04352ee --- /dev/null +++ b/examples/tpu/v6e/fsdp_config.json @@ -0,0 +1,8 @@ +{ + "fsdp_transformer_layer_cls_to_wrap": [ + "LlamaDecoderLayer" + ], + "xla": true, + "xla_fsdp_v2": true, + "xla_fsdp_grad_ckpt": true +} diff --git a/examples/tpu/v6e/serve-llama2-7b.yaml b/examples/tpu/v6e/serve-llama2-7b.yaml new file mode 100644 index 00000000000..49d0bf9fcd2 --- /dev/null +++ b/examples/tpu/v6e/serve-llama2-7b.yaml @@ -0,0 +1,60 @@ +resources: + accelerators: tpu-v6e-8 # Fill in the accelerator type you want to use + +envs: + 
HF_TOKEN: # fill in your huggingface token + HF_REPO_ID: meta-llama/Llama-2-7b + model_name: llama-2 + input_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original + output_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/converted + tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model + +setup: | + pip3 install huggingface_hub + python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')" + + # Setup TPU + pip3 install cloud-tpu-client + sudo apt update + sudo apt install -y libopenblas-base + pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \ + --index-url https://download.pytorch.org/whl/nightly/cpu + pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \ + -f https://storage.googleapis.com/libtpu-releases/index.html + pip install torch_xla[pallas] \ + -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + + # Setup runtime for serving + git clone https://github.com/google/JetStream.git + cd JetStream + git checkout main + git pull origin main + pip install -e . + cd benchmarks + pip install -r requirements.in + cd ../.. + git clone https://github.com/google/jetstream-pytorch.git + cd jetstream-pytorch/ + git checkout jetstream-v0.2.3 + source install_everything.sh + pip3 install -U --pre jax jaxlib libtpu-nightly requests \ + -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + + + # Prepare checkpoint, inside jetstream-pytorch repo + mkdir -p ${input_ckpt_dir} + python3 -c "import huggingface_hub; huggingface_hub.snapshot_download('${HF_REPO_ID}', local_dir='${input_ckpt_dir}')" + mkdir -p ${output_ckpt_dir} + python -m convert_checkpoints --model_name=$model_name \ + --input_checkpoint_dir=$input_ckpt_dir \ + --output_checkpoint_dir=$output_ckpt_dir + +run: | + cd jetstream-pytorch + python run_server.py --model_name=$model_name \ + --size=7b --batch_size=24 --max_cache_length=2048 \ + --checkpoint_path=$output_ckpt_dir \ + --tokenizer_path=$tokenizer_path \ + --sharding_config="default_shardings/llama.yaml" diff --git a/examples/tpu/v6e/train-llama3-8b.yaml b/examples/tpu/v6e/train-llama3-8b.yaml new file mode 100644 index 00000000000..3acdbdbbe0b --- /dev/null +++ b/examples/tpu/v6e/train-llama3-8b.yaml @@ -0,0 +1,53 @@ +resources: + accelerators: tpu-v6e-8 # Fill in the accelerator type you want to use + +envs: + HF_TOKEN: # fill in your huggingface token + +workdir: . 
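+# Note: the workdir should contain config-8B.json and fsdp_config.json (both
+# provided in this directory); SkyPilot syncs it to ~/sky_workdir on the TPU
+# VM, which is where the run commands below look for them.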
+ +setup: | + pip3 install huggingface_hub + python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')" + + # Setup TPU + pip3 install cloud-tpu-client + sudo apt update + sudo apt install -y libopenblas-base + pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \ + --index-url https://download.pytorch.org/whl/nightly/cpu + pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \ + -f https://storage.googleapis.com/libtpu-releases/index.html + pip install torch_xla[pallas] \ + -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + + # Setup runtime for training + git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git + cd transformers + pip3 install -e . + pip3 install datasets evaluate scikit-learn accelerate + +run: | + unset LD_PRELOAD + PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true \ + python3 transformers/examples/pytorch/language-modeling/run_clm.py \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --per_device_train_batch_size 16 \ + --do_train \ + --output_dir /home/$USER/tmp/test-clm \ + --overwrite_output_dir \ + --config_name /home/$USER/sky_workdir/config-8B.json \ + --cache_dir /home/$USER/cache \ + --tokenizer_name meta-llama/Meta-Llama-3-8B \ + --block_size 8192 \ + --optim adafactor \ + --save_strategy no \ + --logging_strategy no \ + --fsdp "full_shard" \ + --fsdp_config /home/$USER/sky_workdir/fsdp_config.json \ + --torch_dtype bfloat16 \ + --dataloader_drop_last yes \ + --flash_attention \ + --max_steps 20 diff --git a/format.sh b/format.sh index 66b966c3029..b06481b4c10 100755 --- a/format.sh +++ b/format.sh @@ -60,6 +60,10 @@ BLACK_INCLUDES=( 'sky/skylet/providers/ibm' ) +PYLINT_FLAGS=( + '--load-plugins' 'pylint_quotes' +) + # Format specified files format() { yapf --in-place "${YAPF_FLAGS[@]}" "$@" @@ -77,8 +81,9 @@ format_changed() { MERGEBASE="$(git merge-base origin/master HEAD)" if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then - git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \ - yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | \ + tr '\n' '\0' | xargs -P 5 -0 \ + yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" fi } @@ -119,7 +124,21 @@ mypy $(cat tests/mypy_files.txt) # Run Pylint echo 'Sky Pylint:' -pylint --load-plugins pylint_quotes sky +if [[ "$1" == '--files' ]]; then + # If --files is passed, filter to files within sky/ and pass to pylint. + pylint "${PYLINT_FLAGS[@]}" "${@:2}" +elif [[ "$1" == '--all' ]]; then + # Pylint entire sky directory. + pylint "${PYLINT_FLAGS[@]}" sky +else + # Pylint only files in sky/ that have changed in last commit. + changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- 'sky/*.py' 'sky/*.pyi') + if [[ -n "$changed_files" ]]; then + echo "$changed_files" | tr '\n' '\0' | xargs -0 pylint "${PYLINT_FLAGS[@]}" + else + echo 'Pylint skipped: no files changed in sky/.' + fi +fi if ! git diff --quiet &>/dev/null; then echo 'Reformatted files. Please review and stage the changes.' 
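
The new pylint modes can be invoked as follows (a usage sketch; the file paths are illustrative):

```bash
# Default: lint only files under sky/ changed relative to the merge base.
./format.sh

# Lint specific files.
./format.sh --files sky/cli.py sky/task.py

# Lint the entire sky/ directory.
./format.sh --all
```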
diff --git a/llm/axolotl/axolotl-spot.yaml b/llm/axolotl/axolotl-spot.yaml index b22a8ae3fce..0e04ba11992 100644 --- a/llm/axolotl/axolotl-spot.yaml +++ b/llm/axolotl/axolotl-spot.yaml @@ -4,7 +4,7 @@ # HF_TOKEN=abc BUCKET= sky launch -c axolotl-spot axolotl-spot.yaml --env HF_TOKEN --env BUCKET -i30 --down # # Managed spot (auto-recovery; for full runs): -# HF_TOKEN=abc BUCKET= sky spot launch -n axolotl-spot axolotl-spot.yaml --env HF_TOKEN --env BUCKET +# HF_TOKEN=abc BUCKET= sky jobs launch -n axolotl-spot axolotl-spot.yaml --env HF_TOKEN --env BUCKET name: axolotl diff --git a/llm/axolotl/readme.md b/llm/axolotl/readme.md index 0cc06b98723..eb80231aa93 100644 --- a/llm/axolotl/readme.md +++ b/llm/axolotl/readme.md @@ -22,5 +22,5 @@ ssh -L 8888:localhost:8888 axolotl-spot Launch managed spot instances (auto-recovery; for full runs): ``` -HF_TOKEN=abc BUCKET= sky spot launch -n axolotl-spot axolotl-spot.yaml --env HF_TOKEN --env BUCKET +HF_TOKEN=abc BUCKET= sky jobs launch -n axolotl-spot axolotl-spot.yaml --env HF_TOKEN --env BUCKET ``` diff --git a/llm/falcon/README.md b/llm/falcon/README.md index 6eb480d9ea8..1f40dc9f524 100644 --- a/llm/falcon/README.md +++ b/llm/falcon/README.md @@ -1,6 +1,6 @@ # Finetuning Falcon with SkyPilot -This README contains instructions on how to use SkyPilot to finetune Falcon-7B and Falcon-40B, an open-source LLM that rivals many current closed-source models, including ChatGPT. +This README contains instructions on how to use SkyPilot to finetune Falcon-7B and Falcon-40B, an open-source LLM that rivals many current closed-source models, including ChatGPT. * [Blog post](https://huggingface.co/blog/falcon) * [Repo](https://huggingface.co/tiiuae/falcon-40b) @@ -16,10 +16,10 @@ sky check See the Falcon SkyPilot YAML for [training](train.yaml). Serving is currently a work in progress and a YAML will be provided for that soon! We are also working on adding an evaluation step to evaluate the model you finetuned compared to the base model. ## Running Falcon on SkyPilot -Finetuning `Falcon-7B` and `Falcon-40B` require GPUs with 80GB memory, +Finetuning `Falcon-7B` and `Falcon-40B` require GPUs with 80GB memory, but `Falcon-7b-sharded` requires only 40GB memory. Thus, * If your GPU has 40 GB memory or less (e.g., Nvidia A100): use `ybelkada/falcon-7b-sharded-bf16`. -* If your GPU has 80 GB memory (e.g., Nvidia A100-80GB): you can also use `tiiuae/falcon-7b` and `tiiuae/falcon-40b`. +* If your GPU has 80 GB memory (e.g., Nvidia A100-80GB): you can also use `tiiuae/falcon-7b` and `tiiuae/falcon-40b`. Try `sky show-gpus --all` for supported GPUs. @@ -32,13 +32,13 @@ Steps for training on your cloud(s): 1. In [train.yaml](train.yaml), set the following variables in `envs`: - Replace the `OUTPUT_BUCKET_NAME` with a unique name. SkyPilot will create this bucket for you to store the model weights. - - Replace the `WANDB_API_KEY` to your own key. - - Replace the `MODEL_NAME` with your desired base model. + - Replace the `WANDB_API_KEY` to your own key. + - Replace the `MODEL_NAME` with your desired base model. 2. **Training the Falcon model using spot instances**: ```bash -sky spot launch -n falcon falcon.yaml +sky jobs launch --use-spot -n falcon falcon.yaml ``` Currently, such `A100-80GB:1` spot instances are only available on AWS and GCP. diff --git a/llm/vicuna-llama-2/README.md b/llm/vicuna-llama-2/README.md index 24caa525a56..e392b231e64 100644 --- a/llm/vicuna-llama-2/README.md +++ b/llm/vicuna-llama-2/README.md @@ -120,12 +120,12 @@ sky launch --no-use-spot ... 
### Reducing costs by 3x with spot instances -[SkyPilot Managed Spot](https://skypilot.readthedocs.io/en/latest/examples/spot-jobs.html) is a library built on top of SkyPilot that helps users run jobs on spot instances without worrying about interruptions. That is the tool used by the LMSYS organization to train the first version of Vicuna (more details can be found in their [launch blog post](https://lmsys.org/blog/2023-03-30-vicuna/) and [example](https://github.com/skypilot-org/skypilot/tree/master/llm/vicuna)). With this, the training cost can be reduced from $1000 to **\$300**. +[SkyPilot Managed Jobs](https://skypilot.readthedocs.io/en/latest/examples/managed-jobs.html) is a library built on top of SkyPilot that helps users run jobs on spot instances without worrying about interruptions. That is the tool used by the LMSYS organization to train the first version of Vicuna (more details can be found in their [launch blog post](https://lmsys.org/blog/2023-03-30-vicuna/) and [example](https://github.com/skypilot-org/skypilot/tree/master/llm/vicuna)). With this, the training cost can be reduced from $1000 to **\$300**. -To use SkyPilot Managed Spot, you can simply replace `sky launch` with `sky spot launch` in the above command: +To use SkyPilot Managed Spot Jobs, you can simply replace `sky launch` with `sky jobs launch` in the above command: ```bash -sky spot launch -n vicuna train.yaml \ +sky jobs launch -n vicuna train.yaml \ --env ARTIFACT_BUCKET_NAME= \ --env WANDB_API_KEY= ``` diff --git a/llm/vicuna/README.md b/llm/vicuna/README.md index b511eb7f4b0..6d9f46127d4 100644 --- a/llm/vicuna/README.md +++ b/llm/vicuna/README.md @@ -63,14 +63,14 @@ Steps for training on your cloud(s): 2. **Training the Vicuna-7B model on 8 A100 GPUs (80GB memory) using spot instances**: ```bash # Launch it on managed spot to save 3x cost -sky spot launch -n vicuna train.yaml +sky jobs launch -n vicuna train.yaml ``` Note: if you would like to see the training curve on W&B, you can add `--env WANDB_API_KEY` to the above command, which will propagate your local W&B API key in the environment variable to the job. [Optional] Train a larger 13B model ``` # Train a 13B model instead of the default 7B -sky spot launch -n vicuna-7b train.yaml --env MODEL_SIZE=13 +sky jobs launch -n vicuna-7b train.yaml --env MODEL_SIZE=13 # Use *unmanaged* spot instances (i.e., preemptions won't get auto-recovered). # Unmanaged spot provides a better interactive development experience but is vulnerable to spot preemptions. 
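+# For example (a sketch; the job name is hypothetical), train the 13B variant
+# and propagate your local W&B key to the managed job:
+#   sky jobs launch -n vicuna-13b train.yaml --env MODEL_SIZE=13 --env WANDB_API_KEY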
diff --git a/sky/__init__.py b/sky/__init__.py index 37b5a1caf08..b851775dabf 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -128,6 +128,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): Lambda = clouds.Lambda SCP = clouds.SCP Kubernetes = clouds.Kubernetes +K8s = Kubernetes OCI = clouds.OCI Paperspace = clouds.Paperspace RunPod = clouds.RunPod @@ -143,6 +144,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): 'GCP', 'IBM', 'Kubernetes', + 'K8s', 'Lambda', 'OCI', 'Paperspace', diff --git a/sky/adaptors/azure.py b/sky/adaptors/azure.py index 61d8d14352e..0730b76ec88 100644 --- a/sky/adaptors/azure.py +++ b/sky/adaptors/azure.py @@ -69,6 +69,17 @@ def exceptions(): return azure_exceptions +@functools.lru_cache() +@common.load_lazy_modules(modules=_LAZY_MODULES) +def azure_mgmt_models(name: str): + if name == 'compute': + from azure.mgmt.compute import models + return models + elif name == 'network': + from azure.mgmt.network import models + return models + + # We should keep the order of the decorators having 'lru_cache' followed # by 'load_lazy_modules' as we need to make sure a caller can call # 'get_client.cache_clear', which is a function provided by 'lru_cache' @@ -120,6 +131,9 @@ def get_client(name: str, from azure.mgmt import authorization return authorization.AuthorizationManagementClient( credential, subscription_id) + elif name == 'msi': + from azure.mgmt import msi + return msi.ManagedServiceIdentityClient(credential, subscription_id) elif name == 'graph': import msgraph return msgraph.GraphServiceClient(credential) diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index caa6c9292d5..e4633ef0671 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -401,6 +401,8 @@ class SSHConfigHelper(object): ssh_conf_path = '~/.ssh/config' ssh_conf_lock_path = os.path.expanduser('~/.sky/ssh_config.lock') + ssh_conf_per_cluster_lock_path = os.path.expanduser( + '~/.sky/ssh_config_{}.lock') ssh_cluster_path = SKY_USER_FILE_PATH + '/ssh/{}' @classmethod @@ -486,12 +488,6 @@ def add_cluster( config_path = os.path.expanduser(cls.ssh_conf_path) - # For backward compatibility: before #2706, we wrote the config of SkyPilot clusters - # directly in ~/.ssh/config. For these clusters, we remove the config in ~/.ssh/config - # and write/overwrite the config in ~/.sky/ssh/ instead. - cls._remove_stale_cluster_config_for_backward_compatibility( - cluster_name, ip, auth_config, docker_user) - if not os.path.exists(config_path): config = ['\n'] with open(config_path, @@ -560,139 +556,20 @@ def add_cluster( f.write(codegen) @classmethod - def _remove_stale_cluster_config_for_backward_compatibility( - cls, - cluster_name: str, - ip: str, - auth_config: Dict[str, str], - docker_user: Optional[str] = None, - ): - """Remove authentication information for cluster from local SSH config. - - If no existing host matching the provided specification is found, then - nothing is removed. - - Args: - ip: Head node's IP address. - auth_config: read_yaml(handle.cluster_yaml)['auth'] - docker_user: If not None, use this user to ssh into the docker - """ - username = auth_config['ssh_user'] - config_path = os.path.expanduser(cls.ssh_conf_path) - cluster_config_path = os.path.expanduser( - cls.ssh_cluster_path.format(cluster_name)) - if not os.path.exists(config_path): - return - - with open(config_path, 'r', encoding='utf-8') as f: - config = f.readlines() - - start_line_idx = None - - # Scan the config for the cluster name. 
- for i, line in enumerate(config): - next_line = config[i + 1] if i + 1 < len(config) else '' - if docker_user is None: - found = (line.strip() == f'HostName {ip}' and - next_line.strip() == f'User {username}') - else: - found = (line.strip() == 'HostName localhost' and - next_line.strip() == f'User {docker_user}') - if found: - # Find the line starting with ProxyCommand and contains the ip - found = False - for idx in range(i, len(config)): - # Stop if we reach an empty line, which means a new host - if not config[idx].strip(): - break - if config[idx].strip().startswith('ProxyCommand'): - proxy_command_line = config[idx].strip() - if proxy_command_line.endswith(f'@{ip}'): - found = True - break - if found: - start_line_idx = i - 1 - break - - if start_line_idx is not None: - # Scan for end of previous config. - cursor = start_line_idx - while cursor > 0 and len(config[cursor].strip()) > 0: - cursor -= 1 - prev_end_line_idx = cursor - - # Scan for end of the cluster config. - end_line_idx = None - cursor = start_line_idx + 1 - start_line_idx -= 1 # remove auto-generated comment - while cursor < len(config): - if config[cursor].strip().startswith( - '# ') or config[cursor].strip().startswith('Host '): - end_line_idx = cursor - break - cursor += 1 - - # Remove sky-generated config and update the file. - config[prev_end_line_idx:end_line_idx] = [ - '\n' - ] if end_line_idx is not None else [] - with open(config_path, 'w', encoding='utf-8') as f: - f.write(''.join(config).strip()) - f.write('\n' * 2) - - # Delete include statement if it exists in the config. - sky_autogen_comment = ('# Added by sky (use `sky stop/down ' - f'{cluster_name}` to remove)') - with open(config_path, 'r', encoding='utf-8') as f: - config = f.readlines() - - for i, line in enumerate(config): - config_str = line.strip() - if f'Include {cluster_config_path}' in config_str: - with open(config_path, 'w', encoding='utf-8') as f: - if i < len(config) - 1 and config[i + 1] == '\n': - del config[i + 1] - # Delete Include string - del config[i] - # Delete Sky Autogen Comment - if i > 0 and sky_autogen_comment in config[i - 1].strip(): - del config[i - 1] - f.write(''.join(config)) - break - if 'Host' in config_str: - break - - @classmethod - # TODO: We can remove this after 0.6.0 and have a lock only per cluster. - @timeline.FileLockEvent(ssh_conf_lock_path) - def remove_cluster( - cls, - cluster_name: str, - ip: str, - auth_config: Dict[str, str], - docker_user: Optional[str] = None, - ): + def remove_cluster(cls, cluster_name: str): """Remove authentication information for cluster from ~/.sky/ssh/. - For backward compatibility also remove the config from ~/.ssh/config if it exists. - If no existing host matching the provided specification is found, then nothing is removed. Args: - ip: Head node's IP address. - auth_config: read_yaml(handle.cluster_yaml)['auth'] - docker_user: If not None, use this user to ssh into the docker + cluster_name: Cluster name. """ - cluster_config_path = os.path.expanduser( - cls.ssh_cluster_path.format(cluster_name)) - common_utils.remove_file_if_exists(cluster_config_path) - - # Ensures backward compatibility: before #2706, we wrote the config of SkyPilot clusters - # directly in ~/.ssh/config. For these clusters, we should clean up the config. 
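With the stale-config scanning gone, `remove_cluster` needs only the cluster name and takes a lock scoped to that cluster instead of the single global `ssh_config.lock`, so tearing down one cluster no longer serializes against others. A simplified sketch of the same idea using the `filelock` package directly (the real code goes through `timeline.FileLockEvent`, and the exact paths under `~/.sky/` are assumptions here):

```python
import os

import filelock

# Assumed layout: one generated SSH config and one lock file per cluster.
SSH_CLUSTER_PATH = os.path.expanduser('~/.sky/generated/ssh/{}')
LOCK_PATH = os.path.expanduser('~/.sky/ssh_config_{}.lock')

def remove_cluster(cluster_name: str) -> None:
    """Remove the per-cluster SSH config while holding a per-cluster lock."""
    os.makedirs(os.path.dirname(LOCK_PATH), exist_ok=True)
    with filelock.FileLock(LOCK_PATH.format(cluster_name)):
        config_path = SSH_CLUSTER_PATH.format(cluster_name)
        if os.path.exists(config_path):
            os.remove(config_path)
```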
- # TODO: Remove this after 0.6.0 - cls._remove_stale_cluster_config_for_backward_compatibility( - cluster_name, ip, auth_config, docker_user) + with timeline.FileLockEvent( + cls.ssh_conf_per_cluster_lock_path.format(cluster_name)): + cluster_config_path = os.path.expanduser( + cls.ssh_cluster_path.format(cluster_name)) + common_utils.remove_file_if_exists(cluster_config_path) def _replace_yaml_dicts( @@ -867,7 +744,7 @@ def write_cluster_config( labels = skypilot_config.get_nested((str(cloud).lower(), 'labels'), {}) # Deprecated: instance_tags have been replaced by labels. For backward # compatibility, we support them and the schema allows them only if - # `labels` are not specified. This should be removed after 0.7.0. + # `labels` are not specified. This should be removed after 0.8.0. labels = skypilot_config.get_nested((str(cloud).lower(), 'instance_tags'), labels) # labels is a dict, which is guaranteed by the type check in @@ -2621,10 +2498,12 @@ def get_task_resources_str(task: 'task_lib.Task', the accelerator demands (if any). Otherwise, the CPU demand is shown. """ spot_str = '' + is_controller_task = task.is_controller_task() task_cpu_demand = (str(constants.CONTROLLER_PROCESS_CPU_DEMAND) - if task.is_controller_task() else - str(DEFAULT_TASK_CPU_DEMAND)) - if task.best_resources is not None: + if is_controller_task else str(DEFAULT_TASK_CPU_DEMAND)) + if is_controller_task: + resources_str = f'CPU:{task_cpu_demand}' + elif task.best_resources is not None: accelerator_dict = task.best_resources.accelerators if is_managed_job: if task.best_resources.use_spot: diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index e9cdfc539d6..7bebf209942 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -1950,17 +1950,8 @@ def provision_with_retries( failover_history: List[Exception] = list() - style = colorama.Style - fore = colorama.Fore # Retrying launchable resources. while True: - if (isinstance(to_provision.cloud, clouds.Azure) and - to_provision.accelerators is not None and - 'A10' in to_provision.accelerators and prev_handle is None): - logger.warning(f'{style.BRIGHT}{fore.YELLOW}Trying to launch ' - 'an A10 cluster on Azure. This may take ~20 ' - 'minutes due to driver installation.' - f'{style.RESET_ALL}') try: # Recheck cluster name as the 'except:' block below may # change the cloud assignment. @@ -2118,18 +2109,16 @@ def __init__( stable_internal_external_ips: Optional[List[Tuple[str, str]]] = None, stable_ssh_ports: Optional[List[int]] = None, - cluster_info: Optional[provision_common.ClusterInfo] = None, - # The following 2 fields are deprecated. SkyPilot new provisioner - # API handles the TPU node creation/deletion. - # Backward compatibility for TPU nodes created before #2943. - # TODO (zhwu): Remove this after 0.6.0. - tpu_create_script: Optional[str] = None, - tpu_delete_script: Optional[str] = None) -> None: + cluster_info: Optional[provision_common.ClusterInfo] = None + ) -> None: self._version = self._VERSION self.cluster_name = cluster_name self.cluster_name_on_cloud = cluster_name_on_cloud - self._cluster_yaml = cluster_yaml.replace(os.path.expanduser('~'), '~', - 1) + # Replace the home directory with ~ for better robustness across systems + # with different home directories. 
+ if cluster_yaml.startswith(os.path.expanduser('~')): + cluster_yaml = cluster_yaml.replace(os.path.expanduser('~'), '~', 1) + self._cluster_yaml = cluster_yaml # List of (internal_ip, feasible_ip) tuples for all the nodes in the # cluster, sorted by the feasible ips. The feasible ips can be either # internal or external ips, depending on the use_internal_ips flag. @@ -2139,12 +2128,6 @@ def __init__( self.launched_nodes = launched_nodes self.launched_resources = launched_resources self.docker_user: Optional[str] = None - # Deprecated. SkyPilot new provisioner API handles the TPU node - # creation/deletion. - # Backward compatibility for TPU nodes created before #2943. - # TODO (zhwu): Remove this after 0.6.0. - self.tpu_create_script = tpu_create_script - self.tpu_delete_script = tpu_delete_script def __repr__(self): return (f'ResourceHandle(' @@ -2160,10 +2143,7 @@ def __repr__(self): f'\n\tlaunched_resources={self.launched_nodes}x ' f'{self.launched_resources}, ' f'\n\tdocker_user={self.docker_user},' - f'\n\tssh_user={self.ssh_user},' - # TODO (zhwu): Remove this after 0.6.0. - f'\n\ttpu_create_script={self.tpu_create_script}, ' - f'\n\ttpu_delete_script={self.tpu_delete_script})') + f'\n\tssh_user={self.ssh_user}') def get_cluster_name(self): return self.cluster_name @@ -2176,26 +2156,6 @@ def _use_internal_ips(self): return common_utils.read_yaml(self.cluster_yaml).get( 'provider', {}).get('use_internal_ips', False) - def _update_cluster_region(self): - """Update the region in handle.launched_resources. - - This is for backward compatibility to handle the clusters launched - long before. We should remove this after 0.6.0. - """ - if self.launched_resources.region is not None: - return - - config = common_utils.read_yaml(self.cluster_yaml) - provider = config['provider'] - cloud = self.launched_resources.cloud - if cloud.is_same_cloud(clouds.Azure()): - region = provider['location'] - elif cloud.is_same_cloud(clouds.GCP()) or cloud.is_same_cloud( - clouds.AWS()): - region = provider['region'] - - self.launched_resources = self.launched_resources.copy(region=region) - def update_ssh_ports(self, max_attempts: int = 1) -> None: """Fetches and sets the SSH ports for the cluster nodes. @@ -2510,7 +2470,7 @@ def num_ips_per_node(self) -> int: """Returns number of IPs per node in the cluster, handling TPU Pod.""" is_tpu_vm_pod = gcp_utils.is_tpu_vm_pod(self.launched_resources) if is_tpu_vm_pod: - num_ips = gcp_utils.get_num_tpu_devices(self.launched_resources) + num_ips = len(self.internal_ips()) else: num_ips = 1 return num_ips @@ -2567,8 +2527,6 @@ def __setstate__(self, state): if version < 4: self.update_ssh_ports() - self._update_cluster_region() - if version < 8: try: self._update_cluster_info() @@ -2649,8 +2607,6 @@ def check_resources_fit_cluster( if record is not None: usage_lib.messages.usage.update_cluster_status(record['status']) - # Backward compatibility: the old launched_resources without region info - # was handled by ResourceHandle._update_cluster_region. 
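The guarded `~` substitution above rewrites the cluster YAML path only when it actually begins with the current home directory, which keeps `ResourceHandle` portable across machines with different home paths without touching paths that merely contain the home string elsewhere. A standalone sketch of the same normalization:

```python
import os

def compress_home(path: str) -> str:
    """Replace a leading home-directory prefix with '~'; leave other paths alone."""
    home = os.path.expanduser('~')
    if path.startswith(home):
        # Replace only the first occurrence, and only when home is a string
        # prefix of the path (same behavior as the handle code above).
        return path.replace(home, '~', 1)
    return path

# e.g. '/home/alice/.sky/generated/mycluster.yml' -> '~/.sky/generated/mycluster.yml'
# while '/data/home/alice/cluster.yml' is returned unchanged.
```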
assert launched_resources.region is not None, handle mismatch_str = (f'To fix: specify a new cluster name, or down the ' @@ -2713,6 +2669,21 @@ def check_resources_fit_cluster( f' Existing:\t{handle.launched_nodes}x ' f'{handle.launched_resources}\n' f'{mismatch_str}') + else: + # For fractional acc count clusters, we round up the number of accs + # to 1 (sky/utils/resources_utils.py::make_ray_custom_resources_str) + # Here we scale the required acc count to (required / launched) * 1 + # so the total number of accs is the same as the requested number. + launched_accs = launched_resources.accelerators + if (launched_accs is not None and + valid_resource.accelerators is not None): + for _, count in launched_accs.items(): + if isinstance(count, float) and not count.is_integer(): + valid_resource = valid_resource.copy( + accelerators={ + k: v / count + for k, v in valid_resource.accelerators.items() + }) return valid_resource def _provision( @@ -2737,7 +2708,7 @@ def _provision( (e.g., cluster name invalid) or a region/zone throwing resource unavailability. exceptions.CommandError: any ssh command error. - RuntimeErorr: raised when 'rsync' is not installed. + RuntimeError: raised when 'rsync' is not installed. # TODO(zhwu): complete the list of exceptions. """ # FIXME: ray up for Azure with different cluster_names will overwrite @@ -3198,9 +3169,19 @@ def _run_setup(setup_cmd: str) -> int: returncode = _run_setup(f'{create_script_code} && {setup_cmd}',) if returncode == 255: is_message_too_long = False - with open(setup_log_path, 'r', encoding='utf-8') as f: - if 'too long' in f.read(): - is_message_too_long = True + try: + with open(os.path.expanduser(setup_log_path), + 'r', + encoding='utf-8') as f: + if 'too long' in f.read(): + is_message_too_long = True + except Exception as e: # pylint: disable=broad-except + # We don't crash the setup if we cannot read the log file. + # Instead, we should retry the setup with dumping the script + # to a file to be safe. + logger.debug('Failed to read setup log file ' + f'{setup_log_path}: {e}') + is_message_too_long = True if is_message_too_long: # If the setup script is too long, we retry it with dumping @@ -3593,9 +3574,6 @@ def _teardown(self, backend_utils.CLUSTER_STATUS_LOCK_PATH.format(cluster_name)) try: - # TODO(mraheja): remove pylint disabling when filelock - # version updated - # pylint: disable=abstract-class-instantiated with filelock.FileLock( lock_path, backend_utils.CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS): @@ -4106,55 +4084,9 @@ def post_teardown_cleanup(self, * Removing ssh configs for the cluster; * Updating the local state of the cluster; * Removing the terminated cluster's scripts and ray yaml files. - - Raises: - RuntimeError: If it fails to delete the TPU. """ - log_path = os.path.join(os.path.expanduser(self.log_dir), - 'teardown.log') - log_abs_path = os.path.abspath(log_path) cluster_name_on_cloud = handle.cluster_name_on_cloud - # Backward compatibility for TPU nodes created before #2943. Any TPU - # node launched before that PR have the delete script generated (and do - # not have the tpu_node config set in its cluster yaml), so we have to - # call the deletion script to clean up the TPU node. - # For TPU nodes launched after the PR, deletion is done in SkyPilot's - # new GCP provisioner API. - # TODO (zhwu): Remove this after 0.6.0. - if (handle.tpu_delete_script is not None and - os.path.exists(handle.tpu_delete_script)): - # Only call the deletion script if the cluster config does not - # contain TPU node config. 
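To make the fractional-accelerator adjustment above concrete: Ray custom resources for fractional GPUs are rounded up to 1 (see `make_ray_custom_resources_str`), so the requested share is rescaled by `requested / launched` before scheduling. A small worked example, assuming a `Standard_NV6ads_A10_v5`-style cluster launched with an `A10:0.167` share:

```python
# Assumed example values; mirrors the rescaling loop in
# check_resources_fit_cluster() above.
launched_accs = {'A10': 0.167}    # what the cluster was launched with
requested_accs = {'A10': 0.167}   # what the task requests

for _, count in launched_accs.items():
    if isinstance(count, float) and not count.is_integer():
        requested_accs = {k: v / count for k, v in requested_accs.items()}

print(requested_accs)  # {'A10': 1.0}: matches the rounded-up Ray custom resource
```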
Otherwise, the deletion should - # already be handled by the new provisioner. - config = common_utils.read_yaml(handle.cluster_yaml) - tpu_node_config = config['provider'].get('tpu_node') - if tpu_node_config is None: - with rich_utils.safe_status( - ux_utils.spinner_message('Terminating TPU')): - tpu_rc, tpu_stdout, tpu_stderr = log_lib.run_with_log( - ['bash', handle.tpu_delete_script], - log_abs_path, - stream_logs=False, - require_outputs=True) - if tpu_rc != 0: - if _TPU_NOT_FOUND_ERROR in tpu_stderr: - logger.info('TPU not found. ' - 'It should have been deleted already.') - elif purge: - logger.warning( - _TEARDOWN_PURGE_WARNING.format( - reason='stopping/terminating TPU', - details=tpu_stderr)) - else: - raise RuntimeError( - _TEARDOWN_FAILURE_MESSAGE.format( - extra_reason='It is caused by TPU failure.', - cluster_name=common_utils.cluster_name_in_hint( - handle.cluster_name, cluster_name_on_cloud), - stdout=tpu_stdout, - stderr=tpu_stderr)) - if (terminate and handle.launched_resources.is_image_managed is True): # Delete the image when terminating a "cloned" cluster, i.e., # whose image is created by SkyPilot (--clone-disk-from) @@ -4199,11 +4131,7 @@ def post_teardown_cleanup(self, # The cluster file must exist because the cluster_yaml will only # be removed after the cluster entry in the database is removed. config = common_utils.read_yaml(handle.cluster_yaml) - auth_config = config['auth'] - backend_utils.SSHConfigHelper.remove_cluster(handle.cluster_name, - handle.head_ip, - auth_config, - handle.docker_user) + backend_utils.SSHConfigHelper.remove_cluster(handle.cluster_name) global_user_state.remove_cluster(handle.cluster_name, terminate=terminate) @@ -4212,13 +4140,6 @@ def post_teardown_cleanup(self, # This function could be directly called from status refresh, # where we need to cleanup the cluster profile. metadata_utils.remove_cluster_metadata(handle.cluster_name) - # Clean up TPU creation/deletion scripts - # Backward compatibility for TPU nodes created before #2943. - # TODO (zhwu): Remove this after 0.6.0. - if handle.tpu_delete_script is not None: - assert handle.tpu_create_script is not None - common_utils.remove_file_if_exists(handle.tpu_create_script) - common_utils.remove_file_if_exists(handle.tpu_delete_script) # Clean up generated config # No try-except is needed since Ray will fail to teardown the diff --git a/sky/check.py b/sky/check.py index 9ac2848733c..dcaa349d234 100644 --- a/sky/check.py +++ b/sky/check.py @@ -1,4 +1,5 @@ """Credential checks: check cloud credentials and enable clouds.""" +import os import traceback from types import ModuleType from typing import Dict, Iterable, List, Optional, Tuple, Union @@ -194,19 +195,25 @@ def get_cached_enabled_clouds_or_refresh( def get_cloud_credential_file_mounts( excluded_clouds: Optional[Iterable[sky_clouds.Cloud]] ) -> Dict[str, str]: - """Returns the files necessary to access all enabled clouds. + """Returns the files necessary to access all clouds. Returns a dictionary that will be added to a task's file mounts and a list of patterns that will be excluded (used as rsync_exclude). """ - enabled_clouds = get_cached_enabled_clouds_or_refresh() + # Uploading credentials for all clouds instead of only sky check + # enabled clouds because users may have partial credentials for some + # clouds to access their specific resources (e.g. cloud storage) but + # not have the complete credentials to pass sky check. 
+ clouds = sky_clouds.CLOUD_REGISTRY.values() file_mounts = {} - for cloud in enabled_clouds: + for cloud in clouds: if (excluded_clouds is not None and sky_clouds.cloud_in_iterable(cloud, excluded_clouds)): continue cloud_file_mounts = cloud.get_credential_file_mounts() - file_mounts.update(cloud_file_mounts) + for remote_path, local_path in cloud_file_mounts.items(): + if os.path.exists(os.path.expanduser(local_path)): + file_mounts[remote_path] = local_path # Currently, get_cached_enabled_clouds_or_refresh() does not support r2 as # only clouds with computing instances are marked as enabled by skypilot. # This will be removed when cloudflare/r2 is added as a 'cloud'. diff --git a/sky/cli.py b/sky/cli.py index 521c41f2844..37311b540c2 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -339,7 +339,6 @@ def _get_shell_complete_args(complete_fn): _RELOAD_ZSH_CMD = 'source ~/.zshrc' -_RELOAD_FISH_CMD = 'source ~/.config/fish/config.fish' _RELOAD_BASH_CMD = 'source ~/.bashrc' @@ -378,7 +377,9 @@ def _install_shell_completion(ctx: click.Context, param: click.Parameter, cmd = '_SKY_COMPLETE=fish_source sky > \ ~/.config/fish/completions/sky.fish' - reload_cmd = _RELOAD_FISH_CMD + # Fish does not need to be reloaded and will automatically pick up + # completions. + reload_cmd = None elif value == 'zsh': install_cmd = f'_SKY_COMPLETE=zsh_source sky > \ @@ -398,9 +399,10 @@ def _install_shell_completion(ctx: click.Context, param: click.Parameter, check=True, executable=shutil.which('bash')) click.secho(f'Shell completion installed for {value}', fg='green') - click.echo( - 'Completion will take effect once you restart the terminal: ' + - click.style(f'{reload_cmd}', bold=True)) + if reload_cmd is not None: + click.echo( + 'Completion will take effect once you restart the terminal: ' + + click.style(f'{reload_cmd}', bold=True)) except subprocess.CalledProcessError as e: click.secho(f'> Installation failed with code {e.returncode}', fg='red') ctx.exit() @@ -431,7 +433,9 @@ def _uninstall_shell_completion(ctx: click.Context, param: click.Parameter, elif value == 'fish': cmd = 'rm -f ~/.config/fish/completions/sky.fish' - reload_cmd = _RELOAD_FISH_CMD + # Fish does not need to be reloaded and will automatically pick up + # completions. + reload_cmd = None elif value == 'zsh': cmd = 'sed -i"" -e "/# For SkyPilot shell completion/d" ~/.zshrc && \ @@ -447,8 +451,10 @@ def _uninstall_shell_completion(ctx: click.Context, param: click.Parameter, try: subprocess.run(cmd, shell=True, check=True) click.secho(f'Shell completion uninstalled for {value}', fg='green') - click.echo('Changes will take effect once you restart the terminal: ' + - click.style(f'{reload_cmd}', bold=True)) + if reload_cmd is not None: + click.echo( + 'Changes will take effect once you restart the terminal: ' + + click.style(f'{reload_cmd}', bold=True)) except subprocess.CalledProcessError as e: click.secho(f'> Uninstallation failed with code {e.returncode}', fg='red') @@ -549,6 +555,7 @@ def _launch_with_confirm( retry_until_up: bool = False, no_setup: bool = False, clone_disk_from: Optional[str] = None, + fast: bool = False, ): """Launch a cluster with a Task.""" if cluster is None: @@ -613,6 +620,7 @@ def _launch_with_confirm( retry_until_up=retry_until_up, no_setup=no_setup, clone_disk_from=clone_disk_from, + fast=fast, ) @@ -1034,6 +1042,13 @@ def cli(): help=('[Experimental] Clone disk from an existing cluster to launch ' 'a new one. 
This is useful when the new cluster needs to have ' 'the same data on the boot disk as an existing cluster.')) +@click.option( + '--fast', + is_flag=True, + default=False, + required=False, + help=('[Experimental] If the cluster is already up and available, skip ' + 'provisioning and setup steps.')) @usage_lib.entrypoint def launch( entrypoint: Tuple[str, ...], @@ -1065,6 +1080,7 @@ def launch( yes: bool, no_setup: bool, clone_disk_from: Optional[str], + fast: bool, ): """Launch a cluster or task. @@ -1133,7 +1149,8 @@ def launch( down=down, retry_until_up=retry_until_up, no_setup=no_setup, - clone_disk_from=clone_disk_from) + clone_disk_from=clone_disk_from, + fast=fast) @cli.command(cls=_DocumentedCodeCommand) @@ -3056,7 +3073,8 @@ def show_gpus( # This will validate 'cloud' and raise if not found. cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud) - service_catalog.validate_region_zone(region, None, clouds=cloud) + cloud_name = cloud_obj.canonical_name() if cloud_obj is not None else None + service_catalog.validate_region_zone(region, None, clouds=cloud_name) show_all = all if show_all and accelerator_str is not None: raise click.UsageError('--all is only allowed without a GPU name.') @@ -3078,7 +3096,7 @@ def _get_kubernetes_realtime_gpu_table( qty_header = 'QTY_FILTER' free_header = 'FILTERED_FREE_GPUS' else: - qty_header = 'QTY_PER_NODE' + qty_header = 'REQUESTABLE_QTY_PER_NODE' free_header = 'TOTAL_FREE_GPUS' realtime_gpu_table = log_utils.create_table( ['GPU', qty_header, 'TOTAL_GPUS', free_header]) @@ -3142,8 +3160,8 @@ def _output(): # Optimization - do not poll for Kubernetes API for fetching # common GPUs because that will be fetched later for the table after # common GPUs. - clouds_to_list = cloud - if cloud is None: + clouds_to_list = cloud_name + if cloud_name is None: clouds_to_list = [ c for c in service_catalog.ALL_CLOUDS if c != 'kubernetes' ] @@ -3153,7 +3171,8 @@ def _output(): # Collect k8s related messages in k8s_messages and print them at end print_section_titles = False # If cloud is kubernetes, we want to show real-time capacity - if kubernetes_is_enabled and (cloud is None or cloud_is_kubernetes): + if kubernetes_is_enabled and (cloud_name is None or + cloud_is_kubernetes): if region: context = region else: @@ -3263,8 +3282,8 @@ def _output(): name, quantity = accelerator_str, None print_section_titles = False - if (kubernetes_is_enabled and (cloud is None or cloud_is_kubernetes) and - not show_all): + if (kubernetes_is_enabled and + (cloud_name is None or cloud_is_kubernetes) and not show_all): # Print section title if not showing all and instead a specific # accelerator is requested print_section_titles = True @@ -3336,7 +3355,7 @@ def _output(): if len(result) == 0: quantity_str = (f' with requested quantity {quantity}' if quantity else '') - cloud_str = f' on {cloud_obj}.' if cloud else ' in cloud catalogs.' + cloud_str = f' on {cloud_obj}.' if cloud_name else ' in cloud catalogs.' 
yield f'Resources \'{name}\'{quantity_str} not found{cloud_str} ' yield 'To show available accelerators, run: sky show-gpus --all' return @@ -3511,7 +3530,7 @@ def jobs(): default=None, type=str, hidden=True, - help=('Alias for --name, the name of the spot job.')) + help=('Alias for --name, the name of the managed job.')) @click.option('--job-recovery', default=None, type=str, @@ -3541,6 +3560,15 @@ def jobs(): default=False, required=False, help='Skip confirmation prompt.') +# TODO(cooperc): remove this flag once --fast can robustly detect cluster +# yaml config changes +@click.option('--fast', + default=False, + is_flag=True, + help='[Experimental] Launch the job faster by skipping ' + 'controller initialization steps. If you update SkyPilot or ' + 'your local cloud credentials, they will not be reflected until ' + 'you run `sky jobs launch` at least once without this flag.') @timeline.event @usage_lib.entrypoint def jobs_launch( @@ -3567,6 +3595,7 @@ def jobs_launch( detach_run: bool, retry_until_up: bool, yes: bool, + fast: bool, ): """Launch a managed job from a YAML or a command. @@ -3650,7 +3679,8 @@ def jobs_launch( managed_jobs.launch(dag, name, detach_run=detach_run, - retry_until_up=retry_until_up) + retry_until_up=retry_until_up, + fast=fast) @jobs.command('queue', cls=_DocumentedCodeCommand) diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index a0962b17cac..43062ebf393 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -2,13 +2,12 @@ import enum import fnmatch import functools -import json import os import re import subprocess import time import typing -from typing import Any, Dict, Iterator, List, Optional, Set, Tuple +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union from sky import clouds from sky import exceptions @@ -383,7 +382,7 @@ def get_default_instance_type( def get_accelerators_from_instance_type( cls, instance_type: str, - ) -> Optional[Dict[str, int]]: + ) -> Optional[Dict[str, Union[int, float]]]: return service_catalog.get_accelerators_from_instance_type( instance_type, clouds='aws') @@ -411,10 +410,8 @@ def make_deploy_resources_variables( r = resources # r.accelerators is cleared but .instance_type encodes the info. 
acc_dict = self.get_accelerators_from_instance_type(r.instance_type) - if acc_dict is not None: - custom_resources = json.dumps(acc_dict, separators=(',', ':')) - else: - custom_resources = None + custom_resources = resources_utils.make_ray_custom_resources_str( + acc_dict) if r.extract_docker_image() is not None: image_id_to_use = None diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index adffd32ad88..edd5840d271 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -1,20 +1,21 @@ """Azure.""" import functools -import json import os import re import subprocess import textwrap import typing -from typing import Any, Dict, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import colorama from sky import clouds from sky import exceptions from sky import sky_logging +from sky import skypilot_config from sky.adaptors import azure from sky.clouds import service_catalog +from sky.clouds.utils import azure_utils from sky.utils import common_utils from sky.utils import resources_utils from sky.utils import ux_utils @@ -36,6 +37,17 @@ _DEFAULT_AZURE_UBUNTU_HPC_IMAGE_GB = 30 _DEFAULT_AZURE_UBUNTU_2004_IMAGE_GB = 150 +_DEFAULT_SKYPILOT_IMAGE_GB = 30 + +_DEFAULT_CPU_IMAGE_ID = 'skypilot:custom-cpu-ubuntu-v2' +_DEFAULT_GPU_IMAGE_ID = 'skypilot:custom-gpu-ubuntu-v2' +_DEFAULT_V1_IMAGE_ID = 'skypilot:custom-gpu-ubuntu-v1' +_DEFAULT_GPU_K80_IMAGE_ID = 'skypilot:k80-ubuntu-2004' +_FALLBACK_IMAGE_ID = 'skypilot:gpu-ubuntu-2204' +# This is used by Azure GPU VMs that use grid drivers (e.g. A10). +_DEFAULT_GPU_GRID_IMAGE_ID = 'skypilot:custom-gpu-ubuntu-v2-grid' + +_COMMUNITY_IMAGE_PREFIX = '/CommunityGalleries' def _run_output(cmd): @@ -132,29 +144,56 @@ def get_egress_cost(self, num_gigabytes: float): cost += 0.0 return cost + @classmethod + def get_default_instance_type( + cls, + cpus: Optional[str] = None, + memory: Optional[str] = None, + disk_tier: Optional[resources_utils.DiskTier] = None + ) -> Optional[str]: + return service_catalog.get_default_instance_type(cpus=cpus, + memory=memory, + disk_tier=disk_tier, + clouds='azure') + @classmethod def get_image_size(cls, image_id: str, region: Optional[str]) -> float: - if region is None: - # The region used here is only for where to send the query, - # not the image location. Azure's image is globally available. - region = 'eastus' - is_skypilot_image_tag = False + # Process skypilot images. if image_id.startswith('skypilot:'): - is_skypilot_image_tag = True image_id = service_catalog.get_image_id_from_tag(image_id, clouds='azure') - image_id_splitted = image_id.split(':') - if len(image_id_splitted) != 4: - with ux_utils.print_exception_no_traceback(): - raise ValueError(f'Invalid image id: {image_id}. Expected ' - 'format: :::') - publisher, offer, sku, version = image_id_splitted - if is_skypilot_image_tag: - if offer == 'ubuntu-hpc': - return _DEFAULT_AZURE_UBUNTU_HPC_IMAGE_GB + if image_id.startswith(_COMMUNITY_IMAGE_PREFIX): + # Avoid querying the image size from Azure as + # all skypilot custom images have the same size. + return _DEFAULT_SKYPILOT_IMAGE_GB else: - return _DEFAULT_AZURE_UBUNTU_2004_IMAGE_GB + publisher, offer, sku, version = image_id.split(':') + if offer == 'ubuntu-hpc': + return _DEFAULT_AZURE_UBUNTU_HPC_IMAGE_GB + else: + return _DEFAULT_AZURE_UBUNTU_2004_IMAGE_GB + + # Process user-specified images. + azure_utils.validate_image_id(image_id) compute_client = azure.get_client('compute', cls.get_project_id()) + + # Community gallery image. 
+ if image_id.startswith(_COMMUNITY_IMAGE_PREFIX): + if region is None: + return 0.0 + _, _, gallery_name, _, image_name = image_id.split('/') + try: + return azure_utils.get_community_image_size( + compute_client, gallery_name, image_name, region) + except exceptions.ResourcesUnavailableError: + return 0.0 + + # Marketplace image + if region is None: + # The region used here is only for where to send the query, + # not the image location. Marketplace image is globally available. + region = 'eastus' + publisher, offer, sku, version = image_id.split(':') try: image = compute_client.virtual_machine_images.get( region, publisher, offer, sku, version) @@ -176,40 +215,25 @@ def get_image_size(cls, image_id: str, region: Optional[str]) -> float: size_in_gb = size_in_bytes / (1024**3) return size_in_gb - @classmethod - def get_default_instance_type( - cls, - cpus: Optional[str] = None, - memory: Optional[str] = None, - disk_tier: Optional[resources_utils.DiskTier] = None - ) -> Optional[str]: - return service_catalog.get_default_instance_type(cpus=cpus, - memory=memory, - disk_tier=disk_tier, - clouds='azure') - def _get_default_image_tag(self, gen_version, instance_type) -> str: # ubuntu-2004 v21.08.30, K80 requires image with old NVIDIA driver version acc = self.get_accelerators_from_instance_type(instance_type) if acc is not None: acc_name = list(acc.keys())[0] if acc_name == 'K80': - return 'skypilot:k80-ubuntu-2004' - - # ubuntu-2004 v21.11.04, the previous image we used in the past for - # V1 HyperV instance before we change default image to ubuntu-hpc. + return _DEFAULT_GPU_K80_IMAGE_ID + if acc_name == 'A10': + return _DEFAULT_GPU_GRID_IMAGE_ID + # About Gen V1 vs V2: # In Azure, all instances with K80 (Standard_NC series), some # instances with M60 (Standard_NV series) and some cpu instances - # (Basic_A, Standard_D, ...) are V1 instance. For these instances, - # we use the previous image. + # (Basic_A, Standard_D, ...) are V1 instance. + # All A100 instances are V2. if gen_version == 'V1': - return 'skypilot:v1-ubuntu-2004' - - # nvidia-driver: 535.54.03, cuda: 12.2 - # see: https://github.com/Azure/azhpc-images/releases/tag/ubuntu-hpc-20230803 - # All A100 instances is of gen2, so it will always use - # the latest ubuntu-hpc:2204 image. 
- return 'skypilot:gpu-ubuntu-2204' + return _DEFAULT_V1_IMAGE_ID + if acc is None: + return _DEFAULT_CPU_IMAGE_ID + return _DEFAULT_GPU_IMAGE_ID @classmethod def regions_with_offering(cls, instance_type: str, @@ -252,7 +276,7 @@ def zones_provision_loop( def get_accelerators_from_instance_type( cls, instance_type: str, - ) -> Optional[Dict[str, int]]: + ) -> Optional[Dict[str, Union[int, float]]]: return service_catalog.get_accelerators_from_instance_type( instance_type, clouds='azure') @@ -284,10 +308,9 @@ def make_deploy_resources_variables( acc_dict = self.get_accelerators_from_instance_type(r.instance_type) acc_count = None if acc_dict is not None: - custom_resources = json.dumps(acc_dict, separators=(',', ':')) acc_count = str(sum(acc_dict.values())) - else: - custom_resources = None + custom_resources = resources_utils.make_ray_custom_resources_str( + acc_dict) if (resources.image_id is None or resources.extract_docker_image() is not None): @@ -302,21 +325,41 @@ def make_deploy_resources_variables( else: assert region_name in resources.image_id, resources.image_id image_id = resources.image_id[region_name] + + # Checked basic image syntax in resources.py if image_id.startswith('skypilot:'): image_id = service_catalog.get_image_id_from_tag(image_id, clouds='azure') - # Already checked in resources.py - publisher, offer, sku, version = image_id.split(':') - image_config = { - 'image_publisher': publisher, - 'image_offer': offer, - 'image_sku': sku, - 'image_version': version, - } - - # Setup the A10 nvidia driver. - need_nvidia_driver_extension = (acc_dict is not None and - 'A10' in acc_dict) + # Fallback if image does not exist in the specified region. + # Putting fallback here instead of at image validation + # when creating the resource because community images are + # regional so we need the correct region when we check whether + # the image exists. + if image_id.startswith( + _COMMUNITY_IMAGE_PREFIX + ) and region_name not in azure_catalog.COMMUNITY_IMAGE_AVAILABLE_REGIONS: + logger.info(f'Azure image {image_id} does not exist in region ' + f'{region_name} so use the fallback image instead.') + image_id = service_catalog.get_image_id_from_tag( + _FALLBACK_IMAGE_ID, clouds='azure') + + if image_id.startswith(_COMMUNITY_IMAGE_PREFIX): + image_config = {'community_gallery_image_id': image_id} + else: + publisher, offer, sku, version = image_id.split(':') + image_config = { + 'image_publisher': publisher, + 'image_offer': offer, + 'image_sku': sku, + 'image_version': version, + } + + # Determine resource group for deploying the instance. + resource_group_name = skypilot_config.get_nested( + ('azure', 'resource_group_vm'), None) + use_external_resource_group = resource_group_name is not None + if resource_group_name is None: + resource_group_name = f'{cluster_name.name_on_cloud}-{region_name}' # Setup commands to eliminate the banner and restart sshd. # This script will modify /etc/ssh/sshd_config and add a bash script @@ -370,17 +413,16 @@ def _failover_disk_tier() -> Optional[resources_utils.DiskTier]: # Azure does not support specific zones. 
'zones': None, **image_config, - 'need_nvidia_driver_extension': need_nvidia_driver_extension, 'disk_tier': Azure._get_disk_type(disk_tier), 'cloud_init_setup_commands': cloud_init_setup_commands, 'azure_subscription_id': self.get_project_id(dryrun), - 'resource_group': f'{cluster_name.name_on_cloud}-{region_name}', + 'resource_group': resource_group_name, + 'use_external_resource_group': use_external_resource_group, } # Setting disk performance tier for high disk tier. if disk_tier == resources_utils.DiskTier.HIGH: resources_vars['disk_performance_tier'] = 'P50' - return resources_vars def _get_feasible_launchable_resources( diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index dae1d56d309..4028c1fef59 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -9,8 +9,9 @@ """ import collections import enum +import math import typing -from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple +from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union from sky import exceptions from sky import skypilot_config @@ -306,7 +307,7 @@ def get_vcpus_mem_from_instance_type( def get_accelerators_from_instance_type( cls, instance_type: str, - ) -> Optional[Dict[str, int]]: + ) -> Optional[Dict[str, Union[int, float]]]: """Returns {acc: acc_count} held by 'instance_type', if any.""" raise NotImplementedError @@ -673,8 +674,9 @@ def _check_instance_type_accelerators_combination( assert resources.is_launchable(), resources def _equal_accelerators( - acc_requested: Optional[Dict[str, int]], - acc_from_instance_type: Optional[Dict[str, int]]) -> bool: + acc_requested: Optional[Dict[str, Union[int, float]]], + acc_from_instance_type: Optional[Dict[str, Union[int, + float]]]) -> bool: """Check the requested accelerators equals to the instance type Check the requested accelerators equals to the accelerators @@ -689,12 +691,14 @@ def _equal_accelerators( for acc in acc_requested: if acc not in acc_from_instance_type: return False - if acc_requested[acc] != acc_from_instance_type[acc]: + # Avoid float point precision issue. 
+ if not math.isclose(acc_requested[acc], + acc_from_instance_type[acc]): return False return True - acc_from_instance_type = (cls.get_accelerators_from_instance_type( - resources.instance_type)) + acc_from_instance_type = cls.get_accelerators_from_instance_type( + resources.instance_type) if not _equal_accelerators(resources.accelerators, acc_from_instance_type): with ux_utils.print_exception_no_traceback(): @@ -819,6 +823,10 @@ def delete_image(cls, image_id: str, region: Optional[str]) -> None: # === End of image related methods === + @classmethod + def canonical_name(cls) -> str: + return cls.__name__.lower() + def __repr__(self): return self._REPR diff --git a/sky/clouds/cloud_registry.py b/sky/clouds/cloud_registry.py index 5c4b10b9fd4..52a026aa330 100644 --- a/sky/clouds/cloud_registry.py +++ b/sky/clouds/cloud_registry.py @@ -1,7 +1,7 @@ """Clouds need to be registered in CLOUD_REGISTRY to be discovered""" import typing -from typing import Optional, Type +from typing import Callable, Dict, List, Optional, overload, Type, Union from sky.utils import ux_utils @@ -12,20 +12,65 @@ class _CloudRegistry(dict): """Registry of clouds.""" + def __init__(self) -> None: + super().__init__() + self.aliases: Dict[str, str] = {} + def from_str(self, name: Optional[str]) -> Optional['cloud.Cloud']: + """Returns the cloud instance from the canonical name or alias.""" if name is None: return None - if name.lower() not in self: - with ux_utils.print_exception_no_traceback(): - raise ValueError(f'Cloud {name!r} is not a valid cloud among ' - f'{list(self.keys())}') - return self.get(name.lower()) + search_name = name.lower() + + if search_name in self: + return self[search_name] + + if search_name in self.aliases: + return self[self.aliases[search_name]] + + with ux_utils.print_exception_no_traceback(): + raise ValueError(f'Cloud {name!r} is not a valid cloud among ' + f'{[*self.keys(), *self.aliases.keys()]}') + + @overload def register(self, cloud_cls: Type['cloud.Cloud']) -> Type['cloud.Cloud']: - name = cloud_cls.__name__.lower() - assert name not in self, f'{name} already registered' - self[name] = cloud_cls() - return cloud_cls + ... + + @overload + def register( + self, + cloud_cls: None = None, + aliases: Optional[List[str]] = None, + ) -> Callable[[Type['cloud.Cloud']], Type['cloud.Cloud']]: + ... + + def register( + self, + cloud_cls: Optional[Type['cloud.Cloud']] = None, + aliases: Optional[List[str]] = None, + ) -> Union[Type['cloud.Cloud'], Callable[[Type['cloud.Cloud']], + Type['cloud.Cloud']]]: + + def _register(cloud_cls: Type['cloud.Cloud']) -> Type['cloud.Cloud']: + name = cloud_cls.canonical_name() + assert name not in self, f'{name} already registered' + self[name] = cloud_cls() + + for alias in aliases or []: + alias = alias.lower() + assert alias not in self.aliases, ( + f'alias {alias} already registered') + self.aliases[alias] = name + + return cloud_cls + + if cloud_cls is not None: + # invocation without parens (e.g. just `@register`) + return _register(cloud_cls) + + # Invocation with parens (e.g. 
`@register(aliases=['alias'])`) + return _register CLOUD_REGISTRY: _CloudRegistry = _CloudRegistry() diff --git a/sky/clouds/cudo.py b/sky/clouds/cudo.py index 4dca442fa01..6f02e007049 100644 --- a/sky/clouds/cudo.py +++ b/sky/clouds/cudo.py @@ -1,8 +1,7 @@ """Cudo Compute""" -import json import subprocess import typing -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Union from sky import clouds from sky.clouds import service_catalog @@ -183,7 +182,7 @@ def get_default_instance_type( def get_accelerators_from_instance_type( cls, instance_type: str, - ) -> Optional[Dict[str, int]]: + ) -> Optional[Dict[str, Union[int, float]]]: return service_catalog.get_accelerators_from_instance_type( instance_type, clouds='cudo') @@ -202,10 +201,8 @@ def make_deploy_resources_variables( del zones, cluster_name # unused r = resources acc_dict = self.get_accelerators_from_instance_type(r.instance_type) - if acc_dict is not None: - custom_resources = json.dumps(acc_dict, separators=(',', ':')) - else: - custom_resources = None + custom_resources = resources_utils.make_ray_custom_resources_str( + acc_dict) return { 'instance_type': resources.instance_type, diff --git a/sky/clouds/fluidstack.py b/sky/clouds/fluidstack.py index 473fceabbe3..31e2112f8f7 100644 --- a/sky/clouds/fluidstack.py +++ b/sky/clouds/fluidstack.py @@ -1,8 +1,7 @@ """Fluidstack Cloud.""" -import json import os import typing -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Union import requests @@ -155,7 +154,7 @@ def get_default_instance_type( def get_accelerators_from_instance_type( cls, instance_type: str, - ) -> Optional[Dict[str, int]]: + ) -> Optional[Dict[str, Union[int, float]]]: return service_catalog.get_accelerators_from_instance_type( instance_type, clouds='fluidstack') @@ -184,10 +183,8 @@ def make_deploy_resources_variables( r = resources acc_dict = self.get_accelerators_from_instance_type(r.instance_type) - if acc_dict is not None: - custom_resources = json.dumps(acc_dict, separators=(',', ':')) - else: - custom_resources = None + custom_resources = resources_utils.make_ray_custom_resources_str( + acc_dict) return { 'instance_type': resources.instance_type, diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 1b70abf914d..0e20fdc9789 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -7,7 +7,7 @@ import subprocess import time import typing -from typing import Any, Dict, Iterator, List, Optional, Set, Tuple +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union import colorama @@ -669,7 +669,7 @@ def _get_feasible_launchable_resources( def get_accelerators_from_instance_type( cls, instance_type: str, - ) -> Optional[Dict[str, int]]: + ) -> Optional[Dict[str, Union[int, float]]]: # GCP handles accelerators separately from regular instance types, # hence return none here. 
return None diff --git a/sky/clouds/ibm.py b/sky/clouds/ibm.py index b78cc4287c0..0ac3c36cc48 100644 --- a/sky/clouds/ibm.py +++ b/sky/clouds/ibm.py @@ -1,8 +1,7 @@ """IBM Web Services.""" -import json import os import typing -from typing import Any, Dict, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import colorama @@ -206,10 +205,8 @@ def _get_profile_resources(instance_profile): 'IBM does not currently support spot instances in this framework' acc_dict = self.get_accelerators_from_instance_type(r.instance_type) - if acc_dict is not None: - custom_resources = json.dumps(acc_dict, separators=(',', ':')) - else: - custom_resources = None + custom_resources = resources_utils.make_ray_custom_resources_str( + acc_dict) instance_resources = _get_profile_resources(r.instance_type) @@ -247,7 +244,7 @@ def get_vcpus_mem_from_instance_type( def get_accelerators_from_instance_type( cls, instance_type: str, - ) -> Optional[Dict[str, int]]: + ) -> Optional[Dict[str, Union[int, float]]]: """Returns {acc: acc_count} held by 'instance_type', if any.""" return service_catalog.get_accelerators_from_instance_type( instance_type, clouds='ibm') diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index da85246e9ea..d930a24271f 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -1,10 +1,9 @@ """Kubernetes.""" import functools -import json import os import re import typing -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Union from sky import clouds from sky import sky_logging @@ -33,7 +32,7 @@ _SKYPILOT_SYSTEM_NAMESPACE = 'skypilot-system' -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register(aliases=['k8s']) class Kubernetes(clouds.Cloud): """Kubernetes.""" @@ -69,8 +68,8 @@ class Kubernetes(clouds.Cloud): 'Kubernetes.', } - IMAGE_CPU = 'skypilot:cpu-ubuntu-2004' - IMAGE_GPU = 'skypilot:gpu-ubuntu-2004' + IMAGE_CPU = 'skypilot:custom-cpu-ubuntu-2004' + IMAGE_GPU = 'skypilot:custom-gpu-ubuntu-2004' PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT STATUS_VERSION = clouds.StatusVersion.SKYPILOT @@ -271,7 +270,7 @@ def get_default_instance_type( def get_accelerators_from_instance_type( cls, instance_type: str, - ) -> Optional[Dict[str, int]]: + ) -> Optional[Dict[str, Union[int, float]]]: inst = kubernetes_utils.KubernetesInstanceType.from_instance_type( instance_type) return { @@ -328,10 +327,8 @@ def make_deploy_resources_variables( r = resources acc_dict = self.get_accelerators_from_instance_type(r.instance_type) - if acc_dict is not None: - custom_resources = json.dumps(acc_dict, separators=(',', ':')) - else: - custom_resources = None + custom_resources = resources_utils.make_ray_custom_resources_str( + acc_dict) # resources.memory and cpus are None if they are not explicitly set. # We fetch the default values for the instance type in that case. 
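Several of the cloud modules in this diff (AWS, Azure, Cudo, Fluidstack, IBM, Kubernetes, and more below) drop their copy-pasted `json.dumps(acc_dict, separators=(',', ':'))` blocks in favor of a shared `resources_utils.make_ray_custom_resources_str(acc_dict)` helper that also copes with fractional accelerator counts. The real implementation lives in `sky/utils/resources_utils.py`; the sketch below is an assumption based on the comment in `check_resources_fit_cluster` that fractional counts are rounded up to 1 for Ray:

```python
import json
import math
from typing import Dict, Optional, Union

def make_ray_custom_resources_str(
        resource_dict: Optional[Dict[str, Union[int, float]]]) -> Optional[str]:
    """Hedged sketch: serialize accelerators as a Ray custom-resources string.

    Fractional counts (e.g. A10: 0.167 on Azure NVads A10 v5) are rounded up
    so Ray always sees at least one unit of the custom resource.
    """
    if resource_dict is None:
        return None
    resource_dict = {k: math.ceil(v) for k, v in resource_dict.items()}
    return json.dumps(resource_dict, separators=(',', ':'))

# e.g. {'A10': 0.167} -> '{"A10":1}', and {'V100': 4} -> '{"V100":4}'
```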
diff --git a/sky/clouds/lambda_cloud.py b/sky/clouds/lambda_cloud.py index 0201f4f76ad..055a5338750 100644 --- a/sky/clouds/lambda_cloud.py +++ b/sky/clouds/lambda_cloud.py @@ -1,7 +1,6 @@ """Lambda Cloud.""" -import json import typing -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Union import requests @@ -136,7 +135,7 @@ def get_default_instance_type( def get_accelerators_from_instance_type( cls, instance_type: str, - ) -> Optional[Dict[str, int]]: + ) -> Optional[Dict[str, Union[int, float]]]: return service_catalog.get_accelerators_from_instance_type( instance_type, clouds='lambda') @@ -164,10 +163,8 @@ def make_deploy_resources_variables( r = resources acc_dict = self.get_accelerators_from_instance_type(r.instance_type) - if acc_dict is not None: - custom_resources = json.dumps(acc_dict, separators=(',', ':')) - else: - custom_resources = None + custom_resources = resources_utils.make_ray_custom_resources_str( + acc_dict) resources_vars = { 'instance_type': resources.instance_type, diff --git a/sky/clouds/oci.py b/sky/clouds/oci.py index 810e43fe3b5..93a70c5ac37 100644 --- a/sky/clouds/oci.py +++ b/sky/clouds/oci.py @@ -20,11 +20,10 @@ - Hysun He (hysun.he@oracle.com) @ Oct 13, 2024: Support more OS types additional to ubuntu for OCI resources. """ -import json import logging import os import typing -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Union from sky import clouds from sky import exceptions @@ -193,7 +192,7 @@ def get_default_instance_type( def get_accelerators_from_instance_type( cls, instance_type: str, - ) -> Optional[Dict[str, int]]: + ) -> Optional[Dict[str, Union[int, float]]]: return service_catalog.get_accelerators_from_instance_type( instance_type, clouds='oci') @@ -213,10 +212,8 @@ def make_deploy_resources_variables( acc_dict = self.get_accelerators_from_instance_type( resources.instance_type) - if acc_dict is not None: - custom_resources = json.dumps(acc_dict, separators=(',', ':')) - else: - custom_resources = None + custom_resources = resources_utils.make_ray_custom_resources_str( + acc_dict) image_str = self._get_image_id(resources.image_id, region.name, resources.instance_type) @@ -468,8 +465,12 @@ def get_credential_file_mounts(self) -> Dict[str, str]: api_key_file = oci_cfg[ 'key_file'] if 'key_file' in oci_cfg else 'BadConf' sky_cfg_file = oci_utils.oci_config.get_sky_user_config_file() + # Must catch ImportError before any oci_adaptor.oci.exceptions + # because oci_adaptor.oci.exceptions can throw ImportError. except ImportError: return {} + except oci_adaptor.oci.exceptions.ConfigFileNotFound: + return {} # OCI config and API key file are mandatory credential_files = [oci_cfg_file, api_key_file] diff --git a/sky/clouds/paperspace.py b/sky/clouds/paperspace.py index 4c4fa1d695a..4047a2f5926 100644 --- a/sky/clouds/paperspace.py +++ b/sky/clouds/paperspace.py @@ -1,8 +1,7 @@ """ Paperspace Cloud. 
""" -import json import typing -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Union import requests @@ -162,7 +161,7 @@ def get_default_instance_type( @classmethod def get_accelerators_from_instance_type( - cls, instance_type: str) -> Optional[Dict[str, int]]: + cls, instance_type: str) -> Optional[Dict[str, Union[int, float]]]: return service_catalog.get_accelerators_from_instance_type( instance_type, clouds='paperspace') @@ -181,10 +180,8 @@ def make_deploy_resources_variables( r = resources acc_dict = self.get_accelerators_from_instance_type(r.instance_type) - if acc_dict is not None: - custom_resources = json.dumps(acc_dict, separators=(',', ':')) - else: - custom_resources = None + custom_resources = resources_utils.make_ray_custom_resources_str( + acc_dict) return { 'instance_type': resources.instance_type, diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index 6cfdf11c6b4..0d693fd9f60 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -1,8 +1,7 @@ """ RunPod Cloud. """ -import json import typing -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Union from sky import clouds from sky.clouds import service_catalog @@ -147,7 +146,7 @@ def get_default_instance_type( @classmethod def get_accelerators_from_instance_type( - cls, instance_type: str) -> Optional[Dict[str, int]]: + cls, instance_type: str) -> Optional[Dict[str, Union[int, float]]]: return service_catalog.get_accelerators_from_instance_type( instance_type, clouds='runpod') @@ -166,10 +165,8 @@ def make_deploy_resources_variables( r = resources acc_dict = self.get_accelerators_from_instance_type(r.instance_type) - if acc_dict is not None: - custom_resources = json.dumps(acc_dict, separators=(',', ':')) - else: - custom_resources = None + custom_resources = resources_utils.make_ray_custom_resources_str( + acc_dict) if r.image_id is None: image_id = 'runpod/base:0.0.2' diff --git a/sky/clouds/scp.py b/sky/clouds/scp.py index 17a54ce1607..d0ad611bf0c 100644 --- a/sky/clouds/scp.py +++ b/sky/clouds/scp.py @@ -4,9 +4,8 @@ to access the SCP catalog and check credentials for the SCP access. 
""" -import json import typing -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Union from sky import clouds from sky import exceptions @@ -160,7 +159,7 @@ def get_default_instance_type( def get_accelerators_from_instance_type( cls, instance_type: str, - ) -> Optional[Dict[str, int]]: + ) -> Optional[Dict[str, Union[int, float]]]: return service_catalog.get_accelerators_from_instance_type( instance_type, clouds='scp') @@ -188,11 +187,9 @@ def make_deploy_resources_variables( r = resources acc_dict = self.get_accelerators_from_instance_type(r.instance_type) + custom_resources = resources_utils.make_ray_custom_resources_str( + acc_dict) - if acc_dict is not None: - custom_resources = json.dumps(acc_dict, separators=(',', ':')) - else: - custom_resources = None image_id = self._get_image_id(r.image_id, region.name, r.instance_type) return { 'instance_type': resources.instance_type, diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index f2301bac466..4deab8ac204 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -238,7 +238,7 @@ def get_default_instance_type(cpus: Optional[str] = None, def get_accelerators_from_instance_type( instance_type: str, - clouds: CloudFilter = None) -> Optional[Dict[str, int]]: + clouds: CloudFilter = None) -> Optional[Dict[str, Union[int, float]]]: """Returns the accelerators from a instance type.""" return _map_clouds_catalog(clouds, 'get_accelerators_from_instance_type', instance_type) diff --git a/sky/clouds/service_catalog/aws_catalog.py b/sky/clouds/service_catalog/aws_catalog.py index d156135047b..918a4070414 100644 --- a/sky/clouds/service_catalog/aws_catalog.py +++ b/sky/clouds/service_catalog/aws_catalog.py @@ -8,7 +8,7 @@ import os import threading import typing -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from sky import exceptions from sky import sky_logging @@ -243,7 +243,7 @@ def get_default_instance_type( def get_accelerators_from_instance_type( - instance_type: str) -> Optional[Dict[str, int]]: + instance_type: str) -> Optional[Dict[str, Union[int, float]]]: return common.get_accelerators_from_instance_type_impl( _get_df(), instance_type) diff --git a/sky/clouds/service_catalog/azure_catalog.py b/sky/clouds/service_catalog/azure_catalog.py index 2d323cbac5f..62cb422bf83 100644 --- a/sky/clouds/service_catalog/azure_catalog.py +++ b/sky/clouds/service_catalog/azure_catalog.py @@ -4,14 +4,32 @@ instance types and pricing information for Azure. """ import re -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from sky import clouds as cloud_lib +from sky import sky_logging from sky.clouds import Azure from sky.clouds.service_catalog import common from sky.utils import resources_utils from sky.utils import ux_utils +logger = sky_logging.init_logger(__name__) + +# This list should match the list of regions in +# skypilot image generation Packer script's replication_regions +# sky/clouds/service_catalog/images/skypilot-azure-cpu-ubuntu.pkr.hcl +COMMUNITY_IMAGE_AVAILABLE_REGIONS = { + 'centralus', + 'eastus', + 'eastus2', + 'northcentralus', + 'southcentralus', + 'westcentralus', + 'westus', + 'westus2', + 'westus3', +} + # The frequency of pulling the latest catalog from the cloud provider. 
# Though the catalog update is manual in our skypilot-catalog repo, we # still want to pull the latest catalog periodically to make sure the @@ -119,7 +137,7 @@ def _filter_disk_type(instance_type: str) -> bool: def get_accelerators_from_instance_type( - instance_type: str) -> Optional[Dict[str, int]]: + instance_type: str) -> Optional[Dict[str, Union[int, float]]]: return common.get_accelerators_from_instance_type_impl(_df, instance_type) @@ -139,6 +157,7 @@ def get_instance_type_for_accelerator( if zone is not None: with ux_utils.print_exception_no_traceback(): raise ValueError('Azure does not support zones.') + return common.get_instance_type_for_accelerator_impl(df=_df, acc_name=acc_name, acc_count=acc_count, @@ -176,9 +195,16 @@ def list_accelerators( def get_image_id_from_tag(tag: str, region: Optional[str]) -> Optional[str]: """Returns the image id from the tag.""" - # Azure images are not region-specific. - del region # Unused. - return common.get_image_id_from_tag_impl(_image_df, tag, None) + global _image_df + image_id = common.get_image_id_from_tag_impl(_image_df, tag, region) + if image_id is None: + # Refresh the image catalog and try again, if the image tag is not + # found. + logger.debug('Refreshing the image catalog and trying again.') + _image_df = common.read_catalog('azure/images.csv', + pull_frequency_hours=0) + image_id = common.get_image_id_from_tag_impl(_image_df, tag, region) + return image_id def is_image_tag_valid(tag: str, region: Optional[str]) -> bool: diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py index 4df72824027..1082b4e9efd 100644 --- a/sky/clouds/service_catalog/common.py +++ b/sky/clouds/service_catalog/common.py @@ -5,7 +5,7 @@ import os import time import typing -from typing import Callable, Dict, List, NamedTuple, Optional, Tuple +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union import filelock import requests @@ -481,7 +481,7 @@ def get_instance_type_for_cpus_mem_impl( def get_accelerators_from_instance_type_impl( df: 'pd.DataFrame', instance_type: str, -) -> Optional[Dict[str, int]]: +) -> Optional[Dict[str, Union[int, float]]]: df = _get_instance_type(df, instance_type, None) if len(df) == 0: with ux_utils.print_exception_no_traceback(): @@ -490,13 +490,19 @@ def get_accelerators_from_instance_type_impl( acc_name, acc_count = row['AcceleratorName'], row['AcceleratorCount'] if pd.isnull(acc_name): return None - return {acc_name: int(acc_count)} + + def _convert(value): + if int(value) == value: + return int(value) + return float(value) + + return {acc_name: _convert(acc_count)} def get_instance_type_for_accelerator_impl( df: 'pd.DataFrame', acc_name: str, - acc_count: int, + acc_count: Union[int, float], cpus: Optional[str] = None, memory: Optional[str] = None, use_spot: bool = False, @@ -509,7 +515,7 @@ def get_instance_type_for_accelerator_impl( accelerators with sorted prices and a list of candidates with fuzzy search. 
""" result = df[(df['AcceleratorName'].str.fullmatch(acc_name, case=False)) & - (df['AcceleratorCount'] == acc_count)] + (abs(df['AcceleratorCount'] - acc_count) <= 0.01)] result = _filter_region_zone(result, region, zone) if len(result) == 0: fuzzy_result = df[ @@ -522,8 +528,11 @@ def get_instance_type_for_accelerator_impl( fuzzy_candidate_list = [] if len(fuzzy_result) > 0: for _, row in fuzzy_result.iterrows(): + acc_cnt = float(row['AcceleratorCount']) + acc_count_display = (int(acc_cnt) if acc_cnt.is_integer() else + f'{acc_cnt:.2f}') fuzzy_candidate_list.append(f'{row["AcceleratorName"]}:' - f'{int(row["AcceleratorCount"])}') + f'{acc_count_display}') return (None, fuzzy_candidate_list) result = _filter_with_cpus(result, cpus) diff --git a/sky/clouds/service_catalog/cudo_catalog.py b/sky/clouds/service_catalog/cudo_catalog.py index 62832cba5bf..d4adc5baea5 100644 --- a/sky/clouds/service_catalog/cudo_catalog.py +++ b/sky/clouds/service_catalog/cudo_catalog.py @@ -1,7 +1,7 @@ """Cudo Compute Offerings Catalog.""" import typing -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from sky.clouds.service_catalog import common import sky.provision.cudo.cudo_machine_type as cudo_mt @@ -66,7 +66,7 @@ def get_default_instance_type(cpus: Optional[str] = None, def get_accelerators_from_instance_type( - instance_type: str) -> Optional[Dict[str, int]]: + instance_type: str) -> Optional[Dict[str, Union[int, float]]]: return common.get_accelerators_from_instance_type_impl(_df, instance_type) diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py index bbd337e23aa..f646cac339a 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py @@ -93,14 +93,15 @@ def get_regions() -> List[str]: # We have to manually remove it. DEPRECATED_FAMILIES = ['standardNVSv2Family'] -# Some A10 instance types only contains a fractional of GPU. We temporarily -# filter them out here to avoid using it as a whole A10 GPU. -# TODO(zhwu,tian): support fractional GPUs, which can be done on -# kubernetes as well. +# Azure has those fractional A10 instance types, which still shows has 1 A10 GPU +# in the API response. We manually changing the number of GPUs to a float here. # Ref: https://learn.microsoft.com/en-us/azure/virtual-machines/nva10v5-series -FILTERED_A10_INSTANCE_TYPES = [ - f'Standard_NV{vcpu}ads_A10_v5' for vcpu in [6, 12, 18] -] +# TODO(zhwu,tian): Support fractional GPUs on k8s as well. +# TODO(tian): Maybe we should support literally fractional count, i.e. A10:1/6 +# instead of float point count (A10:0.167). +AZURE_FRACTIONAL_A10_INS_TYPE_TO_NUM_GPUS = { + f'Standard_NV{vcpu}ads_A10_v5': round(vcpu / 36, 3) for vcpu in [6, 12, 18] +} USEFUL_COLUMNS = [ 'InstanceType', 'AcceleratorName', 'AcceleratorCount', 'vCPUs', 'MemoryGiB', @@ -274,6 +275,19 @@ def get_additional_columns(row): axis='columns', ) + def _upd_a10_gpu_count(row): + new_gpu_cnt = AZURE_FRACTIONAL_A10_INS_TYPE_TO_NUM_GPUS.get( + row['InstanceType']) + if new_gpu_cnt is not None: + return new_gpu_cnt + return row['AcceleratorCount'] + + # Manually update the GPU count for fractional A10 instance types. + # Those instance types have fractional GPU count, but Azure API returns + # 1 GPU count for them. We manually update the GPU count here. 
+ df_ret['AcceleratorCount'] = df_ret.apply(_upd_a10_gpu_count, + axis='columns') + # As of Dec 2023, a few H100 instance types fetched from Azure APIs do not # have pricing: # @@ -299,10 +313,6 @@ def get_additional_columns(row): after_drop_len = len(df_ret) print(f'Dropped {before_drop_len - after_drop_len} duplicated rows') - # Filter out instance types that only contain a fractional of GPU. - df_ret = df_ret.loc[~df_ret['InstanceType'].isin(FILTERED_A10_INSTANCE_TYPES - )] - # Filter out deprecated families df_ret = df_ret.loc[~df_ret['family'].isin(DEPRECATED_FAMILIES)] df_ret = df_ret[USEFUL_COLUMNS] diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py b/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py index 6550c6bbe64..8cc9fc6c127 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py @@ -47,6 +47,10 @@ TPU_V4_ZONES = ['us-central2-b'] # TPU v3 pods are available in us-east1-d, but hidden in the skus. # We assume the TPU prices are the same as us-central1. +# TPU v6e's pricing info is not available on the SKUs. However, in +# https://cloud.google.com/tpu/pricing, it listed the price for 4 regions: +# us-east1, us-east5, europe-west4, and asia-northeast1. We hardcode them here +# and filtered out the other regions (us-central{1,2}, us-south1). HIDDEN_TPU_DF = pd.read_csv( io.StringIO( textwrap.dedent("""\ @@ -58,8 +62,50 @@ ,tpu-v3-512,1,,,tpu-v3-512,512.0,153.6,us-east1,us-east1-d ,tpu-v3-1024,1,,,tpu-v3-1024,1024.0,307.2,us-east1,us-east1-d ,tpu-v3-2048,1,,,tpu-v3-2048,2048.0,614.4,us-east1,us-east1-d + ,tpu-v6e-1,1,,,tpu-v6e-1,2.7,,us-east5,us-east5-b + ,tpu-v6e-1,1,,,tpu-v6e-1,2.7,,us-east5,us-east5-c + ,tpu-v6e-1,1,,,tpu-v6e-1,2.97,,europe-west4,europe-west4-a + ,tpu-v6e-1,1,,,tpu-v6e-1,3.24,,asia-northeast1,asia-northeast1-b + ,tpu-v6e-1,1,,,tpu-v6e-1,2.7,,us-east1,us-east1-d + ,tpu-v6e-4,1,,,tpu-v6e-4,10.8,,us-east5,us-east5-b + ,tpu-v6e-4,1,,,tpu-v6e-4,10.8,,us-east5,us-east5-c + ,tpu-v6e-4,1,,,tpu-v6e-4,11.88,,europe-west4,europe-west4-a + ,tpu-v6e-4,1,,,tpu-v6e-4,12.96,,asia-northeast1,asia-northeast1-b + ,tpu-v6e-4,1,,,tpu-v6e-4,10.8,,us-east1,us-east1-d + ,tpu-v6e-8,1,,,tpu-v6e-8,21.6,,us-east5,us-east5-b + ,tpu-v6e-8,1,,,tpu-v6e-8,21.6,,us-east5,us-east5-c + ,tpu-v6e-8,1,,,tpu-v6e-8,23.76,,europe-west4,europe-west4-a + ,tpu-v6e-8,1,,,tpu-v6e-8,25.92,,asia-northeast1,asia-northeast1-b + ,tpu-v6e-8,1,,,tpu-v6e-8,21.6,,us-east1,us-east1-d + ,tpu-v6e-16,1,,,tpu-v6e-16,43.2,,us-east5,us-east5-b + ,tpu-v6e-16,1,,,tpu-v6e-16,43.2,,us-east5,us-east5-c + ,tpu-v6e-16,1,,,tpu-v6e-16,47.52,,europe-west4,europe-west4-a + ,tpu-v6e-16,1,,,tpu-v6e-16,51.84,,asia-northeast1,asia-northeast1-b + ,tpu-v6e-16,1,,,tpu-v6e-16,43.2,,us-east1,us-east1-d + ,tpu-v6e-32,1,,,tpu-v6e-32,86.4,,us-east5,us-east5-b + ,tpu-v6e-32,1,,,tpu-v6e-32,86.4,,us-east5,us-east5-c + ,tpu-v6e-32,1,,,tpu-v6e-32,95.04,,europe-west4,europe-west4-a + ,tpu-v6e-32,1,,,tpu-v6e-32,103.68,,asia-northeast1,asia-northeast1-b + ,tpu-v6e-32,1,,,tpu-v6e-32,86.4,,us-east1,us-east1-d + ,tpu-v6e-64,1,,,tpu-v6e-64,172.8,,us-east5,us-east5-b + ,tpu-v6e-64,1,,,tpu-v6e-64,172.8,,us-east5,us-east5-c + ,tpu-v6e-64,1,,,tpu-v6e-64,190.08,,europe-west4,europe-west4-a + ,tpu-v6e-64,1,,,tpu-v6e-64,207.36,,asia-northeast1,asia-northeast1-b + ,tpu-v6e-64,1,,,tpu-v6e-64,172.8,,us-east1,us-east1-d + ,tpu-v6e-128,1,,,tpu-v6e-128,345.6,,us-east5,us-east5-b + ,tpu-v6e-128,1,,,tpu-v6e-128,345.6,,us-east5,us-east5-c + 
,tpu-v6e-128,1,,,tpu-v6e-128,380.16,,europe-west4,europe-west4-a + ,tpu-v6e-128,1,,,tpu-v6e-128,414.72,,asia-northeast1,asia-northeast1-b + ,tpu-v6e-128,1,,,tpu-v6e-128,345.6,,us-east1,us-east1-d + ,tpu-v6e-256,1,,,tpu-v6e-256,691.2,,us-east5,us-east5-b + ,tpu-v6e-256,1,,,tpu-v6e-256,691.2,,us-east5,us-east5-c + ,tpu-v6e-256,1,,,tpu-v6e-256,760.32,,europe-west4,europe-west4-a + ,tpu-v6e-256,1,,,tpu-v6e-256,829.44,,asia-northeast1,asia-northeast1-b + ,tpu-v6e-256,1,,,tpu-v6e-256,691.2,,us-east1,us-east1-d """))) +TPU_V6E_MISSING_REGIONS = ['us-central1', 'us-central2', 'us-south1'] + # TPU V5 is not visible in specific zones. We hardcode the missing zones here. # NOTE(dev): Keep the zones and the df in sync. TPU_V5_MISSING_ZONES_DF = { @@ -683,11 +729,13 @@ def get_tpu_price(row: pd.Series, spot: bool) -> Optional[float]: 'not found in SKUs or hidden TPU price DF.') # TODO(tian): Hack. Should investigate how to retrieve the price # for TPU-v6e. - if not tpu_name.startswith('tpu-v6e'): + if (tpu_name.startswith('tpu-v6e') and + tpu_region in TPU_V6E_MISSING_REGIONS): + if not spot: + tpu_price = 0.0 + else: assert spot or tpu_price is not None, (row, hidden_tpu, HIDDEN_TPU_DF) - else: - tpu_price = 0.0 return tpu_price df['Price'] = df.apply(lambda row: get_tpu_price(row, spot=False), axis=1) diff --git a/sky/clouds/service_catalog/fluidstack_catalog.py b/sky/clouds/service_catalog/fluidstack_catalog.py index 2f47a38df43..7a28ac8174a 100644 --- a/sky/clouds/service_catalog/fluidstack_catalog.py +++ b/sky/clouds/service_catalog/fluidstack_catalog.py @@ -4,7 +4,7 @@ instance types and pricing information for FluidStack. """ import typing -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from sky.clouds.service_catalog import common from sky.utils import ux_utils @@ -65,7 +65,7 @@ def get_default_instance_type(cpus: Optional[str] = None, def get_accelerators_from_instance_type( - instance_type: str) -> Optional[Dict[str, int]]: + instance_type: str) -> Optional[Dict[str, Union[int, float]]]: return common.get_accelerators_from_instance_type_impl(_df, instance_type) diff --git a/sky/clouds/service_catalog/ibm_catalog.py b/sky/clouds/service_catalog/ibm_catalog.py index 51b4e14f569..5cec86fbb65 100644 --- a/sky/clouds/service_catalog/ibm_catalog.py +++ b/sky/clouds/service_catalog/ibm_catalog.py @@ -4,7 +4,7 @@ instance types and pricing information for IBM. """ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from sky import sky_logging from sky.adaptors import ibm @@ -43,7 +43,7 @@ def get_vcpus_mem_from_instance_type( def get_accelerators_from_instance_type( - instance_type: str) -> Optional[Dict[str, int]]: + instance_type: str) -> Optional[Dict[str, Union[int, float]]]: return common.get_accelerators_from_instance_type_impl(_df, instance_type) diff --git a/sky/clouds/service_catalog/images/README.md b/sky/clouds/service_catalog/images/README.md index 31ce7c6d9ce..8f9d8e85f16 100644 --- a/sky/clouds/service_catalog/images/README.md +++ b/sky/clouds/service_catalog/images/README.md @@ -8,44 +8,71 @@ You only need to do this once. packer init plugins.pkr.hcl ``` 3. Setup cloud credentials +4. `cd sky/clouds/service_catalog/images` ## Generate Images -```bash -export CLOUD=gcp # Update this -export TYPE=gpu # Update this -export IMAGE=skypilot-${CLOUD}-${TYPE}-ubuntu -packer build ${IMAGE}.pkr.hcl -``` -You will see the image ID after the build is complete. 
- -FYI time to packer build an image: - +FYI time to packer build images: | Cloud | Type | Approx. Time | |-------|------|------------------------| | AWS | GPU | 15 min | | AWS | CPU | 10 min | | GCP | GPU | 16 min | | GCP | CPU | 5 min | +| Azure | GPU | 35 min | +| Azure | CPU | 25 min | ### GCP +1. Build a single global image. +```bash +export TYPE=cpu # Update this +export IMAGE=skypilot-gcp-${TYPE}-ubuntu +packer build ${IMAGE}.pkr.hcl +``` +2. Make the image public ```bash -export IMAGE_NAME=skypilot-gcp-cpu-ubuntu-20241011003407 # Update this - # Make image public +export IMAGE_NAME=skypilot-gcp-gpu-ubuntu-241030 # Update this export IMAGE_ID=projects/sky-dev-465/global/images/${IMAGE_NAME} gcloud compute images add-iam-policy-binding ${IMAGE_NAME} --member='allAuthenticatedUsers' --role='roles/compute.imageUser' ``` ### AWS -1. Generate images for all regions +1. Generate the source image for a single region. ```bash -export IMAGE_ID=ami-0b31b24524afa8e47 # Update this - +export TYPE=cpu # Update this +export IMAGE=skypilot-aws-${TYPE}-ubuntu +packer build ${IMAGE}.pkr.hcl +``` +2. Copy images to all regions +```bash +export TYPE=gpu # Update this +export IMAGE_ID=ami-0989556a89639b1bb # Update this python aws_utils/image_gen.py --image-id ${IMAGE_ID} --processor ${TYPE} ``` -2. Add fallback images if any region failed \ +3. Add fallback images if any region failed \ Look for "NEED_FALLBACK" in the output `images.csv` and edit. (You can use public [ubuntu images](https://cloud-images.ubuntu.com/locator/ec2/) as fallback.) +### Azure +1. Generate a client secret for packer [here](https://portal.azure.com/?feature.msaljs=true#view/Microsoft_AAD_RegisteredApps/ApplicationMenuBlade/~/Credentials/appId/1d249f23-c22e-4d02-b62b-a6827bd113fe/isMSAApp~/false). +```bash +export SECRET=xxxxxx # Update this +``` +2. Build and copy images for all regions for GPU (gen 1 & 2) and CPU (gen 2 only). +```bash +packer build --var vm_generation=2 --var client_secret=${SECRET} skypilot-azure-cpu-ubuntu.pkr.hcl +packer build --var vm_generation=2 --var client_secret=${SECRET} skypilot-azure-gpu-ubuntu.pkr.hcl +packer build --var vm_generation=1 --var client_secret=${SECRET} skypilot-azure-gpu-ubuntu.pkr.hcl +packer build --var vm_generation=2 --var client_secret=${SECRET} --var use_grid_driver=true skypilot-azure-gpu-ubuntu.pkr.hcl +``` + +### Kubernetes +1. Build the image +```bash +export REGION=europe # Update this: us, europe, asia +./skypilot-k8s-image.sh -p -l -r ${REGION} +./skypilot-k8s-image.sh -p -l -g -r ${REGION} +``` + ## Test Images 1. Minimal GPU test: `sky launch --image ${IMAGE_ID} --gpus=L4:1 --cloud ${CLOUD}` then run `nvidia-smi` in the launched instance. 2. Update the image ID in `sky/clouds/gcp.py` and run the test: @@ -60,13 +87,16 @@ pytest tests/test_smoke.py::test_cancel_gcp Submit a PR to update [`SkyPilot Catalog`](https://github.com/skypilot-org/skypilot-catalog/tree/master/catalogs) then clean up the old images to avoid extra iamge storage fees. ### GCP -1. Example PR: [#86](https://github.com/skypilot-org/skypilot-catalog/pull/86) -2. Go to console and delete old images. +1. Update Catalog with new images: [example PR](https://github.com/skypilot-org/skypilot-catalog/pull/86) +2. Go to [GCP console](https://console.cloud.google.com/compute/images?tab=images&project=sky-dev-465) and delete old images. ### AWS 1. Copy the old custom image rows from Catalog's existing `images.csv` to a local `images.csv` in this folder. -2. Update Catalog with new images. 
Example PR: [#89](https://github.com/skypilot-org/skypilot-catalog/pull/89) +2. Update Catalog with new images: [example PR](https://github.com/skypilot-org/skypilot-catalog/pull/89) 3. Delete AMIs across regions by running ```bash python aws_utils/image_delete.py --tag ${TAG} ``` + +### Azure +1. Update Catalog with new images: [example PR](https://github.com/skypilot-org/skypilot-catalog/pull/92) diff --git a/sky/clouds/service_catalog/images/aws_utils/image_gen.py b/sky/clouds/service_catalog/images/aws_utils/image_gen.py index cb39355ad2c..cadfee912a9 100644 --- a/sky/clouds/service_catalog/images/aws_utils/image_gen.py +++ b/sky/clouds/service_catalog/images/aws_utils/image_gen.py @@ -133,7 +133,7 @@ def process_region(copy_to_region): except Exception as e: print(f"Error generating image to {copy_to_region}: {str(e)}") new_image_id = 'NEED_FALLBACK' - image_cache.append((new_image_id, copy_to_region)) + image_cache.append((new_image_id, copy_to_region)) with concurrent.futures.ThreadPoolExecutor() as executor: executor.map(process_region, ALL_REGIONS) diff --git a/sky/clouds/service_catalog/images/provisioners/cloud.sh b/sky/clouds/service_catalog/images/provisioners/cloud.sh deleted file mode 100644 index b326c9fde51..00000000000 --- a/sky/clouds/service_catalog/images/provisioners/cloud.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/bin/bash - -PYTHON_EXEC=$(echo ~/skypilot-runtime)/bin/python - -# TODO: keep this dependency installation align with utils/controller_utils.py and setup.py -install_azure() { - echo "Install cloud dependencies on controller: Azure" - $PYTHON_EXEC -m pip install "azure-cli>=2.31.0" azure-core "azure-identity>=1.13.0" azure-mgmt-network - $PYTHON_EXEC -m pip install azure-storage-blob msgraph-sdk -} - -install_gcp() { - echo "Install cloud dependencies on controller: GCP" - $PYTHON_EXEC -m pip install "google-api-python-client>=2.69.0" - $PYTHON_EXEC -m pip install google-cloud-storage - if ! gcloud --help > /dev/null 2>&1; then - pushd /tmp &>/dev/null - mkdir -p ~/.sky/logs - wget --quiet https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-424.0.0-linux-x86_64.tar.gz > ~/.sky/logs/gcloud_installation.log - tar xzf google-cloud-sdk-424.0.0-linux-x86_64.tar.gz >> ~/.sky/logs/gcloud_installation.log - rm -rf ~/google-cloud-sdk >> ~/.sky/logs/gcloud_installation.log - mv google-cloud-sdk ~/ - ~/google-cloud-sdk/install.sh -q >> ~/.sky/logs/gcloud_installation.log 2>&1 - echo "source ~/google-cloud-sdk/path.bash.inc > /dev/null 2>&1" >> ~/.bashrc - source ~/google-cloud-sdk/path.bash.inc >> ~/.sky/logs/gcloud_installation.log 2>&1 - popd &>/dev/null - fi -} - -install_aws() { - echo "Install cloud dependencies on controller: AWS" - $PYTHON_EXEC -m pip install botocore>=1.29.10 boto3>=1.26.1 - $PYTHON_EXEC -m pip install "urllib3<2" awscli>=1.27.10 "colorama<0.4.5" -} - -if [ "$CLOUD" = "azure" ]; then - install_azure -elif [ "$CLOUD" = "gcp" ]; then - install_gcp -elif [ "$CLOUD" = "aws" ]; then - install_aws -else - echo "Error: Unknown cloud $CLOUD so not installing any cloud dependencies." -fi - -if [ $? 
-eq 0 ]; then - echo "Successfully installed cloud dependencies on controller: $CLOUD" -else - echo "Error: Failed to install cloud dependencies on controller: $CLOUD" -fi diff --git a/sky/clouds/service_catalog/images/provisioners/cuda-azure-grid.sh b/sky/clouds/service_catalog/images/provisioners/cuda-azure-grid.sh new file mode 100644 index 00000000000..6177dfa5d53 --- /dev/null +++ b/sky/clouds/service_catalog/images/provisioners/cuda-azure-grid.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +sudo apt update +sudo apt install -y build-essential + +echo "Installing GRID driver..." +GRID_DRIVER_URL="https://download.microsoft.com/download/8/d/a/8da4fb8e-3a9b-4e6a-bc9a-72ff64d7a13c/NVIDIA-Linux-x86_64-535.161.08-grid-azure.run" +GRID_DRIVER_FILE="NVIDIA-Linux-x86_64-535.161.08-grid-azure.run" + +wget -nv $GRID_DRIVER_URL -O $GRID_DRIVER_FILE +sudo chmod +x $GRID_DRIVER_FILE +sudo sh $GRID_DRIVER_FILE --silent --disable-nouveau + +echo "Set vGPU Licensing Daemon config..." +sudo cp /etc/nvidia/gridd.conf.template /etc/nvidia/gridd.conf +sudo sed -i '/^FeatureType=0/s/^/# /' /etc/nvidia/gridd.conf +echo "IgnoreSP=FALSE" | sudo tee -a /etc/nvidia/gridd.conf +echo "EnableUI=FALSE" | sudo tee -a /etc/nvidia/gridd.conf + +echo "Installing CUDA toolkit..." +CUDA_TOOLKIT_URL="https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run" +CUDA_TOOLKIT_FILE="cuda_12.2.0_535.54.03_linux.run" +wget -nv $CUDA_TOOLKIT_URL -O $CUDA_TOOLKIT_FILE +sudo sh $CUDA_TOOLKIT_FILE --silent --toolkit --override + +# Set environment variables +echo 'export PATH=$PATH:/usr/local/cuda-12.2/bin' >> ~/.bashrc +echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.2/lib64' >> ~/.bashrc +source ~/.bashrc + +# Verify installations +rm -f NVIDIA-Linux-x86_64-535.161.08-grid-azure.run cuda_12.2.0_535.54.03_linux.run +nvidia-smi diff --git a/sky/clouds/service_catalog/images/provisioners/skypilot.sh b/sky/clouds/service_catalog/images/provisioners/skypilot.sh index ff2aa06b2b6..cecb3664736 100644 --- a/sky/clouds/service_catalog/images/provisioners/skypilot.sh +++ b/sky/clouds/service_catalog/images/provisioners/skypilot.sh @@ -36,14 +36,12 @@ echo PATH=$PATH python3 -m venv ~/skypilot-runtime PYTHON_EXEC=$(echo ~/skypilot-runtime)/bin/python -# Pip installs -$PYTHON_EXEC -m pip install "setuptools<70" -$PYTHON_EXEC -m pip install "grpcio!=1.48.0,<=1.51.3,>=1.42.0" -$PYTHON_EXEC -m pip install "skypilot-nightly" +# Install SkyPilot +$PYTHON_EXEC -m pip install "skypilot-nightly[remote]" -# Install ray +# Install Ray RAY_ADDRESS=127.0.0.1:6380 -$PYTHON_EXEC -m pip install --exists-action w -U ray[default]==2.9.3 +$PYTHON_EXEC -m pip install --exists-action w -U "ray[default]==2.9.3" export PATH=$PATH:$HOME/.local/bin source ~/skypilot-runtime/bin/activate which ray > ~/.sky/ray_path || exit 1 @@ -51,6 +49,18 @@ $PYTHON_EXEC -m pip list | grep "ray " | grep 2.9.3 2>&1 > /dev/null && { $PYTHON_EXEC -c "from sky.skylet.ray_patches import patch; patch()" || exit 1 } +# Install cloud dependencies +if [ "$CLOUD" = "azure" ]; then + $PYTHON_EXEC -m pip install "skypilot-nightly[azure]" +elif [ "$CLOUD" = "gcp" ]; then + # We don't have to install the google-cloud-sdk since it is installed by default in GCP machines. + $PYTHON_EXEC -m pip install "skypilot-nightly[gcp]" +elif [ "$CLOUD" = "aws" ]; then + $PYTHON_EXEC -m pip install "skypilot-nightly[aws]" +else + echo "Error: Unknown cloud $CLOUD so not installing any cloud dependencies." 
+fi + # System configurations sudo bash -c 'rm -rf /etc/security/limits.d; echo "* soft nofile 1048576" >> /etc/security/limits.conf; echo "* hard nofile 1048576" >> /etc/security/limits.conf' sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf' @@ -67,3 +77,4 @@ sudo systemctl disable jupyter > /dev/null 2>&1 || true # Cleanup # Remove SkyPilot in OS image because when user sky launch we will install whatever version of SkyPilot user has on their local machine. $PYTHON_EXEC -m pip uninstall "skypilot-nightly" -y +rm -rf ~/.sky diff --git a/sky/clouds/service_catalog/images/provisioners/user-toolkit.sh b/sky/clouds/service_catalog/images/provisioners/user-toolkit.sh new file mode 100644 index 00000000000..cea02249938 --- /dev/null +++ b/sky/clouds/service_catalog/images/provisioners/user-toolkit.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# This script installs popular toolkits for users to use in the base environment. + +eval "$(~/miniconda3/bin/conda shell.bash hook)" +conda activate base +pip install numpy +pip install pandas + +if [ "$AZURE_GRID_DRIVER" = 1 ]; then + # Need PyTorch X.X.X+cu121 version to be compatible with older NVIDIA driver (535.161.08 or lower) + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 +fi diff --git a/sky/clouds/service_catalog/images/skypilot-aws-cpu-ubuntu.pkr.hcl b/sky/clouds/service_catalog/images/skypilot-aws-cpu-ubuntu.pkr.hcl index c21fbf51b20..370ea42e653 100644 --- a/sky/clouds/service_catalog/images/skypilot-aws-cpu-ubuntu.pkr.hcl +++ b/sky/clouds/service_catalog/images/skypilot-aws-cpu-ubuntu.pkr.hcl @@ -4,11 +4,11 @@ variable "region" { } locals { - timestamp = regex_replace(timestamp(), "[- TZ:]", "") + date = formatdate("YYMMDD", timestamp()) } source "amazon-ebs" "cpu-ubuntu" { - ami_name = "skypilot-aws-cpu-ubuntu-${local.timestamp}" + ami_name = "skypilot-aws-cpu-ubuntu-${local.date}" instance_type = "t2.micro" region = var.region ssh_username = "ubuntu" @@ -22,9 +22,9 @@ source "amazon-ebs" "cpu-ubuntu" { owners = ["099720109477"] } launch_block_device_mappings { - device_name = "/dev/sda1" - volume_size = 8 - volume_type = "gp2" + device_name = "/dev/sda1" + volume_size = 8 + volume_type = "gp2" delete_on_termination = true } } @@ -35,13 +35,13 @@ build { provisioner "shell" { script = "./provisioners/docker.sh" } - provisioner "shell" { - script = "./provisioners/skypilot.sh" - } provisioner "shell" { environment_vars = [ "CLOUD=aws", ] - script = "./provisioners/cloud.sh" + script = "./provisioners/skypilot.sh" + } + provisioner "shell" { + script = "./provisioners/user-toolkit.sh" } } diff --git a/sky/clouds/service_catalog/images/skypilot-aws-gpu-ubuntu.pkr.hcl b/sky/clouds/service_catalog/images/skypilot-aws-gpu-ubuntu.pkr.hcl index c4a8efac4dc..22f78cd88d0 100644 --- a/sky/clouds/service_catalog/images/skypilot-aws-gpu-ubuntu.pkr.hcl +++ b/sky/clouds/service_catalog/images/skypilot-aws-gpu-ubuntu.pkr.hcl @@ -4,11 +4,11 @@ variable "region" { } locals { - timestamp = regex_replace(timestamp(), "[- TZ:]", "") + date = formatdate("YYMMDD", timestamp()) } source "amazon-ebs" "gpu-ubuntu" { - ami_name = "skypilot-aws-gpu-ubuntu-${local.timestamp}" + ami_name = "skypilot-aws-gpu-ubuntu-${local.date}" instance_type = "g6.xlarge" region = var.region ssh_username = "ubuntu" @@ -22,9 +22,9 @@ source "amazon-ebs" "gpu-ubuntu" { owners = ["099720109477"] } launch_block_device_mappings { - device_name = "/dev/sda1" - volume_size = 
30 - volume_type = "gp2" + device_name = "/dev/sda1" + volume_size = 30 + volume_type = "gp2" delete_on_termination = true } } @@ -43,13 +43,13 @@ build { provisioner "shell" { script = "./provisioners/nvidia-container-toolkit.sh" } - provisioner "shell" { - script = "./provisioners/skypilot.sh" - } provisioner "shell" { environment_vars = [ "CLOUD=aws", ] - script = "./provisioners/cloud.sh" + script = "./provisioners/skypilot.sh" + } + provisioner "shell" { + script = "./provisioners/user-toolkit.sh" } } diff --git a/sky/clouds/service_catalog/images/skypilot-azure-cpu-ubuntu.pkr.hcl b/sky/clouds/service_catalog/images/skypilot-azure-cpu-ubuntu.pkr.hcl new file mode 100644 index 00000000000..1b41f8a029c --- /dev/null +++ b/sky/clouds/service_catalog/images/skypilot-azure-cpu-ubuntu.pkr.hcl @@ -0,0 +1,71 @@ +variable "client_secret" { + type = string + description = "The client secret for the packer client registered in Azure (see Azure app registration)" +} + +variable "vm_generation" { + type = number + description = "Azure's VM generation, currently support 1 or 2" +} + +locals { + date = formatdate("YYMMDD", timestamp()) + version = formatdate("YY.MM.DD", timestamp()) +} + +source "azure-arm" "cpu-ubuntu" { + managed_image_resource_group_name = "skypilot-images" + + subscription_id = "59d8c23c-7ef5-42c7-b2f3-a919ad8026a7" + tenant_id = "7c81f068-46f8-4b26-9a46-2fbec2287e3d" + client_id = "1d249f23-c22e-4d02-b62b-a6827bd113fe" + client_secret = var.client_secret + + os_type = "Linux" + image_publisher = "Canonical" + image_offer = "0001-com-ubuntu-server-jammy" + image_sku = var.vm_generation == 1 ? "22_04-lts" : "22_04-lts-gen2" + location = "centralus" + vm_size = var.vm_generation == 1 ? "Standard_D1_v2" : "Standard_B2s" + ssh_username = "azureuser" + azure_tags = { + Created_by = "packer" + Purpose = "skypilot" + } + + shared_image_gallery_destination { + subscription = "59d8c23c-7ef5-42c7-b2f3-a919ad8026a7" + resource_group = "skypilot-images" + gallery_name = "skypilot_image_gallery" + image_name = "skypilot-cpu-gen${var.vm_generation}" + image_version = "${local.version}" + replication_regions = [ + "centralus", + "eastus", + "eastus2", + "northcentralus", + "southcentralus", + "westcentralus", + "westus", + "westus2", + "westus3" + ] + } +} + +build { + name = "azure-cpu-ubuntu-build" + sources = ["sources.azure-arm.cpu-ubuntu"] + provisioner "shell" { + script = "./provisioners/docker.sh" + } + provisioner "shell" { + environment_vars = [ + "CLOUD=azure", + ] + script = "./provisioners/skypilot.sh" + } + provisioner "shell" { + script = "./provisioners/user-toolkit.sh" + } +} diff --git a/sky/clouds/service_catalog/images/skypilot-azure-gpu-ubuntu.pkr.hcl b/sky/clouds/service_catalog/images/skypilot-azure-gpu-ubuntu.pkr.hcl new file mode 100644 index 00000000000..a68708cc66b --- /dev/null +++ b/sky/clouds/service_catalog/images/skypilot-azure-gpu-ubuntu.pkr.hcl @@ -0,0 +1,86 @@ +variable "client_secret" { + type = string + description = "The client secret for the packer client registered in Azure (see Azure app registration)" +} + +variable "vm_generation" { + type = number + description = "Azure's VM generation, currently support 1 or 2" +} + +variable "use_grid_driver" { + type = bool + default = false + description = "Whether to use the Azure GRID driver. Currently only A10 GPU VMs need this." 
+} + +locals { + date = formatdate("YYMMDD", timestamp()) + version = formatdate("YY.MM.DD", timestamp()) +} + +source "azure-arm" "gpu-ubuntu" { + managed_image_resource_group_name = "skypilot-images" + + subscription_id = "59d8c23c-7ef5-42c7-b2f3-a919ad8026a7" + tenant_id = "7c81f068-46f8-4b26-9a46-2fbec2287e3d" + client_id = "1d249f23-c22e-4d02-b62b-a6827bd113fe" + client_secret = var.client_secret + + os_type = "Linux" + image_publisher = "Canonical" + image_offer = "0001-com-ubuntu-server-jammy" + image_sku = var.vm_generation == 1 ? "22_04-lts" : "22_04-lts-gen2" + location = var.use_grid_driver || var.vm_generation == 1 ? "eastus" : "centralus" + vm_size = var.use_grid_driver ? "Standard_NV12ads_A10_v5" : (var.vm_generation == 1 ? "Standard_NC4as_T4_v3" : "Standard_NC24ads_A100_v4") + ssh_username = "azureuser" + azure_tags = { + Created_by = "packer" + Purpose = "skypilot" + } + + shared_image_gallery_destination { + subscription = "59d8c23c-7ef5-42c7-b2f3-a919ad8026a7" + resource_group = "skypilot-images" + gallery_name = var.use_grid_driver || var.vm_generation == 1 ? "skypilot_images" : "skypilot_image_gallery" + image_name = var.use_grid_driver ? "skypilot-gpu-gen2-grid" : "skypilot-gpu-gen${var.vm_generation}" + image_version = "${local.version}" + replication_regions = [ + "centralus", + "eastus", + "eastus2", + "northcentralus", + "southcentralus", + "westcentralus", + "westus", + "westus2", + "westus3" + ] + } +} + +build { + name = "azure-gpu-ubuntu-build" + sources = ["sources.azure-arm.gpu-ubuntu"] + provisioner "shell" { + script = "./provisioners/docker.sh" + } + provisioner "shell" { + script = var.use_grid_driver ? "./provisioners/cuda-azure-grid.sh" : "./provisioners/cuda.sh" + } + provisioner "shell" { + script = "./provisioners/nvidia-container-toolkit.sh" + } + provisioner "shell" { + environment_vars = [ + "CLOUD=azure", + ] + script = "./provisioners/skypilot.sh" + } + provisioner "shell" { + environment_vars = [ + var.use_grid_driver ? "AZURE_GRID_DRIVER=1" : "AZURE_GRID_DRIVER=0", + ] + script = "./provisioners/user-toolkit.sh" + } +} diff --git a/sky/clouds/service_catalog/images/skypilot-gcp-cpu-ubuntu.pkr.hcl b/sky/clouds/service_catalog/images/skypilot-gcp-cpu-ubuntu.pkr.hcl index bf3af0519e4..2ddd836d886 100644 --- a/sky/clouds/service_catalog/images/skypilot-gcp-cpu-ubuntu.pkr.hcl +++ b/sky/clouds/service_catalog/images/skypilot-gcp-cpu-ubuntu.pkr.hcl @@ -1,11 +1,11 @@ locals { - timestamp = regex_replace(timestamp(), "[- TZ:]", "") + date = formatdate("YYMMDD", timestamp()) } source "googlecompute" "cpu-ubuntu" { project_id = "sky-dev-465" - image_name = "skypilot-gcp-cpu-ubuntu-${local.timestamp}" + image_name = "skypilot-gcp-cpu-ubuntu-${local.date}" source_image_family = "ubuntu-2204-lts" zone = "us-west1-a" image_description = "SkyPilot custom image for launching GCP CPU instances." 
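The Packer templates in this diff switch their `locals` from the full `timestamp()` string to `formatdate(...)` tags, producing names like `skypilot-gcp-gpu-ubuntu-241030` (matching the `IMAGE_NAME` example in the README above) and Azure gallery versions like `24.10.30`. A rough Python equivalent of the resulting tags, for illustration only (not code used by the repo):

```python
from datetime import datetime, timezone

# Illustration only: mirrors formatdate("YYMMDD", timestamp()) and
# formatdate("YY.MM.DD", timestamp()) in the Packer templates above.
now = datetime.now(timezone.utc)
date_tag = now.strftime('%y%m%d')       # e.g. '241030'
version_tag = now.strftime('%y.%m.%d')  # e.g. '24.10.30'
print(f'skypilot-gcp-gpu-ubuntu-{date_tag}')  # image-name style
print(version_tag)                            # Azure gallery image_version style
```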
@@ -21,13 +21,13 @@ build { provisioner "shell" { script = "./provisioners/docker.sh" } - provisioner "shell" { - script = "./provisioners/skypilot.sh" - } provisioner "shell" { environment_vars = [ "CLOUD=gcp", ] - script = "./provisioners/cloud.sh" + script = "./provisioners/skypilot.sh" + } + provisioner "shell" { + script = "./provisioners/user-toolkit.sh" } } diff --git a/sky/clouds/service_catalog/images/skypilot-gcp-gpu-ubuntu.pkr.hcl b/sky/clouds/service_catalog/images/skypilot-gcp-gpu-ubuntu.pkr.hcl index f46d414493b..a4799c3025b 100644 --- a/sky/clouds/service_catalog/images/skypilot-gcp-gpu-ubuntu.pkr.hcl +++ b/sky/clouds/service_catalog/images/skypilot-gcp-gpu-ubuntu.pkr.hcl @@ -4,11 +4,11 @@ variable "zone" { } locals { - timestamp = regex_replace(timestamp(), "[- TZ:]", "") + date = formatdate("YYMMDD", timestamp()) } source "googlecompute" "gpu-ubuntu" { - image_name = "skypilot-gcp-gpu-ubuntu-${local.timestamp}" + image_name = "skypilot-gcp-gpu-ubuntu-${local.date}" project_id = "sky-dev-465" source_image_family = "ubuntu-2204-lts" zone = var.zone @@ -34,13 +34,13 @@ build { provisioner "shell" { script = "./provisioners/nvidia-container-toolkit.sh" } - provisioner "shell" { - script = "./provisioners/skypilot.sh" - } provisioner "shell" { environment_vars = [ "CLOUD=gcp", ] - script = "./provisioners/cloud.sh" + script = "./provisioners/skypilot.sh" + } + provisioner "shell" { + script = "./provisioners/user-toolkit.sh" } } diff --git a/tests/kubernetes/build_image.sh b/sky/clouds/service_catalog/images/skypilot-k8s-image.sh similarity index 78% rename from tests/kubernetes/build_image.sh rename to sky/clouds/service_catalog/images/skypilot-k8s-image.sh index 54d1d356b59..075aa7ae4bc 100755 --- a/tests/kubernetes/build_image.sh +++ b/sky/clouds/service_catalog/images/skypilot-k8s-image.sh @@ -1,38 +1,38 @@ #!/bin/bash # Builds the Dockerfile_k8s image as the SkyPilot image. -# Optionally, if -p is specified, pushes the image to the registry. # Uses buildx to build the image for both amd64 and arm64. -# If -p flag is specified, pushes the image to the registry. -# If -g flag is specified, builds the GPU image in Dockerfile_k8s_gpu. GPU image is built only for amd64. -# If -l flag is specified, uses the latest tag instead of the date tag. Date tag is of the form YYYYMMDD. -# Usage: ./build_image.sh [-p] [-g] +# Usage: ./skypilot-k8s-image.sh [-p] [-g] [-l] [-r region] # -p: Push the image to the registry -# -g: Build the GPU image -# -l: Use latest tag - -TAG=us-central1-docker.pkg.dev/skypilot-375900/skypilotk8s/skypilot - +# -g: Builds the GPU image in Dockerfile_k8s_gpu. GPU image is built only for amd64 +# -l: Use latest tag instead of the date tag. Date tag is of the form YYYYMMDD +# -r: Specify the region to be us, europe or asia +region=us push=false gpu=false latest=false # Parse command line arguments -while getopts ":pgl" opt; do +OPTSTRING=":pglr:" +while getopts ${OPTSTRING} opt; do case ${opt} in - p ) + p) push=true ;; - g ) + g) gpu=true ;; - l ) + l) latest=true ;; - \? ) - echo "Usage: ./build_image.sh [-p] [-g] [-l]" + r) + region=${OPTARG} + ;; + ?) 
+ echo "Usage: ./build_image.sh [-p] [-g] [-l] [-r region]" echo "-p: Push the image to the registry" echo "-g: Build the GPU image" echo "-l: Use latest tag instead of the date tag" + echo "-r: Specify the region to be us, europe or asia" exit 1 ;; esac @@ -42,6 +42,9 @@ echo "Options:" echo "Push: $push" echo "GPU: $gpu" echo "Latest: $latest" +echo "Region: $region" + +TAG=$region-docker.pkg.dev/sky-dev-465/skypilotk8s/skypilot # Set the version tag. If the latest flag is used, use the latest tag if [[ $latest == "true" ]]; then diff --git a/sky/clouds/service_catalog/kubernetes_catalog.py b/sky/clouds/service_catalog/kubernetes_catalog.py index 24daeabf9d4..7ff8f49c621 100644 --- a/sky/clouds/service_catalog/kubernetes_catalog.py +++ b/sky/clouds/service_catalog/kubernetes_catalog.py @@ -8,12 +8,15 @@ from typing import Dict, List, Optional, Set, Tuple from sky import check as sky_check +from sky import sky_logging from sky.adaptors import common as adaptors_common from sky.clouds import Kubernetes from sky.clouds.service_catalog import CloudFilter from sky.clouds.service_catalog import common from sky.provision.kubernetes import utils as kubernetes_utils +logger = sky_logging.init_logger(__name__) + if typing.TYPE_CHECKING: import pandas as pd else: @@ -31,7 +34,16 @@ def get_image_id_from_tag(tag: str, region: Optional[str]) -> Optional[str]: """Returns the image id from the tag.""" - return common.get_image_id_from_tag_impl(_image_df, tag, region) + global _image_df + image_id = common.get_image_id_from_tag_impl(_image_df, tag, region) + if image_id is None: + # Refresh the image catalog and try again, if the image tag is not + # found. + logger.debug('Refreshing the image catalog and trying again.') + _image_df = common.read_catalog('kubernetes/images.csv', + pull_frequency_hours=0) + image_id = common.get_image_id_from_tag_impl(_image_df, tag, region) + return image_id def is_image_tag_valid(tag: str, region: Optional[str]) -> bool: @@ -120,8 +132,14 @@ def list_accelerators_realtime( # Generate the GPU quantities for the accelerators if accelerator_name and accelerator_count > 0: - for count in range(1, accelerator_count + 1): + count = 1 + while count <= accelerator_count: accelerators_qtys.add((accelerator_name, count)) + count *= 2 + # Add the accelerator count if it's not already in the set + # (e.g., if there's 12 GPUs, we should have qtys 1, 2, 4, 8, 12) + if accelerator_count not in accelerators_qtys: + accelerators_qtys.add((accelerator_name, accelerator_count)) for pod in pods: # Get all the pods running on the node diff --git a/sky/clouds/service_catalog/lambda_catalog.py b/sky/clouds/service_catalog/lambda_catalog.py index e843ab72cc0..24cb4064d54 100644 --- a/sky/clouds/service_catalog/lambda_catalog.py +++ b/sky/clouds/service_catalog/lambda_catalog.py @@ -4,7 +4,7 @@ instance types and pricing information for Lambda. 
""" import typing -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from sky.clouds.service_catalog import common from sky.utils import resources_utils @@ -72,7 +72,7 @@ def get_default_instance_type( def get_accelerators_from_instance_type( - instance_type: str) -> Optional[Dict[str, int]]: + instance_type: str) -> Optional[Dict[str, Union[int, float]]]: return common.get_accelerators_from_instance_type_impl(_df, instance_type) diff --git a/sky/clouds/service_catalog/oci_catalog.py b/sky/clouds/service_catalog/oci_catalog.py index 47d0489f6ab..c8e475df871 100644 --- a/sky/clouds/service_catalog/oci_catalog.py +++ b/sky/clouds/service_catalog/oci_catalog.py @@ -14,7 +14,7 @@ import logging import threading import typing -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from sky.adaptors import oci as oci_adaptor from sky.clouds import OCI @@ -131,7 +131,7 @@ def _filter_disk_type(instance_type: str) -> bool: def get_accelerators_from_instance_type( - instance_type: str) -> Optional[Dict[str, int]]: + instance_type: str) -> Optional[Dict[str, Union[int, float]]]: return common.get_accelerators_from_instance_type_impl( _get_df(), instance_type) diff --git a/sky/clouds/service_catalog/paperspace_catalog.py b/sky/clouds/service_catalog/paperspace_catalog.py index 1eb635c93e5..49948b219a1 100644 --- a/sky/clouds/service_catalog/paperspace_catalog.py +++ b/sky/clouds/service_catalog/paperspace_catalog.py @@ -5,7 +5,7 @@ """ import typing -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from sky.clouds.service_catalog import common from sky.utils import ux_utils @@ -60,7 +60,7 @@ def get_default_instance_type( def get_accelerators_from_instance_type( - instance_type: str) -> Optional[Dict[str, int]]: + instance_type: str) -> Optional[Dict[str, Union[int, float]]]: return common.get_accelerators_from_instance_type_impl(_df, instance_type) diff --git a/sky/clouds/service_catalog/runpod_catalog.py b/sky/clouds/service_catalog/runpod_catalog.py index 2d3ed44307b..7fbc46206ed 100644 --- a/sky/clouds/service_catalog/runpod_catalog.py +++ b/sky/clouds/service_catalog/runpod_catalog.py @@ -5,7 +5,7 @@ """ import typing -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from sky.clouds.service_catalog import common from sky.utils import ux_utils @@ -56,7 +56,7 @@ def get_default_instance_type(cpus: Optional[str] = None, def get_accelerators_from_instance_type( - instance_type: str) -> Optional[Dict[str, int]]: + instance_type: str) -> Optional[Dict[str, Union[int, float]]]: return common.get_accelerators_from_instance_type_impl(_df, instance_type) diff --git a/sky/clouds/service_catalog/scp_catalog.py b/sky/clouds/service_catalog/scp_catalog.py index 209bb4cf631..e4773ab3250 100644 --- a/sky/clouds/service_catalog/scp_catalog.py +++ b/sky/clouds/service_catalog/scp_catalog.py @@ -5,7 +5,7 @@ """ import typing -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from sky.clouds.service_catalog import common from sky.utils import resources_utils @@ -67,7 +67,7 @@ def get_default_instance_type( def get_accelerators_from_instance_type( - instance_type: str) -> Optional[Dict[str, int]]: + instance_type: str) -> Optional[Dict[str, Union[int, float]]]: return common.get_accelerators_from_instance_type_impl(_df, instance_type) diff --git 
a/sky/clouds/service_catalog/vsphere_catalog.py b/sky/clouds/service_catalog/vsphere_catalog.py index e1199d3d266..74fb2fbe60d 100644 --- a/sky/clouds/service_catalog/vsphere_catalog.py +++ b/sky/clouds/service_catalog/vsphere_catalog.py @@ -2,7 +2,7 @@ import io import os import typing -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from sky.adaptors import common as adaptors_common from sky.clouds.service_catalog import common @@ -85,7 +85,7 @@ def get_default_instance_type( def get_accelerators_from_instance_type( - instance_type: str) -> Optional[Dict[str, int]]: + instance_type: str) -> Optional[Dict[str, Union[int, float]]]: return common.get_accelerators_from_instance_type_impl( _get_df(), instance_type) diff --git a/sky/clouds/utils/azure_utils.py b/sky/clouds/utils/azure_utils.py new file mode 100644 index 00000000000..83b86f4d54f --- /dev/null +++ b/sky/clouds/utils/azure_utils.py @@ -0,0 +1,91 @@ +"""Utilies for Azure""" + +import typing + +from sky import exceptions +from sky.adaptors import azure +from sky.utils import ux_utils + +if typing.TYPE_CHECKING: + from azure.mgmt import compute as azure_compute + from azure.mgmt.compute import models as azure_compute_models + + +def validate_image_id(image_id: str): + """Check if the image ID has a valid format. + + Raises: + ValueError: If the image ID is invalid. + """ + image_id_colon_splitted = image_id.split(':') + image_id_slash_splitted = image_id.split('/') + if len(image_id_slash_splitted) != 5 and len(image_id_colon_splitted) != 4: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Invalid image id for Azure: {image_id}. Expected format: \n' + '* Marketplace image ID: :::\n' + '* Community image ID: ' + '/CommunityGalleries//Images/') + if len(image_id_slash_splitted) == 5: + _, gallery_type, _, image_type, _ = image_id.split('/') + if gallery_type != 'CommunityGalleries' or image_type != 'Images': + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Invalid community image id for Azure: {image_id}.\n' + 'Expected format: ' + '/CommunityGalleries//Images/') + + +def get_community_image( + compute_client: 'azure_compute.ComputeManagementClient', image_id: str, + region: str) -> 'azure_compute_models.CommunityGalleryImage': + """Get community image from cloud. + + Args: + image_id: /CommunityGalleries//Images/ + Raises: + ResourcesUnavailableError + """ + try: + _, _, gallery_name, _, image_name = image_id.split('/') + return compute_client.community_gallery_images.get( + location=region, + public_gallery_name=gallery_name, + gallery_image_name=image_name) + except azure.exceptions().AzureError as e: + raise exceptions.ResourcesUnavailableError( + f'Community image {image_id} does not exist in region {region}.' + ) from e + + +def get_community_image_size( + compute_client: 'azure_compute.ComputeManagementClient', + gallery_name: str, image_name: str, region: str) -> float: + """Get the size of the community image from cloud. 
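For the new `validate_image_id` above: the two accepted shapes are Azure's marketplace form `publisher:offer:sku:version` and the community-gallery form `/CommunityGalleries/<gallery-name>/Images/<image-name>`. A dependency-free mirror of the check with illustrative IDs (the gallery and image names below are placeholders, not necessarily real images):

```python
def looks_like_valid_azure_image_id(image_id: str) -> bool:
    # Mirrors validate_image_id above, returning a bool instead of raising.
    by_colon = image_id.split(':')
    by_slash = image_id.split('/')
    if len(by_slash) == 5:
        _, gallery_type, _, image_type, _ = by_slash
        return gallery_type == 'CommunityGalleries' and image_type == 'Images'
    return len(by_colon) == 4


# Marketplace image: publisher:offer:sku:version (values taken from the
# Packer templates earlier in this diff).
assert looks_like_valid_azure_image_id(
    'Canonical:0001-com-ubuntu-server-jammy:22_04-lts-gen2:latest')
# Community gallery image (gallery/image names here are placeholders).
assert looks_like_valid_azure_image_id(
    '/CommunityGalleries/example-gallery/Images/skypilot-gpu-gen2')
assert not looks_like_valid_azure_image_id('not-a-valid-image-id')
```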
+ + Args: + image_id: /CommunityGalleries//Images/ + Raises: + ResourcesUnavailableError + """ + try: + image_versions = compute_client.community_gallery_image_versions.list( + location=region, + public_gallery_name=gallery_name, + gallery_image_name=image_name, + ) + image_versions = list(image_versions) + if not image_versions: + raise exceptions.ResourcesUnavailableError( + f'No versions available for Azure community image {image_name}') + latest_version = image_versions[-1].name + + image_details = compute_client.community_gallery_image_versions.get( + location=region, + public_gallery_name=gallery_name, + gallery_image_name=image_name, + gallery_image_version_name=latest_version) + return image_details.storage_profile.os_disk_image.disk_size_gb + except azure.exceptions().AzureError as e: + raise exceptions.ResourcesUnavailableError( + f'Failed to get community image size: {e}.') from e diff --git a/sky/clouds/utils/gcp_utils.py b/sky/clouds/utils/gcp_utils.py index 68e6192d351..cfb893c8cb4 100644 --- a/sky/clouds/utils/gcp_utils.py +++ b/sky/clouds/utils/gcp_utils.py @@ -49,14 +49,6 @@ def is_tpu_vm_pod(resources: Optional['resources_lib.Resources']) -> bool: return not acc.endswith('-8') -def get_num_tpu_devices(resources: Optional['resources_lib.Resources']) -> int: - if resources is None or not is_tpu(resources): - raise ValueError('resources must be a valid TPU resource.') - acc, _ = list(resources.accelerators.items())[0] - num_tpu_devices = int(int(acc.split('-')[2]) / 8) - return num_tpu_devices - - @dataclasses.dataclass class SpecificReservation: count: int diff --git a/sky/clouds/vsphere.py b/sky/clouds/vsphere.py index 7cf56b46a8d..88d5df3232a 100644 --- a/sky/clouds/vsphere.py +++ b/sky/clouds/vsphere.py @@ -1,8 +1,7 @@ """Vsphere cloud implementation.""" -import json import subprocess import typing -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Union import requests @@ -152,7 +151,7 @@ def get_default_instance_type( def get_accelerators_from_instance_type( cls, instance_type: str, - ) -> Optional[Dict[str, int]]: + ) -> Optional[Dict[str, Union[int, float]]]: return service_catalog.get_accelerators_from_instance_type( instance_type, clouds=_CLOUD_VSPHERE) @@ -182,10 +181,8 @@ def make_deploy_resources_variables( zone_names = [zone.name for zone in zones] r = resources acc_dict = self.get_accelerators_from_instance_type(r.instance_type) - if acc_dict is not None: - custom_resources = json.dumps(acc_dict, separators=(',', ':')) - else: - custom_resources = None + custom_resources = resources_utils.make_ray_custom_resources_str( + acc_dict) return { 'instance_type': resources.instance_type, diff --git a/sky/data/storage.py b/sky/data/storage.py index 6fbb95a8c56..897f2f96b94 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -1082,16 +1082,31 @@ class S3Store(AbstractStore): for S3 buckets. 
""" + _DEFAULT_REGION = 'us-east-1' _ACCESS_DENIED_MESSAGE = 'Access Denied' + _CUSTOM_ENDPOINT_REGIONS = [ + 'ap-east-1', 'me-south-1', 'af-south-1', 'eu-south-1', 'eu-south-2', + 'ap-south-2', 'ap-southeast-3', 'ap-southeast-4', 'me-central-1', + 'il-central-1' + ] def __init__(self, name: str, source: str, - region: Optional[str] = 'us-east-2', + region: Optional[str] = _DEFAULT_REGION, is_sky_managed: Optional[bool] = None, sync_on_reconstruction: bool = True): self.client: 'boto3.client.Client' self.bucket: 'StorageHandle' + # TODO(romilb): This is purely a stopgap fix for + # https://github.com/skypilot-org/skypilot/issues/3405 + # We should eventually make all opt-in regions also work for S3 by + # passing the right endpoint flags. + if region in self._CUSTOM_ENDPOINT_REGIONS: + logger.warning('AWS opt-in regions are not supported for S3. ' + f'Falling back to default region ' + f'{self._DEFAULT_REGION} for bucket {name!r}.') + region = self._DEFAULT_REGION super().__init__(name, source, region, is_sky_managed, sync_on_reconstruction) @@ -1424,7 +1439,7 @@ def mount_command(self, mount_path: str) -> str: def _create_s3_bucket(self, bucket_name: str, - region='us-east-2') -> StorageHandle: + region=_DEFAULT_REGION) -> StorageHandle: """Creates S3 bucket with specific name in specific region Args: diff --git a/sky/exceptions.py b/sky/exceptions.py index 066d36c3cf3..f78c6605261 100644 --- a/sky/exceptions.py +++ b/sky/exceptions.py @@ -1,7 +1,7 @@ """Exceptions.""" import enum import typing -from typing import List, Optional +from typing import List, Optional, Sequence if typing.TYPE_CHECKING: from sky import status_lib @@ -61,12 +61,12 @@ class ProvisionPrechecksError(Exception): the error will be raised. Args: - reasons: (List[Exception]) The reasons why the prechecks failed. + reasons: (Sequence[Exception]) The reasons why the prechecks failed. """ - def __init__(self, reasons: List[Exception]) -> None: + def __init__(self, reasons: Sequence[Exception]) -> None: super().__init__() - self.reasons = list(reasons) + self.reasons = reasons class ManagedJobReachedMaxRetriesError(Exception): diff --git a/sky/execution.py b/sky/execution.py index d9a346a99cf..8fab5e583fb 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -11,6 +11,7 @@ from sky import admin_policy from sky import backends from sky import clouds +from sky import exceptions from sky import global_user_state from sky import optimizer from sky import sky_logging @@ -171,10 +172,11 @@ def _execute( task = dag.tasks[0] if any(r.job_recovery is not None for r in task.resources): - with ux_utils.print_exception_no_traceback(): - raise ValueError( - 'Job recovery is specified in the task. To launch a ' - 'managed job, please use: sky jobs launch') + logger.warning( + f'{colorama.Style.DIM}The task has `job_recovery` specified, ' + 'but is launched as an unmanaged job. It will be ignored.' + 'To enable job recovery, use managed jobs: sky jobs launch.' + f'{colorama.Style.RESET_ALL}') cluster_exists = False if cluster_name is not None: @@ -215,7 +217,8 @@ def _execute( '(after all jobs finish).' 
f'{colorama.Style.RESET_ALL}') idle_minutes_to_autostop = 1 - stages.remove(Stage.DOWN) + if Stage.DOWN in stages: + stages.remove(Stage.DOWN) if idle_minutes_to_autostop >= 0: requested_features.add( clouds.CloudImplementationFeatures.AUTO_TERMINATE) @@ -354,6 +357,7 @@ def launch( detach_run: bool = False, no_setup: bool = False, clone_disk_from: Optional[str] = None, + fast: bool = False, # Internal only: # pylint: disable=invalid-name _is_launched_by_jobs_controller: bool = False, @@ -408,6 +412,8 @@ def launch( clone_disk_from: [Experimental] if set, clone the disk from the specified cluster. This is useful to migrate the cluster to a different availability zone or region. + fast: [Experimental] If the cluster is already up and available, + skip provisioning and setup steps. Example: .. code-block:: python @@ -451,15 +457,43 @@ def launch( controller_utils.check_cluster_name_not_controller( cluster_name, operation_str='sky.launch') + handle = None + stages = None + # Check if cluster exists and we are doing fast provisioning + if fast and cluster_name is not None: + maybe_handle = global_user_state.get_handle_from_cluster_name( + cluster_name) + if maybe_handle is not None: + try: + # This will throw if the cluster is not available + backend_utils.check_cluster_available( + cluster_name, + operation='executing tasks', + check_cloud_vm_ray_backend=False, + dryrun=dryrun) + handle = maybe_handle + # Get all stages + stages = [ + Stage.SYNC_WORKDIR, + Stage.SYNC_FILE_MOUNTS, + Stage.PRE_EXEC, + Stage.EXEC, + Stage.DOWN, + ] + except exceptions.ClusterNotUpError: + # Proceed with normal provisioning + pass + return _execute( entrypoint=entrypoint, dryrun=dryrun, down=down, stream_logs=stream_logs, - handle=None, + handle=handle, backend=backend, retry_until_up=retry_until_up, optimize_target=optimize_target, + stages=stages, cluster_name=cluster_name, detach_setup=detach_setup, detach_run=detach_run, diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py index cd6664eb114..fca60750af9 100644 --- a/sky/jobs/controller.py +++ b/sky/jobs/controller.py @@ -182,6 +182,11 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: if task_id == 0: submitted_at = backend_utils.get_timestamp_from_run_timestamp( self._backend.run_timestamp) + assert task.name is not None, task + cluster_name = managed_job_utils.generate_managed_job_cluster_name( + task.name, self._job_id) + strategy_executor = recovery_strategy.StrategyExecutor.make( + cluster_name, self._backend, task, self._retry_until_up) managed_job_state.set_submitted( self._job_id, task_id, @@ -189,15 +194,14 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: submitted_at, resources_str=backend_utils.get_task_resources_str( task, is_managed_job=True), + specs={ + 'max_restarts_on_errors': + strategy_executor.max_restarts_on_errors + }, callback_func=callback_func) logger.info( f'Submitted managed job {self._job_id} (task: {task_id}, name: ' f'{task.name!r}); {constants.TASK_ID_ENV_VAR}: {task_id_env_var}') - assert task.name is not None, task - cluster_name = managed_job_utils.generate_managed_job_cluster_name( - task.name, self._job_id) - strategy_executor = recovery_strategy.StrategyExecutor.make( - cluster_name, self._backend, task, self._retry_until_up) logger.info('Started monitoring.') managed_job_state.set_starting(job_id=self._job_id, @@ -237,7 +241,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: end_time=end_time, callback_func=callback_func) logger.info( - f'Spot job 
{self._job_id} (task: {task_id}) SUCCEEDED. ' + f'Managed job {self._job_id} (task: {task_id}) SUCCEEDED. ' f'Cleaning up the cluster {cluster_name}.') # Only clean up the cluster, not the storages, because tasks may # share storages. @@ -305,23 +309,35 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: failure_reason = ( 'To see the details, run: ' f'sky jobs logs --controller {self._job_id}') - - managed_job_state.set_failed( - self._job_id, - task_id, - failure_type=managed_job_status, - failure_reason=failure_reason, - end_time=end_time, - callback_func=callback_func) - return False - # Although the cluster is healthy, we fail to access the - # job status. Try to recover the job (will not restart the - # cluster, if the cluster is healthy). - assert job_status is None, job_status - logger.info('Failed to fetch the job status while the ' - 'cluster is healthy. Try to recover the job ' - '(the cluster will not be restarted).') - + should_restart_on_failure = ( + strategy_executor.should_restart_on_failure()) + if should_restart_on_failure: + max_restarts = ( + strategy_executor.max_restarts_on_errors) + logger.info( + f'User program crashed ' + f'({managed_job_status.value}). ' + f'Retry the job as max_restarts_on_errors is ' + f'set to {max_restarts}. ' + f'[{strategy_executor.restart_cnt_on_failure}' + f'/{max_restarts}]') + else: + managed_job_state.set_failed( + self._job_id, + task_id, + failure_type=managed_job_status, + failure_reason=failure_reason, + end_time=end_time, + callback_func=callback_func) + return False + else: + # Although the cluster is healthy, we fail to access the + # job status. Try to recover the job (will not restart the + # cluster, if the cluster is healthy). + assert job_status is None, job_status + logger.info('Failed to fetch the job status while the ' + 'cluster is healthy. Try to recover the job ' + '(the cluster will not be restarted).') # When the handle is None, the cluster should be cleaned up already. if handle is not None: resources = handle.launched_resources @@ -375,13 +391,12 @@ def _handle_future_completion(self, future: futures.Future, task_id: int): task_id, managed_job_state.ManagedJobStatus.FAILED_PRECHECKS, failure_reason) except exceptions.ManagedJobReachedMaxRetriesError as e: - # Please refer to the docstring of self._run for - # the cases when this exception can occur. + # Please refer to the docstring of self._run for the cases when + # this exception can occur. failure_reason = common_utils.format_exception(e) logger.error(failure_reason) - # The managed job should be marked as - # FAILED_NO_RESOURCE, as the managed job may be able to - # launch next time. + # The managed job should be marked as FAILED_NO_RESOURCE, as the + # managed job may be able to launch next time. self._update_failed_task_state( task_id, managed_job_state.ManagedJobStatus.FAILED_NO_RESOURCE, failure_reason) diff --git a/sky/jobs/core.py b/sky/jobs/core.py index 4ac0d5a78d1..534b0ae113a 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -36,6 +36,7 @@ def launch( stream_logs: bool = True, detach_run: bool = False, retry_until_up: bool = False, + fast: bool = False, ) -> None: # NOTE(dev): Keep the docstring consistent between the Python API and CLI. """Launch a managed job. @@ -47,6 +48,9 @@ def launch( managed job. name: Name of the managed job. detach_run: Whether to detach the run. + fast: Whether to use sky.launch(fast=True) for the jobs controller. 
If + True, the SkyPilot wheel and the cloud credentials may not be updated + on the jobs controller. Raises: ValueError: cluster does not exist. Or, the entrypoint is not a valid @@ -141,6 +145,7 @@ def launch( idle_minutes_to_autostop=skylet_constants. CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP, retry_until_up=True, + fast=fast, _disable_controller_check=True) diff --git a/sky/jobs/recovery_strategy.py b/sky/jobs/recovery_strategy.py index e2e7c8c8f11..0b332cc983d 100644 --- a/sky/jobs/recovery_strategy.py +++ b/sky/jobs/recovery_strategy.py @@ -66,7 +66,8 @@ class StrategyExecutor: RETRY_INIT_GAP_SECONDS = 60 def __init__(self, cluster_name: str, backend: 'backends.Backend', - task: 'task_lib.Task', retry_until_up: bool) -> None: + task: 'task_lib.Task', retry_until_up: bool, + max_restarts_on_errors: int) -> None: """Initialize the strategy executor. Args: @@ -82,6 +83,8 @@ def __init__(self, cluster_name: str, backend: 'backends.Backend', self.cluster_name = cluster_name self.backend = backend self.retry_until_up = retry_until_up + self.max_restarts_on_errors = max_restarts_on_errors + self.restart_cnt_on_failure = 0 def __init_subclass__(cls, name: str, default: bool = False): RECOVERY_STRATEGIES[name] = cls @@ -109,8 +112,17 @@ def make(cls, cluster_name: str, backend: 'backends.Backend', # set the new_task_resources to be the same type (list or set) as the # original task.resources task.set_resources(type(task.resources)(new_resources_list)) - return RECOVERY_STRATEGIES[job_recovery](cluster_name, backend, task, - retry_until_up) + if isinstance(job_recovery, dict): + job_recovery_name = job_recovery.pop('strategy', + DEFAULT_RECOVERY_STRATEGY) + max_restarts_on_errors = job_recovery.pop('max_restarts_on_errors', + 0) + else: + job_recovery_name = job_recovery + max_restarts_on_errors = 0 + return RECOVERY_STRATEGIES[job_recovery_name](cluster_name, backend, + task, retry_until_up, + max_restarts_on_errors) def launch(self) -> float: """Launch the cluster for the first time. @@ -327,8 +339,7 @@ def _launch(self, 'Failure happened before provisioning. Failover ' f'reasons: {reasons_str}') if raise_on_failure: - raise exceptions.ProvisionPrechecksError( - reasons=reasons) + raise exceptions.ProvisionPrechecksError(reasons) return None logger.info('Failed to launch a cluster with error: ' f'{common_utils.format_exception(e)})') @@ -368,6 +379,17 @@ def _launch(self, f'{gap_seconds:.1f} seconds.') time.sleep(gap_seconds) + def should_restart_on_failure(self) -> bool: + """Increments counter & checks if job should be restarted on a failure. + + Returns: + True if the job should be restarted, otherwise False. + """ + self.restart_cnt_on_failure += 1 + if self.restart_cnt_on_failure > self.max_restarts_on_errors: + return False + return True + class FailoverStrategyExecutor(StrategyExecutor, name='FAILOVER', default=False): @@ -376,8 +398,10 @@ class FailoverStrategyExecutor(StrategyExecutor, name='FAILOVER', _MAX_RETRY_CNT = 240 # Retry for 4 hours. def __init__(self, cluster_name: str, backend: 'backends.Backend', - task: 'task_lib.Task', retry_until_up: bool) -> None: - super().__init__(cluster_name, backend, task, retry_until_up) + task: 'task_lib.Task', retry_until_up: bool, + max_restarts_on_errors: int) -> None: + super().__init__(cluster_name, backend, task, retry_until_up, + max_restarts_on_errors) # Note down the cloud/region of the launched cluster, so that we can # first retry in the same cloud/region. 
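The `StrategyExecutor.make` change above lets `job_recovery` be either a plain strategy name or a dict carrying `strategy` and `max_restarts_on_errors`, and `should_restart_on_failure` then rations restarts after user-program failures. A condensed, dependency-free sketch of both pieces (the default strategy name here is an assumption; the real default lives in `recovery_strategy.py`):

```python
from typing import Dict, Tuple, Union

DEFAULT_RECOVERY_STRATEGY = 'EAGER_NEXT_REGION'  # assumed default for this sketch


def parse_job_recovery(
        job_recovery: Union[str, Dict[str, Union[str, int]]]) -> Tuple[str, int]:
    """Mirror of the parsing in StrategyExecutor.make above."""
    if isinstance(job_recovery, dict):
        job_recovery = dict(job_recovery)  # copy so the caller's dict is untouched
        name = job_recovery.pop('strategy', DEFAULT_RECOVERY_STRATEGY)
        return name, int(job_recovery.pop('max_restarts_on_errors', 0))
    return job_recovery, 0


assert parse_job_recovery('FAILOVER') == ('FAILOVER', 0)
assert parse_job_recovery({'strategy': 'FAILOVER',
                           'max_restarts_on_errors': 3}) == ('FAILOVER', 3)


class RestartCounter:
    """Same contract as should_restart_on_failure(): allow up to N restarts."""

    def __init__(self, max_restarts_on_errors: int) -> None:
        self.max_restarts_on_errors = max_restarts_on_errors
        self.restart_cnt_on_failure = 0

    def should_restart_on_failure(self) -> bool:
        self.restart_cnt_on_failure += 1
        return self.restart_cnt_on_failure <= self.max_restarts_on_errors


counter = RestartCounter(max_restarts_on_errors=2)
assert [counter.should_restart_on_failure() for _ in range(3)] == [True, True, False]
```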
(Inside recover() we may not # rely on cluster handle, as it can be None if the cluster is diff --git a/sky/jobs/state.py b/sky/jobs/state.py index 6392cb55e5a..e5cb5f9d6ee 100644 --- a/sky/jobs/state.py +++ b/sky/jobs/state.py @@ -2,6 +2,7 @@ # TODO(zhwu): maybe use file based status instead of database, so # that we can easily switch to a s3-based storage. import enum +import json import pathlib import sqlite3 import time @@ -65,7 +66,8 @@ def _get_db_path() -> str: failure_reason TEXT, spot_job_id INTEGER, task_id INTEGER DEFAULT 0, - task_name TEXT)""") + task_name TEXT, + specs TEXT)""") _CONN.commit() db_utils.add_column_to_table(_CURSOR, _CONN, 'spot', 'failure_reason', 'TEXT') @@ -92,6 +94,17 @@ def _get_db_path() -> str: 'TEXT', copy_from='job_name') +# Specs is some useful information about the task, e.g., the +# max_restarts_on_errors value. It is stored in JSON format. +db_utils.add_column_to_table(_CURSOR, + _CONN, + 'spot', + 'specs', + 'TEXT', + value_to_replace_existing_entries=json.dumps({ + 'max_restarts_on_errors': 0, + })) + # `job_info` contains the mapping from job_id to the job_name. # In the future, it may contain more information about each job. _CURSOR.execute("""\ @@ -128,9 +141,10 @@ def _get_db_path() -> str: 'job_id', 'task_id', 'task_name', + 'specs', # columns from the job_info table '_job_info_job_id', # This should be the same as job_id - 'job_name' + 'job_name', ] @@ -283,7 +297,8 @@ def set_pending(job_id: int, task_id: int, task_name: str, resources_str: str): def set_submitted(job_id: int, task_id: int, run_timestamp: str, submit_time: float, resources_str: str, - callback_func: CallbackType): + specs: Dict[str, Union[str, + int]], callback_func: CallbackType): """Set the task to submitted. Args: @@ -293,6 +308,8 @@ def set_submitted(job_id: int, task_id: int, run_timestamp: str, determine the log directory of the managed task. submit_time: The time when the managed task is submitted. resources_str: The resources string of the managed task. + specs: The specs of the managed task. + callback_func: The callback function. """ # Use the timestamp in the `run_timestamp` ('sky-2022-10...'), to make # the log directory and submission time align with each other, so as to @@ -306,11 +323,12 @@ def set_submitted(job_id: int, task_id: int, run_timestamp: str, resources=(?), submitted_at=(?), status=(?), - run_timestamp=(?) + run_timestamp=(?), + specs=(?) WHERE spot_job_id=(?) AND task_id=(?)""", (resources_str, submit_time, ManagedJobStatus.SUBMITTED.value, - run_timestamp, job_id, task_id)) + run_timestamp, json.dumps(specs), job_id, task_id)) callback_func('SUBMITTED') @@ -639,3 +657,13 @@ def get_latest_job_id() -> Optional[int]: for (job_id,) in rows: return job_id return None + + +def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]: + with db_utils.safe_cursor(_DB_PATH) as cursor: + task_specs = cursor.execute( + """\ + SELECT specs FROM spot + WHERE spot_job_id=(?) AND task_id=(?)""", + (job_id, task_id)).fetchone() + return json.loads(task_specs[0]) diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index bda14fedd5f..2ff83c668e3 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -70,7 +70,7 @@ # state, after the job finished. This is a safeguard to avoid the case where # the managed job status fails to be updated and keep the `sky jobs logs` # blocking for a long time. 
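For context on the new `specs` column added above: it stores small per-task settings (currently just `max_restarts_on_errors`) as a JSON string, so new fields can be added later without another schema migration, and readers such as the log streamer can fetch them with a single query. A toy round-trip using an in-memory SQLite table rather than SkyPilot's real schema helpers:

```python
# A minimal sketch of the `specs` column round-trip described above:
# per-task settings are serialized to JSON on write and parsed on read.
# The table layout here is a toy stand-in, not SkyPilot's real schema.
import json
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE spot (spot_job_id INTEGER, task_id INTEGER, '
             'specs TEXT)')


def set_submitted(job_id: int, task_id: int, specs: dict) -> None:
    conn.execute('INSERT INTO spot VALUES (?, ?, ?)',
                 (job_id, task_id, json.dumps(specs)))


def get_task_specs(job_id: int, task_id: int) -> dict:
    row = conn.execute(
        'SELECT specs FROM spot WHERE spot_job_id=? AND task_id=?',
        (job_id, task_id)).fetchone()
    return json.loads(row[0])


set_submitted(1, 0, {'max_restarts_on_errors': 3})
print(get_task_specs(1, 0))  # {'max_restarts_on_errors': 3}
```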
-_FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 20 +_FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 25 class UserSignal(enum.Enum): @@ -481,32 +481,56 @@ def get_next_task_id_status( assert task_id == specific_task_id, (task_id, specific_task_id) break - if (task_id == num_tasks - 1 or not follow): - break - - # The log for the current job is finished. We need to - # wait until next job to be started. - logger.debug( - f'INFO: Log for the current task ({task_id}) ' - 'is finished. Waiting for the next task\'s log ' - 'to be started.') - update_message(f'Waiting for the next task: {task_id + 1}.') - status_display.start() - - original_task_id = task_id - while True: - task_id, managed_job_status = get_next_task_id_status( - job_id, specific_task_id) - if original_task_id != task_id: + if task_id < num_tasks - 1 and follow: + # The log for the current job is finished. We need to + # wait until next job to be started. + logger.debug( + f'INFO: Log for the current task ({task_id}) ' + 'is finished. Waiting for the next task\'s log ' + 'to be started.') + # Add a newline to avoid the status display below + # removing the last line of the task output. + print() + status_display.update( + ux_utils.spinner_message( + f'Waiting for the next task: {task_id + 1}')) + status_display.start() + original_task_id = task_id + while True: + task_id, managed_job_status = ( + managed_job_state.get_latest_task_id_status( + job_id)) + if original_task_id != task_id: + break + time.sleep(JOB_STATUS_CHECK_GAP_SECONDS) + continue + else: + task_specs = managed_job_state.get_task_specs( + job_id, task_id) + if task_specs.get('max_restarts_on_errors', 0) == 0: + # We don't need to wait for the managed job status + # update, as the job is guaranteed to be in terminal + # state afterwards. break - time.sleep(JOB_STATUS_CHECK_GAP_SECONDS) - continue - else: - # The job can be cancelled by the user or the controller - # (when the cluster is partially preempted). - logger.debug( - 'INFO: Job is cancelled. Waiting for the status ' - f'update in {JOB_STATUS_CHECK_GAP_SECONDS} seconds.') + print() + status_display.update( + ux_utils.spinner_message( + 'Waiting for next restart for the failed task')) + status_display.start() + while True: + _, managed_job_status = ( + managed_job_state.get_latest_task_id_status( + job_id)) + if (managed_job_status != + managed_job_state.ManagedJobStatus.RUNNING): + break + time.sleep(JOB_STATUS_CHECK_GAP_SECONDS) + continue + # The job can be cancelled by the user or the controller (when + # the cluster is partially preempted). + logger.debug( + 'INFO: Job is cancelled. Waiting for the status update in ' + f'{JOB_STATUS_CHECK_GAP_SECONDS} seconds.') else: logger.debug( f'INFO: (Log streaming) Got return code {returncode}. ' diff --git a/sky/optimizer.py b/sky/optimizer.py index b69121954a5..98b7c0387f3 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -832,13 +832,17 @@ def format_number(x): return row def _get_resource_group_hash(resources: 'resources_lib.Resources'): - return json.dumps( - { - 'cloud': f'{resources.cloud}', - 'accelerators': f'{resources.accelerators}', - 'use_spot': resources.use_spot - }, - sort_keys=True) + resource_key_dict = { + 'cloud': f'{resources.cloud}', + 'accelerators': f'{resources.accelerators}', + 'use_spot': resources.use_spot + } + if isinstance(resources.cloud, clouds.Kubernetes): + # Region for Kubernetes is the context name, i.e. different + # Kubernetes clusters. 
We add region to the key to show all the + # Kubernetes clusters in the optimizer table for better UX. + resource_key_dict['region'] = resources.region + return json.dumps(resource_key_dict, sort_keys=True) # Print the list of resouces that the optimizer considered. resource_fields = [ diff --git a/sky/provision/azure/azure-config-template.json b/sky/provision/azure/azure-config-template.json index 489783faf98..0c70c4d3999 100644 --- a/sky/provision/azure/azure-config-template.json +++ b/sky/provision/azure/azure-config-template.json @@ -13,14 +13,26 @@ "metadata": { "description": "Subnet parameters." } + }, + "location": { + "type": "string", + "metadata": { + "description": "Location of where the resources are allocated." + } + }, + "nsgName": { + "type": "string", + "metadata": { + "description": "Name of the Network Security Group associated with the SkyPilot cluster." + } } }, "variables": { "contributor": "[subscriptionResourceId('Microsoft.Authorization/roleDefinitions', 'b24988ac-6180-42a0-ab88-20f7382dd24c')]", - "location": "[resourceGroup().location]", + "location": "[parameters('location')]", "msiName": "[concat('sky-', parameters('clusterId'), '-msi')]", "roleAssignmentName": "[concat('sky-', parameters('clusterId'), '-ra')]", - "nsgName": "[concat('sky-', parameters('clusterId'), '-nsg')]", + "nsgName": "[parameters('nsgName')]", "nsg": "[resourceId('Microsoft.Network/networkSecurityGroups', variables('nsgName'))]", "vnetName": "[concat('sky-', parameters('clusterId'), '-vnet')]", "subnetName": "[concat('sky-', parameters('clusterId'), '-subnet')]" diff --git a/sky/provision/azure/azure-vm-template.json b/sky/provision/azure/azure-vm-template.json deleted file mode 100644 index 52e82dc532c..00000000000 --- a/sky/provision/azure/azure-vm-template.json +++ /dev/null @@ -1,301 +0,0 @@ -{ - "$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#", - "contentVersion": "1.0.0.0", - "parameters": { - "vmName": { - "type": "string", - "metadata": { - "description": "The name of you Virtual Machine." - } - }, - "adminUsername": { - "type": "string", - "metadata": { - "description": "Username for the Virtual Machine." - } - }, - "publicKey": { - "type": "securestring", - "metadata": { - "description": "SSH Key for the Virtual Machine" - } - }, - "imagePublisher": { - "type": "string", - "metadata": { - "description": "The publisher of the VM image" - } - }, - "imageOffer": { - "type": "string", - "metadata": { - "description": "The offer of the VM image" - } - }, - "imageSku": { - "type": "string", - "metadata": { - "description": "The sku of the VM image" - } - }, - "imageVersion": { - "type": "string", - "metadata": { - "description": "The version of the VM image" - } - }, - "vmSize": { - "type": "string", - "metadata": { - "description": "The size of the VM" - } - }, - "vmTags": { - "type": "object", - "metadata": { - "description": "Tags for the VM" - } - }, - "vmCount": { - "type": "int", - "metadata": { - "description": "Number of VMs to deploy" - } - }, - "provisionPublicIp": { - "type": "bool", - "defaultValue": true, - "metadata": { - "description": "If true creates a public ip" - } - }, - "priority": { - "type": "string", - "defaultValue": "Regular", - "metadata": { - "description": "Specifies the priority for the virtual machine." - } - }, - "billingProfile": { - "type": "object", - "defaultValue": {}, - "metadata": { - "description": "Specifies the maximum price to pay for Azure Spot VM." 
- } - }, - "osDiskSizeGB": { - "type": "int", - "metadata": { - "description": "OS disk size in GBs." - } - }, - "msi": { - "type": "string", - "metadata": { - "description": "Managed service identity resource id." - } - }, - "nsg": { - "type": "string", - "metadata": { - "description": "Network security group resource id." - } - }, - "subnet": { - "type": "string", - "metadata": { - "descriptions": "Subnet resource id." - } - }, - "osDiskTier": { - "type": "string", - "allowedValues": [ - "Premium_LRS", - "StandardSSD_LRS", - "Standard_LRS" - ], - "metadata": { - "description": "OS disk tier." - } - }, - "cloudInitSetupCommands": { - "type": "string", - "metadata": { - "description": "Base64 encoded cloud-init setup commands." - } - } - }, - "variables": { - "location": "[resourceGroup().location]", - "networkInterfaceNamePrivate": "[concat(parameters('vmName'), '-nic')]", - "networkInterfaceNamePublic": "[concat(parameters('vmName'), '-nic-public')]", - "networkInterfaceName": "[if(parameters('provisionPublicIp'), variables('networkInterfaceNamePublic'), variables('networkInterfaceNamePrivate'))]", - "networkIpConfig": "[guid(resourceGroup().id, parameters('vmName'))]", - "publicIpAddressName": "[concat(parameters('vmName'), '-ip')]" - }, - "resources": [ - { - "type": "Microsoft.Network/networkInterfaces", - "apiVersion": "2020-06-01", - "name": "[concat(variables('networkInterfaceNamePublic'), copyIndex())]", - "location": "[variables('location')]", - "dependsOn": [ - "[resourceId('Microsoft.Network/publicIpAddresses/', concat(variables('publicIpAddressName'), copyIndex()))]" - ], - "copy": { - "name": "NICPublicCopy", - "count": "[parameters('vmCount')]" - }, - "properties": { - "ipConfigurations": [ - { - "name": "[variables('networkIpConfig')]", - "properties": { - "subnet": { - "id": "[parameters('subnet')]" - }, - "privateIPAllocationMethod": "Dynamic", - "publicIpAddress": { - "id": "[resourceId('Microsoft.Network/publicIPAddresses', concat(variables('publicIPAddressName'), copyIndex()))]" - } - } - } - ], - "networkSecurityGroup": { - "id": "[parameters('nsg')]" - } - }, - "condition": "[parameters('provisionPublicIp')]" - }, - { - "type": "Microsoft.Network/networkInterfaces", - "apiVersion": "2020-06-01", - "name": "[concat(variables('networkInterfaceNamePrivate'), copyIndex())]", - "location": "[variables('location')]", - "copy": { - "name": "NICPrivateCopy", - "count": "[parameters('vmCount')]" - }, - "properties": { - "ipConfigurations": [ - { - "name": "[variables('networkIpConfig')]", - "properties": { - "subnet": { - "id": "[parameters('subnet')]" - }, - "privateIPAllocationMethod": "Dynamic" - } - } - ], - "networkSecurityGroup": { - "id": "[parameters('nsg')]" - } - }, - "condition": "[not(parameters('provisionPublicIp'))]" - }, - { - "type": "Microsoft.Network/publicIpAddresses", - "apiVersion": "2019-02-01", - "name": "[concat(variables('publicIpAddressName'), copyIndex())]", - "location": "[variables('location')]", - "properties": { - "publicIpAllocationMethod": "Static", - "publicIPAddressVersion": "IPv4" - }, - "copy": { - "name": "PublicIpCopy", - "count": "[parameters('vmCount')]" - }, - "sku": { - "name": "Basic", - "tier": "Regional" - }, - "condition": "[parameters('provisionPublicIp')]" - }, - { - "type": "Microsoft.Compute/virtualMachines", - "apiVersion": "2019-03-01", - "name": "[concat(parameters('vmName'), copyIndex())]", - "location": "[variables('location')]", - "dependsOn": [ - "[resourceId('Microsoft.Network/networkInterfaces/', 
concat(variables('networkInterfaceName'), copyIndex()))]" - ], - "copy": { - "name": "VmCopy", - "count": "[parameters('vmCount')]" - }, - "tags": "[parameters('vmTags')]", - "properties": { - "hardwareProfile": { - "vmSize": "[parameters('vmSize')]" - }, - "storageProfile": { - "osDisk": { - "createOption": "fromImage", - "managedDisk": { - "storageAccountType": "[parameters('osDiskTier')]" - }, - "diskSizeGB": "[parameters('osDiskSizeGB')]" - }, - "imageReference": { - "publisher": "[parameters('imagePublisher')]", - "offer": "[parameters('imageOffer')]", - "sku": "[parameters('imageSku')]", - "version": "[parameters('imageVersion')]" - } - }, - "networkProfile": { - "networkInterfaces": [ - { - "id": "[resourceId('Microsoft.Network/networkInterfaces', concat(variables('networkInterfaceName'), copyIndex()))]" - } - ] - }, - "osProfile": { - "computerName": "[concat(parameters('vmName'), copyIndex())]", - "adminUsername": "[parameters('adminUsername')]", - "adminPassword": "[parameters('publicKey')]", - "linuxConfiguration": { - "disablePasswordAuthentication": true, - "ssh": { - "publicKeys": [ - { - "path": "[concat('/home/', parameters('adminUsername'), '/.ssh/authorized_keys')]", - "keyData": "[parameters('publicKey')]" - } - ] - } - }, - "customData": "[parameters('cloudInitSetupCommands')]" - }, - "priority": "[parameters('priority')]", - "billingProfile": "[parameters('billingProfile')]" - }, - "identity": { - "type": "UserAssigned", - "userAssignedIdentities": { - "[parameters('msi')]": { - } - } - } - } - ], - "outputs": { - "publicIp": { - "type": "array", - "copy": { - "count": "[parameters('vmCount')]", - "input": "[reference(concat(variables('publicIpAddressName'), copyIndex())).ipAddress]" - }, - "condition": "[parameters('provisionPublicIp')]" - }, - "privateIp": { - "type": "array", - "copy": { - "count": "[parameters('vmCount')]", - "input": "[reference(concat(variables('networkInterfaceName'), copyIndex())).ipConfigurations[0].properties.privateIPAddress]" - } - } - } -} diff --git a/sky/provision/azure/config.py b/sky/provision/azure/config.py index b3cb357512a..e7ab59daa33 100644 --- a/sky/provision/azure/config.py +++ b/sky/provision/azure/config.py @@ -8,20 +8,20 @@ from pathlib import Path import random import time -from typing import Any, Callable +from typing import Any, Callable, Tuple from sky import exceptions from sky import sky_logging from sky.adaptors import azure from sky.provision import common +from sky.provision import constants from sky.utils import common_utils logger = sky_logging.init_logger(__name__) UNIQUE_ID_LEN = 4 -_DEPLOYMENT_NAME = 'skypilot-config' -_LEGACY_DEPLOYMENT_NAME = 'ray-config' _RESOURCE_GROUP_WAIT_FOR_DELETION_TIMEOUT = 480 # 8 minutes +_CLUSTER_ID = '{cluster_name_on_cloud}-{unique_id}' def get_azure_sdk_function(client: Any, function_name: str) -> Callable: @@ -41,11 +41,25 @@ def get_azure_sdk_function(client: Any, function_name: str) -> Callable: return func +def get_cluster_id_and_nsg_name(resource_group: str, + cluster_name_on_cloud: str) -> Tuple[str, str]: + hasher = hashlib.md5(resource_group.encode('utf-8')) + unique_id = hasher.hexdigest()[:UNIQUE_ID_LEN] + # We use the cluster name + resource group hash as the + # unique ID for the cluster, as we need to make sure that + # the deployments have unique names during failover. 
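The helper above is deterministic: a given resource group always hashes to the same short suffix, so repeated bootstraps and failovers land on a stable cluster ID and NSG name. A self-contained sketch of the naming scheme with made-up inputs (not imported from SkyPilot):

```python
# Sketch of the deterministic naming scheme: resource group -> short hash,
# then cluster ID and NSG name derived from it. Inputs are made up.
import hashlib
from typing import Tuple

UNIQUE_ID_LEN = 4


def cluster_id_and_nsg_name(resource_group: str,
                            cluster_name_on_cloud: str) -> Tuple[str, str]:
    unique_id = hashlib.md5(
        resource_group.encode('utf-8')).hexdigest()[:UNIQUE_ID_LEN]
    cluster_id = f'{cluster_name_on_cloud}-{unique_id}'
    return cluster_id, f'sky-{cluster_id}-nsg'


if __name__ == '__main__':
    # The same resource group always yields the same names.
    print(cluster_id_and_nsg_name('my-rg', 'sky-train-abc1'))
    print(cluster_id_and_nsg_name('my-rg', 'sky-train-abc1'))
```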
+ cluster_id = _CLUSTER_ID.format(cluster_name_on_cloud=cluster_name_on_cloud, + unique_id=unique_id) + nsg_name = f'sky-{cluster_id}-nsg' + return cluster_id, nsg_name + + @common.log_function_start_end def bootstrap_instances( region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionConfig: """See sky/provision/__init__.py""" + # TODO: use new azure sdk instead of ARM deployment. del region # unused provider_config = config.provider_config subscription_id = provider_config.get('subscription_id') @@ -67,46 +81,55 @@ def bootstrap_instances( in provider_config), 'Provider config must include location field' params = {'location': provider_config['location']} + assert ('use_external_resource_group' + in provider_config), ('Provider config must include ' + 'use_external_resource_group field') + use_external_resource_group = provider_config['use_external_resource_group'] + if 'tags' in provider_config: params['tags'] = provider_config['tags'] - logger.info(f'Creating/Updating resource group: {resource_group}') - rg_create_or_update = get_azure_sdk_function( - client=resource_client.resource_groups, - function_name='create_or_update') - rg_creation_start = time.time() - retry = 0 - while (time.time() - rg_creation_start < - _RESOURCE_GROUP_WAIT_FOR_DELETION_TIMEOUT): - try: - rg_create_or_update(resource_group_name=resource_group, - parameters=params) - break - except azure.exceptions().ResourceExistsError as e: - if 'ResourceGroupBeingDeleted' in str(e): - if retry % 5 == 0: - logger.info( - f'Azure resource group {resource_group} of a recent ' - f'terminated cluster {cluster_name_on_cloud} is being ' - 'deleted. It can only be provisioned after it is fully ' - 'deleted. Waiting...') - time.sleep(1) - retry += 1 - continue - raise - except azure.exceptions().ClientAuthenticationError as e: + # When resource group is user specified, it already exists in certain + # region. + if not use_external_resource_group: + logger.info(f'Creating/Updating resource group: {resource_group}') + rg_create_or_update = get_azure_sdk_function( + client=resource_client.resource_groups, + function_name='create_or_update') + rg_creation_start = time.time() + retry = 0 + while (time.time() - rg_creation_start < + _RESOURCE_GROUP_WAIT_FOR_DELETION_TIMEOUT): + try: + rg_create_or_update(resource_group_name=resource_group, + parameters=params) + break + except azure.exceptions().ResourceExistsError as e: + if 'ResourceGroupBeingDeleted' in str(e): + if retry % 5 == 0: + logger.info( + f'Azure resource group {resource_group} of a ' + 'recent terminated cluster ' + f'{cluster_name_on_cloud} is being deleted. It can' + ' only be provisioned after it is fully deleted. ' + 'Waiting...') + time.sleep(1) + retry += 1 + continue + raise + except azure.exceptions().ClientAuthenticationError as e: + message = ( + 'Failed to authenticate with Azure. Please check your ' + 'Azure credentials. Error: ' + f'{common_utils.format_exception(e)}').replace('\n', ' ') + logger.error(message) + raise exceptions.NoClusterLaunchedError(message) from e + else: message = ( - 'Failed to authenticate with Azure. Please check your Azure ' - f'credentials. 
Error: {common_utils.format_exception(e)}' - ).replace('\n', ' ') + f'Timed out waiting for resource group {resource_group} to be ' + 'deleted.') logger.error(message) - raise exceptions.NoClusterLaunchedError(message) from e - else: - message = ( - f'Timed out waiting for resource group {resource_group} to be ' - 'deleted.') - logger.error(message) - raise TimeoutError(message) + raise TimeoutError(message) # load the template file current_path = Path(__file__).parent @@ -116,12 +139,13 @@ def bootstrap_instances( logger.info(f'Using cluster name: {cluster_name_on_cloud}') - hasher = hashlib.md5(provider_config['resource_group'].encode('utf-8')) - unique_id = hasher.hexdigest()[:UNIQUE_ID_LEN] + cluster_id, nsg_name = get_cluster_id_and_nsg_name( + resource_group=provider_config['resource_group'], + cluster_name_on_cloud=cluster_name_on_cloud) subnet_mask = provider_config.get('subnet_mask') if subnet_mask is None: # choose a random subnet, skipping most common value of 0 - random.seed(unique_id) + random.seed(cluster_id) subnet_mask = f'10.{random.randint(1, 254)}.0.0/16' logger.info(f'Using subnet mask: {subnet_mask}') @@ -134,11 +158,14 @@ def bootstrap_instances( 'value': subnet_mask }, 'clusterId': { - # We use the cluster name + resource group hash as the - # unique ID for the cluster, as we need to make sure that - # the deployments have unique names during failover. - 'value': f'{cluster_name_on_cloud}-{unique_id}' + 'value': cluster_id + }, + 'nsgName': { + 'value': nsg_name }, + 'location': { + 'value': params['location'] + } }, } } @@ -148,11 +175,22 @@ def bootstrap_instances( get_deployment = get_azure_sdk_function(client=resource_client.deployments, function_name='get') deployment_exists = False - for deployment_name in [_DEPLOYMENT_NAME, _LEGACY_DEPLOYMENT_NAME]: + if use_external_resource_group: + deployment_name = ( + constants.EXTERNAL_RG_BOOTSTRAP_DEPLOYMENT_NAME.format( + cluster_name_on_cloud=cluster_name_on_cloud)) + deployment_list = [deployment_name] + else: + deployment_name = constants.DEPLOYMENT_NAME + deployment_list = [ + constants.DEPLOYMENT_NAME, constants.LEGACY_DEPLOYMENT_NAME + ] + + for deploy_name in deployment_list: try: deployment = get_deployment(resource_group_name=resource_group, - deployment_name=deployment_name) - logger.info(f'Deployment {deployment_name!r} already exists. ' + deployment_name=deploy_name) + logger.info(f'Deployment {deploy_name!r} already exists. ' 'Skipping deployment creation.') outputs = deployment.properties.outputs @@ -163,22 +201,20 @@ def bootstrap_instances( deployment_exists = False if not deployment_exists: - logger.info(f'Creating/Updating deployment: {_DEPLOYMENT_NAME}') + logger.info(f'Creating/Updating deployment: {deployment_name}') create_or_update = get_azure_sdk_function( client=resource_client.deployments, function_name='create_or_update') # TODO (skypilot): this takes a long time (> 40 seconds) to run. 
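The deployment-name logic shown here boils down to: a user-provided (external) resource group gets a per-cluster deployment name, since several clusters may share that group, while a SkyPilot-managed resource group keeps the fixed name and also probes the legacy Ray-era name. A pure-Python sketch of that selection, reusing the string templates added in `sky/provision/constants.py`:

```python
# Sketch of how the bootstrap deployment name(s) are chosen. The string
# templates match the constants added in sky/provision/constants.py; the
# function itself is illustrative, not SkyPilot's code.
from typing import List, Tuple

DEPLOYMENT_NAME = 'skypilot-config'
LEGACY_DEPLOYMENT_NAME = 'ray-config'
EXTERNAL_RG_BOOTSTRAP_DEPLOYMENT_NAME = (
    'skypilot-bootstrap-{cluster_name_on_cloud}')


def pick_deployment_names(cluster_name_on_cloud: str,
                          use_external_resource_group: bool
                         ) -> Tuple[str, List[str]]:
    """Returns (name to create, names to probe for an existing deployment)."""
    if use_external_resource_group:
        name = EXTERNAL_RG_BOOTSTRAP_DEPLOYMENT_NAME.format(
            cluster_name_on_cloud=cluster_name_on_cloud)
        return name, [name]
    return DEPLOYMENT_NAME, [DEPLOYMENT_NAME, LEGACY_DEPLOYMENT_NAME]


if __name__ == '__main__':
    print(pick_deployment_names('sky-serve-1234', True))
    print(pick_deployment_names('sky-serve-1234', False))
```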
outputs = create_or_update( resource_group_name=resource_group, - deployment_name=_DEPLOYMENT_NAME, + deployment_name=deployment_name, parameters=parameters, ).result().properties.outputs - nsg_id = outputs['nsg']['value'] - # append output resource ids to be used with vm creation provider_config['msi'] = outputs['msi']['value'] - provider_config['nsg'] = nsg_id + provider_config['nsg'] = outputs['nsg']['value'] provider_config['subnet'] = outputs['subnet']['value'] return config diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index 3c5ed8801a4..60159232787 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -2,10 +2,8 @@ import base64 import copy import enum -import json import logging from multiprocessing import pool -import pathlib import time import typing from typing import Any, Callable, Dict, List, Optional, Tuple @@ -17,13 +15,16 @@ from sky.adaptors import azure from sky.provision import common from sky.provision import constants +from sky.provision.azure import config as config_lib from sky.utils import common_utils from sky.utils import subprocess_utils from sky.utils import ux_utils if typing.TYPE_CHECKING: from azure.mgmt import compute as azure_compute - from azure.mgmt import resource as azure_resource + from azure.mgmt import network as azure_network + from azure.mgmt.compute import models as azure_compute_models + from azure.mgmt.network import models as azure_network_models logger = sky_logging.init_logger(__name__) @@ -31,6 +32,8 @@ # https://github.com/Azure/azure-sdk-for-python/issues/9422 azure_logger = logging.getLogger('azure') azure_logger.setLevel(logging.WARNING) +Client = Any +NetworkSecurityGroup = Any _RESUME_INSTANCE_TIMEOUT = 480 # 8 minutes _RESUME_PER_INSTANCE_TIMEOUT = 120 # 2 minutes @@ -38,8 +41,21 @@ _TAG_SKYPILOT_VM_ID = 'skypilot-vm-id' _WAIT_CREATION_TIMEOUT_SECONDS = 600 +_RESOURCE_MANAGED_IDENTITY_TYPE = ( + 'Microsoft.ManagedIdentity/userAssignedIdentities') +_RESOURCE_NETWORK_SECURITY_GROUP_TYPE = ( + 'Microsoft.Network/networkSecurityGroups') +_RESOURCE_VIRTUAL_NETWORK_TYPE = 'Microsoft.Network/virtualNetworks' +_RESOURCE_PUBLIC_IP_ADDRESS_TYPE = 'Microsoft.Network/publicIPAddresses' +_RESOURCE_VIRTUAL_MACHINE_TYPE = 'Microsoft.Compute/virtualMachines' +_RESOURCE_NETWORK_INTERFACE_TYPE = 'Microsoft.Network/networkInterfaces' + _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound' _POLL_INTERVAL = 1 +# TODO(Doyoung): _LEGACY_NSG_NAME can be remove this after 0.8.0 to ignore +# legacy nsg names. 
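The legacy patterns referenced here exist because older releases named the NSG after the raw cluster name, first with a `ray-` prefix and later with a `sky-` prefix, whereas the current scheme appends the resource-group hash. A sketch of how a lookup can stay backward compatible by probing all three candidates, with a stubbed NSG listing instead of the Azure SDK:

```python
# Sketch: build the current and legacy NSG name candidates and pick the
# first one that actually exists. `existing_nsgs` stands in for what the
# Azure SDK's network_security_groups.list(...) would return.
import hashlib
from typing import List

_LEGACY_NSG_NAME = 'ray-{cluster_name_on_cloud}-nsg'
_SECOND_LEGACY_NSG_NAME = 'sky-{cluster_name_on_cloud}-nsg'


def candidate_nsg_names(resource_group: str,
                        cluster_name_on_cloud: str) -> List[str]:
    unique_id = hashlib.md5(resource_group.encode('utf-8')).hexdigest()[:4]
    current = f'sky-{cluster_name_on_cloud}-{unique_id}-nsg'
    return [
        current,
        _LEGACY_NSG_NAME.format(cluster_name_on_cloud=cluster_name_on_cloud),
        _SECOND_LEGACY_NSG_NAME.format(
            cluster_name_on_cloud=cluster_name_on_cloud),
    ]


def find_cluster_nsg(existing_nsgs: List[str], resource_group: str,
                     cluster_name_on_cloud: str) -> str:
    candidates = candidate_nsg_names(resource_group, cluster_name_on_cloud)
    for name in existing_nsgs:
        if name in candidates:
            return name
    raise ValueError(f'No NSG matching {candidates} found.')


if __name__ == '__main__':
    print(find_cluster_nsg(['ray-sky-train-1b2c-nsg', 'unrelated-nsg'],
                           'my-rg', 'sky-train-1b2c'))
```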
+_LEGACY_NSG_NAME = 'ray-{cluster_name_on_cloud}-nsg' +_SECOND_LEGACY_NSG_NAME = 'sky-{cluster_name_on_cloud}-nsg' class AzureInstanceStatus(enum.Enum): @@ -184,14 +200,131 @@ def _get_head_instance_id(instances: List) -> Optional[str]: return head_instance_id -def _create_instances( - compute_client: 'azure_compute.ComputeManagementClient', - resource_client: 'azure_resource.ResourceManagementClient', - cluster_name_on_cloud: str, resource_group: str, - provider_config: Dict[str, Any], node_config: Dict[str, Any], - tags: Dict[str, str], count: int) -> List: +def _create_network_interface( + network_client: 'azure_network.NetworkManagementClient', vm_name: str, + provider_config: Dict[str, + Any]) -> 'azure_network_models.NetworkInterface': + network = azure.azure_mgmt_models('network') + compute = azure.azure_mgmt_models('compute') + logger.info(f'Start creating network interface for {vm_name}...') + if provider_config.get('use_internal_ips', False): + name = f'{vm_name}-nic-private' + ip_config = network.IPConfiguration( + name=f'ip-config-private-{vm_name}', + subnet=compute.SubResource(id=provider_config['subnet']), + private_ip_allocation_method=network.IPAllocationMethod.DYNAMIC) + else: + name = f'{vm_name}-nic-public' + public_ip_address = network.PublicIPAddress( + location=provider_config['location'], + public_ip_allocation_method='Static', + public_ip_address_version='IPv4', + sku=network.PublicIPAddressSku(name='Basic', tier='Regional')) + ip_poller = network_client.public_ip_addresses.begin_create_or_update( + resource_group_name=provider_config['resource_group'], + public_ip_address_name=f'{vm_name}-ip', + parameters=public_ip_address) + logger.info(f'Created public IP address {ip_poller.result().name} ' + f'with address {ip_poller.result().ip_address}.') + ip_config = network.IPConfiguration( + name=f'ip-config-public-{vm_name}', + subnet=compute.SubResource(id=provider_config['subnet']), + private_ip_allocation_method=network.IPAllocationMethod.DYNAMIC, + public_ip_address=network.PublicIPAddress(id=ip_poller.result().id)) + + ni_poller = network_client.network_interfaces.begin_create_or_update( + resource_group_name=provider_config['resource_group'], + network_interface_name=name, + parameters=network.NetworkInterface( + location=provider_config['location'], + ip_configurations=[ip_config], + network_security_group=network.NetworkSecurityGroup( + id=provider_config['nsg']))) + logger.info(f'Created network interface {ni_poller.result().name}.') + return ni_poller.result() + + +def _create_vm( + compute_client: 'azure_compute.ComputeManagementClient', vm_name: str, + node_tags: Dict[str, str], provider_config: Dict[str, Any], + node_config: Dict[str, Any], + network_interface_id: str) -> 'azure_compute_models.VirtualMachine': + compute = azure.azure_mgmt_models('compute') + logger.info(f'Start creating VM {vm_name}...') + hardware_profile = compute.HardwareProfile( + vm_size=node_config['azure_arm_parameters']['vmSize']) + network_profile = compute.NetworkProfile(network_interfaces=[ + compute.NetworkInterfaceReference(id=network_interface_id, primary=True) + ]) + public_key = node_config['azure_arm_parameters']['publicKey'] + username = node_config['azure_arm_parameters']['adminUsername'] + os_linux_custom_data = base64.b64encode( + node_config['azure_arm_parameters']['cloudInitSetupCommands'].encode( + 'utf-8')).decode('utf-8') + os_profile = compute.OSProfile( + admin_username=username, + computer_name=vm_name, + admin_password=public_key, + 
linux_configuration=compute.LinuxConfiguration( + disable_password_authentication=True, + ssh=compute.SshConfiguration(public_keys=[ + compute.SshPublicKey( + path=f'/home/{username}/.ssh/authorized_keys', + key_data=public_key) + ])), + custom_data=os_linux_custom_data) + community_image_id = node_config['azure_arm_parameters'].get( + 'communityGalleryImageId', None) + if community_image_id is not None: + # Prioritize using community gallery image if specified. + image_reference = compute.ImageReference( + community_gallery_image_id=community_image_id) + logger.info( + f'Used community_image_id: {community_image_id} for VM {vm_name}.') + else: + image_reference = compute.ImageReference( + publisher=node_config['azure_arm_parameters']['imagePublisher'], + offer=node_config['azure_arm_parameters']['imageOffer'], + sku=node_config['azure_arm_parameters']['imageSku'], + version=node_config['azure_arm_parameters']['imageVersion']) + storage_profile = compute.StorageProfile( + image_reference=image_reference, + os_disk=compute.OSDisk( + create_option=compute.DiskCreateOptionTypes.FROM_IMAGE, + delete_option=compute.DiskDeleteOptionTypes.DELETE, + managed_disk=compute.ManagedDiskParameters( + storage_account_type=node_config['azure_arm_parameters'] + ['osDiskTier']), + disk_size_gb=node_config['azure_arm_parameters']['osDiskSizeGB'])) + vm_instance = compute.VirtualMachine( + location=provider_config['location'], + tags=node_tags, + hardware_profile=hardware_profile, + os_profile=os_profile, + storage_profile=storage_profile, + network_profile=network_profile, + identity=compute.VirtualMachineIdentity( + type='UserAssigned', + user_assigned_identities={provider_config['msi']: {}})) + vm_poller = compute_client.virtual_machines.begin_create_or_update( + resource_group_name=provider_config['resource_group'], + vm_name=vm_name, + parameters=vm_instance, + ) + # This line will block until the VM is created or the operation times out. + vm = vm_poller.result() + logger.info(f'Created VM {vm.name}.') + return vm + + +def _create_instances(compute_client: 'azure_compute.ComputeManagementClient', + network_client: 'azure_network.NetworkManagementClient', + cluster_name_on_cloud: str, resource_group: str, + provider_config: Dict[str, Any], node_config: Dict[str, + Any], + tags: Dict[str, str], count: int) -> List: vm_id = uuid4().hex[:UNIQUE_ID_LEN] - tags = { + all_tags = { constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name_on_cloud, **constants.WORKER_NODE_TAGS, @@ -199,83 +332,19 @@ def _create_instances( **tags, } node_tags = node_config['tags'].copy() - node_tags.update(tags) + node_tags.update(all_tags) - # load the template file - current_path = pathlib.Path(__file__).parent - template_path = current_path.joinpath('azure-vm-template.json') - with open(template_path, 'r', encoding='utf-8') as template_fp: - template = json.load(template_fp) + # Create VM instances in parallel. 
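Replacing the single ARM template deployment with per-VM SDK calls means one network interface and one VM are now created explicitly for every node, so the work is fanned out across threads and each blocking poller runs in its own worker. A rough sketch of the fan-out pattern, using `concurrent.futures` as a stand-in for SkyPilot's `subprocess_utils.run_in_parallel` helper and a fake creation step:

```python
# Sketch of fanning out per-VM creation across threads. The real code calls
# the Azure SDK to create a NIC and then a VM for each index; here a fake
# create step just sleeps to simulate the blocking pollers.
import concurrent.futures
import time
from typing import List


def create_single_instance(cluster_prefix: str, vm_i: int) -> str:
    vm_name = f'{cluster_prefix}-{vm_i}'
    # Placeholder for: create NIC -> create VM -> wait for poller.result().
    time.sleep(0.1)
    return vm_name


def create_instances(cluster_prefix: str, count: int) -> List[str]:
    with concurrent.futures.ThreadPoolExecutor(max_workers=count) as pool:
        futures = [
            pool.submit(create_single_instance, cluster_prefix, i)
            for i in range(count)
        ]
        return [f.result() for f in futures]


if __name__ == '__main__':
    print(create_instances('sky-train-ab12-3f9e', count=3))
```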
+ def create_single_instance(vm_i): + vm_name = f'{cluster_name_on_cloud}-{vm_id}-{vm_i}' + network_interface = _create_network_interface(network_client, vm_name, + provider_config) + _create_vm(compute_client, vm_name, node_tags, provider_config, + node_config, network_interface.id) - vm_name = f'{cluster_name_on_cloud}-{vm_id}' - use_internal_ips = provider_config.get('use_internal_ips', False) - - template_params = node_config['azure_arm_parameters'].copy() - # We don't include 'head' or 'worker' in the VM name as on Azure the VM - # name is immutable and we may change the node type for existing VM in the - # multi-node cluster, due to manual termination of the head node. - template_params['vmName'] = vm_name - template_params['provisionPublicIp'] = not use_internal_ips - template_params['vmTags'] = node_tags - template_params['vmCount'] = count - template_params['msi'] = provider_config['msi'] - template_params['nsg'] = provider_config['nsg'] - template_params['subnet'] = provider_config['subnet'] - # In Azure, cloud-init script must be encoded in base64. For more - # information, see: - # https://learn.microsoft.com/en-us/azure/virtual-machines/custom-data - template_params['cloudInitSetupCommands'] = (base64.b64encode( - template_params['cloudInitSetupCommands'].encode('utf-8')).decode( - 'utf-8')) - - if node_config.get('need_nvidia_driver_extension', False): - # pylint: disable=line-too-long - # Configure driver extension for A10 GPUs. A10 GPUs requires a - # special type of drivers which is available at Microsoft HPC - # extension. Reference: https://forums.developer.nvidia.com/t/ubuntu-22-04-installation-driver-error-nvidia-a10/285195/2 - for r in template['resources']: - if r['type'] == 'Microsoft.Compute/virtualMachines': - # Add a nested extension resource for A10 GPUs - r['resources'] = [ - { - 'type': 'extensions', - 'apiVersion': '2015-06-15', - 'location': '[variables(\'location\')]', - 'dependsOn': [ - '[concat(\'Microsoft.Compute/virtualMachines/\', parameters(\'vmName\'), copyIndex())]' - ], - 'name': 'NvidiaGpuDriverLinux', - 'properties': { - 'publisher': 'Microsoft.HpcCompute', - 'type': 'NvidiaGpuDriverLinux', - 'typeHandlerVersion': '1.9', - 'autoUpgradeMinorVersion': True, - 'settings': {}, - }, - }, - ] - break - - parameters = { - 'properties': { - 'mode': azure.deployment_mode().incremental, - 'template': template, - 'parameters': { - key: { - 'value': value - } for key, value in template_params.items() - }, - } - } - - create_or_update = _get_azure_sdk_function( - client=resource_client.deployments, function_name='create_or_update') - create_or_update( - resource_group_name=resource_group, - deployment_name=vm_name, - parameters=parameters, - ).wait() + subprocess_utils.run_in_parallel(create_single_instance, range(count)) + # Update disk performance tier performance_tier = node_config.get('disk_performance_tier', None) if performance_tier is not None: disks = compute_client.disks.list_by_resource_group(resource_group) @@ -286,12 +355,14 @@ def _create_instances( f'az disk update -n {name} -g {resource_group} ' f'--set tier={performance_tier}') + # Validation filters = { constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, _TAG_SKYPILOT_VM_ID: vm_id } instances = _filter_instances(compute_client, resource_group, filters) assert len(instances) == count, (len(instances), count) + return instances @@ -303,7 +374,7 @@ def run_instances(region: str, cluster_name_on_cloud: str, resource_group = provider_config['resource_group'] subscription_id = 
provider_config['subscription_id'] compute_client = azure.get_client('compute', subscription_id) - + network_client = azure.get_client('network', subscription_id) instances_to_resume = [] resumed_instance_ids: List[str] = [] created_instance_ids: List[str] = [] @@ -439,12 +510,11 @@ def _create_instance_tag(target_instance, is_head: bool = True) -> str: to_start_count -= len(resumed_instance_ids) if to_start_count > 0: - resource_client = azure.get_client('resource', subscription_id) logger.debug(f'run_instances: Creating {to_start_count} instances.') try: created_instances = _create_instances( compute_client=compute_client, - resource_client=resource_client, + network_client=network_client, cluster_name_on_cloud=cluster_name_on_cloud, resource_group=resource_group, provider_config=provider_config, @@ -617,18 +687,30 @@ def terminate_instances( assert provider_config is not None, cluster_name_on_cloud - resource_group_client = azure.get_client('resource', subscription_id) - delete_resource_group = _get_azure_sdk_function( - client=resource_group_client.resource_groups, function_name='delete') - - try: - delete_resource_group(resource_group, force_deletion_types=None) - except azure.exceptions().ResourceNotFoundError as e: - if 'ResourceGroupNotFound' in str(e): - logger.warning(f'Resource group {resource_group} not found. Skip ' - 'terminating it.') - return - raise + use_external_resource_group = provider_config.get( + 'use_external_resource_group', False) + # When user specified resource group through config.yaml to create a VM, we + # cannot remove the entire resource group as it may contain other resources + # unrelated to this VM being removed. + if use_external_resource_group: + delete_vm_and_attached_resources(subscription_id, resource_group, + cluster_name_on_cloud) + else: + # For SkyPilot default resource groups, delete entire resource group. + # This automatically terminates all resources within, including VMs + resource_group_client = azure.get_client('resource', subscription_id) + delete_resource_group = _get_azure_sdk_function( + client=resource_group_client.resource_groups, + function_name='delete') + try: + delete_resource_group(resource_group, force_deletion_types=None) + except azure.exceptions().ResourceNotFoundError as e: + if 'ResourceGroupNotFound' in str(e): + logger.warning( + f'Resource group {resource_group} not found. Skip ' + 'terminating it.') + return + raise def _get_instance_status( @@ -690,6 +772,188 @@ def match_tags(vm): return nodes +def _delete_nic_with_retries(network_client, + resource_group, + nic_name, + max_retries=15, + retry_interval=20): + """Delete a NIC with retries. + + When a VM is created, its NIC is reserved for 180 seconds, preventing its + immediate deletion. If the NIC is in this reserved state, we must retry + deletion with intervals until the reservation expires. This situation + commonly arises if a VM termination is followed by a failover to another + region due to provisioning failures. + """ + delete_network_interfaces = _get_azure_sdk_function( + client=network_client.network_interfaces, function_name='begin_delete') + for _ in range(max_retries): + try: + delete_network_interfaces(resource_group_name=resource_group, + network_interface_name=nic_name).result() + return + except azure.exceptions().HttpResponseError as e: + if 'NicReservedForAnotherVm' in str(e): + # Retry when deletion fails with reserved NIC. + logger.warning(f'NIC {nic_name} is reserved. 
' + f'Retrying in {retry_interval} seconds...') + time.sleep(retry_interval) + else: + raise e + logger.error( + f'Failed to delete NIC {nic_name} after {max_retries} attempts.') + + +def delete_vm_and_attached_resources(subscription_id: str, resource_group: str, + cluster_name_on_cloud: str) -> None: + """Removes VM with attached resources and Deployments. + + This function deletes a virtual machine and its associated resources + (public IP addresses, virtual networks, managed identities, network + interface and network security groups) that match cluster_name_on_cloud. + There is one attached resources that is not removed within this + method: OS disk. It is configured to be deleted when VM is terminated while + setting up storage profile from _create_vm. + + Args: + subscription_id: The Azure subscription ID. + resource_group: The name of the resource group. + cluster_name_on_cloud: The name of the cluster to filter resources. + """ + resource_client = azure.get_client('resource', subscription_id) + try: + list_resources = _get_azure_sdk_function( + client=resource_client.resources, + function_name='list_by_resource_group') + resources = list(list_resources(resource_group)) + except azure.exceptions().ResourceNotFoundError as e: + if _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE in str(e): + return + raise + + filtered_resources: Dict[str, List[str]] = { + _RESOURCE_VIRTUAL_MACHINE_TYPE: [], + _RESOURCE_MANAGED_IDENTITY_TYPE: [], + _RESOURCE_NETWORK_SECURITY_GROUP_TYPE: [], + _RESOURCE_VIRTUAL_NETWORK_TYPE: [], + _RESOURCE_PUBLIC_IP_ADDRESS_TYPE: [], + _RESOURCE_NETWORK_INTERFACE_TYPE: [] + } + + for resource in resources: + if (resource.type in filtered_resources and + cluster_name_on_cloud in resource.name): + filtered_resources[resource.type].append(resource.name) + + network_client = azure.get_client('network', subscription_id) + msi_client = azure.get_client('msi', subscription_id) + compute_client = azure.get_client('compute', subscription_id) + auth_client = azure.get_client('authorization', subscription_id) + + delete_virtual_machine = _get_azure_sdk_function( + client=compute_client.virtual_machines, function_name='delete') + delete_public_ip_addresses = _get_azure_sdk_function( + client=network_client.public_ip_addresses, function_name='begin_delete') + delete_virtual_networks = _get_azure_sdk_function( + client=network_client.virtual_networks, function_name='begin_delete') + delete_managed_identity = _get_azure_sdk_function( + client=msi_client.user_assigned_identities, function_name='delete') + delete_network_security_group = _get_azure_sdk_function( + client=network_client.network_security_groups, + function_name='begin_delete') + delete_role_assignment = _get_azure_sdk_function( + client=auth_client.role_assignments, function_name='delete') + + for vm_name in filtered_resources[_RESOURCE_VIRTUAL_MACHINE_TYPE]: + try: + # Before removing Network Interface, we need to wait for the VM to + # be completely removed with .result() so the dependency of VM on + # Network Interface is disassociated. This takes abour ~30s. 
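The `_delete_nic_with_retries` helper above exists because Azure keeps a just-detached NIC reserved for its VM for roughly 180 seconds; deleting it during that window fails with `NicReservedForAnotherVm`, so only that error is retried until the reservation lapses. A generic, stdlib-only sketch of the same pattern with a fake delete call:

```python
# Sketch of retry-on-transient-error deletion. `delete_once` stands in for
# the Azure begin_delete(...).result() call; only the 'reserved' error is
# retried, anything else is re-raised immediately.
import time
from typing import Callable


def delete_with_retries(delete_once: Callable[[], None],
                        retryable_marker: str = 'NicReservedForAnotherVm',
                        max_retries: int = 15,
                        retry_interval: float = 20.0) -> bool:
    for _ in range(max_retries):
        try:
            delete_once()
            return True
        except RuntimeError as e:  # the real code catches HttpResponseError
            if retryable_marker not in str(e):
                raise
            time.sleep(retry_interval)
    return False  # gave up; the caller logs an error in the real code


if __name__ == '__main__':
    attempts = {'n': 0}

    def flaky_delete() -> None:
        attempts['n'] += 1
        if attempts['n'] < 3:
            raise RuntimeError('NicReservedForAnotherVm: still attached')

    print(delete_with_retries(flaky_delete, retry_interval=0.01))  # True
```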
+ delete_virtual_machine(resource_group_name=resource_group, + vm_name=vm_name).result() + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete VM: {}'.format(e)) + + for nic_name in filtered_resources[_RESOURCE_NETWORK_INTERFACE_TYPE]: + try: + # Before removing Public IP Address, we need to wait for the + # Network Interface to be completely removed with .result() so the + # dependency of Network Interface on Public IP Address is + # disassociated. This takes about ~1s. + _delete_nic_with_retries(network_client, resource_group, nic_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete nic: {}'.format(e)) + + for public_ip_name in filtered_resources[_RESOURCE_PUBLIC_IP_ADDRESS_TYPE]: + try: + delete_public_ip_addresses(resource_group_name=resource_group, + public_ip_address_name=public_ip_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete public ip: {}'.format(e)) + + for vnet_name in filtered_resources[_RESOURCE_VIRTUAL_NETWORK_TYPE]: + try: + delete_virtual_networks(resource_group_name=resource_group, + virtual_network_name=vnet_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete vnet: {}'.format(e)) + + for msi_name in filtered_resources[_RESOURCE_MANAGED_IDENTITY_TYPE]: + user_assigned_identities = ( + msi_client.user_assigned_identities.list_by_resource_group( + resource_group_name=resource_group)) + for identity in user_assigned_identities: + if msi_name == identity.name: + # We use the principal_id to find the correct guid converted + # role assignment name because each managed identity has a + # unique principal_id, and role assignments are associated + # with security principals (like managed identities) via this + # principal_id. 
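Role assignments are not named after the identity they bind, so the cleanup matches them through the identity's `principal_id` as described here. A stub-object sketch of that matching step, using dataclasses in place of the Azure SDK models:

```python
# Sketch: find the role assignment attached to a managed identity by
# comparing principal IDs. Dataclasses stand in for the Azure SDK models.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Identity:
    name: str
    principal_id: str


@dataclass
class RoleAssignment:
    name: str  # a GUID in Azure
    principal_id: str


def role_assignment_for(identity: Identity,
                        assignments: List[RoleAssignment]) -> Optional[str]:
    for assignment in assignments:
        if assignment.principal_id == identity.principal_id:
            return assignment.name  # this GUID is what gets deleted
    return None


if __name__ == '__main__':
    msi = Identity(name='sky-abc-1234-msi', principal_id='pid-42')
    ras = [RoleAssignment(name='guid-1', principal_id='pid-7'),
           RoleAssignment(name='guid-2', principal_id='pid-42')]
    print(role_assignment_for(msi, ras))  # guid-2
```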
+ target_principal_id = identity.principal_id + scope = (f'/subscriptions/{subscription_id}' + f'/resourceGroups/{resource_group}') + role_assignments = auth_client.role_assignments.list_for_scope( + scope) + for assignment in role_assignments: + if target_principal_id == assignment.principal_id: + guid_role_assignment_name = assignment.name + try: + delete_role_assignment( + scope=scope, + role_assignment_name=guid_role_assignment_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete role ' + 'assignment: {}'.format(e)) + break + try: + delete_managed_identity(resource_group_name=resource_group, + resource_name=msi_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete msi: {}'.format(e)) + + for nsg_name in filtered_resources[_RESOURCE_NETWORK_SECURITY_GROUP_TYPE]: + try: + delete_network_security_group(resource_group_name=resource_group, + network_security_group_name=nsg_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete nsg: {}'.format(e)) + + delete_deployment = _get_azure_sdk_function( + client=resource_client.deployments, function_name='begin_delete') + deployment_names = [ + constants.EXTERNAL_RG_BOOTSTRAP_DEPLOYMENT_NAME.format( + cluster_name_on_cloud=cluster_name_on_cloud), + constants.EXTERNAL_RG_VM_DEPLOYMENT_NAME.format( + cluster_name_on_cloud=cluster_name_on_cloud) + ] + for deployment_name in deployment_names: + try: + delete_deployment(resource_group_name=resource_group, + deployment_name=deployment_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete deployment: {}'.format(e)) + + @common_utils.retry def query_instances( cluster_name_on_cloud: str, @@ -722,6 +986,32 @@ def _fetch_and_map_status(node, resource_group: str) -> None: return statuses +# TODO(Doyoung): _get_cluster_nsg can be remove this after 0.8.0 to ignore +# legacy nsg names. +def _get_cluster_nsg(network_client: Client, resource_group: str, + cluster_name_on_cloud: str) -> NetworkSecurityGroup: + """Retrieve the NSG associated with the given name of the cluster.""" + list_network_security_groups = _get_azure_sdk_function( + client=network_client.network_security_groups, function_name='list') + legacy_nsg_name = _LEGACY_NSG_NAME.format( + cluster_name_on_cloud=cluster_name_on_cloud) + second_legacy_nsg_name = _SECOND_LEGACY_NSG_NAME.format( + cluster_name_on_cloud=cluster_name_on_cloud) + _, nsg_name = config_lib.get_cluster_id_and_nsg_name( + resource_group=resource_group, + cluster_name_on_cloud=cluster_name_on_cloud) + possible_nsg_names = [nsg_name, legacy_nsg_name, second_legacy_nsg_name] + for nsg in list_network_security_groups(resource_group): + if nsg.name in possible_nsg_names: + return nsg + + # Raise an error if no matching NSG is found + raise ValueError('Failed to find a matching NSG for cluster ' + f'{cluster_name_on_cloud!r} in resource group ' + f'{resource_group!r}. Expected NSG names were: ' + f'{possible_nsg_names}.') + + def open_ports( cluster_name_on_cloud: str, ports: List[str], @@ -738,56 +1028,65 @@ def open_ports( function_name='create_or_update') list_network_security_groups = _get_azure_sdk_function( client=network_client.network_security_groups, function_name='list') + for nsg in list_network_security_groups(resource_group): - try: - # Wait the NSG creation to be finished before opening a port. The - # cluster provisioning triggers the NSG creation, but it may not be - # finished yet. 
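The wait described here only applies to NSGs whose name contains the cluster name, because a user-supplied resource group can hold unrelated NSGs that may never leave a transient state. A simplified sketch of the wait-until-provisioned loop, with a plain poll function and a constant interval instead of the SDK object and `common_utils.Backoff`:

```python
# Sketch of waiting for an NSG to leave its transient provisioning states
# before mutating it. `get_provisioning_state` stands in for re-reading the
# NSG from the Azure SDK on each poll.
import time
from typing import Callable

_TRANSIENT_STATES = ('Creating', 'Updating')


def wait_until_provisioned(get_provisioning_state: Callable[[], str],
                           timeout: float = 600.0,
                           poll_interval: float = 1.0) -> bool:
    start = time.time()
    while get_provisioning_state() in _TRANSIENT_STATES:
        if time.time() - start > timeout:
            return False  # caller logs a warning and skips this NSG
        time.sleep(poll_interval)
    return True


if __name__ == '__main__':
    states = iter(['Creating', 'Updating', 'Succeeded'])
    print(wait_until_provisioned(lambda: next(states), poll_interval=0.01))
```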
- backoff = common_utils.Backoff(max_backoff_factor=1) - start_time = time.time() - while True: - if nsg.provisioning_state not in ['Creating', 'Updating']: - break - if time.time() - start_time > _WAIT_CREATION_TIMEOUT_SECONDS: - logger.warning( - f'Fails to wait for the creation of NSG {nsg.name} in ' - f'{resource_group} within ' - f'{_WAIT_CREATION_TIMEOUT_SECONDS} seconds. ' - 'Skip this NSG.') - backoff_time = backoff.current_backoff() - logger.info(f'NSG {nsg.name} is not created yet. Waiting for ' - f'{backoff_time} seconds before checking again.') - time.sleep(backoff_time) - - # Azure NSG rules have a priority field that determines the order - # in which they are applied. The priority must be unique across - # all inbound rules in one NSG. - priority = max(rule.priority - for rule in nsg.security_rules - if rule.direction == 'Inbound') + 1 - nsg.security_rules.append( - azure.create_security_rule( - name=f'sky-ports-{cluster_name_on_cloud}-{priority}', - priority=priority, - protocol='Tcp', - access='Allow', - direction='Inbound', - source_address_prefix='*', - source_port_range='*', - destination_address_prefix='*', - destination_port_ranges=ports, - )) - poller = update_network_security_groups(resource_group, nsg.name, - nsg) - poller.wait() - if poller.status() != 'Succeeded': + # Given resource group can contain network security groups that are + # irrelevant to this provisioning especially with user specified + # resource group at ~/.sky/config. So we make sure to check for the + # completion of nsg relevant to the VM being provisioned. + if cluster_name_on_cloud in nsg.name: + try: + # Wait the NSG creation to be finished before opening a port. + # The cluster provisioning triggers the NSG creation, but it + # may not be finished yet. + backoff = common_utils.Backoff(max_backoff_factor=1) + start_time = time.time() + while True: + if nsg.provisioning_state not in ['Creating', 'Updating']: + break + if time.time( + ) - start_time > _WAIT_CREATION_TIMEOUT_SECONDS: + logger.warning( + f'Fails to wait for the creation of NSG {nsg.name}' + f' in {resource_group} within ' + f'{_WAIT_CREATION_TIMEOUT_SECONDS} seconds. ' + 'Skip this NSG.') + backoff_time = backoff.current_backoff() + logger.info( + f'NSG {nsg.name} is not created yet. Waiting for ' + f'{backoff_time} seconds before checking again.') + time.sleep(backoff_time) + + # Azure NSG rules have a priority field that determines the + # order in which they are applied. The priority must be unique + # across all inbound rules in one NSG. 
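Azure requires every inbound rule within an NSG to carry a unique priority, so the new rule simply takes one more than the highest existing inbound priority. A tiny sketch of that computation, with plain tuples standing in for SDK rule objects:

```python
# Sketch: pick the next free priority for a new inbound NSG rule. Existing
# rules are (direction, priority) pairs standing in for SDK rule objects.
from typing import Iterable, Tuple


def next_inbound_priority(rules: Iterable[Tuple[str, int]]) -> int:
    return max(priority for direction, priority in rules
               if direction == 'Inbound') + 1


if __name__ == '__main__':
    existing = [('Inbound', 1000), ('Inbound', 1001), ('Outbound', 65000)]
    priority = next_inbound_priority(existing)
    print(priority, f'sky-ports-my-cluster-{priority}')  # 1002 ...
```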
+ priority = max(rule.priority + for rule in nsg.security_rules + if rule.direction == 'Inbound') + 1 + nsg.security_rules.append( + azure.create_security_rule( + name=f'sky-ports-{cluster_name_on_cloud}-{priority}', + priority=priority, + protocol='Tcp', + access='Allow', + direction='Inbound', + source_address_prefix='*', + source_port_range='*', + destination_address_prefix='*', + destination_port_ranges=ports, + )) + poller = update_network_security_groups(resource_group, + nsg.name, nsg) + poller.wait() + if poller.status() != 'Succeeded': + with ux_utils.print_exception_no_traceback(): + raise ValueError(f'Failed to open ports {ports} in NSG ' + f'{nsg.name}: {poller.status()}') + except azure.exceptions().HttpResponseError as e: with ux_utils.print_exception_no_traceback(): - raise ValueError(f'Failed to open ports {ports} in NSG ' - f'{nsg.name}: {poller.status()}') - except azure.exceptions().HttpResponseError as e: - with ux_utils.print_exception_no_traceback(): - raise ValueError( - f'Failed to open ports {ports} in NSG {nsg.name}.') from e + raise ValueError( + f'Failed to open ports {ports} in NSG {nsg.name}.' + ) from e def cleanup_ports( diff --git a/sky/provision/constants.py b/sky/provision/constants.py index 760abc4861a..8e8ad5ddf1b 100644 --- a/sky/provision/constants.py +++ b/sky/provision/constants.py @@ -16,3 +16,10 @@ TAG_RAY_NODE_KIND: 'worker', TAG_SKYPILOT_HEAD_NODE: '0', } + +# Names for Azure Deployments. +DEPLOYMENT_NAME = 'skypilot-config' +LEGACY_DEPLOYMENT_NAME = 'ray-config' +EXTERNAL_RG_BOOTSTRAP_DEPLOYMENT_NAME = ( + 'skypilot-bootstrap-{cluster_name_on_cloud}') +EXTERNAL_RG_VM_DEPLOYMENT_NAME = 'skypilot-vm-{cluster_name_on_cloud}' diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 21d04075f59..9872ad73dc7 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -632,13 +632,6 @@ def cleanup_ports( del ports # Unused. assert provider_config is not None, cluster_name_on_cloud project_id = provider_config['project_id'] - if 'ports' in provider_config: - # Backward compatibility for old provider config. - # TODO(tian): remove this after 2 minor releases, 0.6.0. 
- for port in provider_config['ports']: - firewall_rule_name = f'user-ports-{cluster_name_on_cloud}-{port}' - instance_utils.GCPComputeInstance.delete_firewall_rule( - project_id, firewall_rule_name) if 'firewall_rule' in provider_config: firewall_rule_name = provider_config['firewall_rule'] instance_utils.GCPComputeInstance.delete_firewall_rule( diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index 6663ed3f657..26ed5f51a43 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -1,5 +1,6 @@ """Kubernetes instance provisioning.""" import copy +import json import time from typing import Any, Dict, List, Optional import uuid @@ -18,6 +19,7 @@ from sky.utils import command_runner from sky.utils import common_utils from sky.utils import kubernetes_enums +from sky.utils import subprocess_utils from sky.utils import ux_utils POLL_INTERVAL = 2 @@ -398,8 +400,7 @@ def _setup_ssh_in_pods(namespace: str, context: Optional[str], # See https://www.educative.io/answers/error-mesg-ttyname-failed-inappropriate-ioctl-for-device # pylint: disable=line-too-long '$(prefix_cmd) sed -i "s/mesg n/tty -s \\&\\& mesg n/" ~/.profile;') - # TODO(romilb): Parallelize the setup of SSH in pods for multi-node clusters - for new_node in new_nodes: + def _setup_ssh_thread(new_node): pod_name = new_node.metadata.name runner = command_runner.KubernetesCommandRunner( ((namespace, context), pod_name)) @@ -411,6 +412,8 @@ def _setup_ssh_in_pods(namespace: str, context: Optional[str], stdout) logger.info(f'{"-"*20}End: Set up SSH in pod {pod_name!r} {"-"*20}') + subprocess_utils.run_in_parallel(_setup_ssh_thread, new_nodes) + def _label_pod(namespace: str, context: Optional[str], pod_name: str, label: Dict[str, str]) -> None: @@ -423,6 +426,70 @@ def _label_pod(namespace: str, context: Optional[str], pod_name: str, _request_timeout=kubernetes.API_TIMEOUT) +def _create_namespaced_pod_with_retries(namespace: str, pod_spec: dict, + context: Optional[str]) -> Any: + """Attempts to create a Kubernetes Pod and handle any errors. + + Currently, we handle errors due to the AppArmor annotation and retry if + it fails due to the `FieldValueForbidden` error. + See https://github.com/skypilot-org/skypilot/issues/4174 for details. + + Returns: The created Pod object. + """ + try: + # Attempt to create the Pod with the AppArmor annotation + pod = kubernetes.core_api(context).create_namespaced_pod( + namespace, pod_spec) + return pod + except kubernetes.api_exception() as e: + try: + error_body = json.loads(e.body) + error_message = error_body.get('message', '') + except json.JSONDecodeError: + error_message = str(e.body) + # Check if the error is due to the AppArmor annotation and retry. + # We add an AppArmor annotation to set it as unconfined in our + # base template in kubernetes-ray.yml.j2. This is required for + # FUSE to work in the pod on most Kubernetes distributions. + # However, some distributions do not support the AppArmor annotation + # and will fail to create the pod. In this case, we retry without + # the annotation. + if (e.status == 422 and 'FieldValueForbidden' in error_message and + 'AppArmorProfile: nil' in error_message): + logger.warning('AppArmor annotation caused pod creation to fail. ' + 'Retrying without the annotation. 
' + 'Note: this may cause bucket mounting to fail.') + + # Remove the AppArmor annotation + annotations = pod_spec.get('metadata', {}).get('annotations', {}) + if ('container.apparmor.security.beta.kubernetes.io/ray-node' + in annotations): + del annotations[ + 'container.apparmor.security.beta.kubernetes.io/ray-node'] + pod_spec['metadata']['annotations'] = annotations + logger.info('AppArmor annotation removed from Pod spec.') + else: + logger.warning('AppArmor annotation not found in pod spec, ' + 'retrying will not help. ' + f'Current annotations: {annotations}') + raise e + + # Retry Pod creation without the AppArmor annotation + try: + pod = kubernetes.core_api(context).create_namespaced_pod( + namespace, pod_spec) + logger.info(f'Pod {pod.metadata.name} created successfully ' + 'without AppArmor annotation.') + return pod + except kubernetes.api_exception() as retry_exception: + logger.info('Failed to create Pod without AppArmor annotation: ' + f'{retry_exception}') + raise retry_exception + else: + # Re-raise the exception if it's a different error + raise e + + def _create_pods(region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """Create pods based on the config.""" @@ -544,8 +611,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str, } } - pod = kubernetes.core_api(context).create_namespaced_pod( - namespace, pod_spec) + pod = _create_namespaced_pod_with_retries(namespace, pod_spec, context) created_pods[pod.metadata.name] = pod if head_pod_name is None: head_pod_name = pod.metadata.name diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index 7706a3d489b..b3e965769c9 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -28,6 +28,7 @@ from sky.utils import common_utils from sky.utils import resources_utils from sky.utils import rich_utils +from sky.utils import subprocess_utils from sky.utils import ux_utils # Do not use __name__ as we do not want to propagate logs to sky.provision, @@ -365,14 +366,13 @@ def wait_for_ssh(cluster_info: provision_common.ClusterInfo, # use a queue for SSH querying ips = collections.deque(ip_list) ssh_ports = collections.deque(port_list) - while ips: - ip = ips.popleft() - ssh_port = ssh_ports.popleft() - success, stderr = waiter(ip, ssh_port, **ssh_credentials) - if not success: - ips.append(ip) - ssh_ports.append(ssh_port) - if time.time() - start > timeout: + + def _retry_ssh_thread(ip_ssh_port: Tuple[str, int]): + ip, ssh_port = ip_ssh_port + success = False + while not success: + success, stderr = waiter(ip, ssh_port, **ssh_credentials) + if not success and time.time() - start > timeout: with ux_utils.print_exception_no_traceback(): raise RuntimeError( f'Failed to SSH to {ip} after timeout {timeout}s, with ' @@ -380,6 +380,14 @@ def wait_for_ssh(cluster_info: provision_common.ClusterInfo, logger.debug('Retrying in 1 second...') time.sleep(1) + # try one node and multiprocess the rest + if ips: + ip = ips.popleft() + ssh_port = ssh_ports.popleft() + _retry_ssh_thread((ip, ssh_port)) + subprocess_utils.run_in_parallel(_retry_ssh_thread, + list(zip(ips, ssh_ports))) + def _post_provision_setup( cloud_name: str, cluster_name: resources_utils.ClusterName, diff --git a/sky/resources.py b/sky/resources.py index 384f2b6a548..3b33476713b 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -55,7 +55,7 @@ def __init__( accelerators: Union[None, str, Dict[str, int]] = None, accelerator_args: Optional[Dict[str, str]] = None, use_spot: Optional[bool] 
= None, - job_recovery: Optional[str] = None, + job_recovery: Optional[Union[Dict[str, Union[str, int]], str]] = None, region: Optional[str] = None, zone: Optional[str] = None, image_id: Union[Dict[str, str], str, None] = None, @@ -111,6 +111,12 @@ def __init__( job to recover the cluster from preemption. Refer to `recovery_strategy module `__ # pylint: disable=line-too-long for more details. + When a dict is provided, it can have the following fields: + + - strategy: the recovery strategy to use. + - max_restarts_on_errors: the max number of restarts on user code + errors. + region: the region to use. zone: the zone to use. image_id: the image ID to use. If a str, must be a string @@ -161,10 +167,20 @@ def __init__( self._use_spot_specified = use_spot is not None self._use_spot = use_spot if use_spot is not None else False - self._job_recovery = None + self._job_recovery: Optional[Dict[str, Union[str, int]]] = None if job_recovery is not None: - if job_recovery.strip().lower() != 'none': - self._job_recovery = job_recovery.upper() + if isinstance(job_recovery, str): + job_recovery = {'strategy': job_recovery} + if 'strategy' not in job_recovery: + job_recovery['strategy'] = None + + strategy_name = job_recovery['strategy'] + if strategy_name == 'none': + self._job_recovery = None + else: + if strategy_name is not None: + job_recovery['strategy'] = strategy_name.upper() + self._job_recovery = job_recovery if disk_size is not None: if round(disk_size) != disk_size: @@ -225,6 +241,7 @@ def __init__( self._set_memory(memory) self._set_accelerators(accelerators, accelerator_args) + # TODO: move these out of init to prevent repeated calls. self._try_validate_instance_type() self._try_validate_cpus_mem() self._try_validate_managed_job_attributes() @@ -391,7 +408,7 @@ def memory(self) -> Optional[str]: @property @functools.lru_cache(maxsize=1) - def accelerators(self) -> Optional[Dict[str, int]]: + def accelerators(self) -> Optional[Dict[str, Union[int, float]]]: """Returns the accelerators field directly or by inferring. For example, Resources(AWS, 'p3.2xlarge') has its accelerators field @@ -418,7 +435,7 @@ def use_spot_specified(self) -> bool: return self._use_spot_specified @property - def job_recovery(self) -> Optional[str]: + def job_recovery(self) -> Optional[Dict[str, Union[str, int]]]: return self._job_recovery @property @@ -585,6 +602,9 @@ def _get_default_runtime_version() -> str: # TPU V5 requires a newer runtime version. if acc.startswith('tpu-v5'): return 'v2-alpha-tpuv5' + # TPU V6e requires a newer runtime version. + if acc.startswith('tpu-v6e'): + return 'v2-alpha-tpuv6e' return 'tpu-vm-base' accelerator_args['runtime_version'] = ( @@ -813,12 +833,13 @@ def _try_validate_managed_job_attributes(self) -> None: Raises: ValueError: if the attributes are invalid. """ - if self._job_recovery is None: + if self._job_recovery is None or self._job_recovery['strategy'] is None: return - if self._job_recovery not in managed_jobs.RECOVERY_STRATEGIES: + if (self._job_recovery['strategy'] + not in managed_jobs.RECOVERY_STRATEGIES): with ux_utils.print_exception_no_traceback(): raise ValueError( - f'Spot recovery strategy {self._job_recovery} ' + f'Spot recovery strategy {self._job_recovery["strategy"]} ' 'is not supported. 
The strategy should be among ' f'{list(managed_jobs.RECOVERY_STRATEGIES.keys())}') diff --git a/sky/serve/core.py b/sky/serve/core.py index 691a3edea0b..ea8f380a2e7 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -572,8 +572,6 @@ def status( 'controller_port': (Optional[int]) controller port, 'load_balancer_port': (Optional[int]) load balancer port, 'policy': (Optional[str]) load balancer policy description, - 'requested_resources': (sky.Resources) requested resources - for replica (deprecated), 'requested_resources_str': (str) str representation of requested resources, 'replica_info': (List[Dict[str, Any]]) replica information, diff --git a/sky/serve/serve_state.py b/sky/serve/serve_state.py index cbc8ef3d8cc..333e0138fb4 100644 --- a/sky/serve/serve_state.py +++ b/sky/serve/serve_state.py @@ -34,7 +34,7 @@ def _get_db_path() -> str: def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None: """Creates the service and replica tables if they do not exist.""" - # auto_restart column is deprecated. + # auto_restart and requested_resources column is deprecated. cursor.execute("""\ CREATE TABLE IF NOT EXISTS services ( name TEXT PRIMARY KEY, @@ -323,8 +323,8 @@ def set_service_load_balancer_port(service_name: str, def _get_service_from_row(row) -> Dict[str, Any]: (current_version, name, controller_job_id, controller_port, - load_balancer_port, status, uptime, policy, _, requested_resources, - requested_resources_str, _, active_versions) = row[:13] + load_balancer_port, status, uptime, policy, _, _, requested_resources_str, + _, active_versions) = row[:13] return { 'name': name, 'controller_job_id': controller_job_id, @@ -340,10 +340,6 @@ def _get_service_from_row(row) -> Dict[str, Any]: # The versions that is active for the load balancer. This is a list of # integers in json format. This is mainly for display purpose. 'active_versions': json.loads(active_versions), - # TODO(tian): Backward compatibility. - # Remove after 2 minor release, 0.6.0. - 'requested_resources': pickle.loads(requested_resources) - if requested_resources is not None else None, 'requested_resources_str': requested_resources_str, } diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index d83a909db1d..1c82fa0f659 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -796,12 +796,7 @@ def format_service_table(service_records: List[Dict[str, Any]], replicas = _get_replicas(record) endpoint = get_endpoint(record) policy = record['policy'] - # TODO(tian): Backward compatibility. - # Remove `requested_resources` field after 2 minor release, 0.6.0. - if record.get('requested_resources_str') is None: - requested_resources_str = str(record['requested_resources']) - else: - requested_resources_str = record['requested_resources_str'] + requested_resources_str = record['requested_resources_str'] service_values = [ service_name, @@ -975,15 +970,8 @@ def _build(cls, code: List[str]) -> str: @classmethod def update_service(cls, service_name: str, version: int, mode: str) -> str: code = [ - # Backward compatibility for old serve version on the remote - # machine. The `mode` argument was added in #3249, and if the remote - # machine has an old SkyPilot version before that, we need to avoid - # passing the `mode` argument to the job_lib functions. - # TODO(zhwu): Remove this in 0.7.0 release. 
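# Illustrative sketch, not part of the patch: with the Resources change above,
# `job_recovery` accepts either a plain strategy string or a dict that also
# carries `max_restarts_on_errors`. The cloud and strategy values below are
# only examples; the strategy name is upper-cased by Resources itself.
import sky

# Existing string form, still supported.
spot_resources = sky.Resources(cloud=sky.AWS(),
                               use_spot=True,
                               job_recovery='failover')

# New dict form: strategy plus a cap on restarts caused by user-code errors.
spot_resources_with_cap = sky.Resources(
    cloud=sky.AWS(),
    use_spot=True,
    job_recovery={
        'strategy': 'failover',
        'max_restarts_on_errors': 3,
    })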
- f'mode_kwargs = {{"mode": {mode!r}}} ' - 'if getattr(constants, "SERVE_VERSION", 0) >= 1 else {}', f'msg = serve_utils.update_service_encoded({service_name!r}, ' - f'{version}, **mode_kwargs)', + f'{version}, mode={mode!r})', 'print(msg, end="", flush=True)', ] return cls._build(code) diff --git a/sky/serve/service_spec.py b/sky/serve/service_spec.py index 3a97a6f8521..2eff6f40a9d 100644 --- a/sky/serve/service_spec.py +++ b/sky/serve/service_spec.py @@ -29,13 +29,6 @@ def __init__( base_ondemand_fallback_replicas: Optional[int] = None, upscale_delay_seconds: Optional[int] = None, downscale_delay_seconds: Optional[int] = None, - # The following arguments are deprecated. - # TODO(ziming): remove this after 2 minor release, i.e. 0.6.0. - # Deprecated: Always be True - auto_restart: Optional[bool] = None, - # Deprecated: replaced by the target_qps_per_replica. - qps_upper_threshold: Optional[float] = None, - qps_lower_threshold: Optional[float] = None, ) -> None: if max_replicas is not None and max_replicas < min_replicas: with ux_utils.print_exception_no_traceback(): @@ -62,21 +55,6 @@ def __init__( raise ValueError('readiness_path must start with a slash (/). ' f'Got: {readiness_path}') - # TODO(tian): Following field are deprecated. Remove after 2 minor - # release, i.e. 0.6.0. - if qps_upper_threshold is not None or qps_lower_threshold is not None: - with ux_utils.print_exception_no_traceback(): - raise ValueError( - 'Field `qps_upper_threshold` and `qps_lower_threshold`' - 'under `replica_policy` are deprecated. ' - 'Please use target_qps_per_replica instead.') - if auto_restart is not None: - with ux_utils.print_exception_no_traceback(): - raise ValueError( - 'Field `auto_restart` under `replica_policy` is deprecated.' - 'Currently, SkyServe will cleanup failed replicas' - 'and auto restart it to keep the service running.') - self._readiness_path: str = readiness_path self._initial_delay_seconds: int = initial_delay_seconds self._readiness_timeout_seconds: int = readiness_timeout_seconds @@ -160,14 +138,8 @@ def from_yaml_config(config: Dict[str, Any]) -> 'SkyServiceSpec': service_config['min_replicas'] = policy_section['min_replicas'] service_config['max_replicas'] = policy_section.get( 'max_replicas', None) - service_config['qps_upper_threshold'] = policy_section.get( - 'qps_upper_threshold', None) - service_config['qps_lower_threshold'] = policy_section.get( - 'qps_lower_threshold', None) service_config['target_qps_per_replica'] = policy_section.get( 'target_qps_per_replica', None) - service_config['auto_restart'] = policy_section.get( - 'auto_restart', None) service_config['upscale_delay_seconds'] = policy_section.get( 'upscale_delay_seconds', None) service_config['downscale_delay_seconds'] = policy_section.get( diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index 604060c68ae..0fd6978ec03 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -153,7 +153,7 @@ def parse_readme(readme: str) -> str: 'tabulate', # Light weight requirement, can be replaced with "typing" once # we deprecate Python 3.7 (this will take a while). - "typing_extensions", + 'typing_extensions', 'filelock >= 3.6.0', 'packaging', 'psutil', @@ -216,8 +216,9 @@ def parse_readme(readme: str) -> str: # We need azure-identity>=1.13.0 to enable the customization of the # timeout of AzureCliCredential. 
'azure': [ - 'azure-cli>=2.31.0', 'azure-core', 'azure-identity>=1.13.0', - 'azure-mgmt-network', 'azure-storage-blob', 'msgraph-sdk' + 'azure-cli>=2.65.0', 'azure-core>=1.31.0', 'azure-identity>=1.19.0', + 'azure-mgmt-network>=27.0.0', 'azure-mgmt-compute>=33.0.0', + 'azure-storage-blob>=12.23.1', 'msgraph-sdk' ] + local_ray, # We need google-api-python-client>=2.69.0 to enable 'discardLocalSsd' # parameter for stopping instances. diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 032ad5d25b1..a9b8013cad7 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -44,9 +44,6 @@ # We need to add SKY_PYTHON_CMD before ray executable because: # The ray executable is a python script with a header like: # #!/opt/conda/bin/python3 -# When we create the skypilot-runtime venv, the previously installed ray -# executable will be reused (due to --system-site-packages), and that will cause -# running ray CLI commands to use the wrong python executable. SKY_RAY_CMD = (f'{SKY_PYTHON_CMD} $([ -s {SKY_RAY_PATH_FILE} ] && ' f'cat {SKY_RAY_PATH_FILE} 2> /dev/null || which ray)') # Separate env for SkyPilot runtime dependencies. @@ -152,10 +149,11 @@ f'conda create -y -n {SKY_REMOTE_PYTHON_ENV_NAME} python=3.10 && ' f'conda activate {SKY_REMOTE_PYTHON_ENV_NAME};' # Create a separate conda environment for SkyPilot dependencies. - # We use --system-site-packages to reuse the system site packages to avoid - # the overhead of installing the same packages in the new environment. f'[ -d {SKY_REMOTE_PYTHON_ENV} ] || ' - f'{SKY_PYTHON_CMD} -m venv {SKY_REMOTE_PYTHON_ENV} --system-site-packages;' + # Do NOT use --system-site-packages here, because if users upgrade any + # packages in the base env, they interfere with skypilot dependencies. + # Reference: https://github.com/skypilot-org/skypilot/issues/4097 + f'{SKY_PYTHON_CMD} -m venv {SKY_REMOTE_PYTHON_ENV};' f'echo "$(echo {SKY_REMOTE_PYTHON_ENV})/bin/python" > {SKY_PYTHON_PATH_FILE};' ) diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py index 5e7008e55d8..12d42d8c79c 100644 --- a/sky/skylet/job_lib.py +++ b/sky/skylet/job_lib.py @@ -512,16 +512,13 @@ def _get_jobs_by_ids(job_ids: List[int]) -> List[Dict[str, Any]]: return records -def _get_pending_jobs(): - rows = _CURSOR.execute( - 'SELECT job_id, created_time, submit FROM pending_jobs') - rows = list(rows) - return { - job_id: { - 'created_time': created_time, - 'submit': submit - } for job_id, created_time, submit in rows - } +def _get_pending_job(job_id: int) -> Optional[Dict[str, Any]]: + rows = _CURSOR.execute('SELECT created_time, submit FROM pending_jobs ' + f'WHERE job_id={job_id!r}') + for row in rows: + created_time, submit = row + return {'created_time': created_time, 'submit': submit} + return None def update_job_status(job_ids: List[int], @@ -535,7 +532,7 @@ def update_job_status(job_ids: List[int], during job cancelling, we still need this to handle the staleness problem, caused by instance restarting and other corner cases (if any). - This function should only be run on the remote instance with ray==2.4.0. + This function should only be run on the remote instance with ray>=2.4.0. """ if len(job_ids) == 0: return [] @@ -547,50 +544,45 @@ def update_job_status(job_ids: List[int], # In ray 2.4.0, job_client.list_jobs returns a list of JobDetails, # which contains the job status (str) and submission_id (str). 
+ ray_job_query_time = time.time() job_detail_lists: List['ray_pydantic.JobDetails'] = job_client.list_jobs() - pending_jobs = _get_pending_jobs() job_details = {} ray_job_ids_set = set(ray_job_ids) for job_detail in job_detail_lists: if job_detail.submission_id in ray_job_ids_set: job_details[job_detail.submission_id] = job_detail - job_statuses: List[Optional[JobStatus]] = [None] * len(ray_job_ids) - for i, ray_job_id in enumerate(ray_job_ids): - job_id = job_ids[i] - if ray_job_id in job_details: - ray_status = job_details[ray_job_id].status - job_statuses[i] = _RAY_TO_JOB_STATUS_MAP[ray_status] - if job_id in pending_jobs: - if pending_jobs[job_id]['created_time'] < psutil.boot_time(): - logger.info( - f'Job {job_id} is stale, setting to FAILED: ' - f'created_time={pending_jobs[job_id]["created_time"]}, ' - f'boot_time={psutil.boot_time()}') - # The job is stale as it is created before the instance - # is booted, e.g. the instance is rebooted. - job_statuses[i] = JobStatus.FAILED - # Gives a 60 second grace period between job being submit from - # the pending table until appearing in ray jobs. - if (pending_jobs[job_id]['submit'] > 0 and - pending_jobs[job_id]['submit'] < - time.time() - _PENDING_SUBMIT_GRACE_PERIOD): - # For jobs submitted outside of the grace period, we will - # consider the ray job status. - continue - else: - # Reset the job status to PENDING even though it may not appear - # in the ray jobs, so that it will not be considered as stale. - job_statuses[i] = JobStatus.PENDING - - assert len(job_statuses) == len(job_ids), (job_statuses, job_ids) statuses = [] - for job_id, status in zip(job_ids, job_statuses): + for job_id, ray_job_id in zip(job_ids, ray_job_ids): # Per-job status lock is required because between the job status # query and the job status update, the job status in the databse # can be modified by the generated ray program. with filelock.FileLock(_get_lock_path(job_id)): + status = None + if ray_job_id in job_details: + ray_status = job_details[ray_job_id].status + status = _RAY_TO_JOB_STATUS_MAP[ray_status] + pending_job = _get_pending_job(job_id) + if pending_job is not None: + if pending_job['created_time'] < psutil.boot_time(): + logger.info(f'Job {job_id} is stale, setting to FAILED: ' + f'created_time={pending_job["created_time"]}, ' + f'boot_time={psutil.boot_time()}') + # The job is stale as it is created before the instance + # is booted, e.g. the instance is rebooted. + status = JobStatus.FAILED + # Gives a 60 second grace period between job being submit from + # the pending table until appearing in ray jobs. For jobs + # submitted outside of the grace period, we will consider the + # ray job status. + if not (pending_job['submit'] > 0 and pending_job['submit'] < + ray_job_query_time - _PENDING_SUBMIT_GRACE_PERIOD): + # Reset the job status to PENDING even though it may not + # appear in the ray jobs, so that it will not be considered + # as stale. + status = JobStatus.PENDING + original_status = get_status_no_lock(job_id) assert original_status is not None, (job_id, status) if status is None: @@ -827,14 +819,6 @@ class JobLibCodeGen: 'import os', 'import getpass', 'from sky.skylet import job_lib, log_lib, constants', - # Backward compatibility for old skylet lib version on the remote - # machine. The `job_owner` argument was removed in #3037, and if the - # remote machine has an old SkyPilot version before that, we need to - # pass the `job_owner` argument to the job_lib functions. - # TODO(zhwu): Remove this in 0.7.0 release. 
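# Illustrative sketch, not part of the patch: the decision the reworked
# update_job_status() above makes for a job that still has a row in the
# pending table. The function and constant below are hypothetical; only the
# ordering mirrors the real code: staleness wins, then the submit grace
# period, otherwise the ray-reported status (possibly None) is kept.
from typing import Optional

PENDING_SUBMIT_GRACE_PERIOD = 60  # seconds; assumed for illustration


def resolve_pending_status(created_time: float, submit_time: float,
                           boot_time: float, ray_job_query_time: float,
                           ray_status: Optional[str]) -> Optional[str]:
    if created_time < boot_time:
        # Created before the last reboot: the job can never run, mark FAILED.
        return 'FAILED'
    if not (submit_time > 0 and
            submit_time < ray_job_query_time - PENDING_SUBMIT_GRACE_PERIOD):
        # Not yet submitted to ray, or submitted too recently to appear in
        # list_jobs: keep it PENDING rather than treating it as stale.
        return 'PENDING'
    # Outside the grace period: fall back to whatever ray reported.
    return ray_status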
- 'job_owner_kwargs = {} ' - 'if getattr(constants, "SKYLET_LIB_VERSION", 0) >= 1 ' - 'else {"job_owner": getpass.getuser()}', ] @classmethod @@ -861,7 +845,7 @@ def queue_job(cls, job_id: int, cmd: str) -> str: @classmethod def update_status(cls) -> str: - code = ['job_lib.update_status(**job_owner_kwargs)'] + code = ['job_lib.update_status()'] return cls._build(code) @classmethod @@ -879,7 +863,7 @@ def cancel_jobs(cls, """See job_lib.cancel_jobs().""" code = [ (f'cancelled = job_lib.cancel_jobs_encoded_results(' - f' {job_ids!r}, {cancel_all}, **job_owner_kwargs)'), + f' {job_ids!r}, {cancel_all})'), # Print cancelled IDs. Caller should parse by decoding. 'print(cancelled, flush=True)', ] @@ -902,7 +886,7 @@ def tail_logs(cls, 'run_timestamp = job_lib.get_run_timestamp(job_id)', f'log_dir = None if run_timestamp is None else os.path.join({constants.SKY_LOGS_DIRECTORY!r}, run_timestamp)', f'log_lib.tail_logs(job_id=job_id, log_dir=log_dir, ' - f'managed_job_id={managed_job_id!r}, follow={follow}, **job_owner_kwargs)', + f'managed_job_id={managed_job_id!r}, follow={follow})', ] return cls._build(code) diff --git a/sky/skylet/log_lib.py b/sky/skylet/log_lib.py index df90736da17..1b647ca0c29 100644 --- a/sky/skylet/log_lib.py +++ b/sky/skylet/log_lib.py @@ -186,20 +186,11 @@ def run_with_log( daemon_script = os.path.join( os.path.dirname(os.path.abspath(job_lib.__file__)), 'subprocess_daemon.py') - if not hasattr(constants, 'SKY_GET_PYTHON_PATH_CMD'): - # Backward compatibility: for cluster started before #3326, this - # constant does not exist. Since we generate the job script - # in backends.cloud_vm_ray_backend with inspect, so the - # the lates `run_with_log` will be used, but the `constants` is - # not updated. We fallback to `python3` in this case. - # TODO(zhwu): remove this after 0.7.0. - python_path = 'python3' - else: - python_path = subprocess.check_output( - constants.SKY_GET_PYTHON_PATH_CMD, - shell=True, - stderr=subprocess.DEVNULL, - encoding='utf-8').strip() + python_path = subprocess.check_output( + constants.SKY_GET_PYTHON_PATH_CMD, + shell=True, + stderr=subprocess.DEVNULL, + encoding='utf-8').strip() daemon_cmd = [ python_path, daemon_script, diff --git a/sky/templates/azure-ray.yml.j2 b/sky/templates/azure-ray.yml.j2 index 77ddda6652f..7b9737748d3 100644 --- a/sky/templates/azure-ray.yml.j2 +++ b/sky/templates/azure-ray.yml.j2 @@ -34,6 +34,7 @@ provider: # instead of the cluster_name. This ensures that ray creates new instances # for different cluster_name. resource_group: {{resource_group}} + use_external_resource_group: {{use_external_resource_group}} # Keep (otherwise cannot reuse when re-provisioning). # teardown(terminate=True) will override this. 
cache_stopped_nodes: True @@ -67,6 +68,8 @@ available_node_types: imageOffer: {{image_offer}} imageSku: "{{image_sku}}" imageVersion: {{image_version}} + # Community Gallery Image ID + communityGalleryImageId: {{community_gallery_image_id}} osDiskSizeGB: {{disk_size}} osDiskTier: {{disk_tier}} {%- if use_spot %} @@ -80,7 +83,6 @@ available_node_types: {%- for cmd in cloud_init_setup_commands %} {{ cmd }} {%- endfor %} - need_nvidia_driver_extension: {{need_nvidia_driver_extension}} {%- if disk_performance_tier is not none %} disk_performance_tier: {{disk_performance_tier}} {%- endif %} diff --git a/sky/usage/usage_lib.py b/sky/usage/usage_lib.py index a6c10da5c7a..07867939ee5 100644 --- a/sky/usage/usage_lib.py +++ b/sky/usage/usage_lib.py @@ -432,8 +432,9 @@ def entrypoint_context(name: str, fallback: bool = False): with ux_utils.enable_traceback(): trace = traceback.format_exc() messages.usage.stacktrace = trace - if hasattr(e, 'detailed_reason') and e.detailed_reason is not None: - messages.usage.stacktrace += '\nDetails: ' + e.detailed_reason + detailed_reason = getattr(e, 'detailed_reason', None) + if detailed_reason is not None: + messages.usage.stacktrace += '\nDetails: ' + detailed_reason messages.usage.exception = common_utils.remove_color( common_utils.format_exception(e)) raise diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index 1216c463046..539e4124a63 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -373,7 +373,6 @@ def _wrapper(f): @functools.wraps(f) def _record(*args, **kwargs): - nonlocal name_or_fn with cls(name_or_fn, **ctx_kwargs): return f(*args, **kwargs) @@ -387,7 +386,6 @@ def _record(*args, **kwargs): @functools.wraps(name_or_fn) def _record(*args, **kwargs): - nonlocal name_or_fn f = name_or_fn func_name = getattr(f, '__qualname__', f.__name__) module_name = getattr(f, '__module__', '') @@ -590,7 +588,10 @@ def validate_schema(obj, schema, err_msg_prefix='', skip_none=True): e.message) else: err_msg = err_msg_prefix + assert isinstance(e.schema, dict), 'Schema must be a dictionary' known_fields = set(e.schema.get('properties', {}).keys()) + assert isinstance(e.instance, + dict), 'Instance must be a dictionary' for field in e.instance: if field not in known_fields: most_similar_field = difflib.get_close_matches( diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 0c71357c856..0ab2fd7e117 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -505,20 +505,17 @@ def get_controller_resources( if handle is not None: controller_resources_to_use = handle.launched_resources - if controller_resources_to_use.cloud is not None: - return {controller_resources_to_use} + # If the controller and replicas are from the same cloud (and region/zone), + # it should provide better connectivity. We will let the controller choose + # from the clouds (and regions/zones) of the resources if the user does not + # specify the cloud (and region/zone) for the controller. - # If the controller and replicas are from the same cloud, it should - # provide better connectivity. We will let the controller choose from - # the clouds of the resources if the controller does not exist. - # TODO(tian): Consider respecting the regions/zones specified for the - # resources as well. 
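# Illustrative sketch, not part of the patch: the replacement code below keeps
# the clouds/regions/zones requested by the tasks in a nested mapping, where
# None at a level means "no preference at that level". The entries here are
# made-up examples of what that mapping can look like.
from typing import Dict, Optional, Set

requested: Dict[str, Dict[Optional[str], Set[Optional[str]]]] = {
    # One task pinned GCP/us-central1/us-central1-a; another only said "AWS".
    'GCP': {'us-central1': {'us-central1-a'}},
    'AWS': {None: {None}},
}

# Controller candidates are then the cross product of these entries, further
# narrowed by any cloud/region/zone the user set for the controller itself.
for cloud, regions in requested.items():
    for region, zones in regions.items():
        for zone in zones:
            print(cloud, region, zone)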
- requested_clouds: Set['clouds.Cloud'] = set() + requested_clouds_with_region_zone: Dict[str, Dict[Optional[str], + Set[Optional[str]]]] = {} for resource in task_resources: - # cloud is an object and will not be able to be distinguished by set. - # Here we manually check if the cloud is in the set. if resource.cloud is not None: - if not clouds.cloud_in_iterable(resource.cloud, requested_clouds): + cloud_name = str(resource.cloud) + if cloud_name not in requested_clouds_with_region_zone: try: resource.cloud.check_features_are_supported( resources.Resources(), @@ -526,7 +523,26 @@ def get_controller_resources( except exceptions.NotSupportedError: # Skip the cloud if it does not support hosting controllers. continue - requested_clouds.add(resource.cloud) + requested_clouds_with_region_zone[cloud_name] = {} + if resource.region is None: + # If one of the resource.region is None, this could represent + # that the user is unsure about which region the resource is + # hosted in. In this case, we allow any region for this cloud. + requested_clouds_with_region_zone[cloud_name] = {None: {None}} + elif None not in requested_clouds_with_region_zone[cloud_name]: + if resource.region not in requested_clouds_with_region_zone[ + cloud_name]: + requested_clouds_with_region_zone[cloud_name][ + resource.region] = set() + # If one of the resource.zone is None, allow any zone in the + # region. + if resource.zone is None: + requested_clouds_with_region_zone[cloud_name][ + resource.region] = {None} + elif None not in requested_clouds_with_region_zone[cloud_name][ + resource.region]: + requested_clouds_with_region_zone[cloud_name][ + resource.region].add(resource.zone) else: # if one of the resource.cloud is None, this could represent user # does not know which cloud is best for the specified resources. @@ -536,14 +552,49 @@ def get_controller_resources( # - cloud: runpod # accelerators: A40 # In this case, we allow the controller to be launched on any cloud. - requested_clouds.clear() + requested_clouds_with_region_zone.clear() break - if not requested_clouds: + + # Extract filtering criteria from the controller resources specified by the + # user. + controller_cloud = str( + controller_resources_to_use.cloud + ) if controller_resources_to_use.cloud is not None else None + controller_region = controller_resources_to_use.region + controller_zone = controller_resources_to_use.zone + + # Filter clouds if controller_resources_to_use.cloud is specified. + filtered_clouds = ({controller_cloud} if controller_cloud is not None else + requested_clouds_with_region_zone.keys()) + + # Filter regions and zones and construct the result. + result: Set[resources.Resources] = set() + for cloud_name in filtered_clouds: + regions = requested_clouds_with_region_zone.get(cloud_name, + {None: {None}}) + + # Filter regions if controller_resources_to_use.region is specified. + filtered_regions = ({controller_region} if controller_region is not None + else regions.keys()) + + for region in filtered_regions: + zones = regions.get(region, {None}) + + # Filter zones if controller_resources_to_use.zone is specified. + filtered_zones = ({controller_zone} + if controller_zone is not None else zones) + + # Create combinations of cloud, region, and zone. 
+ for zone in filtered_zones: + resource_copy = controller_resources_to_use.copy( + cloud=clouds.CLOUD_REGISTRY.from_str(cloud_name), + region=region, + zone=zone) + result.add(resource_copy) + + if not result: return {controller_resources_to_use} - return { - controller_resources_to_use.copy(cloud=controller_cloud) - for controller_cloud in requested_clouds - } + return result def _setup_proxy_command_on_controller( diff --git a/sky/utils/dag_utils.py b/sky/utils/dag_utils.py index fefd2b3ad1f..68eaaef4c3e 100644 --- a/sky/utils/dag_utils.py +++ b/sky/utils/dag_utils.py @@ -197,11 +197,21 @@ def fill_default_config_in_dag_for_job_launch(dag: dag_lib.Dag) -> None: for task_ in dag.tasks: new_resources_list = [] + default_strategy = jobs.DEFAULT_RECOVERY_STRATEGY + assert default_strategy is not None for resources in list(task_.resources): - change_default_value: Dict[str, Any] = {} - if resources.job_recovery is None: - change_default_value[ - 'job_recovery'] = jobs.DEFAULT_RECOVERY_STRATEGY + original_job_recovery = resources.job_recovery + job_recovery = {'strategy': default_strategy} + if isinstance(original_job_recovery, str): + job_recovery['strategy'] = original_job_recovery + elif isinstance(original_job_recovery, dict): + job_recovery.update(original_job_recovery) + strategy = job_recovery.get('strategy') + if strategy is None: + job_recovery['strategy'] = default_strategy + change_default_value: Dict[str, Any] = { + 'job_recovery': job_recovery + } new_resources = resources.copy(**change_default_value) new_resources_list.append(new_resources) diff --git a/sky/utils/resources_utils.py b/sky/utils/resources_utils.py index 72aa5ac05d3..653bb109ac0 100644 --- a/sky/utils/resources_utils.py +++ b/sky/utils/resources_utils.py @@ -2,9 +2,11 @@ import dataclasses import enum import itertools +import json +import math import re import typing -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set, Union from sky import skypilot_config from sky.clouds import cloud_registry @@ -163,6 +165,16 @@ def get_readable_resources_repr(handle: 'backends.CloudVmRayResourceHandle', return _DEFAULT_MESSAGE_HANDLE_INITIALIZING +def make_ray_custom_resources_str( + resource_dict: Optional[Dict[str, Union[int, float]]]) -> Optional[str]: + """Convert resources to Ray custom resources format.""" + if resource_dict is None: + return None + # Ray does not allow fractional resources, so we need to ceil the values. + ceiled_dict = {k: math.ceil(v) for k, v in resource_dict.items()} + return json.dumps(ceiled_dict, separators=(',', ':')) + + @dataclasses.dataclass class FeasibleResources: """Feasible resources returned by cloud. diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 94a6ed690e1..81c4cb332a6 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -92,7 +92,27 @@ def _get_single_resources_schema(): 'type': 'string', }, 'job_recovery': { - 'type': 'string', + # Either a string or a dict. + 'anyOf': [{ + 'type': 'string', + }, { + 'type': 'object', + 'required': [], + 'additionalProperties': False, + 'properties': { + 'strategy': { + 'anyOf': [{ + 'type': 'string', + }, { + 'type': 'null', + }], + }, + 'max_restarts_on_errors': { + 'type': 'integer', + 'minimum': 0, + }, + } + }], }, 'disk_size': { 'type': 'integer', @@ -357,19 +377,6 @@ def get_service_schema(): 'downscale_delay_seconds': { 'type': 'number', }, - # TODO(MaoZiming): Fields `qps_upper_threshold`, - # `qps_lower_threshold` and `auto_restart` are deprecated. 
- # Temporarily keep these fields for backward compatibility. - # Remove after 2 minor release, i.e., 0.6.0. - 'auto_restart': { - 'type': 'boolean', - }, - 'qps_upper_threshold': { - 'type': 'number', - }, - 'qps_lower_threshold': { - 'type': 'number', - }, } }, 'replicas': { @@ -595,7 +602,7 @@ def get_cluster_schema(): _LABELS_SCHEMA = { # Deprecated: 'instance_tags' is replaced by 'labels'. Keeping for backward - # compatibility. Will be removed after 0.7.0. + # compatibility. Will be removed after 0.8.0. 'instance_tags': { 'type': 'object', 'required': [], @@ -771,6 +778,9 @@ def get_config_schema(): 'storage_account': { 'type': 'string', }, + 'resource_group_vm': { + 'type': 'string', + }, } }, 'kubernetes': { diff --git a/tests/backward_compatibility_tests.sh b/tests/backward_compatibility_tests.sh index 4f83c379ccf..276fda899dd 100644 --- a/tests/backward_compatibility_tests.sh +++ b/tests/backward_compatibility_tests.sh @@ -167,8 +167,8 @@ MANAGED_JOB_JOB_NAME=${CLUSTER_NAME}-${uuid:0:4} if [ "$start_from" -le 7 ]; then conda activate sky-back-compat-master rm -r ~/.sky/wheels || true -sky spot launch -d --cloud ${CLOUD} -y --cpus 2 --num-nodes 2 -n ${MANAGED_JOB_JOB_NAME}-7-0 "echo hi; sleep 1000" -sky spot launch -d --cloud ${CLOUD} -y --cpus 2 --num-nodes 2 -n ${MANAGED_JOB_JOB_NAME}-7-1 "echo hi; sleep 400" +sky jobs launch -d --cloud ${CLOUD} -y --cpus 2 --num-nodes 2 -n ${MANAGED_JOB_JOB_NAME}-7-0 "echo hi; sleep 1000" +sky jobs launch -d --cloud ${CLOUD} -y --cpus 2 --num-nodes 2 -n ${MANAGED_JOB_JOB_NAME}-7-1 "echo hi; sleep 400" conda activate sky-back-compat-current rm -r ~/.sky/wheels || true s=$(sky jobs queue | grep ${MANAGED_JOB_JOB_NAME}-7 | grep "RUNNING" | wc -l) diff --git a/tests/common.py b/tests/common.py index c6f08588d99..d41ff3bead0 100644 --- a/tests/common.py +++ b/tests/common.py @@ -70,6 +70,9 @@ def _get_az_mappings(_): lambda *_args, **_kwargs: [True, '']) monkeypatch.setattr('sky.provision.kubernetes.utils.get_spot_label', lambda *_args, **_kwargs: [None, None]) + monkeypatch.setattr( + 'sky.provision.kubernetes.utils.is_kubeconfig_exec_auth', + lambda *_args, **_kwargs: [False, None]) # monkeypatch class Kubernetes. monkeypatch.setattr( diff --git a/tests/conftest.py b/tests/conftest.py index b4e025a8f2d..aa0d0c88289 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ # to mark a test as slow and to skip by default. # https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option -# By default, only run generic tests and cloud-specific tests for GCP and Azure, +# By default, only run generic tests and cloud-specific tests for AWS and Azure, # due to the cloud credit limit for the development account. # # A "generic test" tests a generic functionality (e.g., autostop) that @@ -24,7 +24,7 @@ 'aws', 'gcp', 'azure', 'lambda', 'cloudflare', 'ibm', 'scp', 'oci', 'kubernetes', 'vsphere', 'cudo', 'fluidstack', 'paperspace' ] -default_clouds_to_run = ['gcp', 'azure'] +default_clouds_to_run = ['aws', 'azure'] # Translate cloud name to pytest keyword. We need this because # @pytest.mark.lambda is not allowed, so we use @pytest.mark.lambda_cloud @@ -72,7 +72,7 @@ def pytest_addoption(parser): parser.addoption( '--generic-cloud', type=str, - default='gcp', + default='aws', choices=all_clouds_in_smoke_tests, help='Cloud to use for generic tests. 
If the generic cloud is ' 'not within the clouds to be run, it will be reset to the first ' @@ -138,8 +138,8 @@ def pytest_collection_modifyitems(config, items): for cloud in all_clouds_in_smoke_tests: cloud_keyword = cloud_to_pytest_keyword[cloud] if (cloud_keyword in item.keywords and cloud not in cloud_to_run): - # Need to check both conditions as 'gcp' is added to cloud_to_run - # when tested for cloudflare + # Need to check both conditions as the first default cloud is + # added to cloud_to_run when tested for cloudflare if config.getoption('--cloudflare') and cloud == 'cloudflare': continue item.add_marker(skip_marks[cloud]) @@ -206,7 +206,7 @@ def enable_all_clouds(monkeypatch: pytest.MonkeyPatch) -> None: @pytest.fixture def aws_config_region(monkeypatch: pytest.MonkeyPatch) -> str: from sky import skypilot_config - region = 'us-west-2' + region = 'us-east-2' if skypilot_config.loaded(): ssh_proxy_command = skypilot_config.get_nested( ('aws', 'ssh_proxy_command'), None) diff --git a/tests/test_api.py b/tests/test_api.py index 4d6658fcd05..5a33336dd92 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,7 +1,20 @@ import sky +from sky.clouds.cloud import Cloud def test_sky_launch(enable_all_clouds): task = sky.Task() job_id, handle = sky.launch(task, dryrun=True) assert job_id is None and handle is None + + +def test_k8s_alias(enable_all_clouds): + + def dryrun_task_with_cloud(cloud: Cloud): + task = sky.Task() + task.set_resources_override({'cloud': cloud}) + sky.launch(task, dryrun=True) + + dryrun_task_with_cloud(sky.K8s()) + + dryrun_task_with_cloud(sky.Kubernetes()) diff --git a/tests/test_cli.py b/tests/test_cli.py index 3a2417a6cde..36f2a6ea782 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -3,7 +3,6 @@ from click import testing as cli_testing -import sky from sky import exceptions import sky.cli as cli @@ -103,3 +102,40 @@ def test_show_gpus(): result = cli_runner.invoke(cli.show_gpus, ['V100:4', '--cloud', cloud, '--all']) assert isinstance(result.exception, SystemExit) + + +def test_k8s_alias_check(): + cli_runner = cli_testing.CliRunner() + + result = cli_runner.invoke(cli.check, ['k8s']) + assert not result.exit_code + + result = cli_runner.invoke(cli.check, ['kubernetes']) + assert not result.exit_code + + result = cli_runner.invoke(cli.check, ['notarealcloud']) + assert isinstance(result.exception, ValueError) + + +def test_k8s_alias(enable_all_clouds): + cli_runner = cli_testing.CliRunner() + + result = cli_runner.invoke(cli.launch, ['--cloud', 'k8s', '--dryrun']) + assert not result.exit_code + + result = cli_runner.invoke(cli.launch, + ['--cloud', 'kubernetes', '--dryrun']) + assert not result.exit_code + + result = cli_runner.invoke(cli.launch, + ['--cloud', 'notarealcloud', '--dryrun']) + assert isinstance(result.exception, ValueError) + + result = cli_runner.invoke(cli.show_gpus, ['--cloud', 'k8s']) + assert not result.exit_code + + result = cli_runner.invoke(cli.show_gpus, ['--cloud', 'kubernetes']) + assert not result.exit_code + + result = cli_runner.invoke(cli.show_gpus, ['--cloud', 'notarealcloud']) + assert isinstance(result.exception, ValueError) diff --git a/tests/test_jobs_and_serve.py b/tests/test_jobs_and_serve.py index a599fb7ba88..237ffd440da 100644 --- a/tests/test_jobs_and_serve.py +++ b/tests/test_jobs_and_serve.py @@ -307,7 +307,6 @@ def mock_get_services_one_service( 'controller_port': 30001, 'load_balancer_port': 30000, 'policy': None, - 'requested_resources': sky.Resources(), 'requested_resources_str': '', 'replica_info': 
[], } diff --git a/tests/test_smoke.py b/tests/test_smoke.py index ed86f93ca27..cdfd9dfc7cb 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -85,24 +85,15 @@ 'touch ~/.ssh/id_rsa.pub' ] -# Wait until the jobs controller is not in INIT state. -# This is a workaround for the issue that when multiple job tests -# are running in parallel, the jobs controller may be in INIT and -# the job queue/cancel command will return staled table. -_JOB_QUEUE_WAIT = ('s=$(sky jobs queue); ' - 'until ! echo "$s" | grep "jobs will not be shown until"; ' - 'do echo "Waiting for job queue to be ready..."; ' - 'sleep 5; s=$(sky jobs queue); done; echo "$s"; ' - 'echo; echo; echo "$s"') -_JOB_CANCEL_WAIT = ( - 's=$(sky jobs cancel -y -n {job_name}); ' - 'until ! echo "$s" | grep "Please wait for the controller to be ready."; ' - 'do echo "Waiting for the jobs controller ' - 'to be ready"; sleep 5; s=$(sky jobs cancel -y -n {job_name}); ' - 'done; echo "$s"; echo; echo; echo "$s"') -# TODO(zhwu): make the jobs controller on GCP, to avoid parallel test issues -# when the controller being on Azure, which takes a long time for launching -# step. +# Get the job queue, and print it once on its own, then print it again to +# use with grep by the caller. +_GET_JOB_QUEUE = 's=$(sky jobs queue); echo "$s"; echo "$s"' +# Wait for a job to be not in RUNNING state. Used to check for RECOVERING. +_JOB_WAIT_NOT_RUNNING = ( + 's=$(sky jobs queue);' + 'until ! echo "$s" | grep "{job_name}" | grep "RUNNING"; do ' + 'sleep 10; s=$(sky jobs queue);' + 'echo "Waiting for job to stop RUNNING"; echo "$s"; done') DEFAULT_CMD_TIMEOUT = 15 * 60 @@ -369,6 +360,69 @@ def test_minimal(generic_cloud: str): run_one_test(test) +# ---------- Test fast launch ---------- +def test_launch_fast(generic_cloud: str): + name = _get_cluster_name() + + test = Test( + 'test_launch_fast', + [ + # First launch to create the cluster + f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --cloud {generic_cloud} --fast tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}', + f'sky logs {name} 1 --status', + + # Second launch to test fast launch - should not reprovision + f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --fast tests/test_yamls/minimal.yaml) && ' + ' echo "$s" && ' + # Validate that cluster was not re-launched. + '! echo "$s" | grep -A 1 "Launching on" | grep "is up." && ' + # Validate that setup was not re-run. + '! echo "$s" | grep -A 1 "Running setup on" | grep "running setup" && ' + # Validate that the task ran and finished. + 'echo "$s" | grep -A 1 "task run finish" | grep "Job finished (status: SUCCEEDED)"', + f'sky logs {name} 2 --status', + f'sky status -r {name} | grep UP', + ], + f'sky down -y {name}', + timeout=_get_timeout(generic_cloud), + ) + run_one_test(test) + + +# See cloud exclusion explanations in test_autostop +@pytest.mark.no_fluidstack +@pytest.mark.no_lambda_cloud +@pytest.mark.no_ibm +@pytest.mark.no_kubernetes +def test_launch_fast_with_autostop(generic_cloud: str): + name = _get_cluster_name() + # Azure takes ~ 7m15s (435s) to autostop a VM, so here we use 600 to ensure + # the VM is stopped. 
+ autostop_timeout = 600 if generic_cloud == 'azure' else 250 + + test = Test( + 'test_launch_fast_with_autostop', + [ + # First launch to create the cluster with a short autostop + f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --cloud {generic_cloud} --fast -i 1 tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}', + f'sky logs {name} 1 --status', + f'sky status -r {name} | grep UP', + f'sleep {autostop_timeout}', + + # Ensure cluster is stopped + f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + + # Launch again. Do full output validation - we expect the cluster to re-launch + f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --fast -i 1 tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}', + f'sky logs {name} 2 --status', + f'sky status -r {name} | grep UP', + ], + f'sky down -y {name}', + timeout=_get_timeout(generic_cloud) + autostop_timeout, + ) + run_one_test(test) + + # ---------- Test region ---------- @pytest.mark.aws def test_aws_region(): @@ -2643,6 +2697,9 @@ def test_stop_gcp_spot(): # ---------- Testing managed job ---------- +# TODO(zhwu): make the jobs controller on GCP, to avoid parallel test issues +# when the controller being on Azure, which takes a long time for launching +# step. @pytest.mark.managed_jobs def test_managed_jobs(generic_cloud: str): """Test the managed jobs yaml.""" @@ -2653,22 +2710,21 @@ def test_managed_jobs(generic_cloud: str): f'sky jobs launch -n {name}-1 --cloud {generic_cloud} examples/managed_job.yaml -y -d', f'sky jobs launch -n {name}-2 --cloud {generic_cloud} examples/managed_job.yaml -y -d', 'sleep 5', - f'{_JOB_QUEUE_WAIT}| grep {name}-1 | head -n1 | grep "STARTING\|RUNNING"', - f'{_JOB_QUEUE_WAIT}| grep {name}-2 | head -n1 | grep "STARTING\|RUNNING"', - _JOB_CANCEL_WAIT.format(job_name=f'{name}-1'), + f'{_GET_JOB_QUEUE} | grep {name}-1 | head -n1 | grep "PENDING\|SUBMITTED\|STARTING\|RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "PENDING\|SUBMITTED\|STARTING\|RUNNING"', + f'sky jobs cancel -y -n {name}-1', 'sleep 5', - f'{_JOB_QUEUE_WAIT}| grep {name}-1 | head -n1 | grep "CANCELLING\|CANCELLED"', + f'{_GET_JOB_QUEUE} | grep {name}-1 | head -n1 | grep "CANCELLING\|CANCELLED"', 'sleep 200', - f'{_JOB_QUEUE_WAIT}| grep {name}-1 | head -n1 | grep CANCELLED', + f'{_GET_JOB_QUEUE} | grep {name}-1 | head -n1 | grep CANCELLED', # Test the functionality for logging. f's=$(sky jobs logs -n {name}-2 --no-follow); echo "$s"; echo "$s" | grep "start counting"', f's=$(sky jobs logs --controller -n {name}-2 --no-follow); echo "$s"; echo "$s" | grep "Cluster launched:"', - f'{_JOB_QUEUE_WAIT}| grep {name}-2 | head -n1 | grep "RUNNING\|SUCCEEDED"', + f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "RUNNING\|SUCCEEDED"', ], - # TODO(zhwu): Change to _JOB_CANCEL_WAIT.format(job_name=f'{name}-1 -n {name}-2') when + # TODO(zhwu): Change to f'sky jobs cancel -y -n {name}-1 -n {name}-2' when # canceling multiple job names is supported. - (_JOB_CANCEL_WAIT.format(job_name=f'{name}-1') + '; ' + - _JOB_CANCEL_WAIT.format(job_name=f'{name}-2')), + f'sky jobs cancel -y -n {name}-1; sky jobs cancel -y -n {name}-2', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. 
timeout=20 * 60, ) @@ -2690,26 +2746,26 @@ def test_job_pipeline(generic_cloud: str): [ f'sky jobs launch -n {name} tests/test_yamls/pipeline.yaml -y -d', 'sleep 5', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "STARTING\|RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "STARTING\|RUNNING"', # `grep -A 4 {name}` finds the job with {name} and the 4 lines # after it, i.e. the 4 tasks within the job. # `sed -n 2p` gets the second line of the 4 lines, i.e. the first # task within the job. - f'{_JOB_QUEUE_WAIT}| grep -A 4 {name}| sed -n 2p | grep "STARTING\|RUNNING"', - f'{_JOB_QUEUE_WAIT}| grep -A 4 {name}| sed -n 3p | grep "PENDING"', - _JOB_CANCEL_WAIT.format(job_name=f'{name}'), + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 2p | grep "STARTING\|RUNNING"', + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 3p | grep "PENDING"', + f'sky jobs cancel -y -n {name}', 'sleep 5', - f'{_JOB_QUEUE_WAIT}| grep -A 4 {name}| sed -n 2p | grep "CANCELLING\|CANCELLED"', - f'{_JOB_QUEUE_WAIT}| grep -A 4 {name}| sed -n 3p | grep "CANCELLING\|CANCELLED"', - f'{_JOB_QUEUE_WAIT}| grep -A 4 {name}| sed -n 4p | grep "CANCELLING\|CANCELLED"', - f'{_JOB_QUEUE_WAIT}| grep -A 4 {name}| sed -n 5p | grep "CANCELLING\|CANCELLED"', + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 2p | grep "CANCELLING\|CANCELLED"', + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 3p | grep "CANCELLING\|CANCELLED"', + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 4p | grep "CANCELLING\|CANCELLED"', + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 5p | grep "CANCELLING\|CANCELLED"', 'sleep 200', - f'{_JOB_QUEUE_WAIT}| grep -A 4 {name}| sed -n 2p | grep "CANCELLED"', - f'{_JOB_QUEUE_WAIT}| grep -A 4 {name}| sed -n 3p | grep "CANCELLED"', - f'{_JOB_QUEUE_WAIT}| grep -A 4 {name}| sed -n 4p | grep "CANCELLED"', - f'{_JOB_QUEUE_WAIT}| grep -A 4 {name}| sed -n 5p | grep "CANCELLED"', + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 2p | grep "CANCELLED"', + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 3p | grep "CANCELLED"', + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 4p | grep "CANCELLED"', + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 5p | grep "CANCELLED"', ], - _JOB_CANCEL_WAIT.format(job_name=f'{name}'), + f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. timeout=30 * 60, ) @@ -2732,9 +2788,9 @@ def test_managed_jobs_failed_setup(generic_cloud: str): f'sky jobs launch -n {name} --cloud {generic_cloud} -y -d tests/test_yamls/failed_setup.yaml', 'sleep 330', # Make sure the job failed quickly. - f'{_JOB_QUEUE_WAIT} | grep {name} | head -n1 | grep "FAILED_SETUP"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "FAILED_SETUP"', ], - _JOB_CANCEL_WAIT.format(job_name=name), + f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. timeout=20 * 60, ) @@ -2757,17 +2813,17 @@ def test_managed_jobs_pipeline_failed_setup(generic_cloud: str): f'sky jobs launch -n {name} -y -d tests/test_yamls/failed_setup_pipeline.yaml', 'sleep 600', # Make sure the job failed quickly. - f'{_JOB_QUEUE_WAIT} | grep {name} | head -n1 | grep "FAILED_SETUP"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "FAILED_SETUP"', # Task 0 should be SUCCEEDED. - f'{_JOB_QUEUE_WAIT} | grep -A 4 {name}| sed -n 2p | grep "SUCCEEDED"', + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 2p | grep "SUCCEEDED"', # Task 1 should be FAILED_SETUP. 
- f'{_JOB_QUEUE_WAIT} | grep -A 4 {name}| sed -n 3p | grep "FAILED_SETUP"', + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 3p | grep "FAILED_SETUP"', # Task 2 should be CANCELLED. - f'{_JOB_QUEUE_WAIT} | grep -A 4 {name}| sed -n 4p | grep "CANCELLED"', + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 4p | grep "CANCELLED"', # Task 3 should be CANCELLED. - f'{_JOB_QUEUE_WAIT} | grep -A 4 {name}| sed -n 5p | grep "CANCELLED"', + f'{_GET_JOB_QUEUE} | grep -A 4 {name}| sed -n 5p | grep "CANCELLED"', ], - _JOB_CANCEL_WAIT.format(job_name=name), + f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. timeout=30 * 60, ) @@ -2790,7 +2846,7 @@ def test_managed_jobs_recovery_aws(aws_config_region): [ f'sky jobs launch --cloud aws --region {region} --use-spot -n {name} "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', 'sleep 360', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the cluster manually. (f'aws ec2 terminate-instances --region {region} --instance-ids $(' @@ -2798,13 +2854,13 @@ def test_managed_jobs_recovery_aws(aws_config_region): f'--filters Name=tag:ray-cluster-name,Values={name_on_cloud}* ' f'--query Reservations[].Instances[].InstanceId ' '--output text)'), - 'sleep 100', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RECOVERING"', + _JOB_WAIT_NOT_RUNNING.format(job_name=name), + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', 'sleep 200', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(cat /tmp/{name}-run-id); echo "$RUN_ID"; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | grep "$RUN_ID"', ], - _JOB_CANCEL_WAIT.format(job_name=name), + f'sky jobs cancel -y -n {name}', timeout=25 * 60, ) run_one_test(test) @@ -2830,17 +2886,17 @@ def test_managed_jobs_recovery_gcp(): [ f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot --cpus 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', 'sleep 360', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the cluster manually. 
terminate_cmd, - 'sleep 60', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RECOVERING"', + _JOB_WAIT_NOT_RUNNING.format(job_name=name), + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', 'sleep 200', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(cat /tmp/{name}-run-id); echo "$RUN_ID"; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"', ], - _JOB_CANCEL_WAIT.format(job_name=name), + f'sky jobs cancel -y -n {name}', timeout=25 * 60, ) run_one_test(test) @@ -2861,7 +2917,7 @@ def test_managed_jobs_pipeline_recovery_aws(aws_config_region): [ f'sky jobs launch -n {name} tests/test_yamls/pipeline_aws.yaml -y -d', 'sleep 400', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids', # Terminate the cluster manually. @@ -2878,16 +2934,16 @@ def test_managed_jobs_pipeline_recovery_aws(aws_config_region): f'-{user_hash} ' f'--query Reservations[].Instances[].InstanceId ' '--output text)'), - 'sleep 100', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RECOVERING"', + _JOB_WAIT_NOT_RUNNING.format(job_name=name), + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', 'sleep 200', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids-new', f'diff /tmp/{name}-run-ids /tmp/{name}-run-ids-new', f'cat /tmp/{name}-run-ids | sed -n 2p | grep `cat /tmp/{name}-run-id`', ], - _JOB_CANCEL_WAIT.format(job_name=name), + f'sky jobs cancel -y -n {name}', timeout=25 * 60, ) run_one_test(test) @@ -2912,7 +2968,7 @@ def test_managed_jobs_pipeline_recovery_gcp(): [ f'sky jobs launch -n {name} tests/test_yamls/pipeline_gcp.yaml -y -d', 'sleep 400', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids', # Terminate the cluster manually. @@ -2921,16 +2977,16 @@ def test_managed_jobs_pipeline_recovery_gcp(): # separated by `-`. 
(f'MANAGED_JOB_ID=`cat /tmp/{name}-run-id | rev | ' f'cut -d\'_\' -f1 | rev | cut -d\'-\' -f1`; {terminate_cmd}'), - 'sleep 60', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RECOVERING"', + _JOB_WAIT_NOT_RUNNING.format(job_name=name), + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', 'sleep 200', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids-new', f'diff /tmp/{name}-run-ids /tmp/{name}-run-ids-new', f'cat /tmp/{name}-run-ids | sed -n 2p | grep `cat /tmp/{name}-run-id`', ], - _JOB_CANCEL_WAIT.format(job_name=name), + f'sky jobs cancel -y -n {name}', timeout=25 * 60, ) run_one_test(test) @@ -2951,9 +3007,9 @@ def test_managed_jobs_recovery_default_resources(generic_cloud: str): [ f'sky jobs launch -n {name} --cloud {generic_cloud} --use-spot "sleep 30 && sudo shutdown now && sleep 1000" -y -d', 'sleep 360', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING\|RECOVERING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING\|RECOVERING"', ], - _JOB_CANCEL_WAIT.format(job_name=name), + f'sky jobs cancel -y -n {name}', timeout=25 * 60, ) run_one_test(test) @@ -2972,7 +3028,7 @@ def test_managed_jobs_recovery_multi_node_aws(aws_config_region): [ f'sky jobs launch --cloud aws --region {region} -n {name} --use-spot --num-nodes 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', 'sleep 450', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the worker manually. (f'aws ec2 terminate-instances --region {region} --instance-ids $(' @@ -2981,13 +3037,13 @@ def test_managed_jobs_recovery_multi_node_aws(aws_config_region): 'Name=tag:ray-node-type,Values=worker ' f'--query Reservations[].Instances[].InstanceId ' '--output text)'), - 'sleep 50', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RECOVERING"', + _JOB_WAIT_NOT_RUNNING.format(job_name=name), + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', 'sleep 560', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2 | grep "$RUN_ID"', ], - _JOB_CANCEL_WAIT.format(job_name=name), + f'sky jobs cancel -y -n {name}', timeout=30 * 60, ) run_one_test(test) @@ -3013,17 +3069,17 @@ def test_managed_jobs_recovery_multi_node_gcp(): [ f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot --num-nodes 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', 'sleep 400', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the worker manually. 
terminate_cmd, - 'sleep 50', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RECOVERING"', + _JOB_WAIT_NOT_RUNNING.format(job_name=name), + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', 'sleep 420', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2 | grep "$RUN_ID"', ], - _JOB_CANCEL_WAIT.format(job_name=name), + f'sky jobs cancel -y -n {name}', timeout=25 * 60, ) run_one_test(test) @@ -3046,12 +3102,12 @@ def test_managed_jobs_cancellation_aws(aws_config_region): # Test cancellation during spot cluster being launched. f'sky jobs launch --cloud aws --region {region} -n {name} --use-spot "sleep 1000" -y -d', 'sleep 60', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "STARTING"', - _JOB_CANCEL_WAIT.format(job_name=name), + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "STARTING\|RUNNING"', + f'sky jobs cancel -y -n {name}', 'sleep 5', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "CANCELLING\|CANCELLED"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLING\|CANCELLED"', 'sleep 120', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "CANCELLED"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLED"', (f's=$(aws ec2 describe-instances --region {region} ' f'--filters Name=tag:ray-cluster-name,Values={name_on_cloud}-* ' f'--query Reservations[].Instances[].State[].Name ' @@ -3060,11 +3116,11 @@ def test_managed_jobs_cancellation_aws(aws_config_region): # Test cancelling the spot cluster during spot job being setup. f'sky jobs launch --cloud aws --region {region} -n {name}-2 --use-spot tests/test_yamls/test_long_setup.yaml -y -d', 'sleep 300', - _JOB_CANCEL_WAIT.format(job_name=f'{name}-2'), + f'sky jobs cancel -y -n {name}-2', 'sleep 5', - f'{_JOB_QUEUE_WAIT}| grep {name}-2 | head -n1 | grep "CANCELLING\|CANCELLED"', + f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLING\|CANCELLED"', 'sleep 120', - f'{_JOB_QUEUE_WAIT}| grep {name}-2 | head -n1 | grep "CANCELLED"', + f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLED"', (f's=$(aws ec2 describe-instances --region {region} ' f'--filters Name=tag:ray-cluster-name,Values={name_2_on_cloud}-* ' f'--query Reservations[].Instances[].State[].Name ' @@ -3073,20 +3129,20 @@ def test_managed_jobs_cancellation_aws(aws_config_region): # Test cancellation during spot job is recovering. f'sky jobs launch --cloud aws --region {region} -n {name}-3 --use-spot "sleep 1000" -y -d', 'sleep 300', - f'{_JOB_QUEUE_WAIT}| grep {name}-3 | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RUNNING"', # Terminate the cluster manually. 
(f'aws ec2 terminate-instances --region {region} --instance-ids $(' f'aws ec2 describe-instances --region {region} ' f'--filters Name=tag:ray-cluster-name,Values={name_3_on_cloud}-* ' f'--query Reservations[].Instances[].InstanceId ' '--output text)'), - 'sleep 120', - f'{_JOB_QUEUE_WAIT}| grep {name}-3 | head -n1 | grep "RECOVERING"', - _JOB_CANCEL_WAIT.format(job_name=f'{name}-3'), + _JOB_WAIT_NOT_RUNNING.format(job_name=f'{name}-3'), + f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RECOVERING"', + f'sky jobs cancel -y -n {name}-3', 'sleep 5', - f'{_JOB_QUEUE_WAIT}| grep {name}-3 | head -n1 | grep "CANCELLING\|CANCELLED"', + f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLING\|CANCELLED"', 'sleep 120', - f'{_JOB_QUEUE_WAIT}| grep {name}-3 | head -n1 | grep "CANCELLED"', + f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLED"', # The cluster should be terminated (shutting-down) after cancellation. We don't use the `=` operator here because # there can be multiple VM with the same name due to the recovery. (f's=$(aws ec2 describe-instances --region {region} ' @@ -3122,33 +3178,33 @@ def test_managed_jobs_cancellation_gcp(): # Test cancellation during spot cluster being launched. f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot "sleep 1000" -y -d', 'sleep 60', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "STARTING"', - _JOB_CANCEL_WAIT.format(job_name=name), + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "STARTING"', + f'sky jobs cancel -y -n {name}', 'sleep 5', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "CANCELLING\|CANCELLED"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLING\|CANCELLED"', 'sleep 120', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "CANCELLED"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLED"', # Test cancelling the spot cluster during spot job being setup. f'sky jobs launch --cloud gcp --zone {zone} -n {name}-2 --use-spot tests/test_yamls/test_long_setup.yaml -y -d', 'sleep 300', - _JOB_CANCEL_WAIT.format(job_name=f'{name}-2'), + f'sky jobs cancel -y -n {name}-2', 'sleep 5', - f'{_JOB_QUEUE_WAIT}| grep {name}-2 | head -n1 | grep "CANCELLING\|CANCELLED"', + f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLING\|CANCELLED"', 'sleep 120', - f'{_JOB_QUEUE_WAIT}| grep {name}-2 | head -n1 | grep "CANCELLED"', + f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLED"', # Test cancellation during spot job is recovering. f'sky jobs launch --cloud gcp --zone {zone} -n {name}-3 --use-spot "sleep 1000" -y -d', 'sleep 300', - f'{_JOB_QUEUE_WAIT}| grep {name}-3 | head -n1 | grep "RUNNING"', + f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RUNNING"', # Terminate the cluster manually. terminate_cmd, - 'sleep 80', - f'{_JOB_QUEUE_WAIT}| grep {name}-3 | head -n1 | grep "RECOVERING"', - _JOB_CANCEL_WAIT.format(job_name=f'{name}-3'), + _JOB_WAIT_NOT_RUNNING.format(job_name=f'{name}-3'), + f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RECOVERING"', + f'sky jobs cancel -y -n {name}-3', 'sleep 5', - f'{_JOB_QUEUE_WAIT}| grep {name}-3 | head -n1 | grep "CANCELLING\|CANCELLED"', + f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLING\|CANCELLED"', 'sleep 120', - f'{_JOB_QUEUE_WAIT}| grep {name}-3 | head -n1 | grep "CANCELLED"', + f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLED"', # The cluster should be terminated (STOPPING) after cancellation. 
We don't use the `=` operator here because # there can be multiple VM with the same name due to the recovery. (f's=$({query_state_cmd}) && echo "$s" && echo; [[ -z "$s" ]] || echo "$s" | grep -v -E "PROVISIONING|STAGING|RUNNING|REPAIRING|TERMINATED|SUSPENDING|SUSPENDED|SUSPENDED"' @@ -3239,12 +3295,12 @@ def test_managed_jobs_storage(generic_cloud: str): f'sky jobs launch -n {name}{use_spot} --cloud {generic_cloud}{region_flag} {file_path} -y', region_validation_cmd, # Check if the bucket is created in the correct region 'sleep 60', # Wait the spot queue to be updated - f'{_JOB_QUEUE_WAIT}| grep {name} | grep SUCCEEDED', + f'{_GET_JOB_QUEUE} | grep {name} | grep SUCCEEDED', f'[ $(aws s3api list-buckets --query "Buckets[?contains(Name, \'{storage_name}\')].Name" --output text | wc -l) -eq 0 ]', # Check if file was written to the mounted output bucket output_check_cmd ], - (_JOB_CANCEL_WAIT.format(job_name=name), + (f'sky jobs cancel -y -n {name}', f'; sky storage delete {output_storage_name} || true'), # Increase timeout since sky jobs queue -r can be blocked by other spot tests. timeout=20 * 60, @@ -3264,11 +3320,11 @@ def test_managed_jobs_tpu(): [ f'sky jobs launch -n {name} --use-spot examples/tpu/tpuvm_mnist.yaml -y -d', 'sleep 5', - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep STARTING', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep STARTING', 'sleep 900', # TPU takes a while to launch - f'{_JOB_QUEUE_WAIT}| grep {name} | head -n1 | grep "RUNNING\|SUCCEEDED"', + f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING\|SUCCEEDED"', ], - _JOB_CANCEL_WAIT.format(job_name=name), + f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. timeout=20 * 60, ) @@ -3285,9 +3341,9 @@ def test_managed_jobs_inline_env(generic_cloud: str): [ f'sky jobs launch -n {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"', 'sleep 20', - f'{_JOB_QUEUE_WAIT} | grep {name} | grep SUCCEEDED', + f'{_GET_JOB_QUEUE} | grep {name} | grep SUCCEEDED', ], - _JOB_CANCEL_WAIT.format(job_name=name), + f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. timeout=20 * 60, ) @@ -3579,7 +3635,7 @@ def test_long_setup_run_script(generic_cloud: str): setup: | echo "start long setup" """)) - for i in range(1024 * 120): + for i in range(1024 * 200): f.write(f' echo {i}\n') f.write(' echo "end long setup"\n') f.write( @@ -3587,7 +3643,7 @@ def test_long_setup_run_script(generic_cloud: str): run: | echo "run" """)) - for i in range(1024 * 120): + for i in range(1024 * 200): f.write(f' echo {i}\n') f.write(' echo "end run"\n') f.flush() @@ -4383,6 +4439,28 @@ def test_core_api_sky_launch_exec(): sky.down(name) +# The sky launch CLI has some additional checks to make sure the cluster is up/ +# restarted. 
However, the core API doesn't have these; make sure it still works +def test_core_api_sky_launch_fast(generic_cloud: str): + name = _get_cluster_name() + cloud = sky.clouds.CLOUD_REGISTRY.from_str(generic_cloud) + try: + task = sky.Task(run="whoami").set_resources(sky.Resources(cloud=cloud)) + sky.launch(task, + cluster_name=name, + idle_minutes_to_autostop=1, + fast=True) + # Sleep to let the cluster autostop + time.sleep(120) + # Run it again - should work with fast=True + sky.launch(task, + cluster_name=name, + idle_minutes_to_autostop=1, + fast=True) + finally: + sky.down(name) + + # ---------- Testing Storage ---------- class TestStorageWithCredentials: """Storage tests which require credentials and network connection""" @@ -5566,7 +5644,7 @@ def test_multiple_accelerators_unordered(): def test_multiple_accelerators_unordered_with_default(): name = _get_cluster_name() test = Test( - 'multiple-accelerators-unordered', + 'multiple-accelerators-unordered-with-default', [ f'sky launch -y -c {name} tests/test_yamls/test_multiple_accelerators_unordered_with_default.yaml', f'sky logs {name} 1 --status', # Ensure the job succeeded. @@ -5613,7 +5691,7 @@ def test_sky_bench(generic_cloud: str): @pytest.mark.kubernetes def test_kubernetes_context_failover(): """Test if the kubernetes context failover works. - + This test requires two kubernetes clusters: - kind-skypilot: the local cluster with mock labels for 8 H100 GPUs. - another accessible cluster: with enough CPUs diff --git a/tests/unit_tests/test_azure_utils.py b/tests/unit_tests/test_azure_utils.py new file mode 100644 index 00000000000..93ef5caadb0 --- /dev/null +++ b/tests/unit_tests/test_azure_utils.py @@ -0,0 +1,21 @@ +import pytest + +from sky.clouds.utils import azure_utils + + +def test_validate_image_id(): + # Valid marketplace image ID + azure_utils.validate_image_id("publisher:offer:sku:version") + + # Valid community image ID + azure_utils.validate_image_id( + "/CommunityGalleries/gallery-name/Images/image-name") + + # Invalid format (neither marketplace nor community) + with pytest.raises(ValueError): + azure_utils.validate_image_id( + "CommunityGalleries/gallery-name/Images/image-name") + + # Invalid marketplace image ID (too few parts) + with pytest.raises(ValueError): + azure_utils.validate_image_id("publisher:offer:sku") diff --git a/tests/unit_tests/test_controller_utils.py b/tests/unit_tests/test_controller_utils.py index 7465f648385..f41c7413bc1 100644 --- a/tests/unit_tests/test_controller_utils.py +++ b/tests/unit_tests/test_controller_utils.py @@ -1,5 +1,5 @@ """Test the controller_utils module.""" -from typing import Any, Dict +from typing import Any, Dict, Optional, Set, Tuple import pytest @@ -65,6 +65,24 @@ def get_custom_controller_resources(keys, default): controller_resources_config, k, v) +def _check_controller_resources( + controller_resources: Set[sky.Resources], + expected_combinations: Set[Tuple[Optional[str], Optional[str], + Optional[str]]], + default_controller_resources: Dict[str, Any]) -> None: + """Helper function to check that the controller resources match the + expected combinations.""" + for r in controller_resources: + config = r.to_yaml_config() + cloud = config.pop('cloud') + region = config.pop('region', None) + zone = config.pop('zone', None) + assert (cloud, region, zone) in expected_combinations + expected_combinations.remove((cloud, region, zone)) + assert config == default_controller_resources, config + assert not expected_combinations + + @pytest.mark.parametrize(('controller_type', 
'default_controller_resources'), [ ('jobs', managed_job_constants.CONTROLLER_RESOURCES), ('serve', serve_constants.CONTROLLER_RESOURCES), @@ -79,17 +97,12 @@ def test_get_controller_resources_with_task_resources( # could host controllers. Return a set, each item has # one cloud specified plus the default resources. all_clouds = {sky.AWS(), sky.GCP(), sky.Azure()} - all_cloud_names = {str(c) for c in all_clouds} + expected_combinations = {(str(c), None, None) for c in all_clouds} controller_resources = controller_utils.get_controller_resources( controller=controller_utils.Controllers.from_type(controller_type), task_resources=[sky.Resources(cloud=c) for c in all_clouds]) - for r in controller_resources: - config = r.to_yaml_config() - cloud = config.pop('cloud') - assert cloud in all_cloud_names - all_cloud_names.remove(cloud) - assert config == default_controller_resources, config - assert not all_cloud_names + _check_controller_resources(controller_resources, expected_combinations, + default_controller_resources) # 2. All resources has cloud specified. Some of them # could NOT host controllers. Return a set, only @@ -113,19 +126,14 @@ def _could_host_controllers(cloud: sky.clouds.Cloud) -> bool: return False return True - all_cloud_names_expected = { - str(c) for c in all_clouds if _could_host_controllers(c) + expected_combinations = { + (str(c), None, None) for c in all_clouds if _could_host_controllers(c) } controller_resources = controller_utils.get_controller_resources( controller=controller_utils.Controllers.from_type(controller_type), task_resources=[sky.Resources(cloud=c) for c in all_clouds]) - for r in controller_resources: - config = r.to_yaml_config() - cloud = config.pop('cloud') - assert cloud in all_cloud_names_expected - all_cloud_names_expected.remove(cloud) - assert config == default_controller_resources, config - assert not all_cloud_names_expected + _check_controller_resources(controller_resources, expected_combinations, + default_controller_resources) # 3. Some resources does not have cloud specified. # Return the default resources. @@ -138,3 +146,73 @@ def _could_host_controllers(cloud: sky.clouds.Cloud) -> bool: assert len(controller_resources) == 1 config = list(controller_resources)[0].to_yaml_config() assert config == default_controller_resources, config + + # 4. All resources have clouds, regions, and zones specified. + # Return a set of controller resources for all combinations of clouds, + # regions, and zones. Each combination should contain the default resources + # along with the cloud, region, and zone. + all_cloud_regions_zones = [ + sky.Resources(cloud=sky.AWS(), region='us-east-1', zone='us-east-1a'), + sky.Resources(cloud=sky.AWS(), region='ap-south-1', zone='ap-south-1b'), + sky.Resources(cloud=sky.GCP(), + region='us-central1', + zone='us-central1-a'), + sky.Resources(cloud=sky.GCP(), + region='europe-west1', + zone='europe-west1-b') + ] + expected_combinations = {('AWS', 'us-east-1', 'us-east-1a'), + ('AWS', 'ap-south-1', 'ap-south-1b'), + ('GCP', 'us-central1', 'us-central1-a'), + ('GCP', 'europe-west1', 'europe-west1-b')} + controller_resources = controller_utils.get_controller_resources( + controller=controller_utils.Controllers.from_type(controller_type), + task_resources=all_cloud_regions_zones) + _check_controller_resources(controller_resources, expected_combinations, + default_controller_resources) + + # 5. Clouds and regions are specified, but zones are partially specified. 
+ # Return a set containing combinations where the zone is None when not all + # zones are specified in the input for the given region. The default + # resources should be returned along with the cloud and region, and the + # zone (if specified). + controller_resources = controller_utils.get_controller_resources( + controller=controller_utils.Controllers.from_type(controller_type), + task_resources=[ + sky.Resources(cloud=sky.AWS(), region='us-west-2'), + sky.Resources(cloud=sky.AWS(), + region='us-west-2', + zone='us-west-2b'), + sky.Resources(cloud=sky.GCP(), + region='us-central1', + zone='us-central1-a') + ]) + expected_combinations = {('AWS', 'us-west-2', None), + ('GCP', 'us-central1', 'us-central1-a')} + _check_controller_resources(controller_resources, expected_combinations, + default_controller_resources) + + # 6. Mixed case: Some resources have clouds and regions or zones, others do + # not. For clouds where regions or zones are not specified in the input, + # return None for those fields. The default resources should be returned + # along with the cloud, region (if specified), and zone (if specified). + controller_resources = controller_utils.get_controller_resources( + controller=controller_utils.Controllers.from_type(controller_type), + task_resources=[ + sky.Resources(cloud=sky.GCP(), region='europe-west1'), + sky.Resources(cloud=sky.GCP()), + sky.Resources(cloud=sky.AWS(), + region='eu-north-1', + zone='eu-north-1a'), + sky.Resources(cloud=sky.AWS(), region='eu-north-1'), + sky.Resources(cloud=sky.AWS(), region='ap-south-1'), + sky.Resources(cloud=sky.Azure()), + ]) + expected_combinations = { + ('AWS', 'eu-north-1', None), + ('AWS', 'ap-south-1', None), + ('GCP', None, None), + ('Azure', None, None), + } + _check_controller_resources(controller_resources, expected_combinations, + default_controller_resources) diff --git a/tests/unit_tests/test_dag_utils.py b/tests/unit_tests/test_dag_utils.py new file mode 100644 index 00000000000..a083757800f --- /dev/null +++ b/tests/unit_tests/test_dag_utils.py @@ -0,0 +1,95 @@ +"""Test dag utils.""" +import textwrap + +import pytest +import yaml + +import sky +from sky import jobs +from sky.utils import common_utils +from sky.utils import dag_utils + + +def test_jobs_recovery_fill_default_values(): + """Test jobs recovery fill default values.""" + task_str = textwrap.dedent("""\ + resources: + cpus: 2+ + """) + task_config = yaml.safe_load(task_str) + task = sky.Task.from_yaml_config(task_config) + dag = dag_utils.convert_entrypoint_to_dag(task) + dag_utils.fill_default_config_in_dag_for_job_launch(dag) + + resources = list(dag.tasks[0].resources) + assert len(resources) == 1 + assert resources[0].job_recovery[ + 'strategy'] == jobs.DEFAULT_RECOVERY_STRATEGY + + task_str = textwrap.dedent("""\ + resources: + cpus: 2+ + job_recovery: + max_restarts_on_errors: 3 + """) + + task_config = yaml.safe_load(task_str) + task = sky.Task.from_yaml_config(task_config) + dag = dag_utils.convert_entrypoint_to_dag(task) + dag_utils.fill_default_config_in_dag_for_job_launch(dag) + + resources = list(dag.tasks[0].resources) + assert len(resources) == 1 + assert resources[0].job_recovery[ + 'strategy'] == jobs.DEFAULT_RECOVERY_STRATEGY + assert resources[0].job_recovery['max_restarts_on_errors'] == 3 + + task_str = textwrap.dedent(f"""\ + resources: + cpus: 2+ + job_recovery: + strategy: FAILOVER + max_restarts_on_errors: 3 + """) + + task_config = yaml.safe_load(task_str) + task = sky.Task.from_yaml_config(task_config) + dag = 
dag_utils.convert_entrypoint_to_dag(task) + dag_utils.fill_default_config_in_dag_for_job_launch(dag) + + resources = list(dag.tasks[0].resources) + assert len(resources) == 1 + assert resources[0].job_recovery['strategy'] == 'FAILOVER' + assert resources[0].job_recovery['max_restarts_on_errors'] == 3 + + task_str = textwrap.dedent("""\ + resources: + cpus: 2+ + job_recovery: + """) + + task_config = yaml.safe_load(task_str) + task = sky.Task.from_yaml_config(task_config) + dag = dag_utils.convert_entrypoint_to_dag(task) + dag_utils.fill_default_config_in_dag_for_job_launch(dag) + + resources = list(dag.tasks[0].resources) + assert len(resources) == 1 + assert resources[0].job_recovery[ + 'strategy'] == jobs.DEFAULT_RECOVERY_STRATEGY + + task_str = textwrap.dedent("""\ + resources: + cpus: 2+ + any_of: + - cpus: 2+ + job_recovery: + max_restarts_on_errors: 3 + - cpus: 4+ + """) + + task_config = yaml.safe_load(task_str) + task = sky.Task.from_yaml_config(task_config) + dag = dag_utils.convert_entrypoint_to_dag(task) + with pytest.raises(ValueError): + dag_utils.fill_default_config_in_dag_for_job_launch(dag)
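
A note on the recurring substitution in the smoke-test hunks above: the fixed `sleep 50` / `sleep 80` / `sleep 120` delays that preceded the `RECOVERING` assertions are replaced with `_JOB_WAIT_NOT_RUNNING.format(job_name=...)`. The template itself is defined earlier in `tests/test_smoke.py` and is not part of these hunks; the sketch below only illustrates, as an assumption, the general shape of such a bounded poll-until-state-change helper built on `sky jobs queue`. It is not the actual definition, and `_JOB_WAIT_NOT_RUNNING_SKETCH` is a hypothetical name.

# Illustrative sketch only -- not the real _JOB_WAIT_NOT_RUNNING template.
# Polls the managed job queue until the named job is no longer reported as
# RUNNING, giving up after ~5 minutes so a wedged job cannot hang the test.
_JOB_WAIT_NOT_RUNNING_SKETCH = (
    'i=0; '
    'while sky jobs queue | grep "{job_name}" | head -n1 | grep "RUNNING"; do '
    'i=$((i + 1)); '
    '[ "$i" -ge 30 ] && exit 1; '
    'sleep 10; '
    'done')

# Hypothetical usage, mirroring the recovery tests above:
# _JOB_WAIT_NOT_RUNNING_SKETCH.format(job_name=f'{name}-3')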
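
Similarly, for the new `tests/unit_tests/test_dag_utils.py`, here is a minimal standalone sketch of the behavior the test pins down, using only the calls exercised above (the only assumption is running it outside the pytest harness): a `job_recovery` block that omits `strategy` is filled with `jobs.DEFAULT_RECOVERY_STRATEGY`, while explicitly set fields such as `max_restarts_on_errors` are preserved.

# Minimal sketch, reusing only the APIs exercised by the unit test above.
import sky
from sky import jobs
from sky.utils import dag_utils

task = sky.Task.from_yaml_config({
    'resources': {
        'cpus': '2+',
        'job_recovery': {
            'max_restarts_on_errors': 3,
        },
    },
})
dag = dag_utils.convert_entrypoint_to_dag(task)
dag_utils.fill_default_config_in_dag_for_job_launch(dag)

recovery = list(dag.tasks[0].resources)[0].job_recovery
assert recovery['strategy'] == jobs.DEFAULT_RECOVERY_STRATEGY  # default filled in
assert recovery['max_restarts_on_errors'] == 3  # user-set value preserved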