Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use gokart worker #402

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion gokart/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

import backoff
import luigi
from luigi import rpc, scheduler
from luigi import worker as luigi_worker

import gokart
from gokart import worker
from gokart.conflict_prevention_lock.task_lock import TaskLockException
from gokart.target import TargetOnKart
from gokart.task import TaskOnKart
Expand Down Expand Up @@ -43,6 +46,25 @@ def __init__(self):
self.flag: bool = False


class WorkerSchedulerFactory:
def __init__(self, complete_check_at_run: bool = True):
"""
Args:
complete_check_at_run (bool, optional): If True, check if the task is already completed before running the task. Defaults to True.
When the task is already completed, the task is skipped.
"""
self._complete_check_at_run = complete_check_at_run

def create_local_scheduler(self) -> scheduler.Scheduler:
return scheduler.Scheduler(prune_on_get_work=True, record_task_history=False)

def create_remote_scheduler(self, url) -> rpc.RemoteScheduler:
return rpc.RemoteScheduler(url)

def create_worker(self, scheduler: scheduler.Scheduler, worker_processes: int, assistant=False) -> luigi_worker.Worker:
return worker.Worker(scheduler=scheduler, worker_processes=worker_processes, assistant=assistant, complete_check_at_run=self._complete_check_at_run)


def _get_output(task: TaskOnKart[T]) -> T:
output = task.output()
# FIXME: currently, nested output is not supported
Expand Down Expand Up @@ -98,6 +120,7 @@ def build(
log_level: int = logging.ERROR,
task_lock_exception_max_tries: int = 10,
task_lock_exception_max_wait_seconds: int = 600,
worker_scheduler_factory: Optional[WorkerSchedulerFactory] = None,
**env_params,
) -> Optional[T]:
"""
Expand All @@ -106,6 +129,9 @@ def build(
"""
if reset_register:
_reset_register()
if not worker_scheduler_factory:
worker_scheduler_factory = WorkerSchedulerFactory()

with LoggerConfig(level=log_level):
task_lock_exception_raised = TaskLockExceptionRaisedFlag()

Expand All @@ -119,7 +145,14 @@ def when_failure(task, exception):
)
def _build_task():
task_lock_exception_raised.flag = False
result = luigi.build([task], local_scheduler=True, detailed_summary=True, log_level=logging.getLevelName(log_level), **env_params)
result = luigi.build(
[task],
local_scheduler=True,
detailed_summary=True,
worker_scheduler_factory=worker_scheduler_factory,
log_level=logging.getLevelName(log_level),
**env_params,
)
if task_lock_exception_raised.flag:
raise HasLockedTaskException()
if result.status == luigi.LuigiStatusCode.FAILED:
Expand Down
9 changes: 5 additions & 4 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from gokart.pandas_type_config import PandasTypeConfigMap
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
from gokart.target import TargetOnKart
from gokart.task_complete_check import task_complete_check_wrapper
from gokart.utils import FlattenableItems, flatten

logger = getLogger(__name__)
Expand Down Expand Up @@ -95,7 +94,9 @@ class TaskOnKart(luigi.Task, Generic[T]):
significant=False,
)
complete_check_at_run: bool = ExplicitBoolParameter(
default=True, description='Check if output file exists at run. If exists, run() will be skipped.', significant=False
default=True,
description='Check if output file exists at run. If exists, run() will be skipped.',
significant=False,
)
should_lock_run: bool = ExplicitBoolParameter(default=False, significant=False, description='Whether to use redis lock or not at task run.')

Expand All @@ -112,8 +113,8 @@ def __init__(self, *args, **kwargs):
self._rerun_state = self.rerun
self._lock_at_dump = True

if self.complete_check_at_run:
self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete) # type: ignore
if not self.complete_check_at_run:
logger.warning('parameter `complete_check_at_run` is deprecated. Please set by `WorkerSchedulerFactory`.')

if self.should_lock_run:
self._lock_at_dump = False
Expand Down
50 changes: 50 additions & 0 deletions gokart/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from logging import getLogger
from typing import Any, Dict, Optional

from luigi import scheduler, worker

logger = getLogger(__name__)


class Worker(worker.Worker):
def __init__(
self,
scheduler: Optional[scheduler.Scheduler] = None,
worker_id: Optional[int] = None,
worker_processes: int = 1,
assistant: bool = False,
check_complete_on_run: bool = True,
**kwargs: Dict[str, Any],
) -> None:
super().__init__(
scheduler=scheduler,
worker_id=worker_id,
worker_processes=worker_processes,
assistant=assistant,
check_complete_on_run=check_complete_on_run,
**kwargs,
)

