diff --git a/examples/mpirun.yaml b/examples/mpirun.yaml new file mode 100644 index 00000000000..4ec7ce0107c --- /dev/null +++ b/examples/mpirun.yaml @@ -0,0 +1,24 @@ +workdir: . + +resources: + cloud: aws + +num_nodes: 2 # Total number of nodes (1 head + 1 worker) + +setup: | + echo "Running setup on node ${SKYPILOT_NODE_RANK}." + # Install MPI if not already present. This will vary based on your OS/distro. + sudo apt update + sudo apt install -y openmpi-bin openmpi-common libopenmpi-dev + +run: | + if [ "${SKYPILOT_NODE_RANK}" == "0" ]; then + echo "head node" + num_nodes=`echo "$SKYPILOT_NODE_IPS" | wc -l` + mpi_nodes=$(echo "$SKYPILOT_NODE_IPS" | tr '\n' ',') + mpi_nodes=${mpi_nodes::-1} + echo "$mpi_nodes" + mpirun -np $num_nodes -H $mpi_nodes bash -c 'echo "mpirun hello from IP $(hostname -I)"' + else + echo "worker nodes" + fi diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index d7211d18a65..b83817b9b42 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -428,6 +428,7 @@ def _get_generated_config(cls, autogen_comment: str, host_name: str, HostName {ip} User {username} IdentityFile {ssh_key_path} + AddKeysToAgent yes IdentitiesOnly yes ForwardAgent yes StrictHostKeyChecking no diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index 37b912db979..0c188599ae6 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -259,6 +259,8 @@ def _ssh_probe_command(ip: str, '-o', 'IdentitiesOnly=yes', '-o', + 'AddKeysToAgent=yes', + '-o', 'ExitOnForwardFailure=yes', '-o', 'ServerAliveInterval=5', diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index 4d57854bf90..1cb1dfc88e6 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -85,6 +85,10 @@ def ssh_options_list( 'LogLevel': 'ERROR', # Try fewer extraneous key pairs. 
'IdentitiesOnly': 'yes', + # Add the private key used for this SSH connection to the SSH + # agent, so that the agent-forwarding option can then make the + # agent forward it. + 'AddKeysToAgent': 'yes', # Abort if port forwarding fails (instead of just printing to # stderr). 'ExitOnForwardFailure': 'yes',