diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d281772d3b..a27d017d65 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -4,7 +4,7 @@ # These owners will be the default owners for everything in # the repo. Unless a later match takes precedence, -* @dlstadther @Tarrasch @spotify/dataex +* @dlstadther @spotify/dataex # Specific files, directories, paths, or file types can be # assigned more specificially. diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000000..12be2b33aa --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,65 @@ +name: "CodeQL" + +on: + push: + branches: [ 'master' ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ 'master' ] + schedule: + - cron: '29 18 * * 0' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'python', 'javascript' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] + # Use only 'java' to analyze code written in Java, Kotlin or both + # Use only 'javascript' to analyze code written in JavaScript, TypeScript or both + # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + + # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs + # queries: security-extended,security-and-quality + + + # Autobuild attempts to build any compiled languages (C/C++, C#, Go, Java, or Swift). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + # ℹī¸ Command-line programs to run using the OS shell. + # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun + + # If the Autobuild fails above, remove it and uncomment the following three lines. + # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 
+ + # - run: | + # echo "Run, Build Application using script" + # ./location_of_script_within_repo/buildscript.sh + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 + with: + category: "/language:${{matrix.language}}" diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index d4096dca4a..388b0569cd 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -13,14 +13,16 @@ jobs: strategy: matrix: include: - - python-version: 3.6 + - python-version: "3.6" tox-env: py36-core - - python-version: 3.7 + - python-version: "3.7" tox-env: py37-core - - python-version: 3.8 + - python-version: "3.8" tox-env: py38-core - - python-version: 3.9 + - python-version: "3.9" tox-env: py39-core + - python-version: "3.10" + tox-env: py310-core steps: - uses: actions/checkout@v2 @@ -37,7 +39,7 @@ jobs: key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.txt', format('requirements{0}.txt', matrix.spark-version-suffix))) }} - name: Install dependencies run: | - python -m pip install --upgrade pip 'tox<3.0' + python -m pip install --upgrade pip 'tox<4.0' - name: Setup MySQL DB run: | sudo /etc/init.d/mysql start @@ -49,7 +51,7 @@ jobs: TOXENV: ${{ matrix.tox-env }} run: tox - name: Codecov - env: + env: COVERAGE_PROCESS_START: .coveragerc run: | pip install codecov @@ -76,14 +78,16 @@ jobs: strategy: matrix: include: - - python-version: 3.6 + - python-version: "3.6" tox-env: py36-postgres - - python-version: 3.7 + - python-version: "3.7" tox-env: py37-postgres - - python-version: 3.8 + - python-version: "3.8" tox-env: py38-postgres - - python-version: 3.9 + - python-version: "3.9" tox-env: py39-postgres + - python-version: "3.10" + tox-env: py310-postgres steps: - uses: actions/checkout@v2 @@ -100,7 +104,7 @@ jobs: key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.txt', format('requirements{0}.txt', matrix.spark-version-suffix))) }} - name: Install dependencies run: | - python -m pip install --upgrade pip 'tox<3.0' + python -m pip install --upgrade pip 'tox<4.0' - name: Create PSQL database run: | PGPASSWORD=postgres psql -h localhost -p 5432 -c 'create database spotify;' -U postgres @@ -109,7 +113,7 @@ jobs: TOXENV: ${{ matrix.tox-env }} run: tox - name: Codecov - env: + env: COVERAGE_PROCESS_START: .coveragerc run: | pip install codecov @@ -125,45 +129,54 @@ jobs: strategy: matrix: include: - - python-version: 3.6 + - python-version: "3.6" tox-env: py36-aws - - python-version: 3.7 + - python-version: "3.7" tox-env: py37-aws - - python-version: 3.8 + - python-version: "3.8" tox-env: py38-aws - - python-version: 3.9 + - python-version: "3.9" tox-env: py39-aws + - python-version: "3.10" + tox-env: py310-aws - - python-version: 3.6 + - python-version: "3.6" tox-env: py36-unixsocket OVERRIDE_SKIP_CI_TESTS: True - - python-version: 3.7 + - python-version: "3.7" tox-env: py37-unixsocket OVERRIDE_SKIP_CI_TESTS: True - - python-version: 3.8 + - python-version: "3.8" tox-env: py38-unixsocket OVERRIDE_SKIP_CI_TESTS: True - - python-version: 3.9 + - python-version: "3.9" tox-env: py39-unixsocket OVERRIDE_SKIP_CI_TESTS: True + - python-version: "3.10" + tox-env: py310-unixsocket + OVERRIDE_SKIP_CI_TESTS: True - - python-version: 3.6 + - python-version: "3.6" tox-env: py36-apache - - python-version: 3.7 + - python-version: "3.7" tox-env: py37-apache - - python-version: 3.8.9 + - python-version: "3.8" tox-env: py38-apache - - python-version: 3.9.4 + - python-version: "3.9" tox-env: py39-apache - - - 
python-version: 3.6
+      - python-version: "3.10"
+        tox-env: py310-apache
+
+      - python-version: "3.6"
         tox-env: py36-azureblob
-      - python-version: 3.7
+      - python-version: "3.7"
         tox-env: py37-azureblob
-      - python-version: 3.8
+      - python-version: "3.8"
         tox-env: py38-azureblob
-      - python-version: 3.9
+      - python-version: "3.9"
         tox-env: py39-azureblob
+      - python-version: "3.10"
+        tox-env: py310-azureblob

       - python-version: 3.9
@@ -194,7 +207,7 @@ jobs:
         run: tox
       - name: Codecov
         if: ${{ matrix.tox-env != 'flake8' && matrix.tox-env != 'docs' }}
-        env: 
+        env:
          COVERAGE_PROCESS_START: .coveragerc
         run: |
          pip install codecov
diff --git a/README.rst b/README.rst
index 1040fa49c6..f2bcc3f408 100644
--- a/README.rst
+++ b/README.rst
@@ -2,8 +2,8 @@
    :alt: Luigi Logo
    :align: center

-.. image:: https://img.shields.io/travis/spotify/luigi/master.svg?style=flat
-   :target: https://travis-ci.org/spotify/luigi
+.. image:: https://img.shields.io/endpoint.svg?url=https%3A%2F%2Factions-badge.atrox.dev%2Fspotify%2Fluigi%2Fbadge&label=build&logo=none&%3Fref%3Dmaster&style=flat
+   :target: https://actions-badge.atrox.dev/spotify/luigi/goto?ref=master

 .. image:: https://img.shields.io/codecov/c/github/spotify/luigi/master.svg?style=flat
    :target: https://codecov.io/gh/spotify/luigi?branch=master
@@ -14,7 +14,7 @@
 .. image:: https://img.shields.io/pypi/l/luigi.svg?style=flat
    :target: https://pypi.python.org/pypi/luigi

-Luigi is a Python (3.6, 3.7, 3.8, 3.9 tested) package that helps you build complex
+Luigi is a Python (3.6, 3.7, 3.8, 3.9, 3.10 tested) package that helps you build complex
 pipelines of batch jobs. It handles dependency resolution, workflow management,
 visualization, handling failures, command line integration, and much more.
@@ -100,7 +100,7 @@
 Conceptually, Luigi is similar to `GNU Make <http://www.gnu.org/software/make/>`_ where you have
 certain tasks and these tasks in turn may have dependencies on other tasks. There are
 also some similarities to `Oozie <http://oozie.apache.org/>`_
-and `Azkaban <http://data.linkedin.com/opensource/azkaban>`_. One major
+and `Azkaban <https://azkaban.github.io/>`_. One major
 difference is that Luigi is not just built specifically for Hadoop, and it's easy
 to extend it with other kinds of tasks.
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 0000000000..5a9d5cd552
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,5 @@
+# Security Policy
+
+## Reporting a Vulnerability
+
+Please report sensitive security issues via Spotify's [bug-bounty program](https://hackerone.com/spotify) by following these [instructions](https://docs.hackerone.com/programs/security-page.html), rather than GitHub.
diff --git a/catalog-info.yaml b/catalog-info.yaml
new file mode 100644
index 0000000000..8f6a0305b6
--- /dev/null
+++ b/catalog-info.yaml
@@ -0,0 +1,7 @@
+apiVersion: backstage.io/v1alpha1
+kind: Component
+metadata:
+  name: luigi
+spec:
+  type: library
+  owner: dataex
diff --git a/doc/configuration.rst b/doc/configuration.rst
index 5b73aae1af..2499fb7567 100644
--- a/doc/configuration.rst
+++ b/doc/configuration.rst
@@ -168,24 +168,6 @@ log_level
 logging_conf_file
   Location of the logging configuration file.

-max_shown_tasks
-  .. versionadded:: 1.0.20
-
-  The maximum number of tasks returned in a task_list api call. This
-  will restrict the number of tasks shown in task lists in the
-  visualiser. Small values can alleviate frozen browsers when there are
-  too many done tasks. This defaults to 100000 (one hundred thousand).
-
-max_graph_nodes
-  .. versionadded:: 2.0.0
-
-  The maximum number of nodes returned by a dep_graph or
-  inverse_dep_graph api call.
Small values can greatly speed up graph
-  display in the visualiser by limiting the number of nodes shown. Some
-  of the nodes that are not sent to the visualiser will still show up as
-  dependencies of nodes that were sent. These nodes are given TRUNCATED
-  status.
-
 no_configure_logging
   If true, logging is not configured. Defaults to false.
@@ -303,14 +285,14 @@ wait_interval
   available jobs.

 wait_jitter
-  Size of jitter to add to the worker wait interval such that the multiple
-  workers do not ask the scheduler for another job at the same time.
+  Duration of jitter, in seconds, added to the worker wait interval so that
+  multiple workers do not ask the scheduler for another job at the same time.
   Default: 5.0

 max_keep_alive_idle_duration
   .. versionadded:: 2.8.4

-  Maximum duration to keep worker alive while in idle state.
+  Maximum duration in seconds to keep worker alive while in idle state.
   Default: 0 (Indefinitely)

 max_reschedules
@@ -374,6 +356,15 @@ check_complete_on_run
   missing.
   Defaults to false.

+cache_task_completion
+  By default, luigi task processes might check the completion status multiple
+  times per task, which is a safe way to avoid potential inconsistencies. For
+  tasks with many dynamic dependencies, yielded in multiple stages, this might
+  become expensive, e.g. in case the per-task completion check entails remote
+  resources. When set to true, completion checks are cached so that tasks
+  declared as complete once are not checked again.
+  Defaults to false.
+
 [elasticsearch]
 ---------------
@@ -447,7 +438,7 @@ traceback_max_length
 Parameters controlling the contents of batch notifications sent from the
 scheduler

-email_interval
+email_interval_minutes
   Number of minutes between e-mail sends. Making this larger results in
   fewer, bigger e-mails.
   Defaults to 60.
@@ -789,6 +780,24 @@ disable_window
   scheduler forgets about disables that have occurred longer ago than this
   amount of time. Defaults to 3600 (1 hour).

+max_shown_tasks
+  .. versionadded:: 1.0.20
+
+  The maximum number of tasks returned in a task_list api call. This
+  will restrict the number of tasks shown in task lists in the
+  visualiser. Small values can alleviate frozen browsers when there are
+  too many done tasks. This defaults to 100000 (one hundred thousand).
+
+max_graph_nodes
+  .. versionadded:: 2.0.0
+
+  The maximum number of nodes returned by a dep_graph or
+  inverse_dep_graph api call. Small values can greatly speed up graph
+  display in the visualiser by limiting the number of nodes shown. Some
+  of the nodes that are not sent to the visualiser will still show up as
+  dependencies of nodes that were sent. These nodes are given TRUNCATED
+  status.
+
 record_task_history
   If true, stores task history in a database. Defaults to false.
@@ -836,7 +845,12 @@ metrics_collector
   Optional setting allowing Luigi to use a contribution to collect metrics
   about the pipeline to a third-party. By default this uses the default
   metric collector that acts as a shell and does nothing. The currently available
-  options are "datadog" and "prometheus".
+  options are "datadog", "prometheus" and "custom". If set to "custom", the
+  'metrics_custom_import' option must also be set.
+
+metrics_custom_import
+  Optional setting allowing Luigi to import a custom subclass of MetricsCollector
+  at runtime. The string should be formatted like "module.sub_module.ClassName".
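+
+  A minimal sketch of such a class (illustrative only; it assumes the handler
+  names defined on :class:`luigi.metrics.MetricsCollector` and the ``id``
+  attribute of the scheduler's task objects):
+
+  .. code-block:: python
+
+      # my_metrics.py -- hypothetical module, enabled with:
+      #   metrics_collector = custom
+      #   metrics_custom_import = my_metrics.LogMetricsCollector
+      import logging
+
+      from luigi.metrics import MetricsCollector
+
+
+      class LogMetricsCollector(MetricsCollector):
+          """Collector that simply logs every task state transition."""
+
+          def __init__(self):
+              self._log = logging.getLogger('luigi-interface')
+
+          def handle_task_started(self, task):
+              self._log.info("started: %s", task.id)
+
+          def handle_task_failed(self, task):
+              self._log.info("failed: %s", task.id)
+
+          def handle_task_disabled(self, task, config):
+              self._log.info("disabled: %s", task.id)
+
+          def handle_task_done(self, task):
+              self._log.info("done: %s", task.id)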
[sendgrid] diff --git a/doc/example_top_artists.rst b/doc/example_top_artists.rst index 67db656b4b..119d2083d8 100644 --- a/doc/example_top_artists.rst +++ b/doc/example_top_artists.rst @@ -41,7 +41,7 @@ Step 1 - Aggregate Artist Streams with self.output().open('w') as out_file: for artist, count in artist_count.iteritems(): - print >> out_file, artist, count + print(artist, count, file=out_file) Note that this is just a portion of the file ``examples/top_artists.py``. In particular, ``Streams`` is defined as a :class:`~luigi.task.Task`, @@ -200,7 +200,7 @@ we choose to do this not as a Hadoop job, but just as a plain old for-loop in Py top_10 = nlargest(10, self._input_iterator()) with self.output().open('w') as out_file: for streams, artist in top_10: - print >> out_file, self.date_interval.date_a, self.date_interval.date_b, artist, streams + print(self.date_interval.date_a, self.date_interval.date_b, artist, streams, file=out_file) def _input_iterator(self): with self.input().open('r') as in_file: diff --git a/doc/tasks.rst b/doc/tasks.rst index 936d1e416a..0d8c7bf060 100644 --- a/doc/tasks.rst +++ b/doc/tasks.rst @@ -134,6 +134,8 @@ LocalTarget. Following the above example: .. code:: python + from luigi.format import Nop + class GenerateWords(luigi.Task): def output(self): @@ -196,8 +198,8 @@ You can also yield a list of tasks. def run(self): other_target = yield OtherTask() - # dynamic dependencies resolve into targets - f = other_target.open('r') + # dynamic dependencies resolve into targets + f = other_target.open('r') This mechanism is an alternative to Task.requires_ in case @@ -206,6 +208,10 @@ It does come with some constraints: the Task.run_ method will resume from scratch each time a new task is yielded. In other words, you should make sure your Task.run_ method is idempotent. (This is good practice for all Tasks in Luigi, but especially so for tasks with dynamic dependencies). +As this might entail redundant calls to tasks' :func:`~luigi.task.Task.complete` methods, +you should consider setting the "cache_task_completion" option in the :ref:`worker-config`. +To further control how dynamic task requirements are handled internally by worker nodes, +there is also the option to wrap dependent tasks by :class:`~luigi.task.DynamicRequirements`. For an example of a workflow using dynamic dependencies, see `examples/dynamic_requirements.py `_. diff --git a/examples/dynamic_requirements.py b/examples/dynamic_requirements.py index ed2feba81f..13389c45e7 100644 --- a/examples/dynamic_requirements.py +++ b/examples/dynamic_requirements.py @@ -15,6 +15,7 @@ # limitations under the License. 
# +import os import random as rnd import time @@ -91,6 +92,20 @@ def run(self): with self.output().open('w') as f: f.write('Tada!') + # and in case data is rather long, consider wrapping the requirements + # in DynamicRequirements and optionally define a custom complete method + def custom_complete(complete_fn): + # example: Data() stores all outputs in the same directory, so avoid doing len(data) fs + # calls but rather check only the first, and compare basenames for the rest + # (complete_fn defaults to "lambda task: task.complete()" but can also include caching) + if not complete_fn(data_dependent_deps[0]): + return False + paths = [task.output().path for task in data_dependent_deps] + basenames = os.listdir(os.path.dirname(paths[0])) # a single fs call + return all(os.path.basename(path) in basenames for path in paths) + + yield luigi.DynamicRequirements(data_dependent_deps, custom_complete) + if __name__ == '__main__': luigi.run() diff --git a/luigi/__init__.py b/luigi/__init__.py index b9384422e0..874014211a 100644 --- a/luigi/__init__.py +++ b/luigi/__init__.py @@ -21,7 +21,9 @@ from luigi.__meta__ import __version__ from luigi import task -from luigi.task import Task, Config, ExternalTask, WrapperTask, namespace, auto_namespace +from luigi.task import ( + Task, Config, ExternalTask, WrapperTask, namespace, auto_namespace, DynamicRequirements, +) from luigi import target from luigi.target import Target @@ -36,10 +38,10 @@ Parameter, DateParameter, MonthParameter, YearParameter, DateHourParameter, DateMinuteParameter, DateSecondParameter, DateIntervalParameter, TimeDeltaParameter, - IntParameter, FloatParameter, BoolParameter, - TaskParameter, EnumParameter, DictParameter, ListParameter, TupleParameter, + IntParameter, FloatParameter, BoolParameter, PathParameter, + TaskParameter, EnumParameter, DictParameter, ListParameter, TupleParameter, EnumListParameter, NumericalParameter, ChoiceParameter, OptionalParameter, OptionalStrParameter, - OptionalIntParameter, OptionalFloatParameter, OptionalBoolParameter, + OptionalIntParameter, OptionalFloatParameter, OptionalBoolParameter, OptionalPathParameter, OptionalDictParameter, OptionalListParameter, OptionalTupleParameter, OptionalChoiceParameter, OptionalNumericalParameter, ) @@ -56,15 +58,16 @@ __all__ = [ 'task', 'Task', 'Config', 'ExternalTask', 'WrapperTask', 'namespace', 'auto_namespace', + 'DynamicRequirements', 'target', 'Target', 'LocalTarget', 'rpc', 'RemoteScheduler', 'RPCError', 'parameter', 'Parameter', 'DateParameter', 'MonthParameter', 'YearParameter', 'DateHourParameter', 'DateMinuteParameter', 'DateSecondParameter', 'DateIntervalParameter', 'TimeDeltaParameter', 'IntParameter', - 'FloatParameter', 'BoolParameter', 'TaskParameter', - 'ListParameter', 'TupleParameter', 'EnumParameter', 'DictParameter', + 'FloatParameter', 'BoolParameter', 'PathParameter', 'TaskParameter', + 'ListParameter', 'TupleParameter', 'EnumParameter', 'DictParameter', 'EnumListParameter', 'configuration', 'interface', 'local_target', 'run', 'build', 'event', 'Event', 'NumericalParameter', 'ChoiceParameter', 'OptionalParameter', 'OptionalStrParameter', - 'OptionalIntParameter', 'OptionalFloatParameter', 'OptionalBoolParameter', + 'OptionalIntParameter', 'OptionalFloatParameter', 'OptionalBoolParameter', 'OptionalPathParameter', 'OptionalDictParameter', 'OptionalListParameter', 'OptionalTupleParameter', 'OptionalChoiceParameter', 'OptionalNumericalParameter', 'LuigiStatusCode', '__version__', diff --git a/luigi/__meta__.py b/luigi/__meta__.py index 
fed8e0002c..b3adee7a35 100644
--- a/luigi/__meta__.py
+++ b/luigi/__meta__.py
@@ -7,5 +7,5 @@
 __author__ = 'The Luigi Authors'
 __contact__ = 'https://github.com/spotify/luigi'
 __license__ = 'Apache License 2.0'
-__version__ = '3.0.3'
+__version__ = '3.4.0'
 __status__ = 'Production'
diff --git a/luigi/batch_notifier.py b/luigi/batch_notifier.py
index c42d1ead15..e5c3760699 100644
--- a/luigi/batch_notifier.py
+++ b/luigi/batch_notifier.py
@@ -15,7 +15,9 @@
 class batch_email(luigi.Config):
     email_interval = luigi.parameter.IntParameter(
-        default=60, description='Number of minutes between e-mail sends (default: 60)')
+        default=60,
+        config_path=dict(section="batch-notifier", name="email-interval-minutes"),
+        description='Number of minutes between e-mail sends (default: 60)')
     batch_mode = luigi.parameter.ChoiceParameter(
         default='unbatched_params', choices=('family', 'all', 'unbatched_params'),
         description='Method used for batching failures in e-mail. If "family" all failures for '
diff --git a/luigi/configuration/cfg_parser.py b/luigi/configuration/cfg_parser.py
index 734aedcea3..abca6d713a 100644
--- a/luigi/configuration/cfg_parser.py
+++ b/luigi/configuration/cfg_parser.py
@@ -117,6 +117,7 @@ def before_write(self, parser, section, option, value):
 class LuigiConfigParser(BaseParser, ConfigParser):
     NO_DEFAULT = object()
     enabled = True
+    optionxform = str
     _instance = None
     _config_paths = [
         '/etc/luigi/client.cfg',  # Deprecated old-style global luigi config
diff --git a/luigi/configuration/core.py b/luigi/configuration/core.py
index 54e34f3b64..710f72b129 100644
--- a/luigi/configuration/core.py
+++ b/luigi/configuration/core.py
@@ -32,12 +32,15 @@
     'toml': LuigiTomlParser,
 }

-# select parser via env var
 DEFAULT_PARSER = 'cfg'
-PARSER = os.environ.get('LUIGI_CONFIG_PARSER', DEFAULT_PARSER)
-if PARSER not in PARSERS:
-    warnings.warn("Invalid parser: {parser}".format(parser=PARSER))
-    PARSER = DEFAULT_PARSER
+
+
+def _get_default_parser():
+    parser = os.environ.get('LUIGI_CONFIG_PARSER', DEFAULT_PARSER)
+    if parser not in PARSERS:
+        # warn with the invalid value, then fall back to the default
+        warnings.warn("Invalid parser: {parser}".format(parser=parser))
+        parser = DEFAULT_PARSER
+    return parser


 def _check_parser(parser_class, parser):
@@ -50,9 +53,11 @@ def _check_parser(parser_class, parser):
         raise ImportError(msg.format(parser=parser))


-def get_config(parser=PARSER):
+def get_config(parser=None):
     """Get configs singleton for parser
     """
+    if parser is None:
+        parser = _get_default_parser()
     parser_class = PARSERS[parser]
     _check_parser(parser_class, parser)
     return parser_class.instance()
@@ -66,21 +71,22 @@ def add_config_path(path):
         return False

     # select parser by file extension
+    default_parser = _get_default_parser()
     _base, ext = os.path.splitext(path)
     if ext and ext[1:] in PARSERS:
         parser = ext[1:]
     else:
-        parser = PARSER
+        parser = default_parser
     parser_class = PARSERS[parser]
     _check_parser(parser_class, parser)
-    if parser != PARSER:
+    if parser != default_parser:
         msg = (
             "Config for {added} parser added, but used {used} parser. "
             "Set up right parser via env var: "
             "export LUIGI_CONFIG_PARSER={added}"
         )
-        warnings.warn(msg.format(added=parser, used=PARSER))
+        warnings.warn(msg.format(added=parser, used=default_parser))

     # add config path to parser
     parser_class.add_config_path(path)
diff --git a/luigi/configuration/toml_parser.py b/luigi/configuration/toml_parser.py
index 43a97a08d7..683d8b5c4b 100644
--- a/luigi/configuration/toml_parser.py
+++ b/luigi/configuration/toml_parser.py
@@ -15,6 +15,7 @@
 # limitations under the License.
# import os.path +from configparser import ConfigParser try: import toml @@ -25,7 +26,7 @@ from ..freezing import recursively_freeze -class LuigiTomlParser(BaseParser): +class LuigiTomlParser(BaseParser, ConfigParser): NO_DEFAULT = object() enabled = bool(toml) data = dict() diff --git a/luigi/contrib/bigquery.py b/luigi/contrib/bigquery.py index 9da6d883ed..5d93efdaed 100644 --- a/luigi/contrib/bigquery.py +++ b/luigi/contrib/bigquery.py @@ -51,7 +51,7 @@ def is_error_5xx(err): wait=wait_exponential(multiplier=1, min=1, max=10), stop=stop_after_attempt(3), reraise=True, - after=lambda x: x.args[0].__initialise_client() + after=lambda x: x.args[0]._initialise_client() ) @@ -152,9 +152,9 @@ def __init__(self, oauth_credentials=None, descriptor='', http_=None): self.descriptor = descriptor self.http_ = http_ - self.__initialise_client() + self._initialise_client() - def __initialise_client(self): + def _initialise_client(self): authenticate_kwargs = gcp.get_authenticate_kwargs(self.oauth_credentials, self.http_) if self.descriptor: diff --git a/luigi/contrib/docker_runner.py b/luigi/contrib/docker_runner.py index 9655190760..7fde4d47a1 100644 --- a/luigi/contrib/docker_runner.py +++ b/luigi/contrib/docker_runner.py @@ -77,8 +77,24 @@ def command(self): def name(self): return None + @property + def host_config_options(self): + ''' + Override this to specify host_config options like gpu requests or shm + size e.g. `{"device_requests": [docker.types.DeviceRequest(count=1, capabilities=[["gpu"]])]}` + + See https://docker-py.readthedocs.io/en/stable/api.html#docker.api.container.ContainerApiMixin.create_host_config + ''' + return {} + @property def container_options(self): + ''' + Override this to specify container options like user or ports e.g. + `{"user": f"{os.getuid()}:{os.getgid()}"}` + + See https://docker-py.readthedocs.io/en/stable/api.html#docker.api.container.ContainerApiMixin.create_container + ''' return {} @property @@ -192,7 +208,8 @@ def run(self): % (self._image, self.command, self._binds)) host_config = self._client.create_host_config(binds=self._binds, - network_mode=self.network_mode) + network_mode=self.network_mode, + **self.host_config_options) container = self._client.create_container(self._image, command=self.command, diff --git a/luigi/contrib/dropbox.py b/luigi/contrib/dropbox.py index 9a8f4d7dfb..aaa77953b2 100644 --- a/luigi/contrib/dropbox.py +++ b/luigi/contrib/dropbox.py @@ -305,7 +305,7 @@ def fs(self): @contextmanager def temporary_path(self): tmp_dir = tempfile.mkdtemp() - num = random.randrange(0, 1e10) + num = random.randrange(0, 10_000_000_000) temp_path = '{}{}luigi-tmp-{:010}{}'.format( tmp_dir, os.sep, num, ntpath.basename(self.path)) diff --git a/luigi/contrib/ecs.py b/luigi/contrib/ecs.py index 8612d00499..6311db72b5 100644 --- a/luigi/contrib/ecs.py +++ b/luigi/contrib/ecs.py @@ -50,6 +50,7 @@ """ +import copy import time import logging import luigi @@ -150,8 +151,10 @@ def command(self): Override to return list of dicts with keys 'name' and 'command', describing the container names and commands to pass to the container. - Directly corresponds to the `overrides` parameter of runTask API. For - example:: + These values will be specified in the `containerOverrides` property of + the `overrides` parameter passed to the runTask API. + + Example:: [ { @@ -163,6 +166,80 @@ def command(self): """ pass + @staticmethod + def update_container_overrides_command(container_overrides, command): + """ + Update a list of container overrides with the specified command. 
+ + The specified command will take precedence over any existing commands + in `container_overrides` for the same container name. If no existing + command yet exists in `container_overrides` for the specified command, + it will be added. + """ + for colliding_override in filter(lambda x: x['name'] == command['name'], container_overrides): + colliding_override['command'] = command['command'] + break + else: + container_overrides.append(command) + + @property + def combined_overrides(self): + """ + Return single dict combining any provided `overrides` parameters. + + This is used to allow custom `overrides` parameters to be specified in + `self.run_task_kwargs` while ensuring that the values specified in + `self.command` are honored in `containerOverrides`. + """ + overrides = copy.deepcopy(self.run_task_kwargs.get('overrides', {})) + if self.command: + if 'containerOverrides' in overrides: + for command in self.command: + self.update_container_overrides_command(overrides['containerOverrides'], command) + else: + overrides['containerOverrides'] = self.command + return overrides + + @property + def run_task_kwargs(self): + """ + Additional keyword arguments to be provided to ECS runTask API. + + Override this property in a subclass to provide additional parameters + such as `network_configuration`, `launchType`, etc. + + If the returned `dict` includes an `overrides` value with a nested + `containerOverrides` array defining one or more container `command` + values, prior to calling `run_task` they will be combined with and + superseded by any colliding values specified separately in the + `command` property. + + Example:: + + { + 'launchType': 'FARGATE', + 'platformVersion': '1.4.0', + 'networkConfiguration': { + 'awsvpcConfiguration': { + 'subnets': [ + 'subnet-01234567890abcdef', + 'subnet-abcdef01234567890' + ], + 'securityGroups': [ + 'sg-abcdef01234567890', + ], + 'assignPublicIp': 'ENABLED' + } + }, + 'overrides': { + 'ephemeralStorage': { + 'sizeInGiB': 30 + } + } + } + """ + return {} + def run(self): if (not self.task_def and not self.task_def_arn) or \ (self.task_def and self.task_def_arn): @@ -173,15 +250,16 @@ def run(self): response = client.register_task_definition(**self.task_def) self.task_def_arn = response['taskDefinition']['taskDefinitionArn'] + run_task_kwargs = self.run_task_kwargs + run_task_kwargs.update({ + 'taskDefinition': self.task_def_arn, + 'cluster': self.cluster, + 'overrides': self.combined_overrides, + }) + # Submit the task to AWS ECS and get assigned task ID # (list containing 1 string) - if self.command: - overrides = {'containerOverrides': self.command} - else: - overrides = {} - response = client.run_task(taskDefinition=self.task_def_arn, - overrides=overrides, - cluster=self.cluster) + response = client.run_task(**run_task_kwargs) if response['failures']: raise Exception(", ".join(["fail to run task {0} reason: {1}".format(failure['arn'], failure['reason']) diff --git a/luigi/contrib/external_program.py b/luigi/contrib/external_program.py index d5dcf26d77..b92c4767e2 100644 --- a/luigi/contrib/external_program.py +++ b/luigi/contrib/external_program.py @@ -260,13 +260,13 @@ class ExternalPythonProgramTask(ExternalProgramTask): :py:class:`luigi.parameter.Parameter` s for setting a virtualenv and for extending the ``PYTHONPATH``. """ - virtualenv = luigi.Parameter( + virtualenv = luigi.OptionalParameter( default=None, positional=False, description='path to the virtualenv directory to use. 
It should point to ' 'the directory containing the ``bin/activate`` file used for ' 'enabling the virtualenv.') - extra_pythonpath = luigi.Parameter( + extra_pythonpath = luigi.OptionalParameter( default=None, positional=False, description='extend the search path for modules by prepending this ' diff --git a/luigi/contrib/ftp.py b/luigi/contrib/ftp.py index 91c766729f..d155e8ef8b 100644 --- a/luigi/contrib/ftp.py +++ b/luigi/contrib/ftp.py @@ -254,7 +254,7 @@ def _sftp_put(self, local_path, path, atomic): self.conn.makedirs(directory) if atomic: - tmp_path = os.path.join(directory, 'luigi-tmp-{:09d}'.format(random.randrange(0, 1e10))) + tmp_path = os.path.join(directory, 'luigi-tmp-{:09d}'.format(random.randrange(0, 10_000_000_000))) else: tmp_path = normpath @@ -279,7 +279,7 @@ def _ftp_put(self, local_path, path, atomic): # random file name if atomic: - tmp_path = folder + os.sep + 'luigi-tmp-%09d' % random.randrange(0, 1e10) + tmp_path = folder + os.sep + 'luigi-tmp-%09d' % random.randrange(0, 10_000_000_000) else: tmp_path = normpath @@ -297,7 +297,7 @@ def get(self, path, local_path): if folder and not os.path.exists(folder): os.makedirs(folder) - tmp_local_path = local_path + '-luigi-tmp-%09d' % random.randrange(0, 1e10) + tmp_local_path = local_path + '-luigi-tmp-%09d' % random.randrange(0, 10_000_000_000) # download file self._connect() @@ -409,7 +409,7 @@ def open(self, mode): elif mode == 'r': temppath = '{}-luigi-tmp-{:09d}'.format( - self.path.lstrip('/'), random.randrange(0, 1e10) + self.path.lstrip('/'), random.randrange(0, 10_000_000_000) ) try: # store reference to the TemporaryDirectory because it will be removed on GC diff --git a/luigi/contrib/hadoop.py b/luigi/contrib/hadoop.py index 595e7bedec..07b0702050 100644 --- a/luigi/contrib/hadoop.py +++ b/luigi/contrib/hadoop.py @@ -25,6 +25,7 @@ import abc import datetime import glob +import hashlib import logging import os import pickle @@ -37,7 +38,6 @@ import sys import tempfile import warnings -from hashlib import md5 from itertools import groupby from luigi import configuration @@ -620,7 +620,7 @@ def group(self, input_stream): lines = [] for i, line in enumerate(input_stream): parts = line.rstrip('\n').split('\t') - blob = md5(str(i).encode('ascii')).hexdigest() # pseudo-random blob to make sure the input isn't sorted + blob = hashlib.new('md5', str(i).encode('ascii'), usedforsecurity=False).hexdigest() # pseudo-random blob to make sure the input isn't sorted lines.append((parts[:-1], blob, line)) for _, _, line in sorted(lines): output.write(line) @@ -913,7 +913,7 @@ def extra_modules(self): def extra_files(self): """ - Can be overriden in subclass. + Can be overridden in subclass. Each element is either a string, or a pair of two strings (src, dst). 
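A quick aside on the `random.randrange(0, 1e10)` to `random.randrange(0, 10_000_000_000)` changes that recur throughout this diff (illustration only, not part of the patch; it ties into the Python 3.10 support added above):

    import random

    # Before Python 3.10, float bounds were silently truncated to ints.
    # Since Python 3.10, non-integer arguments to randrange() emit a
    # DeprecationWarning, and later releases reject them outright, so the
    # patch spells the bound as an explicit integer literal.
    num = random.randrange(0, 10_000_000_000)

    # the zero-padded suffix Luigi uses for its atomic temp files
    print('-luigi-tmp-%09d' % num)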
diff --git a/luigi/contrib/hadoop_jar.py b/luigi/contrib/hadoop_jar.py
index 095fac4c4f..2635aeca5e 100644
--- a/luigi/contrib/hadoop_jar.py
+++ b/luigi/contrib/hadoop_jar.py
@@ -47,7 +47,7 @@ def fix_paths(job):
                 args.append(x.path)
             else:  # output
                 x_path_no_slash = x.path[:-1] if x.path[-1] == '/' else x.path
-                y = luigi.contrib.hdfs.HdfsTarget(x_path_no_slash + '-luigi-tmp-%09d' % random.randrange(0, 1e10))
+                y = luigi.contrib.hdfs.HdfsTarget(x_path_no_slash + '-luigi-tmp-%09d' % random.randrange(0, 10_000_000_000))
                 tmp_files.append((y, x_path_no_slash))
                 logger.info('Using temp path: %s for path %s', y.path, x.path)
                 args.append(y.path)
diff --git a/luigi/contrib/hdfs/config.py b/luigi/contrib/hdfs/config.py
index e80ca12fa9..abafcdfb87 100644
--- a/luigi/contrib/hdfs/config.py
+++ b/luigi/contrib/hdfs/config.py
@@ -84,7 +84,7 @@ def tmppath(path=None, include_unix_username=True):
     Note that include_unix_username might work on windows too.
     """
-    addon = "luigitemp-%08d" % random.randrange(1e9)
+    addon = "luigitemp-%09d" % random.randrange(0, 10_000_000_000)
     temp_dir = '/tmp'  # default tmp dir if none is specified in config

     # 1. Figure out to which temporary directory to place
diff --git a/luigi/contrib/hdfs/target.py b/luigi/contrib/hdfs/target.py
index 182c275aa1..8617fde26a 100644
--- a/luigi/contrib/hdfs/target.py
+++ b/luigi/contrib/hdfs/target.py
@@ -177,7 +177,7 @@ def is_writable(self):
             return False

     def _is_writable(self, path):
-        test_path = path + '.test_write_access-%09d' % random.randrange(1e10)
+        test_path = path + '.test_write_access-%09d' % random.randrange(10_000_000_000)
         try:
             self.fs.touchz(test_path)
             self.fs.remove(test_path, recursive=False)
diff --git a/luigi/contrib/mrrunner.py b/luigi/contrib/mrrunner.py
index e86bb6993d..7b60f03a42 100644
--- a/luigi/contrib/mrrunner.py
+++ b/luigi/contrib/mrrunner.py
@@ -19,7 +19,7 @@
 """
 Since after Luigi 2.5.0, this is a private module to Luigi. Luigi users should
 not rely on that importing this module works. Furthermore, "luigi mr streaming"
-have been greatly superseeded by technoligies like Spark, Hive, etc.
+has been greatly superseded by technologies like Spark, Hive, etc.

 The hadoop runner.
diff --git a/luigi/contrib/mssqldb.py b/luigi/contrib/mssqldb.py
index 90e30bd2e2..57c0570673 100644
--- a/luigi/contrib/mssqldb.py
+++ b/luigi/contrib/mssqldb.py
@@ -22,7 +22,7 @@
 logger = logging.getLogger('luigi-interface')

 try:
-    import _mssql
+    from pymssql import _mssql
 except ImportError:
     logger.warning("Loading MSSQL module without the python package pymssql. \
 This will crash at runtime if SQL Server functionality is used.")
@@ -107,7 +107,7 @@ def exists(self, connection=None):
                 WHERE update_id = %s
                 """.format(marker_table=self.marker_table),
                 (self.update_id,))
-        except _mssql.MSSQLDatabaseException as e:
+        except _mssql.MssqlDatabaseException as e:
             # Error number for table doesn't exist
             if e.number == 208:
                 row = None
@@ -145,7 +145,7 @@ def create_marker_table(self):
                 """
                 .format(marker_table=self.marker_table)
             )
-        except _mssql.MSSQLDatabaseException as e:
+        except _mssql.MssqlDatabaseException as e:
             # Table already exists code
             if e.number == 2714:
                 pass
diff --git a/luigi/contrib/postgres.py b/luigi/contrib/postgres.py
index 74ad0836f7..f1dfebf4b8 100644
--- a/luigi/contrib/postgres.py
+++ b/luigi/contrib/postgres.py
@@ -19,6 +19,7 @@
 Also provides a helper task to copy data into a Postgres table.
""" +import os import datetime import logging import re @@ -29,12 +30,77 @@ logger = logging.getLogger('luigi-interface') -try: - import psycopg2 - import psycopg2.errorcodes - import psycopg2.extensions -except ImportError: - logger.warning("Loading postgres module without psycopg2 installed. Will crash at runtime if postgres functionality is used.") +DB_DRIVER = os.environ.get('LUIGI_PGSQL_DRIVER', 'psycopg2') + +DB_ERROR_CODES = {} +ERROR_DUPLICATE_TABLE = 'duplicate_table' +ERROR_UNDEFINED_TABLE = 'undefined_table' + +dbapi = None + +if DB_DRIVER == 'psycopg2': + try: + import psycopg2 as dbapi + + def update_error_codes(): + import psycopg2.errorcodes + + DB_ERROR_CODES.update({ + psycopg2.errorcodes.DUPLICATE_TABLE: ERROR_DUPLICATE_TABLE, + psycopg2.errorcodes.UNDEFINED_TABLE: ERROR_UNDEFINED_TABLE, + }) + update_error_codes() + except ImportError: + pass + +if dbapi is None or DB_DRIVER == 'pg8000': + try: + import pg8000.dbapi as dbapi # noqa: F811 + import pg8000.core + # pg8000 doesn't have an error code catalog so we need to make our own + # from https://www.postgresql.org/docs/8.2/errcodes-appendix.html + DB_ERROR_CODES.update({'42P07': ERROR_DUPLICATE_TABLE, '42P01': ERROR_UNDEFINED_TABLE}) + except ImportError: + pass + + +if dbapi is None: + logger.warning("Loading postgres module without psycopg2 nor pg8000 installed. " + "Will crash at runtime if postgres functionality is used.") + + +def _is_pg8000_error(exception): + try: + return isinstance(exception, dbapi.DatabaseError) and \ + isinstance(exception.args, tuple) and \ + isinstance(exception.args[0], dict) and \ + pg8000.core.RESPONSE_CODE in exception.args[0] + except NameError: + return False + + +def _pg8000_connection_reset(connection): + cursor = connection.cursor() + if connection.autocommit: + cursor.execute("DISCARD ALL") + else: + cursor.execute("ABORT") + cursor.execute("BEGIN TRANSACTION") + cursor.close() + + +def db_error_code(exception): + try: + error_code = None + if hasattr(exception, 'pgcode'): + error_code = exception.pgcode + elif _is_pg8000_error(exception): + error_code = exception.args[0][pg8000.core.RESPONSE_CODE] + + return DB_ERROR_CODES.get(error_code) + except TypeError as error: + error.__cause__ = exception + raise error class MultiReplacer: @@ -61,7 +127,7 @@ class MultiReplacer: >>> MultiReplacer(replace_pairs)("ab") 'xb' """ -# TODO: move to misc/util module + # TODO: move to misc/util module def __init__(self, replace_pairs): """ @@ -111,7 +177,7 @@ class PostgresTarget(luigi.Target): use_db_timestamps = True def __init__( - self, host, database, user, password, table, update_id, port=None + self, host, database, user, password, table, update_id, port=None ): """ Args: @@ -175,8 +241,8 @@ def exists(self, connection=None): (self.update_id,) ) row = cursor.fetchone() - except psycopg2.ProgrammingError as e: - if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE: + except dbapi.DatabaseError as e: + if db_error_code(e) == ERROR_UNDEFINED_TABLE: row = None else: raise @@ -184,9 +250,9 @@ def exists(self, connection=None): def connect(self): """ - Get a psycopg2 connection object to the database where the table is. + Get a DBAPI 2.0 connection object to the database where the table is. 
""" - connection = psycopg2.connect( + connection = dbapi.connect( host=self.host, port=self.port, database=self.database, @@ -219,8 +285,8 @@ def create_marker_table(self): try: cursor.execute(sql) - except psycopg2.ProgrammingError as e: - if e.pgcode == psycopg2.errorcodes.DUPLICATE_TABLE: + except dbapi.DatabaseError as e: + if db_error_code(e) == ERROR_DUPLICATE_TABLE: pass else: raise @@ -261,7 +327,7 @@ def map_column(self, value): else: return default_escape(str(value)) -# everything below will rarely have to be overridden + # everything below will rarely have to be overridden def output(self): """ @@ -286,7 +352,18 @@ def copy(self, cursor, file): column_names = [c[0] for c in self.columns] else: raise Exception('columns must consist of column strings or (column string, type string) tuples (was %r ...)' % (self.columns[0],)) - cursor.copy_from(file, self.table, null=r'\\N', sep=self.column_separator, columns=column_names) + + # cursor.copy_from is not available in pg8000 + if hasattr(cursor, 'copy_from'): + cursor.copy_from( + file, self.table, null=r'\\N', sep=self.column_separator, columns=column_names) + else: + copy_sql = ( + "COPY {table} ({column_list}) FROM STDIN " + "WITH (FORMAT text, NULL '{null_string}', DELIMITER '{delimiter}')" + ).format(table=self.table, delimiter=self.column_separator, null_string=r'\\N', + column_list=", ".join(column_names)) + cursor.execute(copy_sql, stream=file) def run(self): """ @@ -327,11 +404,15 @@ def run(self): self.post_copy(connection) if self.enable_metadata_columns: self.post_copy_metacolumns(cursor) - except psycopg2.ProgrammingError as e: - if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0: + except dbapi.DatabaseError as e: + if db_error_code(e) == ERROR_UNDEFINED_TABLE and attempt == 0: # if first attempt fails with "relation not found", try creating table logger.info("Creating table %s", self.table) - connection.reset() + # reset() is a psycopg2-specific method + if hasattr(connection, 'reset'): + connection.reset() + else: + _pg8000_connection_reset(connection) self.create_table(connection) else: raise diff --git a/luigi/contrib/rdbms.py b/luigi/contrib/rdbms.py index a9e0429a9b..ca24ae8ff1 100644 --- a/luigi/contrib/rdbms.py +++ b/luigi/contrib/rdbms.py @@ -249,7 +249,7 @@ def init_copy(self, connection): raise Exception("The clear_table attribute has been removed. Override init_copy instead!") if self.enable_metadata_columns: - self._add_metadata_columns(connection.cursor()) + self._add_metadata_columns(connection) def post_copy(self, connection): """ diff --git a/luigi/contrib/s3.py b/luigi/contrib/s3.py index 7742053a42..4f3e57b455 100644 --- a/luigi/contrib/s3.py +++ b/luigi/contrib/s3.py @@ -440,7 +440,8 @@ def mkdir(self, path, parents=True, raise_if_exists=False): def listdir(self, path, start_time=None, end_time=None, return_key=False): """ Get an iterable with S3 folder contents. - Iterable contains paths relative to queried path. + Iterable contains absolute paths for which queried path is a prefix. 
+ :param path: URL for target S3 location :param start_time: Optional argument to list files with modified (offset aware) datetime after start_time :param end_time: Optional argument to list files with modified (offset aware) datetime before end_time @@ -471,6 +472,15 @@ def listdir(self, path, start_time=None, end_time=None, return_key=False): yield self._add_path_delimiter(path) + item.key[key_path_len:] def list(self, path, start_time=None, end_time=None, return_key=False): # backwards compat + """ + Get an iterable with S3 folder contents. + Iterable contains paths relative to queried path. + + :param path: URL for target S3 location + :param start_time: Optional argument to list files with modified (offset aware) datetime after start_time + :param end_time: Optional argument to list files with modified (offset aware) datetime before end_time + :param return_key: Optional argument, when set to True will return boto3's ObjectSummary (instead of the filename) + """ key_path_len = len(self._add_path_delimiter(path)) for item in self.listdir(path, start_time=start_time, end_time=end_time, return_key=return_key): if return_key: diff --git a/luigi/contrib/simulate.py b/luigi/contrib/simulate.py index 5ff8274e46..88ea90664c 100644 --- a/luigi/contrib/simulate.py +++ b/luigi/contrib/simulate.py @@ -83,7 +83,7 @@ def get_path(self): """ Returns a temporary file path based on a MD5 hash generated with the task's name and its arguments """ - md5_hash = hashlib.md5(self.task_id.encode()).hexdigest() + md5_hash = hashlib.new('md5', self.task_id.encode(), usedforsecurity=False).hexdigest() logger.debug('Hash %s corresponds to task %s', md5_hash, self.task_id) return os.path.join(self.temp_dir, str(self.unique.value), md5_hash) diff --git a/luigi/contrib/spark.py b/luigi/contrib/spark.py index 668b7ba19a..84068043f1 100644 --- a/luigi/contrib/spark.py +++ b/luigi/contrib/spark.py @@ -292,6 +292,10 @@ def files(self): if self.deploy_mode == "cluster": return [self.run_pickle] + @property + def pickle_protocol(self): + return configuration.get_config().getint('spark', 'pickle-protocol', pickle.DEFAULT_PROTOCOL) + def setup(self, conf): """ Called by the pyspark_runner with a SparkConf instance that will be used to instantiate the SparkContext @@ -335,12 +339,12 @@ def run(self): def _dump(self, fd): with self.no_unpicklable_properties(): if self.__module__ == '__main__': - d = pickle.dumps(self) + d = pickle.dumps(self, protocol=self.pickle_protocol) module_name = os.path.basename(sys.argv[0]).rsplit('.', 1)[0] d = d.replace(b'c__main__', b'c' + module_name.encode('ascii')) fd.write(d) else: - pickle.dump(self, fd) + pickle.dump(self, fd, protocol=self.pickle_protocol) def _setup_packages(self, sc): """ diff --git a/luigi/contrib/ssh.py b/luigi/contrib/ssh.py index 3119a94117..beda1c9e23 100644 --- a/luigi/contrib/ssh.py +++ b/luigi/contrib/ssh.py @@ -254,7 +254,7 @@ def put(self, local_path, path): if folder and not self.exists(folder): self.remote_context.check_output(['mkdir', '-p', folder]) - tmp_path = path + '-luigi-tmp-%09d' % random.randrange(0, 1e10) + tmp_path = path + '-luigi-tmp-%09d' % random.randrange(0, 10_000_000_000) self._scp(local_path, "%s:%s" % (self.remote_context._host_ref(), tmp_path)) self.remote_context.check_output(['mv', tmp_path, path]) @@ -268,7 +268,7 @@ def get(self, path, local_path): except OSError: pass - tmp_local_path = local_path + '-luigi-tmp-%09d' % random.randrange(0, 1e10) + tmp_local_path = local_path + '-luigi-tmp-%09d' % random.randrange(0, 10_000_000_000) 
self._scp("%s:%s" % (self.remote_context._host_ref(), path), tmp_local_path) os.rename(tmp_local_path, local_path) @@ -285,7 +285,7 @@ def __init__(self, fs, path): if folder: self.fs.mkdir(folder) - self.__tmp_path = self.path + '-luigi-tmp-%09d' % random.randrange(0, 1e10) + self.__tmp_path = self.path + '-luigi-tmp-%09d' % random.randrange(0, 10_000_000_000) super(AtomicRemoteFileWriter, self).__init__( self.fs.remote_context._prepare_cmd(['cat', '>', self.__tmp_path])) diff --git a/luigi/execution_summary.py b/luigi/execution_summary.py index a430fd61c8..e818a1ed14 100644 --- a/luigi/execution_summary.py +++ b/luigi/execution_summary.py @@ -94,14 +94,14 @@ def _partition_tasks(worker): Still_pending_not_ext is only used to get upstream_failure, upstream_missing_dependency and run_by_other_worker """ task_history = worker._add_task_history - pending_tasks = {task for(task, status, ext) in task_history if status == 'PENDING'} + pending_tasks = {task for (task, status, ext) in task_history if status == 'PENDING'} set_tasks = {} set_tasks["completed"] = {task for (task, status, ext) in task_history if status == 'DONE' and task in pending_tasks} set_tasks["already_done"] = {task for (task, status, ext) in task_history if status == 'DONE' and task not in pending_tasks and task not in set_tasks["completed"]} set_tasks["ever_failed"] = {task for (task, status, ext) in task_history if status == 'FAILED'} set_tasks["failed"] = set_tasks["ever_failed"] - set_tasks["completed"] - set_tasks["scheduling_error"] = {task for(task, status, ext) in task_history if status == 'UNKNOWN'} + set_tasks["scheduling_error"] = {task for (task, status, ext) in task_history if status == 'UNKNOWN'} set_tasks["still_pending_ext"] = {task for (task, status, ext) in task_history if status == 'PENDING' and task not in set_tasks["ever_failed"] and task not in set_tasks["completed"] and not ext} set_tasks["still_pending_not_ext"] = {task for (task, status, ext) in task_history diff --git a/luigi/format.py b/luigi/format.py index f07befa96e..2faf94508f 100644 --- a/luigi/format.py +++ b/luigi/format.py @@ -69,7 +69,7 @@ def __init__(self, command, input_pipe=None): if input_pipe is not None: try: input_pipe.fileno() - except AttributeError: + except (AttributeError, io.UnsupportedOperation): # subprocess require a fileno to work, if not present we copy to disk first self._original_input = False f = tempfile.NamedTemporaryFile('wb', prefix='luigi-process_tmp', delete=False) diff --git a/luigi/freezing.py b/luigi/freezing.py index 3143d6b7c0..2f0a4b49f6 100644 --- a/luigi/freezing.py +++ b/luigi/freezing.py @@ -56,3 +56,14 @@ def recursively_freeze(value): elif isinstance(value, list) or isinstance(value, tuple): return tuple(recursively_freeze(v) for v in value) return value + + +def recursively_unfreeze(value): + """ + Recursively walks ``FrozenOrderedDict``s and ``tuple``s and converts them to ``dict`` and ``list``, respectively. + """ + if isinstance(value, Mapping): + return dict(((k, recursively_unfreeze(v)) for k, v in value.items())) + elif isinstance(value, list) or isinstance(value, tuple): + return list(recursively_unfreeze(v) for v in value) + return value diff --git a/luigi/interface.py b/luigi/interface.py index 0f51b07783..867463c159 100644 --- a/luigi/interface.py +++ b/luigi/interface.py @@ -18,7 +18,7 @@ This module contains the bindings for command line integration and dynamic loading of tasks If you don't want to run luigi from the command line. 
You may use the methods -defined in this module to programatically run luigi. +defined in this module to programmatically run luigi. """ import logging @@ -49,6 +49,10 @@ class core(task.Config): This is arguably a bit of a hack. ''' use_cmdline_section = False + ignore_unconsumed = { + 'autoload_range', + 'no_configure_logging', + } local_scheduler = parameter.BoolParameter( default=False, @@ -146,7 +150,7 @@ def _schedule_and_run(tasks, worker_scheduler_factory=None, override_defaults=No kill_signal = signal.SIGUSR1 if env_params.take_lock else None if (not env_params.no_lock and - not(lock.acquire_for(env_params.lock_pid_dir, env_params.lock_size, kill_signal))): + not (lock.acquire_for(env_params.lock_pid_dir, env_params.lock_size, kill_signal))): raise PidLockAlreadyTakenExit() if env_params.local_scheduler: @@ -173,6 +177,8 @@ def _schedule_and_run(tasks, worker_scheduler_factory=None, override_defaults=No success &= worker.run() luigi_run_result = LuigiRunResult(worker, success) logger.info(luigi_run_result.summary_text) + if hasattr(sch, 'close'): + sch.close() return luigi_run_result diff --git a/luigi/local_target.py b/luigi/local_target.py index c3302a7118..5cdade2ec7 100644 --- a/luigi/local_target.py +++ b/luigi/local_target.py @@ -40,7 +40,7 @@ def move_to_final_destination(self): os.rename(self.tmp_path, self.path) def generate_tmp_path(self, path): - return path + '-luigi-tmp-%09d' % random.randrange(0, 1e10) + return path + '-luigi-tmp-%09d' % random.randrange(0, 10_000_000_000) class LocalFileSystem(FileSystem): @@ -186,5 +186,5 @@ def fn(self): return self.path def __del__(self): - if self.is_tmp and self.exists(): + if hasattr(self, "is_tmp") and self.is_tmp and self.exists(): self.remove() diff --git a/luigi/lock.py b/luigi/lock.py index ae5fc67a86..dfa5acbcdb 100644 --- a/luigi/lock.py +++ b/luigi/lock.py @@ -80,7 +80,7 @@ def get_info(pid_dir, my_pid=None): my_cmd = getpcmd(my_pid) cmd_hash = my_cmd.encode('utf8') - pid_file = os.path.join(pid_dir, hashlib.md5(cmd_hash).hexdigest()) + '.pid' + pid_file = os.path.join(pid_dir, hashlib.new('md5', cmd_hash, usedforsecurity=False).hexdigest()) + '.pid' return my_pid, my_cmd, pid_file diff --git a/luigi/metrics.py b/luigi/metrics.py index 91b548363e..cd3364e251 100644 --- a/luigi/metrics.py +++ b/luigi/metrics.py @@ -1,16 +1,18 @@ import abc +import importlib from enum import Enum class MetricsCollectors(Enum): + custom = -1 default = 1 none = 1 datadog = 2 prometheus = 3 @classmethod - def get(cls, which): + def get(cls, which, custom_import=None): if which == MetricsCollectors.none: return NoMetricsCollector() elif which == MetricsCollectors.datadog: @@ -19,6 +21,22 @@ def get(cls, which): elif which == MetricsCollectors.prometheus: from luigi.contrib.prometheus_metric import PrometheusMetricsCollector return PrometheusMetricsCollector() + elif which == MetricsCollectors.custom: + if custom_import is None: + raise ValueError(f"MetricsCollectors value ' {which} ' is -1 and custom_import is None") + + split_import_string = custom_import.split(".") + + import_path = ".".join(split_import_string[:-1]) + import_class_string = split_import_string[-1] + + mod = importlib.import_module(import_path) + metrics_class = getattr(mod, import_class_string) + + if issubclass(metrics_class, MetricsCollector): + return metrics_class() + else: + raise ValueError(f"Custom Import: {custom_import} is not a subclass of MetricsCollector") else: raise ValueError("MetricsCollectors value ' {0} ' isn't supported", which) diff --git 
a/luigi/parameter.py b/luigi/parameter.py index 5ef47c31c0..3278377f1c 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -28,6 +28,12 @@ from json import JSONEncoder import operator from ast import literal_eval +from pathlib import Path +try: + import jsonschema + _JSONSCHEMA_ENABLED = True +except ImportError: + _JSONSCHEMA_ENABLED = False from configparser import NoOptionError, NoSectionError @@ -36,7 +42,7 @@ from luigi import configuration from luigi.cmdline_parser import CmdlineParser -from .freezing import recursively_freeze, FrozenOrderedDict +from .freezing import recursively_freeze, recursively_unfreeze, FrozenOrderedDict _no_value = object() @@ -94,6 +100,10 @@ class OptionalParameterTypeWarning(UserWarning): pass +class UnconsumedParameterWarning(UserWarning): + """Warning class for parameters that are not consumed by the task.""" + + class Parameter: """ Parameter whose value is a ``str``, and a base class for other parameter types. @@ -367,10 +377,14 @@ def normalize(self, x): def _warn_on_wrong_param_type(self, param_name, param_value): if not isinstance(param_value, self.expected_type) and param_value is not None: + try: + param_type = "any type in " + str([i.__name__ for i in self.expected_type]).replace("'", '"') + except TypeError: + param_type = f'type "{self.expected_type.__name__}"' warnings.warn( ( f'{self.__class__.__name__} "{param_name}" with value ' - f'"{param_value}" is not of type "{self.expected_type.__name__}" or None.' + f'"{param_value}" is not of {param_type} or None.' ), OptionalParameterTypeWarning, ) @@ -805,6 +819,7 @@ class TimeDeltaParameter(Parameter): """ Class that maps to timedelta using strings in any of the following forms: + * A bare number is interpreted as duration in seconds. * ``n {w[eek[s]]|d[ay[s]]|h[our[s]]|m[inute[s]|s[second[s]]}`` (e.g. "1 week 2 days" or "1 h") Note: multiple arguments must be supplied in longest to shortest unit order * ISO 8601 duration ``PnDTnHnMnS`` (each field optional, years and months not supported) @@ -852,6 +867,10 @@ def parse(self, input): See :py:class:`TimeDeltaParameter` for details on supported formats. """ + try: + return datetime.timedelta(seconds=float(input)) + except ValueError: + pass result = self._parseIso8601(input) if not result: result = self._parseSimple(input) @@ -1042,12 +1061,89 @@ def run(self): It can be used to define dynamic parameters, when you do not know the exact list of your parameters (e.g. list of tags, that are dynamically constructed outside Luigi), or you have a complex parameter containing logically related values (like a database connection config). + + It is possible to provide a JSON schema that should be validated by the given value: + + .. code-block:: python + + class MyTask(luigi.Task): + tags = luigi.DictParameter( + schema={ + "type": "object", + "patternProperties": { + ".*": {"type": "string", "enum": ["web", "staging"]}, + } + } + ) + + def run(self): + logging.info("Find server with role: %s", self.tags['role']) + server = aws.ec2.find_my_resource(self.tags) + + Using this schema, the following command will work: + + .. code-block:: console + + $ luigi --module my_tasks MyTask --tags '{"role": "web", "env": "staging"}' + + while this command will fail because the parameter is not valid: + + .. code-block:: console + + $ luigi --module my_tasks MyTask --tags '{"role": "UNKNOWN_VALUE", "env": "staging"}' + + Finally, the provided schema can be a custom validator: + + .. 
code-block:: python
+
+            custom_validator = jsonschema.Draft4Validator(
+                schema={
+                    "type": "object",
+                    "patternProperties": {
+                        ".*": {"type": "string", "enum": ["web", "staging"]},
+                    }
+                }
+            )
+
+            class MyTask(luigi.Task):
+                tags = luigi.DictParameter(schema=custom_validator)
+
+                def run(self):
+                    logging.info("Find server with role: %s", self.tags['role'])
+                    server = aws.ec2.find_my_resource(self.tags)
+    """
+    def __init__(
+        self,
+        *args,
+        schema=None,
+        **kwargs,
+    ):
+        if schema is not None and not _JSONSCHEMA_ENABLED:
+            warnings.warn(
+                "The 'jsonschema' package is not installed so the parameter can not be validated "
+                "even though a schema is given."
+            )
+            self.schema = None
+        else:
+            self.schema = schema
+        super().__init__(
+            *args,
+            **kwargs,
+        )
+
     def normalize(self, value):
         """
         Ensure that dictionary parameter is converted to a FrozenOrderedDict so it can be hashed.
         """
+        if self.schema is not None:
+            unfrozen_value = recursively_unfreeze(value)
+            try:
+                self.schema.validate(unfrozen_value)
+                value = unfrozen_value  # Validators may update the instance inplace
+            except AttributeError:
+                jsonschema.validate(instance=unfrozen_value, schema=self.schema)
         return recursively_freeze(value)

     def parse(self, source):
@@ -1105,8 +1201,89 @@ def run(self):

         .. code-block:: console

             $ luigi --module my_tasks MyTask --grades '[100,70]'
+
+    It is possible to provide a JSON schema that should be validated by the given value:
+
+    .. code-block:: python
+
+        class MyTask(luigi.Task):
+            grades = luigi.ListParameter(
+                schema={
+                    "type": "array",
+                    "items": {
+                        "type": "number",
+                        "minimum": 0,
+                        "maximum": 10
+                    },
+                    "minItems": 1
+                }
+            )
+
+            def run(self):
+                sum = 0
+                for element in self.grades:
+                    sum += element
+                avg = sum / len(self.grades)
+
+    Using this schema, the following command will work:
+
+    .. code-block:: console
+
+        $ luigi --module my_tasks MyTask --grades '[1, 8.7, 6]'
+
+    while these commands will fail because the parameter is not valid:
+
+    .. code-block:: console
+
+        $ luigi --module my_tasks MyTask --grades '[]'  # must have at least 1 element
+        $ luigi --module my_tasks MyTask --grades '[-999, 999]'  # elements must be in [0, 10]
+
+    Finally, the provided schema can be a custom validator:
+
+    .. code-block:: python
+
+        custom_validator = jsonschema.Draft4Validator(
+            schema={
+                "type": "array",
+                "items": {
+                    "type": "number",
+                    "minimum": 0,
+                    "maximum": 10
+                },
+                "minItems": 1
+            }
+        )
+
+        class MyTask(luigi.Task):
+            grades = luigi.ListParameter(schema=custom_validator)
+
+            def run(self):
+                sum = 0
+                for element in self.grades:
+                    sum += element
+                avg = sum / len(self.grades)
+    """
+    def __init__(
+        self,
+        *args,
+        schema=None,
+        **kwargs,
+    ):
+        if schema is not None and not _JSONSCHEMA_ENABLED:
+            warnings.warn(
+                "The 'jsonschema' package is not installed so the parameter can not be validated "
+                "even though a schema is given."
+            )
+            self.schema = None
+        else:
+            self.schema = schema
+        super().__init__(
+            *args,
+            **kwargs,
+        )
+
     def normalize(self, x):
         """
         Ensure that struct is recursively converted to a tuple so it can be hashed.

         :param str x: the value to parse.
         :return: the normalized (hashable/immutable) value.
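+
+        For example, ``normalize([1, [2, 3]])`` returns the hashable
+        ``(1, (2, 3))``.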
""" + if self.schema is not None: + unfrozen_value = recursively_unfreeze(x) + try: + self.schema.validate(unfrozen_value) + x = unfrozen_value # Validators may update the instance inplace + except AttributeError: + jsonschema.validate(instance=unfrozen_value, schema=self.schema) return recursively_freeze(x) def parse(self, x): @@ -1354,3 +1538,61 @@ class OptionalChoiceParameter(OptionalParameterMixin, ChoiceParameter): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.expected_type = self._var_type + + +class PathParameter(Parameter): + """ + Parameter whose value is a path. + + In the task definition, use + + .. code-block:: python + + class MyTask(luigi.Task): + existing_file_path = luigi.PathParameter(exists=True) + new_file_path = luigi.PathParameter() + + def run(self): + # Get data from existing file + with self.existing_file_path.open("r", encoding="utf-8") as f: + data = f.read() + + # Output message in new file + self.new_file_path.parent.mkdir(parents=True, exist_ok=True) + with self.new_file_path.open("w", encoding="utf-8") as f: + f.write("hello from a PathParameter => ") + f.write(data) + + At the command line, use + + .. code-block:: console + + $ luigi --module my_tasks MyTask --existing-file-path --new-file-path + """ + + def __init__(self, *args, absolute=False, exists=False, **kwargs): + """ + :param bool absolute: If set to ``True``, the given path is converted to an absolute path. + :param bool exists: If set to ``True``, a :class:`ValueError` is raised if the path does not exist. + """ + super().__init__(*args, **kwargs) + + self.absolute = absolute + self.exists = exists + + def normalize(self, x): + """ + Normalize the given value to a :class:`pathlib.Path` object. + """ + path = Path(x) + if self.absolute: + path = path.absolute() + if self.exists and not path.exists(): + raise ValueError(f"The path {path} does not exist.") + return path + + +class OptionalPathParameter(OptionalParameter, PathParameter): + """Class to parse optional path parameters.""" + + expected_type = (str, Path) diff --git a/luigi/rpc.py b/luigi/rpc.py index e30146c7bd..d5215e78ae 100644 --- a/luigi/rpc.py +++ b/luigi/rpc.py @@ -19,6 +19,7 @@ rpc.py implements the client side of it, server.py implements the server side. See :doc:`/central_scheduler` for more info. """ +import abc import os import json import logging @@ -54,11 +55,12 @@ def _urljoin(base, url): """ Join relative URLs to base URLs like urllib.parse.urljoin but support arbitrary URIs (esp. 'http+unix://'). + base part is fixed or mounted point, every url contains full base part. 
""" parsed = urlparse(base) scheme = parsed.scheme return urlparse( - urljoin(parsed._replace(scheme='http').geturl(), url) + urljoin(parsed._replace(scheme='http').geturl(), parsed.path + (url if url[0] == '/' else '/' + url)) )._replace(scheme=scheme).geturl() @@ -69,7 +71,17 @@ def __init__(self, message, sub_exception=None): self.sub_exception = sub_exception -class URLLibFetcher: +class _FetcherInterface(metaclass=abc.ABCMeta): + @abc.abstractmethod + def fetch(self, full_url, body, timeout): + pass + + @abc.abstractmethod + def close(self): + pass + + +class URLLibFetcher(_FetcherInterface): raises = (URLError, socket.timeout) def _create_request(self, full_url, body=None): @@ -96,12 +108,15 @@ def fetch(self, full_url, body, timeout): req = self._create_request(full_url, body=body) return urlopen(req, timeout=timeout).read().decode('utf-8') + def close(self): + pass -class RequestsFetcher: - def __init__(self, session): + +class RequestsFetcher(_FetcherInterface): + def __init__(self): from requests import exceptions as requests_exceptions self.raises = requests_exceptions.RequestException - self.session = session + self.session = requests.Session() self.process_id = os.getpid() def check_pid(self): @@ -117,6 +132,9 @@ def fetch(self, full_url, body, timeout): resp.raise_for_status() return resp.text + def close(self): + self.session.close() + class RemoteScheduler: """ @@ -140,10 +158,13 @@ def __init__(self, url='http://localhost:8082/', connect_timeout=None): self._rpc_log_retries = config.getboolean('core', 'rpc-log-retries', True) if HAS_REQUESTS: - self._fetcher = RequestsFetcher(requests.Session()) + self._fetcher = RequestsFetcher() else: self._fetcher = URLLibFetcher() + def close(self): + self._fetcher.close() + def _get_retryer(self): def retry_logging(retry_state): if self._rpc_log_retries: diff --git a/luigi/scheduler.py b/luigi/scheduler.py index b34f89646a..10d67a10af 100644 --- a/luigi/scheduler.py +++ b/luigi/scheduler.py @@ -145,6 +145,7 @@ class scheduler(Config): send_messages = parameter.BoolParameter(default=True) metrics_collector = parameter.EnumParameter(enum=MetricsCollectors, default=MetricsCollectors.default) + metrics_custom_import = parameter.OptionalStrParameter(default=None) stable_done_cooldown_secs = parameter.IntParameter(default=10, description="Sets cooldown period to avoid running the same task twice") @@ -695,7 +696,7 @@ def __init__(self, config=None, resources=None, task_history_impl=None, **kwargs if self._config.batch_emails: self._email_batcher = BatchNotifier() - self._state._metrics_collector = MetricsCollectors.get(self._config.metrics_collector) + self._state._metrics_collector = MetricsCollectors.get(self._config.metrics_collector, self._config.metrics_custom_import) def load(self): self._state.load() @@ -1230,7 +1231,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None, if len(batched_tasks) > 1: batch_string = '|'.join(task.id for task in batched_tasks) - batch_id = hashlib.md5(batch_string.encode('utf-8')).hexdigest() + batch_id = hashlib.new('md5', batch_string.encode('utf-8'), usedforsecurity=False).hexdigest() for task in batched_tasks: self._state.set_batch_running(task, batch_id, worker_id) @@ -1441,7 +1442,7 @@ def filter_func(_): terms = search.split() def filter_func(t): - return all(term in t.pretty_id for term in terms) + return all(term.casefold() in t.pretty_id.casefold() for term in terms) tasks = self._state.get_active_tasks_by_status(status) if status else self._state.get_active_tasks() 
for task in filter(filter_func, tasks): diff --git a/luigi/server.py b/luigi/server.py index 50948624d3..cc6d32a4f6 100644 --- a/luigi/server.py +++ b/luigi/server.py @@ -245,28 +245,32 @@ def from_utc(utcTime, fmt=None): class RecentRunHandler(BaseTaskHistoryHandler): def get(self): - tasks = self._scheduler.task_history.find_latest_runs() - self.render("recent.html", tasks=tasks) + with self._scheduler.task_history._session(None) as session: + tasks = self._scheduler.task_history.find_latest_runs(session) + self.render("recent.html", tasks=tasks) class ByNameHandler(BaseTaskHistoryHandler): def get(self, name): - tasks = self._scheduler.task_history.find_all_by_name(name) - self.render("recent.html", tasks=tasks) + with self._scheduler.task_history._session(None) as session: + tasks = self._scheduler.task_history.find_all_by_name(name, session) + self.render("recent.html", tasks=tasks) class ByIdHandler(BaseTaskHistoryHandler): def get(self, id): - task = self._scheduler.task_history.find_task_by_id(id) - self.render("show.html", task=task) + with self._scheduler.task_history._session(None) as session: + task = self._scheduler.task_history.find_task_by_id(id, session) + self.render("show.html", task=task) class ByParamsHandler(BaseTaskHistoryHandler): def get(self, name): payload = self.get_argument('data', default="{}") arguments = json.loads(payload) - tasks = self._scheduler.task_history.find_all_by_parameters(name, session=None, **arguments) - self.render("recent.html", tasks=tasks) + with self._scheduler.task_history._session(None) as session: + tasks = self._scheduler.task_history.find_all_by_parameters(name, session=session, **arguments) + self.render("recent.html", tasks=tasks) class RootPathHandler(BaseTaskHistoryHandler): diff --git a/luigi/setup_logging.py b/luigi/setup_logging.py index 60ff39dc83..fdda06b79c 100644 --- a/luigi/setup_logging.py +++ b/luigi/setup_logging.py @@ -24,6 +24,7 @@ import os.path from luigi.configuration import get_config, LuigiConfigParser +from luigi.freezing import recursively_unfreeze from configparser import NoSectionError @@ -40,7 +41,7 @@ def _section(cls, opts): logging_config = cls.config['logging'] except (TypeError, KeyError, NoSectionError): return False - logging.config.dictConfig(logging_config) + logging.config.dictConfig(recursively_unfreeze(logging_config)) return True @classmethod diff --git a/luigi/static/visualiser/index.html b/luigi/static/visualiser/index.html index 9c433bb25e..63052a5e10 100644 --- a/luigi/static/visualiser/index.html +++ b/luigi/static/visualiser/index.html @@ -16,6 +16,7 @@ + diff --git a/luigi/static/visualiser/js/graph.js b/luigi/static/visualiser/js/graph.js index e147edadf5..f996526bd7 100644 --- a/luigi/static/visualiser/js/graph.js +++ b/luigi/static/visualiser/js/graph.js @@ -19,6 +19,9 @@ Graph = (function() { /* Amount of horizontal space given for each node */ var nodeWidth = 200; + /* Random horizontal offset for each row */ + var jitterWidth = 100; + /* Calculate minimum SVG height required for legend */ var legendMaxY = (function () { return Object.keys(statusColors).length * legendLineHeight + ( legendLineHeight / 2 ) @@ -65,21 +68,72 @@ Graph = (function() { }); return edges; } - - /* Compute the maximum depth of each node for layout purposes, returns the number - of nodes at each depth level (for layout purposes) */ + /* Compute the depth of each node for layout purposes */ function computeDepth(nodes, nodeIndex) { + var selfDependencies = false function descend(n, depth) { if (n.depth === 
undefined || depth > n.depth) { n.depth = depth; $.each(n.deps, function(i, dep) { if (nodeIndex[dep]) { - descend(nodes[nodeIndex[dep]], depth + 1); + var child_node = nodes[nodeIndex[dep]] + descend(child_node, depth + 1); + if (!selfDependencies && n.name == child_node.name) { + selfDependencies = true; + } } }); } } descend(nodes[0], 0); + return selfDependencies + } + + /* Group tasks, so all tasks with the same name appear at the same depth. */ + function groupTasks(nodes) { + + // compute average assigned depth + var taskDepths = {}; + $.each(nodes, function(i, n) { + if (taskDepths[n.name] === undefined) { + taskDepths[n.name] = [n.depth]; + } else { + taskDepths[n.name].push(n.depth); + } + }); + var averages = []; + $.each(taskDepths, function(key, array) { + var total = 0; + for (var i in array) total += array[i]; + var mean = total / array.length; + averages.push([key, mean]); + }); + + // sort tasks + averages.sort( function(first, second) { + return first[1] - second[1]; + }); + + // reassign task depths and node depths + var classDepths = {} + $.each(averages, function(i, a) { + classDepths[a[0]] = i; + }); + + $.each(nodes, function(i, n) { + n.depth = classDepths[n.name]; + }); + return classDepths + } + + /* Compute the depth of each node for layout purposes, returns the number + of nodes at each depth level (for layout purposes) */ + function computeRows(nodes, nodeIndex) { + var selfDependencies = computeDepth(nodes, nodeIndex) + + if (!selfDependencies) { + var classDepths = groupTasks(nodes) + } var rowSizes = []; function placeNodes(n, depth) { @@ -91,7 +145,9 @@ Graph = (function() { rowSizes[depth]++; $.each(n.deps, function(i, dep) { if (nodeIndex[dep]) { - placeNodes(nodes[nodeIndex[dep]], depth + 1); + var next_node = nodes[nodeIndex[dep]] + var depth = (selfDependencies ? depth + 1 : classDepths[next_node.name]) + placeNodes(next_node, depth); } }); } @@ -100,7 +156,6 @@ Graph = (function() { return rowSizes; } - /* Format nodes according to their depth and horizontal sort order. Algorithm: evenly distribute nodes along each depth level, offsetting each by the text line height to prevent overlapping text. This is done within @@ -108,18 +163,25 @@ Graph = (function() { is at least nodeWidth to ensure readability. The height of each level is determined by number of nodes divided by number of columns, rounded up. 
*/
    function layoutNodes(nodes, rowSizes) {
-        var numCols = Math.max(2, Math.floor(graphWidth / nodeWidth));
+        var numCols = Math.max(2, Math.floor((graphWidth - jitterWidth) / nodeWidth));
        function rowStartPosition(depth) {
            if (depth === 0) return 20;
            var rowHeight = Math.ceil(rowSizes[depth-1] / numCols);
            return rowStartPosition(depth-1)+Math.max(rowHeight * nodeHeight + 100);
        }
+        var jitter = []
+        for (var i in rowSizes) {
+            jitter[i] = Math.ceil(Math.random() * jitterWidth)
+        }
        $.each(nodes, function(i, node) {
            var numRows = Math.ceil(rowSizes[node.depth] / numCols);
            var levelCols = Math.ceil(rowSizes[node.depth] / numRows);
            var row = node.xOrder % numRows;
            var col = node.xOrder / numRows;
-            node.x = ((col + 1) / (levelCols + 1)) * (graphWidth - 200);
+            node.x =
+                ((col + 1) / (levelCols + 1))
+                * (graphWidth - jitterWidth - nodeWidth)
+                + jitter[node.depth];
            node.y = rowStartPosition(node.depth) + row * nodeHeight;
        });
    }
@@ -132,7 +194,7 @@ Graph = (function() {
        var nodes = $.map(tasks, nodeFromTask);
        var nodeIndex = uniqueIndexByProperty(nodes, "taskId");
-        var rowSizes = computeDepth(nodes, nodeIndex);
+        var rowSizes = computeRows(nodes, nodeIndex);
        nodes = $.map(nodes, function(node) { return node.depth >= 0 ? node: null; });
@@ -207,7 +269,7 @@ Graph = (function() {
            $(svgLink(node.trackingUrl))
                .append(
                    $(svgElement("text"))
-                        .text(node.name)
+                        .text(escapeHtml(node.name))
                        .attr("y", 3))
                .attr("class","graph-node-a")
                .attr("data-task-status", node.status)
@@ -215,7 +277,7 @@
                .appendTo(g);

            var titleText = node.name;
-            var content = $.map(node.params, function (value, name) { return name + ": " + value; }).join("<br>");
+            var content = $.map(node.params, function (value, name) { return escapeHtml(name + ": " + value); }).join("<br>");
"); g.attr("title", titleText) .popover({ trigger: 'hover', @@ -251,7 +313,7 @@ Graph = (function() { .appendTo(legend); $(svgElement("text")) - .text(key.charAt(0).toUpperCase() + key.substring(1).toLowerCase().replace(/_./gi, function (x) { return " " + x[1].toUpperCase(); })) + .text(escapeHtml(key.charAt(0).toUpperCase() + key.substring(1).toLowerCase().replace(/_./gi, function (x) { return " " + x[1].toUpperCase(); }))) .attr("x", legendLineHeight + 14) .attr("y", legendLineHeight+(x*legendLineHeight)) .appendTo(legend); @@ -278,6 +340,7 @@ Graph = (function() { uniqueIndexByProperty: uniqueIndexByProperty, createDependencyEdges: createDependencyEdges, computeDepth: computeDepth, + computeRows: computeRows, createGraph: createGraph, findBounds: findBounds } diff --git a/luigi/static/visualiser/js/test/graph_test.js b/luigi/static/visualiser/js/test/graph_test.js index 66129b5206..780f6d3839 100644 --- a/luigi/static/visualiser/js/test/graph_test.js +++ b/luigi/static/visualiser/js/test/graph_test.js @@ -62,6 +62,40 @@ test("computeDepth", function() { equal(E.depth, -1); }); +test("computeRowsSelfDeps", function () { + var A1 = {name: "A", taskId: "A1", deps: ["A2"], depth: -1} + var A2 = {name: "A", taskId: "A2", deps: [], depth: -1} + var nodes = [A1, A2] + var nodeIndex = {"A1": 0, "A2": 1} + var rowSizes = Graph.testableMethods.computeRows(nodes, nodeIndex) + equal(A1.depth, 0) + equal(A2.depth, 1) + equal(rowSizes, [1, 1]) +}); + +test("computeRowsGrouped", function() { + var A0 = {name: "A", taskId: "A0", deps: ["D0", "B0"], depth: -1} + var B0 = {name: "B", taskId: "B0", deps: ["C1", "C2"], depth: -1} + var C1 = {name: "C", taskId: "C1", deps: ["D1", "E1"], depth: -1} + var C2 = {name: "C", taskId: "C2", deps: ["D2", "E2"], depth: -1} + var D0 = {name: "D", taskId: "D0", deps: [], depth: -1} + var D1 = {name: "D", taskId: "D1", deps: [], depth: -1} + var D2 = {name: "D", taskId: "D2", deps: [], depth: -1} + var E1 = {name: "E", taskId: "E1", deps: [], depth: -1} + var E2 = {name: "E", taskId: "E2", deps: [], depth: -1} + var rowSizes = Graph.testableMethods.computeRows(nodes, nodeIndex) + equal(A0.depth, 0) + equal(B0.depth, 1) + equal(C1.depth, 2) + equal(C2.depth, 2) + equal(D0.depth, 3) + equal(D1.depth, 3) + equal(D2.depth, 3) + equal(E1.depth, 4) + equal(E2.depth, 4) + equal(rowSizes, [1, 1, 2, 3, 2]) +}); + test("createGraph", function() { var tasks = [ {taskId: "A", deps: ["B","C"], status: "PENDING"}, diff --git a/luigi/static/visualiser/js/util.js b/luigi/static/visualiser/js/util.js new file mode 100644 index 0000000000..9a693e515f --- /dev/null +++ b/luigi/static/visualiser/js/util.js @@ -0,0 +1,8 @@ +function escapeHtml(unsafe) { + return unsafe + .replace(/&/g, "&") + .replace(//g, ">") + .replace(/"/g, """) + .replace(/'/g, "'"); +} diff --git a/luigi/static/visualiser/js/visualiserApp.js b/luigi/static/visualiser/js/visualiserApp.js index 83de46b9ba..156106d714 100644 --- a/luigi/static/visualiser/js/visualiserApp.js +++ b/luigi/static/visualiser/js/visualiserApp.js @@ -1018,8 +1018,8 @@ function visualiserApp(luigi) { function renderParams(params) { var htmls = []; for (var key in params) { - htmls.push('' + key + - '=' + params[key] + ''); + htmls.push('' + escapeHtml(key) + + '=' + escapeHtml(params[key]) + ''); } return htmls.join(', '); } diff --git a/luigi/target.py b/luigi/target.py index 7ed094f094..8b333f5326 100644 --- a/luigi/target.py +++ b/luigi/target.py @@ -87,7 +87,7 @@ class FileSystem(metaclass=abc.ABCMeta): FileSystem abstraction used in 
conjunction with :py:class:`FileSystemTarget`. Typically, a FileSystem is associated with instances of a :py:class:`FileSystemTarget`. The - instances of the py:class:`FileSystemTarget` will delegate methods such as + instances of the :py:class:`FileSystemTarget` will delegate methods such as :py:meth:`FileSystemTarget.exists` and :py:meth:`FileSystemTarget.remove` to the FileSystem. Methods of FileSystem raise :py:class:`FileSystemException` if there is a problem completing the @@ -284,7 +284,7 @@ def run(self): with self.output().temporary_path() as self.temp_output_path: run_some_external_command(output_path=self.temp_output_path) """ - num = random.randrange(0, 1e10) + num = random.randrange(0, 10_000_000_000) slashless_path = self.path.rstrip('/').rstrip("\\") _temp_path = '{}-luigi-tmp-{:010}{}'.format( slashless_path, @@ -328,7 +328,7 @@ def close(self): self.move_to_final_destination() def generate_tmp_path(self, path): - return os.path.join(tempfile.gettempdir(), 'luigi-s3-tmp-%09d' % random.randrange(0, 1e10)) + return os.path.join(tempfile.gettempdir(), 'luigi-s3-tmp-%09d' % random.randrange(0, 10_000_000_000)) def move_to_final_destination(self): raise NotImplementedError() diff --git a/luigi/task.py b/luigi/task.py index b24cac16e5..f5a108dc23 100644 --- a/luigi/task.py +++ b/luigi/task.py @@ -19,7 +19,7 @@ It is a central concept of Luigi and represents the state of the workflow. See :doc:`/tasks` for an overview. """ - +from collections import deque, OrderedDict from contextlib import contextmanager import logging import traceback @@ -32,9 +32,11 @@ import luigi +from luigi import configuration from luigi import parameter from luigi.task_register import Register from luigi.parameter import ParameterVisibility +from luigi.parameter import UnconsumedParameterWarning Parameter = parameter.Parameter logger = logging.getLogger('luigi-interface') @@ -123,7 +125,7 @@ def task_id_str(task_family, params): # task_id is a concatenation of task family, the first values of the first 3 parameters # sorted by parameter name and a md5hash of the family/parameters as a cananocalised json. param_str = json.dumps(params, separators=(',', ':'), sort_keys=True) - param_hash = hashlib.md5(param_str.encode('utf-8')).hexdigest() + param_hash = hashlib.new('md5', param_str.encode('utf-8'), usedforsecurity=False).hexdigest() param_summary = '_'.join(p[:TASK_ID_TRUNCATE_PARAMS] for p in (params[p] for p in sorted(params)[:TASK_ID_INCLUDE_PARAMS])) @@ -302,7 +304,7 @@ def task_module(self): task_namespace = __not_user_specified """ - This value can be overriden to set the namespace that will be used. + This value can be overridden to set the namespace that will be used. (See :ref:`Task.namespaces_famlies_and_ids`) If it's not specified and you try to read this value anyway, it will return garbage. Please use :py:meth:`get_task_namespace` to read the namespace. 
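
The next hunk warns, once per task family and key, when a configuration section contains an option that no parameter of the task consumes. A minimal sketch of the resulting behaviour (task and option names are illustrative, not taken from this diff):

.. code-block:: python

    # luigi.cfg (illustrative):
    #
    #   [MyTask]
    #   my-param = 10
    #   stale-option = 1

    import luigi


    class MyTask(luigi.Task):
        my_param = luigi.IntParameter(default=0)
        # keys listed here (dashes normalised to underscores) are exempt from the check
        ignore_unconsumed = {"stale_option"}

    # MyTask consumes my_param. Without the ignore_unconsumed entry above,
    # instantiating MyTask would emit an UnconsumedParameterWarning for
    # "stale_option", since it matches no parameter of the task.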
@@ -429,6 +431,25 @@ def list_to_tuple(x): return tuple(x) else: return x + + # Check for unconsumed parameters + conf = configuration.get_config() + if not hasattr(cls, "_unconsumed_params"): + cls._unconsumed_params = set() + if task_family in conf.sections(): + ignore_unconsumed = getattr(cls, 'ignore_unconsumed', set()) + for key, value in conf[task_family].items(): + key = key.replace('-', '_') + composite_key = f"{task_family}_{key}" + if key not in result and key not in ignore_unconsumed and composite_key not in cls._unconsumed_params: + warnings.warn( + "The configuration contains the parameter " + f"'{key}' with value '{value}' that is not consumed by the task " + f"'{task_family}'.", + UnconsumedParameterWarning, + ) + cls._unconsumed_params.add(composite_key) + # Sort it by the correct order and make a list return [(param_name, list_to_tuple(result[param_name])) for param_name, param_obj in params] @@ -746,6 +767,87 @@ def bulk_complete(cls, parameter_tuples): return generated_tuples +class DynamicRequirements(object): + """ + Wraps dynamic requirements yielded in tasks's run methods to control how completeness checks of + (e.g.) large chunks of tasks are performed. Besides the wrapped *requirements*, instances of + this class can be passed an optional function *custom_complete* that might implement an + optimized check for completeness. If set, the function will be called with a single argument, + *complete_fn*, which should be used to perform the per-task check. Example: + + .. code-block:: python + + class SomeTaskWithDynamicRequirements(luigi.Task): + ... + + def run(self): + large_chunk_of_tasks = [OtherTask(i=i) for i in range(10000)] + + def custom_complete(complete_fn): + # example: assume OtherTask always write into the same directory, so just check + # if the first task is complete, and compare basenames for the rest + if not complete_fn(large_chunk_of_tasks[0]): + return False + paths = [task.output().path for task in large_chunk_of_tasks] + basenames = os.listdir(os.path.dirname(paths[0])) # a single fs call + return all(os.path.basename(path) in basenames for path in paths) + + yield DynamicRequirements(large_chunk_of_tasks, custom_complete) + + .. py:attribute:: requirements + + The original, wrapped requirements. + + .. py:attribute:: flat_requirements + + Flattened view of the wrapped requirements (via :py:func:`flatten`). Read only. + + .. py:attribute:: paths + + Outputs of the requirements in the identical structure (via :py:func:`getpaths`). Read only. + + .. py:attribute:: custom_complete + + The optional, custom function performing the completeness check of the wrapped requirements. 
+ """ + + def __init__(self, requirements, custom_complete=None): + super().__init__() + + # store attributes + self.requirements = requirements + self.custom_complete = custom_complete + + # cached flat requirements and paths + self._flat_requirements = None + self._paths = None + + @property + def flat_requirements(self): + if self._flat_requirements is None: + self._flat_requirements = flatten(self.requirements) + return self._flat_requirements + + @property + def paths(self): + if self._paths is None: + self._paths = getpaths(self.requirements) + return self._paths + + def complete(self, complete_fn=None): + # default completeness check + if complete_fn is None: + def complete_fn(task): + return task.complete() + + # use the custom complete function when set + if self.custom_complete: + return self.custom_complete(complete_fn) + + # default implementation + return all(complete_fn(t) for t in self.flat_requirements) + + class ExternalTask(Task): """ Subclass for references to external dependencies. @@ -855,7 +957,7 @@ def getpaths(struct): def flatten(struct): """ - Creates a flat list of all all items in structured output (dicts, lists, items): + Creates a flat list of all items in structured output (dicts, lists, items): .. code-block:: python @@ -892,14 +994,19 @@ def flatten(struct): def flatten_output(task): """ Lists all output targets by recursively walking output-less (wrapper) tasks. - - FIXME order consistently. """ - r = flatten(task.output()) - if not r: - for dep in flatten(task.requires()): - r += flatten_output(dep) - return r + + output_tasks = OrderedDict() # OrderedDict used as ordered set + tasks_to_process = deque([task]) + while tasks_to_process: + current_task = tasks_to_process.popleft() + if flatten(current_task.output()): + if current_task not in output_tasks: + output_tasks[current_task] = None + else: + tasks_to_process.extend(flatten(current_task.requires())) + + return flatten(task.output() for task in output_tasks) def _task_wraps(task_class): diff --git a/luigi/task_register.py b/luigi/task_register.py index 5d61df36c7..f5e0acdd32 100644 --- a/luigi/task_register.py +++ b/luigi/task_register.py @@ -54,7 +54,7 @@ class Register(abc.ABCMeta): ambiguous task name (two :py:class:`Task` have the same name). This denotes an error.""" - def __new__(metacls, classname, bases, classdict): + def __new__(metacls, classname, bases, classdict, **kwargs): """ Custom class creation for namespacing. @@ -63,7 +63,7 @@ def __new__(metacls, classname, bases, classdict): When the set or inherited namespace evaluates to ``None``, set the task namespace to whatever the currently declared namespace is. """ - cls = super(Register, metacls).__new__(metacls, classname, bases, classdict) + cls = super(Register, metacls).__new__(metacls, classname, bases, classdict, **kwargs) cls._namespace_at_class_time = metacls._get_namespace(cls.__module__) metacls._reg.append(cls) return cls @@ -118,10 +118,11 @@ def task_family(cls): """ Internal note: This function will be deleted soon. 
""" - if not cls.get_task_namespace(): + task_namespace = cls.get_task_namespace() + if not task_namespace: return cls.__name__ else: - return "{}.{}".format(cls.get_task_namespace(), cls.__name__) + return f"{task_namespace}.{cls.__name__}" @classmethod def _get_reg(cls): diff --git a/luigi/tools/deps_tree.py b/luigi/tools/deps_tree.py index 9f207341b7..77e49e2229 100755 --- a/luigi/tools/deps_tree.py +++ b/luigi/tools/deps_tree.py @@ -50,7 +50,7 @@ def print_tree(task, indent='', last=True): name = task.__class__.__name__ params = task.to_str_params(only_significant=True) result = '\n' + indent - if(last): + if (last): result += '└─--' indent += ' ' else: diff --git a/luigi/worker.py b/luigi/worker.py index ba575b7fdf..01a2772a68 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -29,6 +29,7 @@ """ import collections +import collections.abc import datetime import getpass import importlib @@ -39,6 +40,7 @@ import subprocess import sys import contextlib +import functools import queue as Queue import random @@ -46,7 +48,6 @@ import threading import time import traceback -import types from luigi import notifications from luigi.event import Event @@ -54,7 +55,7 @@ from luigi.scheduler import DISABLED, DONE, FAILED, PENDING, UNKNOWN, Scheduler, RetryPolicy from luigi.scheduler import WORKER_STATE_ACTIVE, WORKER_STATE_DISABLED from luigi.target import Target -from luigi.task import Task, flatten, getpaths, Config +from luigi.task import Task, Config, DynamicRequirements from luigi.task_register import TaskClassException from luigi.task_status import RUNNING from luigi.parameter import BoolParameter, FloatParameter, IntParameter, OptionalParameter, Parameter, TimeDeltaParameter @@ -117,7 +118,7 @@ class TaskProcess(multiprocessing.Process): def __init__(self, task, worker_id, result_queue, status_reporter, use_multiprocessing=False, worker_timeout=0, check_unfulfilled_deps=True, - check_complete_on_run=False): + check_complete_on_run=False, task_completion_cache=None): super(TaskProcess, self).__init__() self.task = task self.worker_id = worker_id @@ -128,11 +129,15 @@ def __init__(self, task, worker_id, result_queue, status_reporter, self.use_multiprocessing = use_multiprocessing or self.timeout_time is not None self.check_unfulfilled_deps = check_unfulfilled_deps self.check_complete_on_run = check_complete_on_run + self.task_completion_cache = task_completion_cache + + # completeness check using the cache + self.check_complete = functools.partial(check_complete_cached, completion_cache=task_completion_cache) def _run_get_new_deps(self): task_gen = self.task.run() - if not isinstance(task_gen, types.GeneratorType): + if not isinstance(task_gen, collections.abc.Generator): return None next_send = None @@ -145,20 +150,27 @@ def _run_get_new_deps(self): except StopIteration: return None - new_req = flatten(requires) - if all(t.complete() for t in new_req): - next_send = getpaths(requires) - else: + # if requires is not a DynamicRequirements, create one to use its default behavior + if not isinstance(requires, DynamicRequirements): + requires = DynamicRequirements(requires) + + if not requires.complete(self.check_complete): + # not all requirements are complete, return them which adds them to the tree new_deps = [(t.task_module, t.task_family, t.to_str_params()) - for t in new_req] + for t in requires.flat_requirements] return new_deps + # get the next generator result + next_send = requires.paths + def run(self): logger.info('[pid %s] Worker %s running %s', os.getpid(), self.worker_id, 
self.task) if self.use_multiprocessing: # Need to have different random seeds if running in separate processes - random.seed((os.getpid(), time.time())) + processID = os.getpid() + currentTime = time.time() + random.seed(processID * currentTime) status = FAILED expl = '' @@ -170,7 +182,7 @@ def run(self): # checking completeness of self.task so outputs of dependencies are # irrelevant. if self.check_unfulfilled_deps and not _is_external(self.task): - missing = [dep.task_id for dep in self.task.deps() if not dep.complete()] + missing = [dep.task_id for dep in self.task.deps() if not self.check_complete(dep)] if missing: deps = 'dependency' if len(missing) == 1 else 'dependencies' raise RuntimeError('Unfulfilled %s at run time: %s' % (deps, ', '.join(missing))) @@ -180,7 +192,7 @@ def run(self): if _is_external(self.task): # External task - if self.task.complete(): + if self.check_complete(self.task): status = DONE else: status = FAILED @@ -190,7 +202,12 @@ def run(self): with self._forward_attributes(): new_deps = self._run_get_new_deps() if not new_deps: - if not self.check_complete_on_run or self.task.complete(): + if not self.check_complete_on_run: + # update the cache + if self.task_completion_cache is not None: + self.task_completion_cache[self.task.task_id] = True + status = DONE + elif self.check_complete(self.task): status = DONE else: raise TaskException("Task finished running, but complete() is still returning false.") @@ -392,13 +409,29 @@ def __init__(self, trace): self.trace = trace -def check_complete(task, out_queue): +def check_complete_cached(task, completion_cache=None): + # check if cached and complete + cache_key = task.task_id + if completion_cache is not None and completion_cache.get(cache_key): + return True + + # (re-)check the status + is_complete = task.complete() + + # tell the cache when complete + if completion_cache is not None and is_complete: + completion_cache[cache_key] = is_complete + + return is_complete + + +def check_complete(task, out_queue, completion_cache=None): """ - Checks if task is complete, puts the result to out_queue. + Checks if task is complete, puts the result to out_queue, optionally using the completion cache. """ logger.debug("Checking if %s is complete", task) try: - is_complete = task.complete() + is_complete = check_complete_cached(task, completion_cache) except Exception: is_complete = TracebackWrapper(traceback.format_exc()) out_queue.put((task, is_complete)) @@ -460,6 +493,11 @@ class worker(Config): 'applied as a context manager around its run() call, so this can be ' 'used for obtaining high level customizable monitoring or logging of ' 'each individual Task run.') + cache_task_completion = BoolParameter(default=False, + description='If true, cache the response of successful completion checks ' + 'of tasks assigned to a worker. 
This can especially speed up tasks with ' + 'dynamic dependencies but assumes that the completion status does not change ' + 'after it was true the first time.') class KeepAliveThread(threading.Thread): @@ -558,6 +596,11 @@ def __init__(self, scheduler=None, worker_id=None, worker_processes=1, assistant self._running_tasks = {} self._idle_since = None + # mp-safe dictionary for caching completation checks across task processes + self._task_completion_cache = None + if self._config.cache_task_completion: + self._task_completion_cache = multiprocessing.Manager().dict() + # Stuff for execution_summary self._add_task_history = [] self._get_work_response_history = [] @@ -612,7 +655,7 @@ def __exit__(self, type, value, traceback): def _generate_worker_info(self): # Generate as much info as possible about the worker # Some of these calls might not be available on all OS's - args = [('salt', '%09d' % random.randrange(0, 999999999)), + args = [('salt', '%09d' % random.randrange(0, 10_000_000_000)), ('workers', self.worker_processes)] try: args += [('host', socket.gethostname())] @@ -743,7 +786,7 @@ def add(self, task, multiprocess=False, processes=0): queue = DequeQueue() pool = SingleProcessPool() self._validate_task(task) - pool.apply_async(check_complete, [task, queue]) + pool.apply_async(check_complete, [task, queue, self._task_completion_cache]) # we track queue size ourselves because len(queue) won't work for multiprocessing queue_size = 1 @@ -757,7 +800,7 @@ def add(self, task, multiprocess=False, processes=0): if next.task_id not in seen: self._validate_task(next) seen.add(next.task_id) - pool.apply_async(check_complete, [next, queue]) + pool.apply_async(check_complete, [next, queue, self._task_completion_cache]) queue_size += 1 except (KeyboardInterrupt, TaskException): raise @@ -1022,6 +1065,7 @@ def _create_task_process(self, task): worker_timeout=self._config.timeout, check_unfulfilled_deps=self._config.check_unfulfilled_deps, check_complete_on_run=self._config.check_complete_on_run, + task_completion_cache=self._task_completion_cache, ) def _purge_children(self): diff --git a/setup.py b/setup.py index a10f81b108..3a47b22d5f 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def get_static_files(path): with open('README.rst') as fobj: long_description = "\n\n" + readme_note + "\n\n" + fobj.read() -install_requires = ['python-dateutil>=2.7.5,<3', 'tenacity>=6.3.0,<7'] +install_requires = ['python-dateutil>=2.7.5,<3', 'tenacity>=8,<9'] # Can't use python-daemon>=2.2.0 if on windows # See https://pagure.io/python-daemon/issue/18 @@ -100,7 +100,8 @@ def get_static_files(path): }, install_requires=install_requires, extras_require={ - 'prometheus': ['prometheus-client==0.5.0'], + 'jsonschema': ['jsonschema'], + 'prometheus': ['prometheus-client>=0.5,<0.15'], 'toml': ['toml<2.0.0'], }, classifiers=[ @@ -115,6 +116,7 @@ def get_static_files(path): 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', 'Topic :: System :: Monitoring', ], ) diff --git a/test/config_env_test.py b/test/config_env_test.py index dbc9064ca9..3ff7e5523c 100644 --- a/test/config_env_test.py +++ b/test/config_env_test.py @@ -16,7 +16,7 @@ # import os -from luigi.configuration import LuigiConfigParser, get_config +from luigi.configuration import LuigiConfigParser, LuigiTomlParser, get_config from luigi.configuration.cfg_parser import InterpolationMissingEnvvarError from helpers import LuigiTestCase, 
with_config @@ -43,6 +43,8 @@ def tearDown(self): os.environ.pop(key) for key, value in self.environ_backup: os.environ[key] = value + if 'LUIGI_CONFIG_PARSER' in os.environ: + del os.environ["LUIGI_CONFIG_PARSER"] @with_config({"test": { "a": "testval", @@ -95,3 +97,10 @@ def test_underscore_vs_dash_style_priority(self): config = get_config() self.assertEqual(config.get("test", "foo-bar"), "bax") self.assertEqual(config.get("test", "foo_bar"), "bax") + + def test_default_parser(self): + config = get_config() + self.assertIsInstance(config, LuigiConfigParser) + os.environ["LUIGI_CONFIG_PARSER"] = "toml" + config = get_config() + self.assertIsInstance(config, LuigiTomlParser) diff --git a/test/contrib/azureblob_test.py b/test/contrib/azureblob_test.py index 4e725f320c..d587768c2e 100644 --- a/test/contrib/azureblob_test.py +++ b/test/contrib/azureblob_test.py @@ -60,7 +60,7 @@ def test_splitfilepath_blob_nested(self): def test_create_delete_container(self): import datetime import hashlib - m = hashlib.md5() + m = hashlib.new('md5', usedforsecurity=False) m.update(datetime.datetime.now().__str__().encode()) container_name = m.hexdigest() @@ -74,7 +74,7 @@ def test_upload_copy_move_remove_blob(self): import datetime import hashlib import tempfile - m = hashlib.md5() + m = hashlib.new('md5', usedforsecurity=False) m.update(datetime.datetime.now().__str__().encode()) container_name = m.hexdigest() m.update(datetime.datetime.now().__str__().encode()) diff --git a/test/contrib/bigquery_test.py b/test/contrib/bigquery_test.py index 61ee2d7cda..25ad19fef9 100644 --- a/test/contrib/bigquery_test.py +++ b/test/contrib/bigquery_test.py @@ -23,9 +23,11 @@ import mock import pytest +from mock.mock import MagicMock +from luigi.contrib import bigquery from luigi.contrib.bigquery import BigQueryLoadTask, BigQueryTarget, BQDataset, \ - BigQueryRunQueryTask, BigQueryExtractTask + BigQueryRunQueryTask, BigQueryExtractTask, BigQueryClient from luigi.contrib.gcs import GCSTarget @@ -147,3 +149,31 @@ def output(self): } } run_job.assert_called_with('proj', expected_body, dataset=BQDataset('proj', 'ds', None)) + + +class BigQueryClientTest(unittest.TestCase): + + def test_retry_succeeds_on_second_attempt(self): + try: + from googleapiclient import errors + except ImportError: + raise unittest.SkipTest('Unable to load googleapiclient module') + client = MagicMock(spec=BigQueryClient) + attempts = 0 + + @bigquery.bq_retry + def fail_once(bq_client): + nonlocal attempts + attempts += 1 + if attempts == 1: + raise errors.HttpError( + resp=MagicMock(status=500), + content=b'{"error": {"message": "stub"}', + ) + else: + return MagicMock(status=200) + + response = fail_once(client) + client._initialise_client.assert_called_once() + self.assertEqual(attempts, 2) + self.assertEqual(response.status, 200) diff --git a/test/contrib/docker_runner_test.py b/test/contrib/docker_runner_test.py index 22982a9dc1..a33694f577 100644 --- a/test/contrib/docker_runner_test.py +++ b/test/contrib/docker_runner_test.py @@ -49,7 +49,7 @@ except Exception: raise unittest.SkipTest('Unable to connect to docker daemon') -tempfile.tempdir = '/tmp' # set it explicitely to make it work out of the box in mac os +tempfile.tempdir = '/tmp' # set it explicitly to make it work out of the box in mac os local_file = NamedTemporaryFile() local_file.write(b'this is a test file\n') local_file.flush() diff --git a/test/contrib/ecs_test.py b/test/contrib/ecs_test.py index 9a8e8819d4..ee77431552 100644 --- a/test/contrib/ecs_test.py +++ 
b/test/contrib/ecs_test.py @@ -53,6 +53,13 @@ 'name': 'hello-world', 'image': 'ubuntu', 'command': ['/bin/echo', 'hello world'] + }, + { + 'memory': 1, + 'essential': True, + 'name': 'hello-world-2', + 'image': 'ubuntu', + 'command': ['/bin/echo', 'hello world #2!'] } ] } @@ -74,6 +81,86 @@ def command(self): return [{'name': 'hello-world', 'command': ['/bin/sleep', '10']}] +class ECSTaskCustomRunTaskKwargs(ECSTaskNoOutput): + + @property + def run_task_kwargs(self): + return {'overrides': {'ephemeralStorage': {'sizeInGiB': 30}}} + + +class ECSTaskCustomRunTaskKwargsWithCollidingCommand(ECSTaskNoOutput): + + @property + def command(self): + return [ + {'name': 'hello-world', 'command': ['/bin/sleep', '10']}, + {'name': 'hello-world-2', 'command': ['/bin/sleep', '10']}, + ] + + @property + def run_task_kwargs(self): + return { + 'launchType': 'FARGATE', + 'platformVersion': '1.4.0', + 'networkConfiguration': { + 'awsvpcConfiguration': { + 'subnets': [ + 'subnet-01234567890abcdef', + 'subnet-abcdef01234567890' + ], + 'securityGroups': [ + 'sg-abcdef01234567890', + ], + 'assignPublicIp': 'ENABLED' + } + }, + 'overrides': { + 'containerOverrides': [ + {'name': 'hello-world-2', 'command': ['command-to-be-overwritten']} + ], + 'ephemeralStorage': { + 'sizeInGiB': 30 + } + } + } + + +class ECSTaskCustomRunTaskKwargsWithMergedCommands(ECSTaskNoOutput): + + @property + def command(self): + return [ + {'name': 'hello-world', 'command': ['/bin/sleep', '10']} + ] + + @property + def run_task_kwargs(self): + return { + 'launchType': 'FARGATE', + 'platformVersion': '1.4.0', + 'networkConfiguration': { + 'awsvpcConfiguration': { + 'subnets': [ + 'subnet-01234567890abcdef', + 'subnet-abcdef01234567890' + ], + 'securityGroups': [ + 'sg-abcdef01234567890', + ], + 'assignPublicIp': 'ENABLED' + } + }, + 'overrides': { + 'containerOverrides': [ + {'name': 'hello-world-2', 'command': ['/bin/sleep', '10']} + ], + 'ephemeralStorage': { + 'sizeInGiB': 30 + } + } + } + + @pytest.mark.aws class TestECSTask(unittest.TestCase): @@ -97,3 +184,45 @@ def test_registered_task(self): def test_override_command(self): t = ECSTaskOverrideCommand(task_def_arn=self.arn) luigi.build([t], local_scheduler=True) + + @mock_ecs + def test_custom_run_task_kwargs(self): + t = ECSTaskCustomRunTaskKwargs(task_def_arn=self.arn) + self.assertEqual(t.combined_overrides, { + 'ephemeralStorage': {'sizeInGiB': 30} + }) + luigi.build([t], local_scheduler=True) + + @mock_ecs + def test_custom_run_task_kwargs_with_colliding_command(self): + t = ECSTaskCustomRunTaskKwargsWithCollidingCommand(task_def_arn=self.arn) + combined_overrides = t.combined_overrides + self.assertEqual( + sorted(combined_overrides['containerOverrides'], key=lambda x: x['name']), + sorted( + [ + {'name': 'hello-world', 'command': ['/bin/sleep', '10']}, + {'name': 'hello-world-2', 'command': ['/bin/sleep', '10']}, + ], + key=lambda x: x['name'] + ) + ) + self.assertEqual(combined_overrides['ephemeralStorage'], {'sizeInGiB': 30}) + luigi.build([t], local_scheduler=True) + + @mock_ecs + def test_custom_run_task_kwargs_with_merged_commands(self): + t = ECSTaskCustomRunTaskKwargsWithMergedCommands(task_def_arn=self.arn) + combined_overrides = t.combined_overrides + self.assertEqual( + sorted(combined_overrides['containerOverrides'], key=lambda x: x['name']), + sorted( + [ + {'name': 'hello-world', 'command': ['/bin/sleep', '10']}, + {'name': 'hello-world-2', 'command': ['/bin/sleep', '10']}, + ], + key=lambda x: x['name'] + ) + ) + 
self.assertEqual(combined_overrides['ephemeralStorage'], {'sizeInGiB': 30}) + luigi.build([t], local_scheduler=True) diff --git a/test/contrib/hadoop_test.py b/test/contrib/hadoop_test.py index 761871b221..b3db99a0bc 100644 --- a/test/contrib/hadoop_test.py +++ b/test/contrib/hadoop_test.py @@ -318,42 +318,49 @@ def _assert_package_subpackage(self, add): @mock.patch('tarfile.open') def test_create_packages_archive_module(self, tar): module = __import__("module", None, None, 'dummy') + module.__file__ = os.path.relpath(module.__file__, os.getcwd()) luigi.contrib.hadoop.create_packages_archive([module], '/dev/null') self._assert_module(tar.return_value.add) @mock.patch('tarfile.open') def test_create_packages_archive_package(self, tar): package = __import__("package", None, None, 'dummy') + package.__path__[0] = os.path.relpath(package.__path__[0], os.getcwd()) luigi.contrib.hadoop.create_packages_archive([package], '/dev/null') self._assert_package(tar.return_value.add) @mock.patch('tarfile.open') def test_create_packages_archive_package_submodule(self, tar): package_submodule = __import__("package.submodule", None, None, 'dummy') + package_submodule.__file__ = os.path.relpath(package_submodule.__file__, os.getcwd()) luigi.contrib.hadoop.create_packages_archive([package_submodule], '/dev/null') self._assert_package(tar.return_value.add) @mock.patch('tarfile.open') def test_create_packages_archive_package_submodule_with_absolute_import(self, tar): package_submodule_with_absolute_import = __import__("package.submodule_with_absolute_import", None, None, 'dummy') + package_submodule_with_absolute_import.__file__ = os.path.relpath(package_submodule_with_absolute_import.__file__, os.getcwd()) luigi.contrib.hadoop.create_packages_archive([package_submodule_with_absolute_import], '/dev/null') self._assert_package(tar.return_value.add) @mock.patch('tarfile.open') def test_create_packages_archive_package_submodule_without_imports(self, tar): package_submodule_without_imports = __import__("package.submodule_without_imports", None, None, 'dummy') + package_submodule_without_imports.__file__ = os.path.relpath(package_submodule_without_imports.__file__, os.getcwd()) luigi.contrib.hadoop.create_packages_archive([package_submodule_without_imports], '/dev/null') self._assert_package(tar.return_value.add) @mock.patch('tarfile.open') def test_create_packages_archive_package_subpackage(self, tar): package_subpackage = __import__("package.subpackage", None, None, 'dummy') + package_subpackage.__path__[0] = os.path.relpath(package_subpackage.__path__[0], os.getcwd()) luigi.contrib.hadoop.create_packages_archive([package_subpackage], '/dev/null') self._assert_package_subpackage(tar.return_value.add) @mock.patch('tarfile.open') def test_create_packages_archive_package_subpackage_submodule(self, tar): package_subpackage_submodule = __import__("package.subpackage.submodule", None, None, 'dummy') + package_subpackage_submodule.__file__ = os.path.relpath(package_subpackage_submodule.__file__, os.getcwd()) luigi.contrib.hadoop.create_packages_archive([package_subpackage_submodule], '/dev/null') self._assert_package_subpackage(tar.return_value.add) diff --git a/test/dict_parameter_test.py b/test/dict_parameter_test.py index fe86745dc1..3dd3306dc4 100644 --- a/test/dict_parameter_test.py +++ b/test/dict_parameter_test.py @@ -15,12 +15,16 @@ # limitations under the License. 
# +from jsonschema import Draft4Validator +from jsonschema.exceptions import ValidationError from helpers import unittest, in_parse import luigi import luigi.interface import json +import mock import collections +import pytest class DictParameterTask(luigi.Task): @@ -61,3 +65,80 @@ def test_hash_normalize(self): a = luigi.DictParameter().normalize({"a": [{"b": []}]}) b = luigi.DictParameter().normalize({"a": [{"b": []}]}) self.assertEqual(hash(a), hash(b)) + + def test_schema(self): + a = luigi.parameter.DictParameter( + schema={ + "type": "object", + "properties": { + "an_int": {"type": "integer"}, + "an_optional_str": {"type": "string"}, + }, + "additionalProperties": False, + "required": ["an_int"], + }, + ) + + # Check that the default value is validated + with pytest.raises( + ValidationError, + match=r"Additional properties are not allowed \('INVALID_ATTRIBUTE' was unexpected\)", + ): + a.normalize({"INVALID_ATTRIBUTE": 0}) + + # Check that empty dict is not valid + with pytest.raises(ValidationError, match="'an_int' is a required property"): + a.normalize({}) + + # Check that valid dicts work + a.normalize({"an_int": 1}) + a.normalize({"an_int": 1, "an_optional_str": "hello"}) + + # Check that invalid dicts raise correct errors + with pytest.raises(ValidationError, match="'999' is not of type 'integer'"): + a.normalize({"an_int": "999"}) + + with pytest.raises(ValidationError, match="999 is not of type 'string'"): + a.normalize({"an_int": 1, "an_optional_str": 999}) + + # Test the example given in docstring + b = luigi.DictParameter( + schema={ + "type": "object", + "patternProperties": { + ".*": {"type": "string", "enum": ["web", "staging"]}, + } + } + ) + b.normalize({"role": "web", "env": "staging"}) + with pytest.raises(ValidationError, match=r"'UNKNOWN_VALUE' is not one of \['web', 'staging'\]"): + b.normalize({"role": "UNKNOWN_VALUE", "env": "staging"}) + + # Check that warnings are properly emitted + with mock.patch('luigi.parameter._JSONSCHEMA_ENABLED', False): + with pytest.warns( + UserWarning, + match=( + "The 'jsonschema' package is not installed so the parameter can not be " + "validated even though a schema is given." + ) + ): + luigi.ListParameter(schema={"type": "object"}) + + # Test with a custom validator + validator = Draft4Validator( + schema={ + "type": "object", + "patternProperties": { + ".*": {"type": "string", "enum": ["web", "staging"]}, + }, + } + ) + c = luigi.DictParameter(schema=validator) + c.normalize({"role": "web", "env": "staging"}) + with pytest.raises(ValidationError, match=r"'UNKNOWN_VALUE' is not one of \['web', 'staging'\]"): + c.normalize({"role": "UNKNOWN_VALUE", "env": "staging"}) + + # Test with frozen data + frozen_data = luigi.freezing.recursively_freeze({"role": "web", "env": "staging"}) + c.normalize(frozen_data) diff --git a/test/list_parameter_test.py b/test/list_parameter_test.py new file mode 100644 index 0000000000..26204e48cf --- /dev/null +++ b/test/list_parameter_test.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from jsonschema import Draft4Validator +from jsonschema.exceptions import ValidationError +from helpers import unittest, in_parse + +import luigi +import json +import mock +import pytest + + +class ListParameterTask(luigi.Task): + param = luigi.ListParameter() + + +class ListParameterTest(unittest.TestCase): + + _list = [1, "one", True] + + def test_parse(self): + d = luigi.ListParameter().parse(json.dumps(ListParameterTest._list)) + self.assertEqual(d, ListParameterTest._list) + + def test_serialize(self): + d = luigi.ListParameter().serialize(ListParameterTest._list) + self.assertEqual(d, '[1, "one", true]') + + def test_list_serialize_parse(self): + a = luigi.ListParameter() + b_list = [1, 2, 3] + self.assertEqual(b_list, a.parse(a.serialize(b_list))) + + def test_parse_interface(self): + in_parse(["ListParameterTask", "--param", '[1, "one", true]'], + lambda task: self.assertEqual(task.param, tuple(ListParameterTest._list))) + + def test_serialize_task(self): + t = ListParameterTask(ListParameterTest._list) + self.assertEqual(str(t), 'ListParameterTask(param=[1, "one", true])') + + def test_parse_invalid_input(self): + self.assertRaises(ValueError, lambda: luigi.ListParameter().parse('{"invalid"}')) + + def test_hash_normalize(self): + self.assertRaises(TypeError, lambda: hash(luigi.ListParameter().parse('"NOT A LIST"'))) + a = luigi.ListParameter().normalize([0]) + b = luigi.ListParameter().normalize([0]) + self.assertEqual(hash(a), hash(b)) + + def test_schema(self): + a = luigi.ListParameter( + schema={ + "type": "array", + "items": { + "type": "number", + "minimum": 0, + "maximum": 10, + }, + "minItems": 1, + } + ) + + # Check that the default value is validated + with pytest.raises(ValidationError, match=r"'INVALID_ATTRIBUTE' is not of type 'number'"): + a.normalize(["INVALID_ATTRIBUTE"]) + + # Check that empty list is not valid + with pytest.raises(ValidationError, match=r"\[\] is too short"): + a.normalize([]) + + # Check that valid lists work + valid_list = [1, 2, 3] + a.normalize(valid_list) + + # Check that invalid lists raise correct errors + invalid_list_type = ["NOT AN INT"] + invalid_list_value = [-999, 999] + + with pytest.raises(ValidationError, match="'NOT AN INT' is not of type 'number'"): + a.normalize(invalid_list_type) + + with pytest.raises(ValidationError, match="-999 is less than the minimum of 0"): + a.normalize(invalid_list_value) + + # Check that warnings are properly emitted + with mock.patch('luigi.parameter._JSONSCHEMA_ENABLED', False): + with pytest.warns( + UserWarning, + match=( + "The 'jsonschema' package is not installed so the parameter can not be " + "validated even though a schema is given." 
+ ) + ): + luigi.ListParameter(schema={"type": "array", "items": {"type": "number"}}) + + # Test with a custom validator + validator = Draft4Validator( + schema={ + "type": "array", + "items": { + "type": "number", + "minimum": 0, + "maximum": 10, + }, + "minItems": 1, + } + ) + c = luigi.DictParameter(schema=validator) + c.normalize(valid_list) + with pytest.raises(ValidationError, match=r"'INVALID_ATTRIBUTE' is not of type 'number'",): + c.normalize(["INVALID_ATTRIBUTE"]) + + # Test with frozen data + frozen_data = luigi.freezing.recursively_freeze(valid_list) + c.normalize(frozen_data) diff --git a/test/local_target_test.py b/test/local_target_test.py index 35c497c205..d9c78539c5 100644 --- a/test/local_target_test.py +++ b/test/local_target_test.py @@ -354,3 +354,13 @@ def test_move_to_new_dir(self): LocalTarget(src).open('w').close() self.fs.move(src, dest) self.assertTrue(os.path.exists(dest)) + + +class DestructorTest(unittest.TestCase): + + def test_destructor(self): + # LocalTarget might not be fully initialised if an exception is thrown in the constructor of LocalTarget or a + # subclass. The destructor can't expect attributes to be initialised. + t = LocalTarget(is_tmp=True) + del t.is_tmp + t.__del__() diff --git a/test/lock_test.py b/test/lock_test.py index 7ef8761cc1..2701bd963f 100644 --- a/test/lock_test.py +++ b/test/lock_test.py @@ -21,6 +21,7 @@ import mock from helpers import unittest +from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential import luigi import luigi.lock import luigi.notifications @@ -31,13 +32,21 @@ class TestCmd(unittest.TestCase): def test_getpcmd(self): + def _is_empty(cmd): + return cmd == "" + + # for CI stability, add retring + @retry(retry=retry_if_result(_is_empty), wait=wait_exponential(multiplier=0.2, min=0.1, max=3), stop=stop_after_attempt(3)) + def _getpcmd(pid): + return luigi.lock.getpcmd(pid) + if os.name == 'nt': command = ["ping", "1.1.1.1", "-w", "1000"] else: command = ["sleep", "1"] external_process = subprocess.Popen(command) - result = luigi.lock.getpcmd(external_process.pid) + result = _getpcmd(external_process.pid) self.assertTrue( result.strip() in ["sleep 1", '[sleep]', 'ping 1.1.1.1 -w 1000'] @@ -57,9 +66,17 @@ def tearDown(self): os.rmdir(self.pid_dir) def test_get_info(self): + def _is_empty(result): + return result[1] == "" # cmd is empty + + # for CI stability, add retring + @retry(retry=retry_if_result(_is_empty), wait=wait_exponential(multiplier=0.2, min=0.1, max=3), stop=stop_after_attempt(3)) + def _get_info(pid_dir, pid): + return luigi.lock.get_info(pid_dir, pid) + try: p = subprocess.Popen(["yes", u"à我Ņ„"], stdout=subprocess.PIPE) - pid, cmd, pid_file = luigi.lock.get_info(self.pid_dir, p.pid) + pid, cmd, pid_file = _get_info(self.pid_dir, p.pid) finally: p.kill() self.assertEqual(cmd, u'yes à我Ņ„') diff --git a/test/notifications_test.py b/test/notifications_test.py index f0fd392865..c5974bafa6 100644 --- a/test/notifications_test.py +++ b/test/notifications_test.py @@ -387,7 +387,7 @@ def test_sends_sns_email(self): def test_sns_subject_is_shortened(self): """ Call notifications.send_email_sns with too long Subject (more than 100 chars) - and check that it is cut to lenght of 100 chars. + and check that it is cut to length of 100 chars. 
""" long_subject = 'Luigi: SanityCheck(regexPattern=aligned-source\\|data-not-older\\|source-chunks-compl,'\ diff --git a/test/optional_parameter_test.py b/test/optional_parameter_test.py index 0bddb033ab..2bbea15505 100644 --- a/test/optional_parameter_test.py +++ b/test/optional_parameter_test.py @@ -1,3 +1,5 @@ +import warnings + import luigi import mock @@ -102,3 +104,48 @@ def test_optional_choice_parameter_int(self): choices = [0, 1, 2] self.actual_test(luigi.OptionalChoiceParameter, None, 1, "int", "bad data", var_type=int, choices=choices) self.actual_test(luigi.OptionalChoiceParameter, "default value", 1, "int", "bad data", var_type=int, choices=choices) + + def test_warning(self): + class TestOptionalFloatParameterSingleType( + luigi.parameter.OptionalParameter, luigi.FloatParameter + ): + expected_type = float + + class TestOptionalFloatParameterMultiTypes( + luigi.parameter.OptionalParameter, luigi.FloatParameter + ): + expected_type = (int, float) + + class TestConfig(luigi.Config): + param_single = TestOptionalFloatParameterSingleType() + param_multi = TestOptionalFloatParameterMultiTypes() + + with warnings.catch_warnings(record=True) as record: + TestConfig(param_single=0.0, param_multi=1.0) + + assert len(record) == 0 + + with warnings.catch_warnings(record=True) as record: + warnings.filterwarnings( + action="ignore", + category=Warning, + ) + warnings.simplefilter( + action="always", + category=luigi.parameter.OptionalParameterTypeWarning, + ) + assert luigi.build( + [TestConfig(param_single="0", param_multi="1")], local_scheduler=True + ) + + assert len(record) == 2 + assert issubclass(record[0].category, luigi.parameter.OptionalParameterTypeWarning) + assert issubclass(record[1].category, luigi.parameter.OptionalParameterTypeWarning) + assert str(record[0].message) == ( + 'TestOptionalFloatParameterSingleType "param_single" with value "0" is not of type ' + '"float" or None.' + ) + assert str(record[1].message) == ( + 'TestOptionalFloatParameterMultiTypes "param_multi" with value "1" is not of any ' + 'type in ["int", "float"] or None.' 
+ ) diff --git a/test/parameter_test.py b/test/parameter_test.py index 13af40a8c7..4e1d374d19 100644 --- a/test/parameter_test.py +++ b/test/parameter_test.py @@ -21,6 +21,7 @@ from datetime import timedelta import enum import mock +import pytest import luigi import luigi.date_interval @@ -309,11 +310,6 @@ def test_enum_list_param_invalid(self): def test_enum_list_param_missing(self): self.assertRaises(ParameterException, lambda: luigi.parameter.EnumListParameter()) - def test_list_serialize_parse(self): - a = luigi.ListParameter() - b_list = [1, 2, 3] - self.assertEqual(b_list, a.parse(a.serialize(b_list))) - def test_tuple_serialize_parse(self): a = luigi.TupleParameter() b_tuple = ((1, 2), (3, 4)) @@ -853,6 +849,16 @@ def f(): self.assertRaises(ValueError, f) # ISO 8601 durations with months are not supported exc.assert_called_once_with("Invalid time delta - could not parse P6M") + @with_config({"foo": {"bar": "12.34"}}) + def testTimeDeltaFloat(self): + p = luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar")) + self.assertEqual(timedelta(seconds=12.34), _value(p)) + + @with_config({"foo": {"bar": "56789"}}) + def testTimeDeltaInt(self): + p = luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar")) + self.assertEqual(timedelta(seconds=56789), _value(p)) + def testHasDefaultNoSection(self): self.assertRaises(luigi.parameter.MissingParameterException, lambda: _value(luigi.Parameter(config_path=dict(section="foo", name="bar")))) @@ -1254,3 +1260,67 @@ class MyTask(luigi.Task): task = luigi.IntParameter() self.assertTrue(self.run_locally_split('MyTask --task 5')) + + +class TestPathParameter: + + @pytest.fixture(params=[None, "not_existing_dir"]) + def default(self, request): + return request.param + + @pytest.fixture(params=[True, False]) + def absolute(self, request): + return request.param + + @pytest.fixture(params=[True, False]) + def exists(self, request): + return request.param + + @pytest.fixture() + def path_parameter(self, tmpdir, default, absolute, exists): + class TaskPathParameter(luigi.Task): + + a = luigi.PathParameter( + default=str(tmpdir / default) if default is not None else str(tmpdir), + absolute=absolute, + exists=exists, + ) + b = luigi.OptionalPathParameter( + default=str(tmpdir / default) if default is not None else str(tmpdir), + absolute=absolute, + exists=exists, + ) + c = luigi.OptionalPathParameter(default=None) + d = luigi.OptionalPathParameter(default="not empty default") + + def run(self): + # Use the parameter as a Path object + new_file = self.a / "test.file" + new_optional_file = self.b / "test_optional.file" + if default is not None: + new_file.parent.mkdir(parents=True) + new_file.touch() + new_optional_file.touch() + assert new_file.exists() + assert new_optional_file.exists() + assert self.c is None + assert self.d is None + + def output(self): + return luigi.LocalTarget("not_existing_file") + + return { + "tmpdir": tmpdir, + "default": default, + "absolute": absolute, + "exists": exists, + "cls": TaskPathParameter, + } + + @with_config({"TaskPathParameter": {"d": ""}}) + def test_exists(self, path_parameter): + if path_parameter["default"] is not None and path_parameter["exists"]: + with pytest.raises(ValueError, match="The path .* does not exist"): + luigi.build([path_parameter["cls"]()], local_scheduler=True) + else: + assert luigi.build([path_parameter["cls"]()], local_scheduler=True) diff --git a/test/range_test.py b/test/range_test.py index 4be8760a54..39c9633ec4 100644 --- a/test/range_test.py +++ 
b/test/range_test.py
@@ -652,7 +652,7 @@ def test_start_and_minutes_period(self):
                 ('CommonDateMinuteTask', 13),  # 20 intervals - 7 missing
             ],
             'event.tools.range.complete.fraction': [
-                ('CommonDateMinuteTask', 13. / (13 + 7)),  # (exptected - missing) / expected
+                ('CommonDateMinuteTask', 13. / (13 + 7)),  # (expected - missing) / expected
             ],
         }
     )
@@ -1514,7 +1514,7 @@ class RangeInstantiationTest(LuigiTestCase):
     def test_old_instantiation(self):
         """
-        Verify that you can still programatically set of param as string
+        Verify that you can still programmatically set a param as a string
         """
         class MyTask(luigi.Task):
             date_param = luigi.DateParameter()
diff --git a/test/rpc_test.py b/test/rpc_test.py
index 1537f5c9e2..f4837e3b10 100644
--- a/test/rpc_test.py
+++ b/test/rpc_test.py
@@ -27,7 +27,6 @@
 from server_test import ServerTestBase
 import socket
 from multiprocessing import Process, Queue
-import requests

 class RemoteSchedulerTest(unittest.TestCase):
@@ -39,6 +38,14 @@ def testUrlArgumentVariations(self):
             s._fetch(suffix, '{}')
             fetcher.fetch.assert_called_once_with('http://zorg.com/api/123', '{}', 42)

+    def testUrlArgumentVariationsNotRoot(self):
+        for url in ['http://zorg.com/subpath', 'http://zorg.com/subpath/']:
+            for suffix in ['api/123', '/api/123']:
+                s = luigi.rpc.RemoteScheduler(url, 42)
+                with mock.patch.object(s, '_fetcher') as fetcher:
+                    s._fetch(suffix, '{}')
+                fetcher.fetch.assert_called_once_with('http://zorg.com/subpath/api/123', '{}', 42)
+
     def get_work(self, fetcher_side_effect):
         scheduler = luigi.rpc.RemoteScheduler('http://zorg.com', 42)
         scheduler._rpc_retry_wait = 1  # shorten wait time to speed up tests
@@ -147,8 +154,8 @@ def test_get_work_speed(self):
 class RequestsFetcherTest(ServerTestBase):
     def test_fork_changes_session(self):
-        session = requests.Session()
-        fetcher = luigi.rpc.RequestsFetcher(session)
+        fetcher = luigi.rpc.RequestsFetcher()
+        session = fetcher.session

         q = Queue()
diff --git a/test/scheduler_api_test.py b/test/scheduler_api_test.py
index 66ab77b8f5..9d4ae4487d 100644
--- a/test/scheduler_api_test.py
+++ b/test/scheduler_api_test.py
@@ -1774,6 +1774,8 @@ def test_task_list_filter_by_multiple_search_terms(self):
         self.add_task('ClassA', day='2016-02-01', val='5')

         self.search_pending('ClassA 2016-02-01 num', {expected})
+        # ensure that the task search is case-insensitive
+        self.search_pending('classa 2016-02-01 num', {expected})

     def test_upstream_beyond_limit(self):
         sch = Scheduler(max_shown_tasks=3)
diff --git a/test/scheduler_test.py b/test/scheduler_test.py
index 6bfd39b4d6..4973e3219b 100644
--- a/test/scheduler_test.py
+++ b/test/scheduler_test.py
@@ -281,6 +281,15 @@ def test_prometheus_metrics_collector(self):
         collector = scheduler_state._metrics_collector
         self.assertTrue(isinstance(collector, PrometheusMetricsCollector))

+    @with_config({'scheduler': {'metrics_collector': 'custom', 'metrics_custom_import': 'luigi.contrib.prometheus_metric.PrometheusMetricsCollector'}})
+    def test_custom_metrics_collector(self):
+        from luigi.contrib.prometheus_metric import PrometheusMetricsCollector
+
+        s = luigi.scheduler.Scheduler()
+        scheduler_state = s._state
+        collector = scheduler_state._metrics_collector
+        self.assertTrue(isinstance(collector, PrometheusMetricsCollector))
+

 class SchedulerWorkerTest(unittest.TestCase):
     def get_pending_ids(self, worker, state):
diff --git a/test/setup_logging_test.py b/test/setup_logging_test.py
index 3bc17462f0..18724666e6 100644
--- a/test/setup_logging_test.py
+++ b/test/setup_logging_test.py
@@ -31,10 +31,39 @@ def
test_cli(self): self.assertFalse(result) def test_section(self): - self.cls.config = {'logging': { - 'version': 1, - 'disable_existing_loggers': False, - }} + self.cls.config = { + 'logging': { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'mockformatter': { + 'format': '{levelname}: {message}', + 'style': '{', + 'datefmt': '%Y-%m-%d %H:%M:%S', + }, + }, + 'handlers': { + 'mockhandler': { + 'class': 'logging.StreamHandler', + 'level': 'INFO', + 'formatter': 'mockformatter', + }, + }, + 'loggers': { + 'mocklogger': { + 'handlers': ('mockhandler',), + 'level': 'INFO', + 'disabled': False, + 'propagate': False, + }, + }, + }, + } + result = self.cls._section(None) + self.assertTrue(result) + + self.cls.config = LuigiTomlParser() + self.cls.config.read(['./test/testconfig/luigi_logging.toml']) result = self.cls._section(None) self.assertTrue(result) diff --git a/test/task_forwarded_attributes_test.py b/test/task_forwarded_attributes_test.py index 48ef319136..97f9a05cb3 100644 --- a/test/task_forwarded_attributes_test.py +++ b/test/task_forwarded_attributes_test.py @@ -54,7 +54,7 @@ class YieldingTask(NonYieldingTask): def run(self): # as TaskProcess._run_get_new_deps handles generators in a specific way, store names of - # forwarded attributes before and after yielding a dynamic dependency, so we can explicitely + # forwarded attributes before and after yielding a dynamic dependency, so we can explicitly # validate the attribute forwarding implementation self.attributes_before_yield = self.gather_forwarded_attributes() yield RunOnceTask() diff --git a/test/task_test.py b/test/task_test.py index 09cd8f3d4b..e2da34796c 100644 --- a/test/task_test.py +++ b/test/task_test.py @@ -78,7 +78,7 @@ def test_task_missing_necessary_param(self): def test_external_tasks_loadable(self): task = load_task("luigi", "ExternalTask", {}) - assert(isinstance(task, luigi.ExternalTask)) + self.assertTrue(isinstance(task, luigi.ExternalTask)) def test_getpaths(self): class RequiredTask(luigi.Task): @@ -169,6 +169,163 @@ class ATaskWithBadParam(luigi.Task): with self.assertRaisesRegex(ValueError, r"ATaskWithBadParam\[args=\(\), kwargs={}\]: Error when parsing the default value of 'bad_param'"): ATaskWithBadParam() + @with_config( + { + "TaskA": { + "a": "a", + "b": "b", + "c": "c", + }, + "TaskB": { + "a": "a", + "b": "b", + "c": "c", + }, + } + ) + def test_unconsumed_params(self): + class TaskA(luigi.Task): + a = luigi.Parameter(default="a") + + class TaskB(luigi.Task): + a = luigi.Parameter(default="a") + + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings( + action="ignore", + category=Warning, + ) + warnings.simplefilter( + action="always", + category=luigi.parameter.UnconsumedParameterWarning, + ) + + TaskA() + TaskB() + + assert len(w) == 4 + expected = [ + ("b", "TaskA"), + ("c", "TaskA"), + ("b", "TaskB"), + ("c", "TaskB"), + ] + for i, (expected_value, task_name) in zip(w, expected): + assert issubclass(i.category, luigi.parameter.UnconsumedParameterWarning) + assert str(i.message) == ( + "The configuration contains the parameter " + f"'{expected_value}' with value '{expected_value}' that is not consumed by " + f"the task '{task_name}'." 
+            )
+
+    @with_config(
+        {
+            "TaskEdgeCase": {
+                "camelParam": "camelCase",
+                "underscore_param": "underscore",
+                "dash-param": "dash",
+            },
+        }
+    )
+    def test_unconsumed_params_edge_cases(self):
+        class TaskEdgeCase(luigi.Task):
+            camelParam = luigi.Parameter()
+            underscore_param = luigi.Parameter()
+            dash_param = luigi.Parameter()
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.filterwarnings(
+                action="ignore",
+                category=Warning,
+            )
+            warnings.simplefilter(
+                action="always",
+                category=luigi.parameter.UnconsumedParameterWarning,
+            )
+
+            task = TaskEdgeCase()
+            assert len(w) == 0
+            assert task.camelParam == "camelCase"
+            assert task.underscore_param == "underscore"
+            assert task.dash_param == "dash"
+
+    @with_config(
+        {
+            "TaskIgnoreUnconsumed": {
+                "a": "a",
+                "b": "b",
+                "c": "c",
+            },
+        }
+    )
+    def test_unconsumed_params_ignore_unconsumed(self):
+        class TaskIgnoreUnconsumed(luigi.Task):
+            ignore_unconsumed = {"b", "d"}
+
+            a = luigi.Parameter()
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.filterwarnings(
+                action="ignore",
+                category=Warning,
+            )
+            warnings.simplefilter(
+                action="always",
+                category=luigi.parameter.UnconsumedParameterWarning,
+            )
+
+            TaskIgnoreUnconsumed()
+            assert len(w) == 1
+
+
+class TaskFlattenOutputTest(unittest.TestCase):
+    def test_single_task(self):
+        expected = [luigi.LocalTarget("f1.txt"), luigi.LocalTarget("f2.txt")]
+
+        class TestTask(luigi.ExternalTask):
+            def output(self):
+                return expected
+
+        self.assertListEqual(luigi.task.flatten_output(TestTask()), expected)
+
+    def test_wrapper_task(self):
+        expected = [luigi.LocalTarget("f1.txt"), luigi.LocalTarget("f2.txt")]
+
+        class Test1Task(luigi.ExternalTask):
+            def output(self):
+                return expected[0]
+
+        class Test2Task(luigi.ExternalTask):
+            def output(self):
+                return expected[1]
+
+        @luigi.util.requires(Test1Task, Test2Task)
+        class TestWrapperTask(luigi.WrapperTask):
+            pass
+
+        self.assertListEqual(luigi.task.flatten_output(TestWrapperTask()), expected)
+
+    def test_wrapper_tasks_diamond(self):
+        expected = [luigi.LocalTarget("file.txt")]
+
+        class TestTask(luigi.ExternalTask):
+            def output(self):
+                return expected
+
+        @luigi.util.requires(TestTask)
+        class LeftWrapperTask(luigi.WrapperTask):
+            pass
+
+        @luigi.util.requires(TestTask)
+        class RightWrapperTask(luigi.WrapperTask):
+            pass
+
+        @luigi.util.requires(LeftWrapperTask, RightWrapperTask)
+        class MasterWrapperTask(luigi.WrapperTask):
+            pass
+
+        self.assertListEqual(luigi.task.flatten_output(MasterWrapperTask()), expected)
+

 class ExternalizeTaskTest(LuigiTestCase):

@@ -409,3 +566,15 @@ class MyTask(luigi.Task):
             pass
         luigi.namespace(scope='incorrect_namespace')
         self.assertEqual(MyTask.get_task_namespace(), '')
+
+
+class InitSubclassTest(LuigiTestCase):
+    def test_task_works_with_init_subclass(self):
+        class ReceivesClassKwargs(luigi.Task):
+            def __init_subclass__(cls, x, **kwargs):
+                super(ReceivesClassKwargs, cls).__init_subclass__()
+                cls.x = x
+
+        class Receiver(ReceivesClassKwargs, x=1):
+            pass
+        self.assertEqual(Receiver.x, 1)
diff --git a/test/testconfig/luigi_logging.toml b/test/testconfig/luigi_logging.toml
new file mode 100644
index 0000000000..988170a710
--- /dev/null
+++ b/test/testconfig/luigi_logging.toml
@@ -0,0 +1,18 @@
+[logging]
+version = 1
+disable_existing_loggers = false
+
+[logging.formatters.mockformatter]
+format = "{levelname}: {message}"
+style = "{"
+
+[logging.handlers.mockhandler]
+class = "logging.StreamHandler"
+level = "INFO"
+formatter = "mockformatter"
+
+[logging.loggers.mocklogger] +handlers = ["mockhandler"] +level = 'INFO' +disabled = false +propagate = false diff --git a/test/worker_test.py b/test/worker_test.py index 7f09314b82..eaad63b17b 100644 --- a/test/worker_test.py +++ b/test/worker_test.py @@ -59,6 +59,7 @@ def run(self): class DynamicDummyTask(Task): p = luigi.Parameter() + sleep = luigi.FloatParameter(default=0.5, significant=False) def output(self): return luigi.LocalTarget(self.p) @@ -66,7 +67,7 @@ def output(self): def run(self): with self.output().open('w') as f: f.write('Done!') - time.sleep(0.5) # so we can benchmark & see if parallelization works + time.sleep(self.sleep) # so we can benchmark & see if parallelization works class DynamicDummyTaskWithNamespace(DynamicDummyTask): @@ -95,6 +96,37 @@ def run(self): print('%d: %s' % (i, line.strip()), file=f) +class DynamicRequiresWrapped(Task): + p = luigi.Parameter() + + def output(self): + return luigi.LocalTarget(os.path.join(self.p, 'parent')) + + def run(self): + reqs = [ + DynamicDummyTask(p=os.path.join(self.p, '%s.txt' % i), sleep=0.0) + for i in range(10) + ] + + # yield again as DynamicRequires + yield luigi.DynamicRequirements(reqs) + + # and again with a custom complete function that does base name comparisons + def custom_complete(complete_fn): + if not complete_fn(reqs[0]): + return False + paths = [task.output().path for task in reqs] + basenames = os.listdir(os.path.dirname(paths[0])) + self._custom_complete_called = True + self._custom_complete_result = all(os.path.basename(path) in basenames for path in paths) + return self._custom_complete_result + + yield luigi.DynamicRequirements(reqs, custom_complete) + + with self.output().open('w') as f: + f.write('Done!') + + class DynamicRequiresOtherModule(Task): p = luigi.Parameter() @@ -434,6 +466,88 @@ def requires(self): self.assertEqual(a2.complete_count, 2) self.assertEqual(b2.complete_count, 2) + def test_cache_task_completion_config(self): + class A(Task): + + i = luigi.IntParameter() + + def __init__(self, *args, **kwargs): + super(A, self).__init__(*args, **kwargs) + self.complete_count = 0 + self.has_run = False + + def complete(self): + self.complete_count += 1 + return self.has_run + + def run(self): + self.has_run = True + + class B(A): + + def run(self): + yield A(i=self.i + 0) + yield A(i=self.i + 1) + yield A(i=self.i + 2) + self.has_run = True + + # test with enabled cache_task_completion + with Worker(scheduler=self.sch, worker_id='2', cache_task_completion=True) as w: + b0 = B(i=0) + a0 = A(i=0) + a1 = A(i=1) + a2 = A(i=2) + self.assertTrue(w.add(b0)) + # a's are required dynamically, so their counts must be 0 + self.assertEqual(b0.complete_count, 1) + self.assertEqual(a0.complete_count, 0) + self.assertEqual(a1.complete_count, 0) + self.assertEqual(a2.complete_count, 0) + w.run() + # the complete methods of a's yielded first in b's run method were called equally often + self.assertEqual(b0.complete_count, 1) + self.assertEqual(a0.complete_count, 2) + self.assertEqual(a1.complete_count, 2) + self.assertEqual(a2.complete_count, 2) + + # test with disabled cache_task_completion + with Worker(scheduler=self.sch, worker_id='2', cache_task_completion=False) as w: + b10 = B(i=10) + a10 = A(i=10) + a11 = A(i=11) + a12 = A(i=12) + self.assertTrue(w.add(b10)) + # a's are required dynamically, so their counts must be 0 + self.assertEqual(b10.complete_count, 1) + self.assertEqual(a10.complete_count, 0) + self.assertEqual(a11.complete_count, 0) + self.assertEqual(a12.complete_count, 0) + w.run() + # the 
complete methods of a's yielded first in b's run method were called more often + self.assertEqual(b10.complete_count, 1) + self.assertEqual(a10.complete_count, 5) + self.assertEqual(a11.complete_count, 4) + self.assertEqual(a12.complete_count, 3) + + # test with enabled check_complete_on_run + with Worker(scheduler=self.sch, worker_id='2', check_complete_on_run=True) as w: + b20 = B(i=20) + a20 = A(i=20) + a21 = A(i=21) + a22 = A(i=22) + self.assertTrue(w.add(b20)) + # a's are required dynamically, so their counts must be 0 + self.assertEqual(b20.complete_count, 1) + self.assertEqual(a20.complete_count, 0) + self.assertEqual(a21.complete_count, 0) + self.assertEqual(a22.complete_count, 0) + w.run() + # the complete methods of a's yielded first in b's run method were called more often + self.assertEqual(b20.complete_count, 2) + self.assertEqual(a20.complete_count, 6) + self.assertEqual(a21.complete_count, 5) + self.assertEqual(a22.complete_count, 4) + def test_gets_missed_work(self): class A(Task): done = False @@ -1071,6 +1185,13 @@ def test_dynamic_dependencies_other_module(self): luigi.build([t], local_scheduler=True, workers=self.n_workers) self.assertTrue(t.complete()) + def test_wrapped_dynamic_requirements(self): + t = DynamicRequiresWrapped(p=self.p) + luigi.build([t], local_scheduler=True, workers=1) + self.assertTrue(t.complete()) + self.assertTrue(getattr(t, '_custom_complete_called', False)) + self.assertTrue(getattr(t, '_custom_complete_result', False)) + class DynamicDependenciesWithMultipleWorkersTest(DynamicDependenciesTest): n_workers = 100 @@ -1238,8 +1359,8 @@ def complete(self): worker = Worker(scheduler) a = A() - with mock.patch.object(worker._scheduler, 'announce_scheduling_failure', side_effect=Exception('Unexpected')),\ - self.assertRaises(Exception): + with mock.patch.object(worker._scheduler, 'announce_scheduling_failure', + side_effect=Exception('Unexpected')), self.assertRaises(Exception): worker.add(a) self.assertTrue(len(emails) == 2) # One for `complete` error, one for exception in announcing. 
        self.assertTrue('Luigi: Framework error while scheduling' in emails[1])
diff --git a/tox.ini b/tox.ini
index b2db3fa86f..ff75cf6f12 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
 [tox]
-envlist = py{35,36,37,38,39}-{cdh,hdp,core,contrib,apache,aws,gcloud,postgres,unixsocket,azureblob,dropbox}, visualiser, docs, flake8
+envlist = py{35,36,37,38,39,310}-{cdh,hdp,core,contrib,apache,aws,gcloud,postgres,unixsocket,azureblob,dropbox}, visualiser, docs, flake8
 skipsdist = True

 [pytest]
@@ -38,7 +38,9 @@ deps =
     psutil<4.0
     cdh,hdp: hdfs>=2.0.4,<3.0.0
     postgres: psycopg2<3.0
+    postgres: pg8000>=1.23.0
     mysql-connector-python>=8.0.12
+    py35,py36: mysql-connector-python<8.0.32
     gcloud: google-api-python-client>=1.6.6,<2.0
     avro-python3
     gcloud: google-auth==1.4.1
@@ -56,8 +58,9 @@ deps =
     responses<1.0.0
     azure-storage<=0.36
     datadog==0.22.0
-    prometheus-client==0.5.0
+    prometheus-client>=0.5.0,<0.15
     dropbox: dropbox>=9.3.0<10.0
+    jsonschema
 passenv = USER JAVA_HOME POSTGRES_USER DATAPROC_TEST_PROJECT_ID GCS_TEST_PROJECT_ID GCS_TEST_BUCKET GOOGLE_APPLICATION_CREDENTIALS TRAVIS_BUILD_ID TRAVIS TRAVIS_BRANCH TRAVIS_JOB_NUMBER TRAVIS_PULL_REQUEST TRAVIS_JOB_ID TRAVIS_REPO_SLUG TRAVIS_COMMIT CI DROPBOX_APP_TOKEN DOCKERHUB_TOKEN GITHUB_ACTIONS OVERRIDE_SKIP_CI_TESTS
 setenv =
@@ -131,10 +134,12 @@ basepython=python3
 deps =
     sqlalchemy
     boto3
+    jinja2==3.0.3
     Sphinx>=1.4.4,<1.5
     sphinx_rtd_theme
     azure-storage<=0.36
     prometheus-client==0.5.0
+    alabaster<0.7.13
 commands =
     # build API docs
     sphinx-apidoc -o doc/api -T luigi --separate
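
The tests above collectively exercise three user-facing additions: JSON-schema validation for ListParameter/DictParameter, the pathlib-backed PathParameter/OptionalPathParameter, and luigi.DynamicRequirements with an optional batched completeness check. The following is a minimal sketch of how these pieces might combine in user code; the Chunk/Aggregate task names, the "out" directory, and the parameter values are invented for illustration, and the schema keyword assumes the jsonschema dependency added to tox.ini above.

import os

import luigi


class Chunk(luigi.Task):
    path = luigi.Parameter()

    def output(self):
        return luigi.LocalTarget(self.path)

    def run(self):
        with self.output().open('w') as f:
            f.write('chunk')


class Aggregate(luigi.Task):
    # Values that do not match the schema (e.g. '["a"]') are rejected when
    # the parameter is normalized, per the schema tests above.
    scores = luigi.ListParameter(
        schema={"type": "array", "items": {"type": "number"}},
        default=(1.0, 2.0),
    )
    # Resolved to an absolute pathlib.Path, so "/" composition works in run().
    out_dir = luigi.PathParameter(default="out", absolute=True)

    def output(self):
        return luigi.LocalTarget(str(self.out_dir / "done.txt"))

    def run(self):
        reqs = [Chunk(path=str(self.out_dir / ("%d.txt" % i))) for i in range(5)]

        # One listdir call instead of one complete() call per requirement,
        # mirroring the pattern tested in DynamicRequiresWrapped above.
        def custom_complete(complete_fn):
            if not complete_fn(reqs[0]):
                return False
            names = set(os.listdir(os.path.dirname(reqs[0].output().path)))
            return all(os.path.basename(r.output().path) in names for r in reqs)

        yield luigi.DynamicRequirements(reqs, custom_complete)

        with self.output().open('w') as f:
            f.write('sum=%s' % sum(self.scores))


if __name__ == '__main__':
    luigi.build([Aggregate()], local_scheduler=True)

The completion caching covered by test_cache_task_completion_config is a per-worker switch; outside of tests it would be enabled through the worker configuration (cache_task_completion = true, and optionally check_complete_on_run = true, under the [worker] section) rather than through Worker() keyword arguments.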