Skip to content

Commit

Permalink
add all annotations and rename task_queue to ready_queue
Browse files Browse the repository at this point in the history
  • Loading branch information
andylizf committed Nov 3, 2024
1 parent 4ef56e1 commit e7d99d4
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions sky/jobs/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, job_id: int, dag_yaml: str,
self._backend = cloud_vm_ray_backend.CloudVmRayBackend()

self._dag_graph = self._dag.get_graph()
self._task_queue = self._initialize_task_queue()
self._ready_tasks = self._initialize_ready_tasks()
self._task_status: Dict['sky.Task', TaskStatus] = {}

# Add a unique identifier to the task environment variables, so that
Expand Down Expand Up @@ -100,13 +100,15 @@ def __init__(self, job_id: int, dag_yaml: str,
job_id_env_vars)
task.update_envs(task_envs)

def _initialize_task_queue(self) -> queue.Queue:
task_queue: queue.Queue = queue.Queue()
def _initialize_ready_tasks(self) -> queue.Queue:
"""Initialize a queue with tasks that are ready to execute
(no dependencies)."""
ready_tasks: queue.Queue = queue.Queue()
for task in self._dag_graph.nodes():
if self._dag_graph.in_degree(task) == 0:
task_id = self._dag.tasks.index(task)
task_queue.put(task_id)
return task_queue
ready_tasks.put(task_id)
return ready_tasks

def _download_log_and_stream(
self,
Expand Down Expand Up @@ -374,9 +376,10 @@ def _try_add_successors_to_queue(self, task_id: int) -> None:
for successor in self._dag_graph.successors(task):
successor_id = self._dag.tasks.index(successor)
if is_task_runnable(successor):
self._task_queue.put(successor_id)
self._ready_tasks.put(successor_id)

def _handle_future_completion(self, future: futures.Future, task_id: int):
def _handle_future_completion(self, future: futures.Future,
task_id: int) -> None:
succeeded = False
try:
succeeded = future.result()
Expand Down Expand Up @@ -421,7 +424,7 @@ def _handle_future_completion(self, future: futures.Future, task_id: int):
self._task_status[task] = TaskStatus.FAILED
self._cancel_all_tasks(task_id)

def _cancel_all_tasks(self, task_id: int):
def _cancel_all_tasks(self, task_id: int) -> None:
callback_func = managed_job_utils.event_callback_func(
job_id=self._job_id, task_id=task_id, task=self._dag.tasks[task_id])
for task in self._dag.tasks:
Expand All @@ -433,7 +436,7 @@ def _cancel_all_tasks(self, task_id: int):
managed_job_state.set_cancelling(self._job_id, callback_func)
managed_job_state.set_cancelled(self._job_id, callback_func)

def run(self):
def run(self) -> None:
"""Run controller logic and handle exceptions."""
all_tasks_completed = lambda: self._num_tasks == len(self._task_status)
# TODO(andy):Serve has a logic to prevent from too many services running
Expand All @@ -448,8 +451,8 @@ def run(self):
with futures.ThreadPoolExecutor(max_workers) as executor:
future_to_task = {}
while not all_tasks_completed():
while not self._task_queue.empty():
task_id = self._task_queue.get()
while not self._ready_tasks.empty():
task_id = self._ready_tasks.get()
log_file_name = managed_job_utils.get_launch_log_file_name(
self._job_id, task_id)

Expand Down Expand Up @@ -477,7 +480,7 @@ def run(self):
def _update_failed_task_state(
self, task_id: int,
failure_type: managed_job_state.ManagedJobStatus,
failure_reason: str):
failure_reason: str) -> None:
"""Update the state of the failed task."""
managed_job_state.set_failed(
self._job_id,
Expand All @@ -490,15 +493,15 @@ def _update_failed_task_state(
task=self._dag.tasks[task_id]))


def _run_controller(job_id: int, dag_yaml: str, retry_until_up: bool):
def _run_controller(job_id: int, dag_yaml: str, retry_until_up: bool) -> None:
"""Runs the controller in a remote process for interruption."""
# The controller needs to be instantiated in the remote process, since
# the controller is not serializable.
jobs_controller = JobsController(job_id, dag_yaml, retry_until_up)
jobs_controller.run()


def _handle_signal(job_id):
def _handle_signal(job_id) -> None:
"""Handle the signal if the user sent it."""
signal_file = pathlib.Path(
managed_job_utils.SIGNAL_FILE_PREFIX.format(job_id))
Expand All @@ -508,9 +511,9 @@ def _handle_signal(job_id):
# signal writing.
with filelock.FileLock(str(signal_file) + '.lock'):
with signal_file.open(mode='r', encoding='utf-8') as f:
user_signal = f.read().strip()
user_signal_str = f.read().strip()
try:
user_signal = managed_job_utils.UserSignal(user_signal)
user_signal = managed_job_utils.UserSignal(user_signal_str)
except ValueError:
logger.warning(
f'Unknown signal received: {user_signal}. Ignoring.')
Expand All @@ -526,7 +529,7 @@ def _handle_signal(job_id):
f'User sent {user_signal.value} signal.')


def _cleanup(job_id: int, dag_yaml: str):
def _cleanup(job_id: int, dag_yaml: str) -> None:
"""Clean up the cluster(s) and storages.
(1) Clean up the succeeded task(s)' ephemeral storage. The storage has
Expand All @@ -551,7 +554,7 @@ def _cleanup(job_id: int, dag_yaml: str):
backend.teardown_ephemeral_storage(task)


def start(job_id, dag_yaml, retry_until_up):
def start(job_id, dag_yaml, retry_until_up) -> None:
"""Start the controller."""
controller_process = None
cancelling = False
Expand Down Expand Up @@ -606,6 +609,7 @@ def start(job_id, dag_yaml, retry_until_up):
logger.info(f'Cluster of managed job {job_id} has been cleaned up.')

if cancelling:
assert task_id is not None, task_id
managed_job_state.set_cancelled(
job_id=job_id,
callback_func=managed_job_utils.event_callback_func(
Expand Down

0 comments on commit e7d99d4

Please sign in to comment.