From 9ed3b2d2c7865134e50975a3412f5f29d80eaac1 Mon Sep 17 00:00:00 2001 From: Hironori Yamamoto Date: Tue, 15 Oct 2024 02:15:02 +0900 Subject: [PATCH] tmp --- gokart/build.py | 35 ++++++++++++++++++++++- gokart/task.py | 9 +++--- gokart/worker.py | 50 +++++++++++++++++++++++++++++++++ test/test_task_on_kart.py | 58 --------------------------------------- test/test_worker.py | 50 +++++++++++++++++++++++++++++++++ 5 files changed, 139 insertions(+), 63 deletions(-) create mode 100644 gokart/worker.py create mode 100644 test/test_worker.py diff --git a/gokart/build.py b/gokart/build.py index 46001e32..0af1c327 100644 --- a/gokart/build.py +++ b/gokart/build.py @@ -5,8 +5,11 @@ import backoff import luigi +from luigi import rpc, scheduler +from luigi import worker as luigi_worker import gokart +from gokart import worker from gokart.conflict_prevention_lock.task_lock import TaskLockException from gokart.target import TargetOnKart from gokart.task import TaskOnKart @@ -43,6 +46,25 @@ def __init__(self): self.flag: bool = False +class WorkerSchedulerFactory: + def __init__(self, complete_check_at_run: bool = True): + """ + Args: + complete_check_at_run (bool, optional): If True, check if the task is already completed before running the task. Defaults to True. + When the task is already completed, the task is skipped. + """ + self._complete_check_at_run = complete_check_at_run + + def create_local_scheduler(self) -> scheduler.Scheduler: + return scheduler.Scheduler(prune_on_get_work=True, record_task_history=False) + + def create_remote_scheduler(self, url) -> rpc.RemoteScheduler: + return rpc.RemoteScheduler(url) + + def create_worker(self, scheduler: scheduler.Scheduler, worker_processes: int, assistant=False) -> luigi_worker.Worker: + return worker.Worker(scheduler=scheduler, worker_processes=worker_processes, assistant=assistant, complete_check_at_run=self._complete_check_at_run) + + def _get_output(task: TaskOnKart[T]) -> T: output = task.output() # FIXME: currently, nested output is not supported @@ -98,6 +120,7 @@ def build( log_level: int = logging.ERROR, task_lock_exception_max_tries: int = 10, task_lock_exception_max_wait_seconds: int = 600, + worker_scheduler_factory: Optional[WorkerSchedulerFactory] = None, **env_params, ) -> Optional[T]: """ @@ -106,6 +129,9 @@ def build( """ if reset_register: _reset_register() + if not worker_scheduler_factory: + worker_scheduler_factory = WorkerSchedulerFactory() + with LoggerConfig(level=log_level): task_lock_exception_raised = TaskLockExceptionRaisedFlag() @@ -119,7 +145,14 @@ def when_failure(task, exception): ) def _build_task(): task_lock_exception_raised.flag = False - result = luigi.build([task], local_scheduler=True, detailed_summary=True, log_level=logging.getLevelName(log_level), **env_params) + result = luigi.build( + [task], + local_scheduler=True, + detailed_summary=True, + worker_scheduler_factory=worker_scheduler_factory, + log_level=logging.getLevelName(log_level), + **env_params, + ) if task_lock_exception_raised.flag: raise HasLockedTaskException() if result.status == luigi.LuigiStatusCode.FAILED: diff --git a/gokart/task.py b/gokart/task.py index 2f351d2a..ff4e66be 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -19,7 +19,6 @@ from gokart.pandas_type_config import PandasTypeConfigMap from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter from gokart.target import TargetOnKart -from gokart.task_complete_check import task_complete_check_wrapper from gokart.utils import FlattenableItems, flatten logger = getLogger(__name__) @@ -95,7 +94,9 @@ class TaskOnKart(luigi.Task, Generic[T]): significant=False, ) complete_check_at_run: bool = ExplicitBoolParameter( - default=True, description='Check if output file exists at run. If exists, run() will be skipped.', significant=False + default=True, + description='Check if output file exists at run. If exists, run() will be skipped.', + significant=False, ) should_lock_run: bool = ExplicitBoolParameter(default=False, significant=False, description='Whether to use redis lock or not at task run.') @@ -112,8 +113,8 @@ def __init__(self, *args, **kwargs): self._rerun_state = self.rerun self._lock_at_dump = True - if self.complete_check_at_run: - self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete) # type: ignore + if not self.complete_check_at_run: + logger.warning('parameter `complete_check_at_run` is deprecated. Please set by `WorkerSchedulerFactory`.') if self.should_lock_run: self._lock_at_dump = False diff --git a/gokart/worker.py b/gokart/worker.py new file mode 100644 index 00000000..ba7b2f5b --- /dev/null +++ b/gokart/worker.py @@ -0,0 +1,50 @@ +from logging import getLogger +from typing import Any, Dict, Optional + +from luigi import scheduler, worker + +logger = getLogger(__name__) + + +class Worker(worker.Worker): + def __init__( + self, + scheduler: Optional[scheduler.Scheduler] = None, + worker_id: Optional[int] = None, + worker_processes: int = 1, + assistant: bool = False, + check_complete_on_run: bool = True, + **kwargs: Dict[str, Any], + ) -> None: + super().__init__( + scheduler=scheduler, + worker_id=worker_id, + worker_processes=worker_processes, + assistant=assistant, + check_complete_on_run=check_complete_on_run, + **kwargs, + ) + + # def _run_task(self, task_id: int) -> None: + # if task_id in self._running_tasks: + # logger.debug('Got already running task id {} from scheduler, taking a break'.format(task_id)) + # next(self._sleeper()) + # return + + # task = self._scheduled_tasks[task_id] + # if not isinstance(task, TaskOnKart): + # raise ValueError(f'Task must be an instance of TaskOnKart, but got {type(task)}') + # if self._complete_check_at_run and task.complete(): + # logger.warning(f'Task {task} is already completed. Skipping...') + # return + + # task_process = self._create_task_process(task) + + # self._running_tasks[task_id] = task_process + + # if task_process.use_multiprocessing: + # with worker.fork_lock: + # task_process.start() + # else: + # # Run in the same process + # task_process.run() diff --git a/test/test_task_on_kart.py b/test/test_task_on_kart.py index e3946b49..2a347338 100644 --- a/test/test_task_on_kart.py +++ b/test/test_task_on_kart.py @@ -599,63 +599,5 @@ def test_should_fail_lock_run_when_port_unset(self): gokart.TaskOnKart(redis_host='host', redis_timeout=180, should_lock_run=True) -class _DummyTaskWithNonCompleted(gokart.TaskOnKart): - def dump(self, _obj: Any, _target: Any = None): - # overrive dump() to do nothing. - pass - - def run(self): - self.dump('hello') - - def complete(self): - return False - - -class _DummyTaskWithCompleted(gokart.TaskOnKart): - def dump(self, obj: Any, _target: Any = None): - # overrive dump() to do nothing. - pass - - def run(self): - self.dump('hello') - - def complete(self): - return True - - -class TestCompleteCheckAtRun(unittest.TestCase): - def test_run_when_complete_check_at_run_is_false_and_task_is_not_completed(self): - task = _DummyTaskWithNonCompleted(complete_check_at_run=False) - task.dump = Mock() # type: ignore - task.run() - - # since run() is called, dump() should be called. - task.dump.assert_called_once() - - def test_run_when_complete_check_at_run_is_false_and_task_is_completed(self): - task = _DummyTaskWithCompleted(complete_check_at_run=False) - task.dump = Mock() # type: ignore - task.run() - - # even task is completed, since run() is called, dump() should be called. - task.dump.assert_called_once() - - def test_run_when_complete_check_at_run_is_true_and_task_is_not_completed(self): - task = _DummyTaskWithNonCompleted(complete_check_at_run=True) - task.dump = Mock() # type: ignore - task.run() - - # since task is not completed, when run() is called, dump() should be called. - task.dump.assert_called_once() - - def test_run_when_complete_check_at_run_is_true_and_task_is_completed(self): - task = _DummyTaskWithCompleted(complete_check_at_run=True) - task.dump = Mock() # type: ignore - task.run() - - # since task is completed, even when run() is called, dump() should not be called. - task.dump.assert_not_called() - - if __name__ == '__main__': unittest.main() diff --git a/test/test_worker.py b/test/test_worker.py new file mode 100644 index 00000000..5c510174 --- /dev/null +++ b/test/test_worker.py @@ -0,0 +1,50 @@ +import pdb +import uuid +from unittest.mock import Mock + +import luigi +import pytest +from luigi import scheduler + +import gokart +from gokart.worker import Worker + + +class _DummyTask(gokart.TaskOnKart): + task_namespace = __name__ + random_id = luigi.Parameter() + + def _run(self): ... + + def run(self): + self._run() + self.dump('test') + + +class TestWorkerCompleteCheckAtRun: + @pytest.mark.parametrize( + 'is_complete,complete_check_at_run,n_run_called', + [ + pytest.param(False, True, 1, id='not complete, check at run'), + pytest.param(False, False, 1, id='not complete, no check at run'), + pytest.param(True, True, 0, id='complete, check at run'), + pytest.param(True, False, 1, id='complete, no check at run'), + ], + ) + def test_run(self, is_complete: bool, complete_check_at_run: bool, n_run_called: int, monkeypatch: pytest.MonkeyPatch): + """Check run is called when the task is not completed""" + sch = scheduler.Scheduler() + worker = Worker(scheduler=sch, check_complete_on_run=complete_check_at_run) + + task = _DummyTask(random_id=uuid.uuid4().hex) + mock_run = Mock() + monkeypatch.setattr(task, '_run', mock_run) + with worker: + assert worker.add(task) + mock_complete = Mock() + monkeypatch.setattr(task, 'complete', mock_complete) + mock_complete.return_value = is_complete + # mock_complete.side_effect = pdb.set_trace + assert worker.run() + assert mock_run.call_count == n_run_called +