# def _run_task(self, task_id: int) -> None:
# if task_id in self._running_tasks:
# logger.debug('Got already running task id {} from scheduler, taking a break'.format(task_id))
# next(self._sleeper())
# return

# task = self._scheduled_tasks[task_id]
# if not isinstance(task, TaskOnKart):
# raise ValueError(f'Task must be an instance of TaskOnKart, but got {type(task)}')
# if self._complete_check_at_run and task.complete():
# logger.warning(f'Task {task} is already completed. Skipping...')
# return

# task_process = self._create_task_process(task)

# self._running_tasks[task_id] = task_process

# if task_process.use_multiprocessing:
# with worker.fork_lock:
# task_process.start()
# else:
# # Run in the same process
# task_process.run()
58 changes: 0 additions & 58 deletions test/test_task_on_kart.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,63 +599,5 @@ def test_should_fail_lock_run_when_port_unset(self):
gokart.TaskOnKart(redis_host='host', redis_timeout=180, should_lock_run=True)


class _DummyTaskWithNonCompleted(gokart.TaskOnKart):
def dump(self, _obj: Any, _target: Any = None):
# overrive dump() to do nothing.
pass

def run(self):
self.dump('hello')

def complete(self):
return False


class _DummyTaskWithCompleted(gokart.TaskOnKart):
def dump(self, obj: Any, _target: Any = None):
# overrive dump() to do nothing.
pass

def run(self):
self.dump('hello')

def complete(self):
return True


class TestCompleteCheckAtRun(unittest.TestCase):
def test_run_when_complete_check_at_run_is_false_and_task_is_not_completed(self):
task = _DummyTaskWithNonCompleted(complete_check_at_run=False)
task.dump = Mock() # type: ignore
task.run()

# since run() is called, dump() should be called.
task.dump.assert_called_once()

def test_run_when_complete_check_at_run_is_false_and_task_is_completed(self):
task = _DummyTaskWithCompleted(complete_check_at_run=False)
task.dump = Mock() # type: ignore
task.run()

# even task is completed, since run() is called, dump() should be called.
task.dump.assert_called_once()

def test_run_when_complete_check_at_run_is_true_and_task_is_not_completed(self):
task = _DummyTaskWithNonCompleted(complete_check_at_run=True)
task.dump = Mock() # type: ignore
task.run()

# since task is not completed, when run() is called, dump() should be called.
task.dump.assert_called_once()

def test_run_when_complete_check_at_run_is_true_and_task_is_completed(self):
task = _DummyTaskWithCompleted(complete_check_at_run=True)
task.dump = Mock() # type: ignore
task.run()

# since task is completed, even when run() is called, dump() should not be called.
task.dump.assert_not_called()


if __name__ == '__main__':
unittest.main()
50 changes: 50 additions & 0 deletions test/test_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pdb
import uuid
from unittest.mock import Mock

import luigi
import pytest
from luigi import scheduler

import gokart
from gokart.worker import Worker


class _DummyTask(gokart.TaskOnKart):
task_namespace = __name__
random_id = luigi.Parameter()

def _run(self): ...

def run(self):
self._run()
self.dump('test')


class TestWorkerCompleteCheckAtRun:
@pytest.mark.parametrize(
'is_complete,complete_check_at_run,n_run_called',
[
pytest.param(False, True, 1, id='not complete, check at run'),
pytest.param(False, False, 1, id='not complete, no check at run'),
pytest.param(True, True, 0, id='complete, check at run'),
pytest.param(True, False, 1, id='complete, no check at run'),
],
)
def test_run(self, is_complete: bool, complete_check_at_run: bool, n_run_called: int, monkeypatch: pytest.MonkeyPatch):
"""Check run is called when the task is not completed"""
sch = scheduler.Scheduler()
worker = Worker(scheduler=sch, check_complete_on_run=complete_check_at_run)

task = _DummyTask(random_id=uuid.uuid4().hex)
mock_run = Mock()
monkeypatch.setattr(task, '_run', mock_run)
with worker:
assert worker.add(task)
mock_complete = Mock()
monkeypatch.setattr(task, 'complete', mock_complete)
mock_complete.return_value = is_complete
# mock_complete.side_effect = pdb.set_trace
assert worker.run()
assert mock_run.call_count == n_run_called

Loading