diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index ce54bb0..67120ef 100644 --- a/.github/workflows/pypi-release.yml +++ b/.github/workflows/pypi-release.yml @@ -14,9 +14,9 @@ jobs: - name: Verify tag matches version run: | set -ex - version=$(cat trainer/VERSION) + version=$(grep -m 1 version pyproject.toml | grep -P '\d+\.\d+\.\d+' -o) tag="${GITHUB_REF/refs\/tags\/}" - if [[ "$version" != "$tag" ]]; then + if [[ "v$version" != "$tag" ]]; then exit 1 fi - uses: actions/setup-python@v5 @@ -36,7 +36,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 @@ -67,10 +67,6 @@ jobs: with: name: "sdist" path: "dist/" - - uses: actions/download-artifact@v4 - with: - name: "wheel-3.8" - path: "dist/" - uses: actions/download-artifact@v4 with: name: "wheel-3.9" diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 669b302..6755f44 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -28,13 +28,8 @@ jobs: run: | sudo apt-get update sudo apt-get install -y git make gcc - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Install Trainer - run: | - python3 -m pip install .[all] - python3 setup.py egg_info + - name: Install/upgrade dev dependencies + run: python3 -m pip install -r requirements.dev.txt - name: Lint check run: | make lint diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2f34d2d..2efedb1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,7 +12,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, "3.10", "3.11", "3.12"] + python-version: [3.9, "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} @@ -28,12 +28,37 @@ jobs: run: | sudo apt-get update sudo apt-get install -y --no-install-recommends git make gcc - make system-deps - name: Install/upgrade Python setup deps run: python3 -m pip install --upgrade pip setuptools wheel - name: Install Trainer run: | - python3 -m pip install .[all] - python3 setup.py egg_info + python3 -m pip install .[dev,test] - name: Unit tests run: make test_all + - name: Upload coverage data + uses: actions/upload-artifact@v4 + with: + name: coverage-data-${{ matrix.python-version }} + path: .coverage.* + if-no-files-found: ignore + coverage: + if: always() + needs: test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - uses: actions/download-artifact@v4 + with: + pattern: coverage-data-* + merge-multiple: true + - name: Combine coverage + run: | + python -Im pip install --upgrade coverage[toml] + + python -Im coverage combine + python -Im coverage html --skip-covered --skip-empty + + python -Im coverage report --format=markdown >> $GITHUB_STEP_SUMMARY diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..615dc54 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +repos: + - repo: "https://github.com/pre-commit/pre-commit-hooks" + rev: v4.6.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.10 + hooks: + - id: ruff + args: [--fix, 
        --exit-non-zero-on-fix]
+      - id: ruff-format
+  - repo: local
+    hooks:
+      - id: generate_requirements.py
+        name: generate_requirements.py
+        language: system
+        entry: python bin/generate_requirements.py
+        files: "pyproject.toml|requirements.*\\.txt|bin/generate_requirements.py"
diff --git a/.pylintrc b/.pylintrc
deleted file mode 100644
index 2c8598f..0000000
--- a/.pylintrc
+++ /dev/null
@@ -1,647 +0,0 @@
-[MAIN]
-
-# Analyse import fallback blocks. This can be used to support both Python 2 and
-# 3 compatible code, which means that the block might have code that exists
-# only in one or another interpreter, leading to false positives when analysed.
-analyse-fallback-blocks=no
-
-# Clear in-memory caches upon conclusion of linting. Useful if running pylint
-# in a server-like mode.
-clear-cache-post-run=no
-
-# Load and enable all available extensions. Use --list-extensions to see a list
-# all available extensions.
-#enable-all-extensions=
-
-# In error mode, messages with a category besides ERROR or FATAL are
-# suppressed, and no reports are done by default. Error mode is compatible with
-# disabling specific errors.
-#errors-only=
-
-# Always return a 0 (non-error) status code, even if lint errors are found.
-# This is primarily useful in continuous integration scripts.
-#exit-zero=
-
-# A comma-separated list of package or module names from where C extensions may
-# be loaded. Extensions are loading into the active Python interpreter and may
-# run arbitrary code.
-extension-pkg-allow-list=
-
-# A comma-separated list of package or module names from where C extensions may
-# be loaded. Extensions are loading into the active Python interpreter and may
-# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
-# for backward compatibility.)
-extension-pkg-whitelist=
-
-# Return non-zero exit code if any of these messages/categories are detected,
-# even if score is above --fail-under value. Syntax same as enable. Messages
-# specified are enabled, while categories only check already-enabled messages.
-fail-on=
-
-# Specify a score threshold under which the program will exit with error.
-fail-under=10
-
-# Interpret the stdin as a python script, whose filename needs to be passed as
-# the module_or_package argument.
-#from-stdin=
-
-# Files or directories to be skipped. They should be base names, not paths.
-ignore=CVS
-
-# Add files or directories matching the regular expressions patterns to the
-# ignore-list. The regex matches against paths and can be in Posix or Windows
-# format. Because '\\' represents the directory delimiter on Windows systems,
-# it can't be used as an escape character.
-ignore-paths=
-
-# Files or directories matching the regular expression patterns are skipped.
-# The regex matches against base names, not paths. The default value ignores
-# Emacs file locks
-ignore-patterns=^\.#
-
-# List of module names for which member attributes should not be checked
-# (useful for modules/projects where namespaces are manipulated during runtime
-# and thus existing member attributes cannot be deduced by static analysis). It
-# supports qualified module names, as well as Unix pattern matching.
-ignored-modules=
-
-# Python code to execute, usually for sys.path manipulation such as
-# pygtk.require().
-#init-hook=
-
-# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
-# number of processors available to use, and will cap the count on Windows to
-# avoid hangs.
-jobs=1 - -# Control the amount of potential inferred values when inferring a single -# object. This can help the performance when dealing with large functions or -# complex, nested conditions. -limit-inference-results=100 - -# List of plugins (as comma separated values of python module names) to load, -# usually to register additional checkers. -load-plugins= - -# Pickle collected data for later comparisons. -persistent=yes - -# Minimum Python version to use for version dependent checks. Will default to -# the version used to run pylint. -py-version=3.11 - -# Discover python modules and packages in the file system subtree. -recursive=no - -# Add paths to the list of the source roots. Supports globbing patterns. The -# source root is an absolute path or a path relative to the current working -# directory used to determine a package namespace for modules located under the -# source root. -source-roots= - -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - -# In verbose mode, extra non-checker-related info will be displayed. -#verbose= - - -[BASIC] - -# Naming style matching correct argument names. -argument-naming-style=snake_case - -# Regular expression matching correct argument names. Overrides argument- -# naming-style. If left empty, argument names will be checked with the set -# naming style. -#argument-rgx= - -# Naming style matching correct attribute names. -attr-naming-style=snake_case - -# Regular expression matching correct attribute names. Overrides attr-naming- -# style. If left empty, attribute names will be checked with the set naming -# style. -#attr-rgx= - -# Bad variable names which should always be refused, separated by a comma. -bad-names=foo, - bar, - baz, - toto, - tutu, - tata - -# Bad variable names regexes, separated by a comma. If names match any regex, -# they will always be refused -bad-names-rgxs= - -# Naming style matching correct class attribute names. -class-attribute-naming-style=any - -# Regular expression matching correct class attribute names. Overrides class- -# attribute-naming-style. If left empty, class attribute names will be checked -# with the set naming style. -#class-attribute-rgx= - -# Naming style matching correct class constant names. -class-const-naming-style=UPPER_CASE - -# Regular expression matching correct class constant names. Overrides class- -# const-naming-style. If left empty, class constant names will be checked with -# the set naming style. -#class-const-rgx= - -# Naming style matching correct class names. -class-naming-style=PascalCase - -# Regular expression matching correct class names. Overrides class-naming- -# style. If left empty, class names will be checked with the set naming style. -#class-rgx= - -# Naming style matching correct constant names. -const-naming-style=UPPER_CASE - -# Regular expression matching correct constant names. Overrides const-naming- -# style. If left empty, constant names will be checked with the set naming -# style. -#const-rgx= - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=-1 - -# Naming style matching correct function names. -function-naming-style=snake_case - -# Regular expression matching correct function names. 
Overrides function- -# naming-style. If left empty, function names will be checked with the set -# naming style. -#function-rgx= - -# Good variable names which should always be accepted, separated by a comma. -good-names=i, - j, - k, - ex, - Run, - _ - -# Good variable names regexes, separated by a comma. If names match any regex, -# they will always be accepted -good-names-rgxs= - -# Include a hint for the correct naming format with invalid-name. -include-naming-hint=no - -# Naming style matching correct inline iteration names. -inlinevar-naming-style=any - -# Regular expression matching correct inline iteration names. Overrides -# inlinevar-naming-style. If left empty, inline iteration names will be checked -# with the set naming style. -#inlinevar-rgx= - -# Naming style matching correct method names. -method-naming-style=snake_case - -# Regular expression matching correct method names. Overrides method-naming- -# style. If left empty, method names will be checked with the set naming style. -#method-rgx= - -# Naming style matching correct module names. -module-naming-style=snake_case - -# Regular expression matching correct module names. Overrides module-naming- -# style. If left empty, module names will be checked with the set naming style. -#module-rgx= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=^_ - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -# These decorators are taken in consideration only for invalid-name. -property-classes=abc.abstractproperty - -# Regular expression matching correct type alias names. If left empty, type -# alias names will be checked with the set naming style. -#typealias-rgx= - -# Regular expression matching correct type variable names. If left empty, type -# variable names will be checked with the set naming style. -#typevar-rgx= - -# Naming style matching correct variable names. -variable-naming-style=snake_case - -# Regular expression matching correct variable names. Overrides variable- -# naming-style. If left empty, variable names will be checked with the set -# naming style. -#variable-rgx= - - -[CLASSES] - -# Warn about protected attribute access inside special methods -check-protected-access-in-special-methods=no - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp, - asyncSetUp, - __post_init__ - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=mcs - - -[DESIGN] - -# List of regular expressions of class ancestor names to ignore when counting -# public methods (see R0903) -exclude-too-few-public-methods= - -# List of qualified class names to ignore when counting class parents (see -# R0901) -ignored-parents= - -# Maximum number of arguments for function / method. -max-args=5 - -# Maximum number of attributes for a class (see R0902). -max-attributes=7 - -# Maximum number of boolean expressions in an if statement (see R0916). 
-max-bool-expr=5 - -# Maximum number of branch for function / method body. -max-branches=12 - -# Maximum number of locals for function / method body. -max-locals=15 - -# Maximum number of parents for a class (see R0901). -max-parents=7 - -# Maximum number of public methods for a class (see R0904). -max-public-methods=20 - -# Maximum number of return / yield for function / method body. -max-returns=6 - -# Maximum number of statements in function / method body. -max-statements=50 - -# Minimum number of public methods for a class (see R0903). -min-public-methods=2 - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when caught. -overgeneral-exceptions=builtins.BaseException,builtins.Exception - - -[FORMAT] - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). -indent-string=' ' - -# Maximum number of characters on a single line. -max-line-length=100 - -# Maximum number of lines in a module. -max-module-lines=1000 - -# Allow the body of a class to be on the same line as the declaration if body -# contains single statement. -single-line-class-stmt=no - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=no - - -[IMPORTS] - -# List of modules that can be imported at any level, not just the top level -# one. -allow-any-import-level= - -# Allow explicit reexports by alias from a package __init__. -allow-reexport-from-package=no - -# Allow wildcard imports from modules that define __all__. -allow-wildcard-with-all=no - -# Deprecated modules which should not be used, separated by a comma. -deprecated-modules= - -# Output a graph (.gv or any supported image format) of external dependencies -# to the given file (report RP0402 must not be disabled). -ext-import-graph= - -# Output a graph (.gv or any supported image format) of all (i.e. internal and -# external) dependencies to the given file (report RP0402 must not be -# disabled). -import-graph= - -# Output a graph (.gv or any supported image format) of internal dependencies -# to the given file (report RP0402 must not be disabled). -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant - -# Couples of modules and preferred modules, separated by a comma. -preferred-modules= - - -[LOGGING] - -# The type of string formatting that logging methods do. `old` means using % -# formatting, `new` is for `{}` formatting. -logging-format-style=old - -# Logging modules to check that the string format arguments are in logging -# function parameter format. -logging-modules=logging - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, -# UNDEFINED. -confidence=HIGH, - INFERENCE, - INFERENCE_FAILURE, - UNDEFINED - -# Disable the message, report, category or checker with the given id(s). 
You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once). You can also use "--disable=all" to -# disable everything first and then re-enable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use "--disable=all --enable=classes -# --disable=W". -disable=raw-checker-failed, - bad-inline-option, - locally-disabled, - file-ignored, - suppressed-message, - useless-suppression, - deprecated-pragma, - use-symbolic-message-instead, - line-too-long, - missing-function-docstring, - missing-module-docstring, - missing-class-docstring, - invalid-name, - consider-using-f-string, - too-many-instance-attributes, - no-member, - too-many-locals, - too-many-branches, - too-many-arguments, - fixme, - too-many-lines, - too-many-statements, - too-many-public-methods, - duplicate-code, - - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -enable=c-extension-no-member - - -[METHOD_ARGS] - -# List of qualified names (i.e., library.method) which require a timeout -# parameter e.g. 'requests.api.get,requests.api.post' -timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=FIXME, - XXX, - TODO - -# Regular expression of note tags to take in consideration. -notes-rgx= - - -[REFACTORING] - -# Maximum number of nested blocks for function / method body -max-nested-blocks=5 - -# Complete name of functions that never returns. When checking for -# inconsistent-return-statements if a never returning function is called then -# it will be considered as an explicit return statement and no message will be -# printed. -never-returning-functions=sys.exit,argparse.parse_error - - -[REPORTS] - -# Python expression which should return a score less than or equal to 10. You -# have access to the variables 'fatal', 'error', 'warning', 'refactor', -# 'convention', and 'info' which contain the number of messages in each -# category, as well as 'statement' which is the total number of statements -# analyzed. This score is used by the global evaluation report (RP0004). -evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details. -msg-template= - -# Set the output format. Available formats are text, parseable, colorized, json -# and msvs (visual studio). You can also give a reporter class, e.g. -# mypackage.mymodule.MyReporterClass. -#output-format= - -# Tells whether to display a full report or only the messages. -reports=no - -# Activate the evaluation score. 
-score=yes - - -[SIMILARITIES] - -# Comments are removed from the similarity computation -ignore-comments=yes - -# Docstrings are removed from the similarity computation -ignore-docstrings=yes - -# Imports are removed from the similarity computation -ignore-imports=yes - -# Signatures are removed from the similarity computation -ignore-signatures=yes - -# Minimum lines number of a similarity. -min-similarity-lines=4 - - -[SPELLING] - -# Limits count of emitted suggestions for spelling mistakes. -max-spelling-suggestions=4 - -# Spelling dictionary name. No available dictionaries : You need to install -# both the python package and the system dependency for enchant to work.. -spelling-dict= - -# List of comma separated words that should be considered directives if they -# appear at the beginning of a comment and should not be checked. -spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains the private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to the private dictionary (see the -# --spelling-private-dict-file option) instead of raising a message. -spelling-store-unknown-words=no - - -[STRING] - -# This flag controls whether inconsistent-quotes generates a warning when the -# character used as a quote delimiter is used inconsistently within a module. -check-quote-consistency=no - -# This flag controls whether the implicit-str-concat should generate a warning -# on implicit string concatenation in sequences defined over several lines. -check-str-concat-over-line-jumps=no - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= - -# Tells whether to warn about missing members when the owner of the attribute -# is inferred to be None. -ignore-none=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes - -# List of symbolic message names to ignore for Mixin members. -ignored-checks-for-mixins=no-member, - not-async-context-manager, - not-context-manager, - attribute-defined-outside-init - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace - -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. -missing-member-hint=yes - -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. 
-missing-member-hint-distance=1
-
-# The total number of similar names that should be taken in consideration when
-# showing a hint for a missing member.
-missing-member-max-choices=1
-
-# Regex pattern to define which classes are considered mixins.
-mixin-class-rgx=.*[Mm]ixin
-
-# List of decorators that change the signature of a decorated function.
-signature-mutators=
-
-
-[VARIABLES]
-
-# List of additional names supposed to be defined in builtins. Remember that
-# you should avoid defining new builtins when possible.
-additional-builtins=
-
-# Tells whether unused global variables should be treated as a violation.
-allow-global-unused-variables=yes
-
-# List of names allowed to shadow builtins
-allowed-redefined-builtins=
-
-# List of strings which can identify a callback function by name. A callback
-# name must start or end with one of those strings.
-callbacks=cb_,
-          _cb
-
-# A regular expression matching the name of dummy variables (i.e. expected to
-# not be used).
-dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
-
-# Argument names that match this expression will be ignored.
-ignored-argument-names=_.*|^ignored_|^unused_
-
-# Tells whether we should check for unused import in __init__ files.
-init-import=no
-
-# List of qualified module names which can have objects that can redefine
-# builtins.
-redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index b80639d..9c83ebc 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -119,11 +119,11 @@
 This Code of Conduct is adapted from the [Contributor Covenant][homepage],
 version 2.0, available at
 [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0].
 
-Community Impact Guidelines were inspired by 
+Community Impact Guidelines were inspired by
 [Mozilla's code of conduct enforcement ladder][Mozilla CoC].
 
 For answers to common questions about this code of conduct, see the FAQ at
-[https://www.contributor-covenant.org/faq][FAQ]. Translations are available 
+[https://www.contributor-covenant.org/faq][FAQ]. Translations are available
 at
 [https://www.contributor-covenant.org/translations][translations].
 
 [homepage]: https://www.contributor-covenant.org
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 587b399..18390e9 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -66,13 +66,13 @@ The following steps are tested on an Ubuntu system.
    $ make test_all  # run all the tests, report all the errors
    ```
 
-9. Format your code. We use ```black``` for code and ```isort``` for ```import``` formatting.
+9. Format your code. We use ```ruff``` for code formatting.
 
    ```bash
    $ make style
    ```
 
-10. Run the linter and correct the issues raised. We use ```pylint``` for linting. It helps to enforce a coding standard, offers simple refactoring suggestions.
+10. Run the linter and correct the issues raised. We use ```ruff``` for linting. It helps enforce a coding standard and offers simple refactoring suggestions.
 
    ```bash
    $ make lint
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000..8999e09
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,202 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021 Coqui + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/MANIFEST.in b/MANIFEST.in index 375bf69..97413e4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,8 +1,6 @@ include README.md include LICENSE.txt -include requirements.*.txt -include requirements.txt -include trainer/VERSION +include requirements.dev.txt recursive-include trainer *.json recursive-include trainer *.html recursive-include trainer *.png diff --git a/Makefile b/Makefile index c56cf33..eacf888 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,10 @@ .DEFAULT_GOAL := help -.PHONY: test system-deps dev-deps deps style lint install help docs +.PHONY: test dev-deps deps style lint install help help: @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -target_dirs := tests trainer +target_dirs := bin examples tests trainer test_all: ## run tests and don't stop on an error. coverage run -m pytest trainer tests @@ -16,26 +16,14 @@ test_failed: ## only run tests failed the last time. coverage run -m pytest --ff trainer tests style: ## update code style. - black ${target_dirs} - isort ${target_dirs} + ruff format ${target_dirs} -lint: ## run pylint linter. - pylint ${target_dirs} +lint: ## run linter. + ruff check ${target_dirs} dev-deps: ## install development deps pip install -r requirements.dev.txt -doc-deps: ## install docs dependencies - pip install -r docs/requirements.txt - -build-docs: ## build the docs - cd docs && make clean && make build - -deps: ## install 🐸 requirements. - pip install -r requirements.txt - install: ## install 🐸 Trainer for development. - pip install -e .[all] - -docs: ## build the docs - $(MAKE) -C docs clean && $(MAKE) -C docs html + pip install -e .[dev,test] + pre-commit install diff --git a/README.md b/README.md index b54cacd..26817ca 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,13 @@
# 👟 Trainer + +[![PyPI - License](https://img.shields.io/pypi/l/coqui-tts-trainer)](https://github.com/idiap/coqui-ai-Trainer/blob/main/LICENSE.txt) +![PyPI - Python Version](https://img.shields.io/pypi/pyversions/coqui-tts-trainer) +[![PyPI - Version](https://img.shields.io/pypi/v/coqui-tts-trainer)](https://pypi.org/project/coqui-tts-trainer) +![GithubActions](https://github.com/idiap/coqui-ai-Trainer/actions/workflows/tests.yml/badge.svg) +![GithubActions](https://github.com/idiap/coqui-ai-Trainer/actions/workflows/style_check.yml/badge.svg) + An opinionated general purpose model trainer on PyTorch with a simple code base. Fork of the [original, unmaintained repository](https://github.com/coqui-ai/Trainer). New PyPI package: [coqui-tts-trainer](https://pypi.org/project/coqui-tts-trainer) diff --git a/bin/collect_env_info.py b/bin/collect_env_info.py index da39c91..f6885d8 100644 --- a/bin/collect_env_info.py +++ b/bin/collect_env_info.py @@ -1,4 +1,6 @@ """Get detailed info about the working environment.""" + +import json import os import platform import sys @@ -6,11 +8,10 @@ import numpy import torch -sys.path += [os.path.abspath(".."), os.path.abspath(".")] -import json - import trainer +sys.path += [os.path.abspath(".."), os.path.abspath(".")] + def system_info(): return { diff --git a/bin/generate_requirements.py b/bin/generate_requirements.py new file mode 100644 index 0000000..bbd32ba --- /dev/null +++ b/bin/generate_requirements.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +"""Generate requirements/*.txt files from pyproject.toml. + +Adapted from: +https://github.com/numpy/numpydoc/blob/e7c6baf00f5f73a4a8f8318d0cb4e04949c9a5d1/tools/generate_requirements.py +""" + +import sys +from pathlib import Path + +try: # standard module since Python 3.11 + import tomllib as toml +except ImportError: + try: # available for older Python via pip + import tomli as toml + except ImportError: + sys.exit("Please install `tomli` first: `pip install tomli`") + +script_pth = Path(__file__) +repo_dir = script_pth.parent.parent +script_relpth = script_pth.relative_to(repo_dir) +header = [ + f"# Generated via {script_relpth.as_posix()} and pre-commit hook.", + "# Do not edit this file; modify pyproject.toml instead.", +] + + +def generate_requirement_file(name: str, req_list: list[str]) -> None: + req_fname = repo_dir / f"requirements.{name}.txt" + req_fname.write_text("\n".join(header + req_list) + "\n") + + +def main() -> None: + pyproject = toml.loads((repo_dir / "pyproject.toml").read_text()) + generate_requirement_file("dev", pyproject["project"]["optional-dependencies"]["dev"]) + + +if __name__ == "__main__": + main() diff --git a/examples/train_mnist.py b/examples/train_mnist.py index cb8d6d8..01ef01d 100644 --- a/examples/train_mnist.py +++ b/examples/train_mnist.py @@ -12,7 +12,7 @@ from torchvision import transforms from torchvision.datasets import MNIST -from trainer import TrainerConfig, TrainerModel, Trainer, TrainerArgs +from trainer import Trainer, TrainerArgs, TrainerConfig, TrainerModel @dataclass @@ -65,9 +65,7 @@ def eval_step(self, batch, criterion): def get_criterion(): return torch.nn.NLLLoss() - def get_data_loader( - self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 - ): # pylint: disable=unused-argument + def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, rank=0): # pylint: disable=unused-argument transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) dataset = MNIST(os.getcwd(), train=not 
is_eval, download=True, transform=transform) dataset.data = dataset.data[:256] diff --git a/examples/train_simple_gan.py b/examples/train_simple_gan.py index ad9be17..908c53c 100644 --- a/examples/train_simple_gan.py +++ b/examples/train_simple_gan.py @@ -13,8 +13,7 @@ from torchvision import transforms from torchvision.datasets import MNIST -from trainer import Trainer, TrainerConfig, TrainerModel -from trainer.trainer import TrainerArgs +from trainer import Trainer, TrainerArgs, TrainerConfig, TrainerModel is_cuda = torch.cuda.is_available() @@ -83,8 +82,7 @@ def __init__(self): self.generator = Generator(latent_dim=100, img_shape=data_shape) self.discriminator = Discriminator(img_shape=data_shape) - def forward(self, x): - ... + def forward(self, x): ... def optimize(self, batch, trainer): imgs, _ = batch @@ -153,9 +151,7 @@ def get_optimizer(self): def get_criterion(self): return nn.BCELoss() - def get_data_loader( - self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 - ): # pylint: disable=unused-argument + def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, rank=0): # pylint: disable=unused-argument transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) dataset.data = dataset.data[:64] @@ -165,7 +161,6 @@ def get_data_loader( if __name__ == "__main__": - config = GANModelConfig() config.batch_size = 64 config.grad_clip = None diff --git a/pyproject.toml b/pyproject.toml index 7ceeaaa..bedc6d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,32 +1,99 @@ [build-system] -requires = ["setuptools", "wheel"] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" -[flake8] -max-line-length=120 +[tool.setuptools.packages.find] +include = ["trainer*"] -[tool.black] +[project] +name = "coqui-tts-trainer" +version = "0.1.2" +description = "General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui." +readme = "README.md" +requires-python = ">=3.9, <3.13" +license = {text = "Apache-2.0"} +authors = [ + {name = "Eren Gölge", email = "egolge@coqui.ai"} +] +maintainers = [ + {name = "Enno Hermann", email = "enno.hermann@gmail.com"} +] +classifiers = [ + "Environment :: Console", + "Natural Language :: English", + # How mature is this project? 
Common values are + # 3 - Alpha, 4 - Beta, 5 - Production/Stable + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dependencies = [ + "coqpit", + "fsspec", + "numpy", + "psutil", + "soundfile", + "tensorboard", + "torch>=2.0", +] + +[project.optional-dependencies] +# Development dependencies +dev = [ + "coverage", + "pre-commit", + "pytest", + "ruff==0.4.10", + "tomli; python_version < '3.11'", +] +# Dependencies for running the tests +test = [ + "accelerate", + "torchvision", +] + +[project.urls] +Homepage = "https://github.com/idiap/coqui-ai-Trainer" +Repository = "https://github.com/idiap/coqui-ai-Trainer" +Issues = "https://github.com/idiap/coqui-ai-Trainer/issues" + +[tool.ruff] line-length = 120 -target-version = ['py38'] -exclude = ''' - -( - /( - \.eggs # exclude a few common directories in the - | \.git # root of the project - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | buck-out - | build - | dist - )/ - | foo.py # also separately exclude a file named foo.py in - # the root of the project -) -''' - -[tool.isort] -profile = "black" -multi_line_output = 3 \ No newline at end of file +target-version = "py39" +lint.extend-select = [ + "B", # bugbear + "I", # import sorting + "PIE", + "PLC", + "PLE", + "PLW", + "RUF", + "UP", # pyupgrade +] + +lint.ignore = [ + "F821", # TODO: enable + "PLW2901", # TODO: enable + "UP032", # TODO: enable +] + +[tool.ruff.lint.per-file-ignores] +"**/__init__.py" = [ + "F401", # init files may have "unused" imports for now + "F403", # init files may have star imports for now +] + +[tool.coverage.run] +parallel = true +source = ["trainer"] diff --git a/requirements.dev.txt b/requirements.dev.txt index bd93ba2..ae2d8bf 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,6 +1,7 @@ -black==24.2.0 +# Generated via bin/generate_requirements.py and pre-commit hook. +# Do not edit this file; modify pyproject.toml instead. coverage -isort +pre-commit pytest -pylint -accelerate # for testing +ruff==0.4.10 +tomli; python_version < '3.11' diff --git a/requirements.test.txt b/requirements.test.txt deleted file mode 100644 index abf5036..0000000 --- a/requirements.test.txt +++ /dev/null @@ -1 +0,0 @@ -torchvision \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index a69286e..0000000 --- a/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -torch>=1.7 -coqpit -psutil -fsspec -tensorboard -soundfile \ No newline at end of file diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 1f31cb5..0000000 --- a/setup.cfg +++ /dev/null @@ -1,8 +0,0 @@ -[build_py] -build_lib=temp_build - -[bdist_wheel] -bdist_dir=temp_build - -[install_lib] -build_dir=temp_build diff --git a/setup.py b/setup.py index 5135830..fc6ddff 100644 --- a/setup.py +++ b/setup.py @@ -20,100 +20,6 @@ # .,*++++::::::++++*,. 
# +from setuptools import setup -import os -import subprocess -import sys - -import setuptools.command.build_py -import setuptools.command.develop -from setuptools import find_packages, setup - -if sys.version_info < (3, 6) or sys.version_info >= (3, 13): - raise RuntimeError("Trainer requires python >= 3.6 and <3.13 " "but your Python version is {}".format(sys.version)) - - -cwd = os.path.dirname(os.path.abspath(__file__)) - -cwd = os.path.dirname(os.path.abspath(__file__)) -with open(os.path.join(cwd, "trainer", "VERSION")) as fin: - version = fin.read().strip() - - -class build_py(setuptools.command.build_py.build_py): # pylint: disable=too-many-ancestors - def run(self): - setuptools.command.build_py.build_py.run(self) - - -class develop(setuptools.command.develop.develop): - def run(self): - setuptools.command.develop.develop.run(self) - - -def pip_install(package_name): - subprocess.call([sys.executable, "-m", "pip", "install", package_name]) - - -requirements = open(os.path.join(cwd, "requirements.txt"), "r").readlines() -with open(os.path.join(cwd, "requirements.dev.txt"), "r") as f: - requirements_dev = f.readlines() -with open(os.path.join(cwd, "requirements.test.txt"), "r") as f: - requirements_test = f.readlines() -requirements_all = requirements + requirements_dev + requirements_test - -with open("README.md", "r", encoding="utf-8") as readme_file: - README = readme_file.read() - -setup( - name="coqui-tts-trainer", - version=version, - url="https://github.com/idiap/coqui-ai-Trainer", - author="Eren Gölge", - author_email="egolge@coqui.ai", - maintainer="Enno Hermann", - maintainer_email="enno.hermann@gmail.com", - description="General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui.", - long_description=README, - long_description_content_type="text/markdown", - license="Apache2", - # package - include_package_data=True, - packages=find_packages(include=["trainer"]), - package_data={ - "trainer": [ - "VERSION", - ] - }, - project_urls={ - "Documentation": "https://github.com/idiap/coqui-ai-Trainer", - "Tracker": "https://github.com/idiap/coqui-ai-Trainer/issues", - "Repository": "https://github.com/idiap/coqui-ai-Trainer", - }, - cmdclass={ - "build_py": build_py, - "develop": develop, - }, - install_requires=requirements, - extras_require={"dev": requirements_dev, "test": requirements_test, "all": requirements_all}, - python_requires=">=3.6.0, <3.13", - classifiers=[ - "Environment :: Console", - "Natural Language :: English", - # How mature is this project? Common values are - # 3 - Alpha, 4 - Beta, 5 - Production/Stable - "Development Status :: 3 - Alpha", - # Indicate who your project is intended for - "Intended Audience :: Developers", - # Pick your license as you wish - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", - # Specify the Python versions you support here. In particular, ensure - # that you indicate whether you support Python 2, Python 3 or both. 
- "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - ], - zip_safe=False, -) +setup() diff --git a/tests/test_train_gan.py b/tests/test_train_gan.py index 37c7fc0..d02976d 100644 --- a/tests/test_train_gan.py +++ b/tests/test_train_gan.py @@ -1,6 +1,5 @@ import os from dataclasses import dataclass -from typing import Any, Dict, Tuple import numpy as np import torch @@ -9,8 +8,7 @@ from torchvision import transforms from torchvision.datasets import MNIST -from trainer import TrainerConfig, TrainerModel -from trainer.trainer import Trainer, TrainerArgs +from trainer import Trainer, TrainerArgs, TrainerConfig, TrainerModel is_cuda = torch.cuda.is_available() @@ -126,9 +124,7 @@ def get_optimizer(self): def get_criterion(self): return nn.BCELoss() - def get_data_loader( - self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 - ): # pylint: disable=unused-argument + def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, rank=0): # pylint: disable=unused-argument transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) dataset.data = dataset.data[:64] @@ -220,9 +216,7 @@ def get_optimizer(self): def get_criterion(self): return nn.BCELoss() - def get_data_loader( - self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 - ): # pylint: disable=unused-argument + def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, rank=0): # pylint: disable=unused-argument transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) dataset.data = dataset.data[:64] @@ -335,9 +329,7 @@ def get_optimizer(self): def get_criterion(self): return nn.BCELoss() - def get_data_loader( - self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 - ): # pylint: disable=unused-argument + def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, rank=0): # pylint: disable=unused-argument transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) dataset.data = dataset.data[:64] @@ -451,9 +443,7 @@ def get_optimizer(self): def get_criterion(self): return nn.BCELoss() - def get_data_loader( - self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 - ): # pylint: disable=unused-argument + def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, rank=0): # pylint: disable=unused-argument transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) dataset.data = dataset.data[:64] @@ -569,9 +559,7 @@ def get_optimizer(self): def get_criterion(self): return nn.BCELoss() - def get_data_loader( - self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 - ): # pylint: disable=unused-argument + def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, rank=0): # pylint: disable=unused-argument transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) dataset = MNIST(os.getcwd(), 
train=not is_eval, download=True, transform=transform) dataset.data = dataset.data[:64] diff --git a/tests/utils/mnist.py b/tests/utils/mnist.py index d9c8150..afb849b 100644 --- a/tests/utils/mnist.py +++ b/tests/utils/mnist.py @@ -61,9 +61,7 @@ def eval_step(self, batch, criterion): def get_criterion(): return torch.nn.NLLLoss() - def get_data_loader( - self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 - ): # pylint: disable=unused-argument + def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, rank=0): # pylint: disable=unused-argument transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) dataset.data = dataset.data[:256] diff --git a/tests/utils/train_mnist.py b/tests/utils/train_mnist.py index e4ec66f..13fda81 100644 --- a/tests/utils/train_mnist.py +++ b/tests/utils/train_mnist.py @@ -1,5 +1,3 @@ -from distutils.command.config import config - from mnist import MnistModel, MnistModelConfig from trainer import Trainer, TrainerArgs diff --git a/trainer/VERSION b/trainer/VERSION deleted file mode 100644 index 8308b63..0000000 --- a/trainer/VERSION +++ /dev/null @@ -1 +0,0 @@ -v0.1.1 diff --git a/trainer/__init__.py b/trainer/__init__.py index 45e4c32..40c72da 100644 --- a/trainer/__init__.py +++ b/trainer/__init__.py @@ -1,9 +1,7 @@ -import os +import importlib.metadata +from trainer.config import TrainerArgs, TrainerConfig from trainer.model import * from trainer.trainer import * -with open(os.path.join(os.path.dirname(__file__), "VERSION"), "r", encoding="utf-8") as f: - version = f.read().strip() - -__version__ = version +__version__ = importlib.metadata.version("coqui-tts-trainer") diff --git a/trainer/callbacks.py b/trainer/callbacks.py index 3a407d7..505fdac 100644 --- a/trainer/callbacks.py +++ b/trainer/callbacks.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict +from typing import Callable class TrainerCallback: @@ -13,7 +13,7 @@ def __init__(self) -> None: self.callbacks_on_train_step_end = [] self.callbacks_on_keyboard_interrupt = [] - def parse_callbacks_dict(self, callbacks_dict: Dict[str, Callable]) -> None: + def parse_callbacks_dict(self, callbacks_dict: dict[str, Callable]) -> None: for key, value in callbacks_dict.items(): if key == "on_init_start": self.callbacks_on_init_start.append(value) diff --git a/trainer/config.py b/trainer/config.py new file mode 100644 index 0000000..872c9a7 --- /dev/null +++ b/trainer/config.py @@ -0,0 +1,232 @@ +from dataclasses import dataclass, field +from typing import Optional, Union + +from coqpit import Coqpit + + +@dataclass +class TrainerArgs(Coqpit): + """Trainer arguments that can be accessed from the command line. + + Examples:: + >>> python train.py --restore_path /path/to/checkpoint.pth + """ + + continue_path: str = field( + default="", + metadata={ + "help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder." + }, + ) + restore_path: str = field( + default="", + metadata={ + "help": "Path to a model checkpoit. Restore the model with the given checkpoint and start a new training." + }, + ) + best_path: str = field( + default="", + metadata={ + "help": "Best model file to be used for extracting the best loss. 
diff --git a/trainer/config.py b/trainer/config.py new file mode 100644 index 0000000..872c9a7 --- /dev/null +++ b/trainer/config.py @@ -0,0 +1,232 @@ +from dataclasses import dataclass, field +from typing import Optional, Union + +from coqpit import Coqpit + + +@dataclass +class TrainerArgs(Coqpit): + """Trainer arguments that can be accessed from the command line. + + Examples:: + >>> python train.py --restore_path /path/to/checkpoint.pth + """ + + continue_path: str = field( + default="", + metadata={ + "help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder." + }, + ) + restore_path: str = field( + default="", + metadata={ + "help": "Path to a model checkpoint. Restore the model with the given checkpoint and start a new training." + }, + ) + best_path: str = field( + default="", + metadata={ + "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used" + }, + ) + use_ddp: bool = field( + default=False, + metadata={"help": "Use DDP in distributed training. It is set in `distribute.py`. Do not set manually."}, + ) + use_accelerate: bool = field(default=False, metadata={"help": "Use HF Accelerate as the back end for training."}) + grad_accum_steps: int = field( + default=1, + metadata={ + "help": "Number of gradient accumulation steps. It is used to accumulate gradients over multiple batches." + }, + ) + overfit_batch: bool = field(default=False, metadata={"help": "Overfit a single batch for debugging."}) + skip_train_epoch: bool = field( + default=False, + metadata={"help": "Skip training and only run evaluation and test."}, + ) + start_with_eval: bool = field( + default=False, + metadata={"help": "Start with evaluation and test."}, + ) + small_run: Optional[int] = field( + default=None, + metadata={ + "help": "Only use a subset of the samples for debugging. Set the number of samples to use. Defaults to None." + }, + ) + gpu: Optional[int] = field( + default=None, metadata={"help": "GPU ID to use if ```CUDA_VISIBLE_DEVICES``` is not set. Defaults to None."} + ) + # only for DDP + rank: int = field(default=0, metadata={"help": "Process rank in a distributed training. Don't set manually."}) + group_id: str = field( + default="", metadata={"help": "Process group id in a distributed training. Don't set manually."} + ) + + +@dataclass +class TrainerConfig(Coqpit): + """Config fields tweaking the Trainer for a model. + + A ```ModelConfig```, inheriting ```TrainerConfig```, must be defined to use 👟. + Inherit it in a new model config and override the fields as needed. + All the fields can be overridden from the command line as ```--coqpit.arg_name=value```. + + Example:: + + Run the training code by overriding the ```lr``` and ```plot_step``` fields. + + >>> python train.py --coqpit.plot_step=22 --coqpit.lr=0.001 + + Defining a model using ```TrainerConfig```. + + >>> from trainer import TrainerConfig + >>> class MyModelConfig(TrainerConfig): + ... optimizer: str = "Adam" + ... lr: float = 0.001 + ... epochs: int = 1 + ... ... + >>> class MyModel(nn.Module): + ... def __init__(self, config): + ... ... + >>> model = MyModel(MyModelConfig()) + + """ + + # Fields for the run + output_path: str = field(default="output") + logger_uri: Optional[str] = field( + default=None, + metadata={ + "help": "URI to save training artifacts by the logger. If not set, logs will be saved in the output_path. Defaults to None" + }, + ) + run_name: str = field(default="run", metadata={"help": "Name of the run. Defaults to 'run'"}) + project_name: Optional[str] = field(default=None, metadata={"help": "Name of the project. Defaults to None"}) + run_description: str = field( + default="🐸Coqui trainer run.", + metadata={"help": "Notes and description about the run. Defaults to '🐸Coqui trainer run.'"}, + ) + # Fields for logging + print_step: int = field( + default=25, metadata={"help": "Print training stats on the terminal every print_step steps. Defaults to 25"} + ) + plot_step: int = field( + default=100, metadata={"help": "Plot training stats on the logger every plot_step steps. Defaults to 100"} + ) + model_param_stats: bool = field( + default=False, metadata={"help": "Log model parameters stats on the logger dashboard. Defaults to False"} + ) + wandb_entity: Optional[str] = field( + default=None, metadata={"help": "Wandb entity to log the run. Defaults to None"} + )
Defaults to None"} + ) + dashboard_logger: str = field( + default="tensorboard", metadata={"help": "Logger to use for the tracking dashboard. Defaults to 'tensorboard'"} + ) + # Fields for checkpointing + save_on_interrupt: bool = field( + default=True, metadata={"help": "Save checkpoint on interrupt (Ctrl+C). Defaults to True"} + ) + log_model_step: Optional[int] = field( + default=None, + metadata={ + "help": "Save checkpoint to the logger every log_model_step steps. If not defined `save_step == log_model_step`." + }, + ) + save_step: int = field( + default=10000, metadata={"help": "Save local checkpoint every save_step steps. Defaults to 10000"} + ) + save_n_checkpoints: int = field(default=5, metadata={"help": "Keep n local checkpoints. Defaults to 5"}) + save_checkpoints: bool = field(default=True, metadata={"help": "Save checkpoints locally. Defaults to True"}) + save_all_best: bool = field( + default=False, metadata={"help": "Save all best checkpoints and keep the older ones. Defaults to False"} + ) + save_best_after: int = field(default=0, metadata={"help": "Wait N steps to save best checkpoints. Defaults to 0"}) + target_loss: Optional[str] = field( + default=None, metadata={"help": "Target loss name to select the best model. Defaults to None"} + ) + # Fields for eval and test run + print_eval: bool = field(default=False, metadata={"help": "Print eval steps on the terminal. Defaults to False"}) + test_delay_epochs: int = field(default=0, metadata={"help": "Wait N epochs before running the test. Defaults to 0"}) + run_eval: bool = field( + default=True, metadata={"help": "Run evalulation epoch after training epoch. Defaults to True"} + ) + run_eval_steps: Optional[int] = field( + default=None, + metadata={ + "help": "Run evalulation epoch after N steps. If None, waits until training epoch is completed. Defaults to None" + }, + ) + # Fields for distributed training + distributed_backend: str = field( + default="nccl", metadata={"help": "Distributed backend to use. Defaults to 'nccl'"} + ) + distributed_url: str = field( + default="tcp://localhost:54321", + metadata={"help": "Distributed url to use. Defaults to 'tcp://localhost:54321'"}, + ) + # Fields for training specs + mixed_precision: bool = field(default=False, metadata={"help": "Use mixed precision training. Defaults to False"}) + precision: str = field( + default="fp16", + metadata={ + "help": "Precision to use in mixed precision training. `fp16` for float16 and `bf16` for bfloat16. Defaults to 'f16'" + }, + ) + epochs: int = field(default=1000, metadata={"help": "Number of epochs to train. Defaults to 1000"}) + batch_size: int = field(default=32, metadata={"help": "Batch size to use. Defaults to 32"}) + eval_batch_size: int = field(default=16, metadata={"help": "Batch size to use for eval. Defaults to 16"}) + grad_clip: float = field( + default=0.0, metadata={"help": "Gradient clipping value. Disabled if <= 0. Defaults to 0.0"} + ) + scheduler_after_epoch: bool = field( + default=True, + metadata={"help": "Step the scheduler after each epoch else step after each iteration. Defaults to True"}, + ) + # Fields for optimzation + lr: Union[float, list[float]] = field( + default=0.001, metadata={"help": "Learning rate for each optimizer. Defaults to 0.001"} + ) + optimizer: Optional[Union[str, list[str]]] = field( + default=None, metadata={"help": "Optimizer(s) to use. Defaults to None"} + ) + optimizer_params: Union[dict, list[dict]] = field( + default_factory=dict, metadata={"help": "Optimizer(s) arguments. 
Defaults to {}"} + ) + lr_scheduler: Optional[Union[str, list[str]]] = field( + default=None, metadata={"help": "Learning rate scheduler(s) to use. Defaults to None"} + ) + lr_scheduler_params: dict = field( + default_factory=dict, metadata={"help": "Learning rate scheduler(s) arguments. Defaults to {}"} + ) + use_grad_scaler: bool = field( + default=False, + metadata={ + "help": "Enable/disable gradient scaler explicitly. It is enabled by default with AMP training. Defaults to False" + }, + ) + allow_tf32: bool = field( + default=False, + metadata={ + "help": "A bool that controls whether TensorFloat-32 tensor cores may be used in matrix multiplications on Ampere or newer GPUs. Default to False." + }, + ) + cudnn_enable: bool = field(default=True, metadata={"help": "Enable/disable cudnn explicitly. Defaults to True"}) + cudnn_deterministic: bool = field( + default=False, + metadata={ + "help": "Enable/disable deterministic cudnn operations. Set this True for reproducibility but it slows down training significantly. Defaults to False." + }, + ) + cudnn_benchmark: bool = field( + default=False, + metadata={ + "help": "Enable/disable cudnn benchmark explicitly. Set this False if your input size change constantly. Defaults to False" + }, + ) + training_seed: int = field( + default=54321, + metadata={"help": "Global seed for torch, random and numpy random number generator. Defaults to 54321"}, + ) diff --git a/trainer/distribute.py b/trainer/distribute.py index b02730a..f1505f5 100644 --- a/trainer/distribute.py +++ b/trainer/distribute.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- import os import pathlib @@ -51,7 +50,7 @@ def distribute(): command[-1] = f"--rank={rank}" # prevent stdout for processes with rank != 0 stdout = None - p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with + p = subprocess.Popen(["python3", *command], stdout=stdout, env=my_env) # pylint: disable=consider-using-with processes.append(p) logger.info(command) diff --git a/trainer/generic_utils.py b/trainer/generic_utils.py index 38e6fa3..b9ff256 100644 --- a/trainer/generic_utils.py +++ b/trainer/generic_utils.py @@ -1,7 +1,7 @@ -# -*- coding: utf-8 -*- import datetime import os import subprocess +from typing import Any, Union import fsspec import torch @@ -9,14 +9,14 @@ from trainer.logger import logger -def isimplemented(obj, method_name): +def isimplemented(obj, method_name) -> bool: """Check if a method is implemented in a class.""" if method_name in dir(obj) and callable(getattr(obj, method_name)): try: obj.__getattribute__(method_name)() # pylint: disable=bad-option-value, unnecessary-dunder-call except NotImplementedError: return False - except: # pylint: disable=bare-except + except Exception: return True return True return False @@ -38,19 +38,19 @@ def get_cuda(): return use_cuda, device -def get_git_branch(): +def get_git_branch() -> str: try: out = subprocess.check_output(["git", "branch"]).decode("utf8") current = next(line for line in out.split("\n") if line.startswith("*")) current.replace("* ", "") except subprocess.CalledProcessError: current = "inside_docker" - except FileNotFoundError: + except (FileNotFoundError, StopIteration): current = "unknown" return current -def get_commit_hash(): +def get_commit_hash() -> str: """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" try: commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip() @@ -60,7 +60,7 @@ 
diff --git a/trainer/distribute.py b/trainer/distribute.py index b02730a..f1505f5 100644 --- a/trainer/distribute.py +++ b/trainer/distribute.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- import os import pathlib @@ -51,7 +50,7 @@ def distribute(): command[-1] = f"--rank={rank}" # prevent stdout for processes with rank != 0 stdout = None - p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with + p = subprocess.Popen(["python3", *command], stdout=stdout, env=my_env) # pylint: disable=consider-using-with processes.append(p) logger.info(command) diff --git a/trainer/generic_utils.py b/trainer/generic_utils.py index 38e6fa3..b9ff256 100644 --- a/trainer/generic_utils.py +++ b/trainer/generic_utils.py @@ -1,7 +1,7 @@ -# -*- coding: utf-8 -*- import datetime import os import subprocess +from typing import Any, Union import fsspec import torch @@ -9,14 +9,14 @@ from trainer.logger import logger -def isimplemented(obj, method_name): +def isimplemented(obj, method_name) -> bool: """Check if a method is implemented in a class.""" if method_name in dir(obj) and callable(getattr(obj, method_name)): try: obj.__getattribute__(method_name)() # pylint: disable=bad-option-value, unnecessary-dunder-call except NotImplementedError: return False - except: # pylint: disable=bare-except + except Exception: return True return True return False @@ -38,19 +38,19 @@ def get_cuda(): return use_cuda, device -def get_git_branch(): +def get_git_branch() -> str: try: out = subprocess.check_output(["git", "branch"]).decode("utf8") current = next(line for line in out.split("\n") if line.startswith("*")) current.replace("* ", "") except subprocess.CalledProcessError: current = "inside_docker" - except FileNotFoundError: + except (FileNotFoundError, StopIteration): current = "unknown" return current -def get_commit_hash(): +def get_commit_hash() -> str: """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" try: commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip() @@ -60,7 +60,7 @@ def get_commit_hash(): return commit -def get_experiment_folder_path(root_path, model_name): +def get_experiment_folder_path(root_path: Union[str, os.PathLike[Any]], model_name: str) -> str: """Get an experiment folder path with the current date and time""" date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") commit_hash = get_commit_hash() @@ -68,8 +68,9 @@ def get_experiment_folder_path(root_path, model_name): return output_folder -def remove_experiment_folder(experiment_path): +def remove_experiment_folder(experiment_path: Union[str, os.PathLike[Any]]) -> None: """Check folder if there is a checkpoint, otherwise remove the folder""" + experiment_path = str(experiment_path) fs = fsspec.get_mapper(experiment_path).fs checkpoint_files = fs.glob(experiment_path + "/*.pth") if not checkpoint_files: @@ -80,14 +81,14 @@ def remove_experiment_folder(experiment_path): logger.info(" ! Run is kept in %s", experiment_path) -def count_parameters(model): +def count_parameters(model: torch.nn.Module) -> int: r"""Count number of trainable parameters in a network""" return sum(p.numel() for p in model.parameters() if p.requires_grad) def set_partial_state_dict(model_dict, checkpoint_state, c): # Partial initialization: if there is a mismatch with new and old layer, it is skipped. - for k, v in checkpoint_state.items(): + for k in checkpoint_state: if k not in model_dict: logger.info(" | > Layer missing in the model definition: %s", k) for k in model_dict: diff --git a/trainer/io.py b/trainer/io.py index 62381f1..165f6ea 100644 --- a/trainer/io.py +++ b/trainer/io.py @@ -4,7 +4,7 @@ import re import sys from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Optional, Union from urllib.parse import urlparse import fsspec @@ -14,9 +14,15 @@ from trainer.logger import logger -def get_user_data_dir(appname): - if sys.platform == "win32": - import winreg # pylint: disable=import-outside-toplevel, import-error +def get_user_data_dir(appname: str) -> Path: + TTS_HOME = os.environ.get("TTS_HOME") + XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME") + if TTS_HOME is not None: + ans = Path(TTS_HOME).expanduser().resolve(strict=False) + elif XDG_DATA_HOME is not None: + ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False) + elif sys.platform == "win32": + import winreg # pylint: disable=import-outside-toplevel key = winreg.OpenKey( winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" ) @@ -30,9 +36,8 @@ def get_user_data_dir(appname): return ans.joinpath(appname) -def copy_model_files(config: Coqpit, out_path, new_fields): - """Copy config.json and other model files to training folder and add - new fields. +def copy_model_files(config: Coqpit, out_path: Union[str, os.PathLike[Any]], new_fields: dict) -> None: + """Copy config.json and other model files to training folder and add new fields. Args: config (Coqpit): Coqpit config defining the training run. @@ -49,17 +54,19 @@ def copy_model_files(config: Coqpit, out_path, new_fields): def load_fsspec( - path: str, - map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, + path: Union[str, os.PathLike[Any]], + map_location: Union[str, Callable, torch.device, dict[Union[str, torch.device], Union[str, torch.device]]] = None, cache: bool = True, **kwargs, ) -> Any: """Like torch.load but can load from other locations (e.g. s3:// , gs://). + Args: path: Any path or url supported by fsspec.
map_location: torch.device or str. cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/trainer_cache`. Defaults to True. **kwargs: Keyword arguments forwarded to torch.load. + Returns: Object stored in path. """ @@ -72,12 +79,18 @@ def load_fsspec( ) as f: return torch.load(f, map_location=map_location, **kwargs) else: - with fsspec.open(path, "rb") as f: + with fsspec.open(str(path), "rb") as f: return torch.load(f, map_location=map_location, **kwargs) -def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) +def load_checkpoint( + model: torch.nn.Module, + checkpoint_path: Union[str, os.PathLike[Any]], + use_cuda: bool = False, + eval: bool = False, + cache: bool = False, +) -> tuple[torch.nn.Module, Any]: # pylint: disable=redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) model.load_state_dict(state["model"]) if use_cuda: model.cuda() @@ -86,7 +99,7 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pyli return model, state -def save_fsspec(state: Any, path: str, **kwargs): +def save_fsspec(state: Any, path: Union[str, os.PathLike[Any]], **kwargs) -> None: """Like torch.save but can save to other locations (e.g. s3:// , gs://). Args: @@ -94,11 +107,21 @@ def save_fsspec(state: Any, path: str, **kwargs): path: Any path or url supported by fsspec. **kwargs: Keyword arguments forwarded to torch.save. """ - with fsspec.open(path, "wb") as f: + with fsspec.open(str(path), "wb") as f: torch.save(state, f, **kwargs) -def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, save_func, **kwargs): +def save_model( + config: Union[dict, Coqpit], + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scaler, + current_step: int, + epoch: int, + output_path: Union[str, os.PathLike[Any]], + save_func: Optional[Callable] = None, + **kwargs, +) -> None: if hasattr(model, "module"): model_state = model.module.state_dict() else: @@ -128,22 +151,22 @@ def save_model(config, model, optimizer, scaler, current_step, epoch, output_pat "date": datetime.date.today().strftime("%B %d, %Y"), } state.update(kwargs) - if save_func: + if save_func is not None: save_func(state, output_path) else: save_fsspec(state, output_path) def save_checkpoint( - config, - model, - optimizer, + config: Union[dict, Coqpit], + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, scaler, - current_step, - epoch, - output_folder, - save_n_checkpoints=None, - save_func=None, + current_step: int, + epoch: int, + output_folder: Union[str, os.PathLike[Any]], + save_n_checkpoints: Optional[int] = None, + save_func: Optional[Callable] = None, **kwargs, ): file_name = f"checkpoint_{current_step}.pth" @@ -166,20 +189,20 @@ def save_checkpoint( def save_best_model( - current_loss, - best_loss, - config, - model, - optimizer, + current_loss: Union[dict, float], + best_loss: Union[dict, float], + config: Union[dict, Coqpit], + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, scaler, - current_step, - epoch, - out_path, - keep_all_best=False, - keep_after=0, - save_func=None, + current_step: int, + epoch: int, + out_path: Union[str, os.PathLike[Any]], + keep_all_best: bool = False, + keep_after: int = 0, + save_func: Optional[Callable] = None, **kwargs, -): +) -> Union[dict, float]: if isinstance(current_loss, dict): use_eval_loss = 
current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None is_save_model = (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or ( @@ -188,9 +211,7 @@ def save_best_model( else: is_save_model = current_loss < best_loss - if isinstance(keep_after, (int, float)): - keep_after = int(keep_after) - is_save_model = is_save_model and current_step > keep_after + is_save_model = is_save_model and current_step > keep_after if is_save_model: best_model_name = f"best_model_{current_step}.pth" @@ -208,7 +229,7 @@ def save_best_model( save_func=save_func, **kwargs, ) - fs = fsspec.get_mapper(out_path).fs + fs = fsspec.get_mapper(str(out_path)).fs # only delete previous if current is saved successfully if not keep_all_best or (current_step < keep_after): model_names = fs.glob(os.path.join(out_path, "best_model*.pth")) @@ -223,7 +244,7 @@ def save_best_model( return best_loss -def get_last_checkpoint(path: str) -> Tuple[str, str]: +def get_last_checkpoint(path: Union[str, os.PathLike]) -> tuple[str, str]: """Get latest checkpoint or/and best model in path. It is based on globbing for `*.pth` and the RegEx @@ -239,6 +260,7 @@ def get_last_checkpoint(path: str) -> Tuple[str, str]: Path to the last checkpoint Path to best checkpoint """ + path = str(path) fs = fsspec.get_mapper(path).fs file_names = fs.glob(os.path.join(path, "*.pth")) scheme = urlparse(path).scheme @@ -288,31 +310,39 @@ def get_last_checkpoint(path: str) -> Tuple[str, str]: return last_models["checkpoint"], last_models["best_model"] -def keep_n_checkpoints(path: str, n: int) -> None: +def keep_n_checkpoints(path: Union[str, os.PathLike[Any]], n: int) -> None: """Keep only the last n checkpoints in path. Args: path: Path to files to be compared. n: Number of checkpoints to keep. """ - fs = fsspec.get_mapper(path).fs + fs = fsspec.get_mapper(str(path)).fs file_names = sort_checkpoints(path, "checkpoint") if len(file_names) > n: for file_name in file_names[:-n]: fs.rm(file_name) -def sort_checkpoints(output_path: str, checkpoint_prefix: str, use_mtime: bool = False) -> List[str]: +def sort_checkpoints( + output_path: Union[str, os.PathLike[Any]], checkpoint_prefix: str, use_mtime: bool = False +) -> list[str]: """Sort checkpoint paths based on the checkpoint step number. Args: - output_path (str): Path to directory containing checkpoints. + output_path: Path to directory containing checkpoints. checkpoint_prefix (str): Prefix of the checkpoint files. use_mtime (bool): If True, use modification dates to determine checkpoint order. 
""" ordering_and_checkpoint_path = [] - glob_checkpoints = [str(x) for x in Path(output_path).glob(f"{checkpoint_prefix}_*")] + output_path = str(output_path) + fs = fsspec.get_mapper(output_path).fs + glob_checkpoints = fs.glob(os.path.join(output_path, f"{checkpoint_prefix}_*")) + scheme = urlparse(output_path).scheme + if scheme and output_path.startswith(scheme + "://"): + # scheme is not preserved in fs.glob, add it back if it exists on the path + glob_checkpoints = [scheme + "://" + file_name for file_name in glob_checkpoints] for path in glob_checkpoints: if use_mtime: diff --git a/trainer/logging/aim_logger.py b/trainer/logging/aim_logger.py index 6a59895..f4cc270 100644 --- a/trainer/logging/aim_logger.py +++ b/trainer/logging/aim_logger.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from trainer.logging.base_dash_logger import BaseDashboardLogger @@ -15,7 +17,7 @@ def __init__( self, repo: str, model_name: str, - tags: str = None, + tags: Optional[str] = None, ): self._context = None self.model_name = model_name diff --git a/trainer/logging/base_dash_logger.py b/trainer/logging/base_dash_logger.py index 9ae891d..5e20e45 100644 --- a/trainer/logging/base_dash_logger.py +++ b/trainer/logging/base_dash_logger.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, Union +from typing import Union from trainer.io import save_fsspec from trainer.utils.distributed import rank_zero_only @@ -37,15 +37,15 @@ def add_artifact(self, file_or_dir: str, name: str, artifact_type: str, aliases= pass @abstractmethod - def add_scalars(self, scope_name: str, scalars: Dict, step: int): + def add_scalars(self, scope_name: str, scalars: dict, step: int): pass @abstractmethod - def add_figures(self, scope_name: str, figures: Dict, step: int): + def add_figures(self, scope_name: str, figures: dict, step: int): pass @abstractmethod - def add_audios(self, scope_name: str, audios: Dict, step: int, sample_rate: int): + def add_audios(self, scope_name: str, audios: dict, step: int, sample_rate: int): pass @abstractmethod @@ -58,7 +58,7 @@ def finish(self): @staticmethod @rank_zero_only - def save_model(state: Dict, path: str): + def save_model(state: dict, path: str): save_fsspec(state, path) def train_step_stats(self, step, stats): diff --git a/trainer/logging/clearml_logger.py b/trainer/logging/clearml_logger.py index f99a02b..17f2479 100644 --- a/trainer/logging/clearml_logger.py +++ b/trainer/logging/clearml_logger.py @@ -1,5 +1,5 @@ import os -from typing import Any +from typing import Any, Optional import torch @@ -35,7 +35,7 @@ def __init__( local_path: str, project_name: str, task_name: str, - tags: str = None, + tags: Optional[str] = None, ): self._context = None self.local_path = local_path diff --git a/trainer/logging/dummy_logger.py b/trainer/logging/dummy_logger.py index ec7b37b..beea20f 100644 --- a/trainer/logging/dummy_logger.py +++ b/trainer/logging/dummy_logger.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Union from trainer.logging.base_dash_logger import BaseDashboardLogger @@ -29,13 +29,13 @@ def add_text(self, title: str, text: str, step: int) -> None: def add_artifact(self, file_or_dir: str, name: str, artifact_type: str, aliases=None): pass - def add_scalars(self, scope_name: str, scalars: Dict, step: int): + def add_scalars(self, scope_name: str, scalars: dict, step: int): pass - def add_figures(self, scope_name: str, figures: Dict, step: int): + def add_figures(self, scope_name: str, figures: dict, step: int): pass - 
def add_audios(self, scope_name: str, audios: Dict, step: int, sample_rate: int): + def add_audios(self, scope_name: str, audios: dict, step: int, sample_rate: int): pass def flush(self): diff --git a/trainer/logging/mlflow_logger.py b/trainer/logging/mlflow_logger.py index 51379f7..f33b376 100644 --- a/trainer/logging/mlflow_logger.py +++ b/trainer/logging/mlflow_logger.py @@ -2,6 +2,7 @@ import shutil import tempfile import traceback +from typing import Optional import soundfile as sf import torch @@ -23,7 +24,7 @@ def __init__( self, log_uri: str, model_name: str, - tags: str = None, + tags: Optional[str] = None, ): self.model_name = model_name self.client = MlflowClient(tracking_uri=os.path.join(log_uri)) diff --git a/trainer/model.py b/trainer/model.py index d865337..9dfd642 100644 --- a/trainer/model.py +++ b/trainer/model.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Optional import torch -from coqpit import Coqpit from torch import nn +from trainer.trainer import Trainer from trainer.trainer_utils import is_apex_available if is_apex_available(): @@ -18,7 +18,7 @@ class TrainerModel(ABC, nn.Module): """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.""" @abstractmethod - def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict: + def forward(self, input: torch.Tensor, *args, aux_input: Optional[dict[str, Any]] = None, **kwargs) -> dict: """Forward ... for the model mainly used in training. You can be flexible here and use different number of arguments and argument names since it is intended to be @@ -31,11 +31,13 @@ def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict: Returns: Dict: Model outputs. Main model output must be named as "model_outputs". """ + if aux_input is None: + aux_input = {} outputs_dict = {"model_outputs": None} ... return outputs_dict - def format_batch(self, batch: Dict) -> Dict: + def format_batch(self, batch: dict) -> dict: """Format batch returned by the data loader before sending it to the model. If not implemented, model uses the batch as is. @@ -43,7 +45,7 @@ def format_batch(self, batch: Dict) -> Dict: """ return batch - def format_batch_on_device(self, batch: Dict) -> Dict: + def format_batch_on_device(self, batch: dict) -> dict: """Format batch on device before sending it to the model. If not implemented, model uses the batch as is. @@ -51,7 +53,7 @@ def format_batch_on_device(self, batch: Dict) -> Dict: """ return batch - def train_step(self, *args: Any, **kwargs: Any) -> Tuple[Dict, Dict]: + def train_step(self, *args: Any, **kwargs: Any) -> tuple[dict, dict]: """Perform a single training step. Run the model forward ... and compute losses. Args: @@ -62,7 +64,6 @@ def train_step(self, *args: Any, **kwargs: Any) -> Tuple[Dict, Dict]: Returns: Tuple[Dict, Dict]: Model ouputs and computed losses. """ - ... raise NotImplementedError(" [!] `train_step()` is not implemented.") def train_log(self, *args: Any, **kwargs: Any) -> None: @@ -81,7 +82,6 @@ def train_log(self, *args: Any, **kwargs: Any) -> None: Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. """ - ... raise NotImplementedError(" [!] `train_log()` is not implemented.") @torch.no_grad() @@ -101,7 +101,6 @@ def eval_step(self, *args: Any, **kwargs: Any): def eval_log(self, *args: Any, **kwargs: Any) -> None: """The same as `train_log()`""" - ... raise NotImplementedError(" [!] 
`eval_log()` is not implemented.") @abstractmethod @@ -126,9 +125,8 @@ def get_data_loader(*args: Any, **kwargs: Any) -> torch.utils.data.DataLoader: def init_for_training(self) -> None: """Initialize model for training.""" - ... - def optimize(self, *args: Any, **kwargs: Any) -> Tuple[Dict, Dict, float]: + def optimize(self, *args: Any, **kwargs: Any) -> tuple[dict, dict, float]: """Model specific optimization step that must perform the following steps: 1. Forward pass 2. Compute loss @@ -144,12 +142,11 @@ def optimize(self, *args: Any, **kwargs: Any) -> Tuple[Dict, Dict, float]: Returns: Tuple[Dict, Dict, float]: Model outputs, loss dictionary and grad_norm value. """ - ... raise NotImplementedError(" [!] `optimize()` is not implemented.") def scaled_backward( - self, loss: torch.Tensor, trainer: "Trainer", optimizer: "Optimizer", *args: Any, **kwargs: Any - ) -> Tuple[float, bool]: + self, loss: torch.Tensor, trainer: Trainer, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any + ) -> tuple[float, bool]: """Backward pass with gradient scaling for custom `optimize` calls. Args: diff --git a/trainer/torch.py b/trainer/torch.py index 9184499..17f3489 100644 --- a/trainer/torch.py +++ b/trainer/torch.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np import torch from torch.utils.data.distributed import DistributedSampler @@ -29,8 +31,8 @@ class DistributedSamplerWrapper(DistributedSampler): def __init__( self, sampler, - num_replicas: int = None, - rank: int = None, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, shuffle: bool = True, seed: int = 0, ): @@ -103,7 +105,7 @@ def get_lr(self): boolean_indeces = np.less_equal(step_thresholds, step) try: - last_true = np.where(boolean_indeces == True)[0][-1] # pylint: disable=singleton-comparison + last_true = np.where(boolean_indeces)[0][-1] # pylint: disable=singleton-comparison except IndexError: # For the steps larger than the last step in the list pass diff --git a/trainer/trainer.py b/trainer/trainer.py index 229c4c9..bc6cd58 100644 --- a/trainer/trainer.py +++ b/trainer/trainer.py @@ -8,9 +8,8 @@ import time import traceback from contextlib import nullcontext -from dataclasses import dataclass, field from inspect import signature -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist @@ -20,6 +19,7 @@ from torch.utils.data import DataLoader from trainer.callbacks import TrainerCallback +from trainer.config import TrainerArgs from trainer.generic_utils import ( KeepAverage, count_parameters, @@ -38,6 +38,7 @@ save_checkpoint, ) from trainer.logging import ConsoleLogger, DummyLogger, logger_factory +from trainer.logging.base_dash_logger import BaseDashboardLogger from trainer.trainer_utils import ( get_optimizer, get_scheduler, @@ -59,230 +60,6 @@ from apex import amp # pylint: disable=import-error -@dataclass -class TrainerConfig(Coqpit): - """Config fields tweaking the Trainer for a model. - - A ````ModelConfig```, by inheriting ```TrainerConfig``` must be defined for using 👟. - Inherit this by a new model config and override the fields as needed. - All the fields can be overridden from comman-line as ```--coqpit.arg_name=value```. - - Example:: - - Run the training code by overriding the ```lr``` and ```plot_step``` fields. - - >>> python train.py --coqpit.plot_step=22 --coqpit.lr=0.001 - - Defining a model using ```TrainerConfig```. 
- - >>> from trainer import TrainerConfig - >>> class MyModelConfig(TrainerConfig): - ... optimizer: str = "Adam" - ... lr: float = 0.001 - ... epochs: int = 1 - ... ... - >>> class MyModel(nn.module): - ... def __init__(self, config): - ... ... - >>> model = MyModel(MyModelConfig()) - - """ - - # Fields for the run - output_path: str = field(default="output") - logger_uri: str = field( - default=None, - metadata={ - "help": "URI to save training artifacts by the logger. If not set, logs will be saved in the output_path. Defaults to None" - }, - ) - run_name: str = field(default="run", metadata={"help": "Name of the run. Defaults to 'run'"}) - project_name: str = field(default=None, metadata={"help": "Name of the project. Defaults to None"}) - run_description: str = field( - default="🐸Coqui trainer run.", - metadata={"help": "Notes and description about the run. Defaults to '🐸Coqui trainer run.'"}, - ) - # Fields for logging - print_step: int = field( - default=25, metadata={"help": "Print training stats on the terminal every print_step steps. Defaults to 25"} - ) - plot_step: int = field( - default=100, metadata={"help": "Plot training stats on the logger every plot_step steps. Defaults to 100"} - ) - model_param_stats: bool = field( - default=False, metadata={"help": "Log model parameters stats on the logger dashboard. Defaults to False"} - ) - wandb_entity: str = field(default=None, metadata={"help": "Wandb entity to log the run. Defaults to None"}) - dashboard_logger: str = field( - default="tensorboard", metadata={"help": "Logger to use for the tracking dashboard. Defaults to 'tensorboard'"} - ) - # Fields for checkpointing - save_on_interrupt: bool = field( - default=True, metadata={"help": "Save checkpoint on interrupt (Ctrl+C). Defaults to True"} - ) - log_model_step: int = field( - default=None, - metadata={ - "help": "Save checkpoint to the logger every log_model_step steps. If not defined `save_step == log_model_step`." - }, - ) - save_step: int = field( - default=10000, metadata={"help": "Save local checkpoint every save_step steps. Defaults to 10000"} - ) - save_n_checkpoints: int = field(default=5, metadata={"help": "Keep n local checkpoints. Defaults to 5"}) - save_checkpoints: bool = field(default=True, metadata={"help": "Save checkpoints locally. Defaults to True"}) - save_all_best: bool = field( - default=False, metadata={"help": "Save all best checkpoints and keep the older ones. Defaults to False"} - ) - save_best_after: int = field(default=0, metadata={"help": "Wait N steps to save best checkpoints. Defaults to 0"}) - target_loss: str = field( - default=None, metadata={"help": "Target loss name to select the best model. Defaults to None"} - ) - # Fields for eval and test run - print_eval: bool = field(default=False, metadata={"help": "Print eval steps on the terminal. Defaults to False"}) - test_delay_epochs: int = field(default=0, metadata={"help": "Wait N epochs before running the test. Defaults to 0"}) - run_eval: bool = field( - default=True, metadata={"help": "Run evalulation epoch after training epoch. Defaults to True"} - ) - run_eval_steps: int = field( - default=None, - metadata={ - "help": "Run evalulation epoch after N steps. If None, waits until training epoch is completed. Defaults to None" - }, - ) - # Fields for distributed training - distributed_backend: str = field( - default="nccl", metadata={"help": "Distributed backend to use. 
Defaults to 'nccl'"} - ) - distributed_url: str = field( - default="tcp://localhost:54321", - metadata={"help": "Distributed url to use. Defaults to 'tcp://localhost:54321'"}, - ) - # Fields for training specs - mixed_precision: bool = field(default=False, metadata={"help": "Use mixed precision training. Defaults to False"}) - precision: str = field( - default="fp16", - metadata={ - "help": "Precision to use in mixed precision training. `fp16` for float16 and `bf16` for bfloat16. Defaults to 'f16'" - }, - ) - epochs: int = field(default=1000, metadata={"help": "Number of epochs to train. Defaults to 1000"}) - batch_size: int = field(default=32, metadata={"help": "Batch size to use. Defaults to 32"}) - eval_batch_size: int = field(default=16, metadata={"help": "Batch size to use for eval. Defaults to 16"}) - grad_clip: float = field( - default=0.0, metadata={"help": "Gradient clipping value. Disabled if <= 0. Defaults to 0.0"} - ) - scheduler_after_epoch: bool = field( - default=True, - metadata={"help": "Step the scheduler after each epoch else step after each iteration. Defaults to True"}, - ) - # Fields for optimzation - lr: Union[float, List[float]] = field( - default=0.001, metadata={"help": "Learning rate for each optimizer. Defaults to 0.001"} - ) - optimizer: Union[str, List[str]] = field(default=None, metadata={"help": "Optimizer(s) to use. Defaults to None"}) - optimizer_params: Union[Dict, List[Dict]] = field( - default_factory=dict, metadata={"help": "Optimizer(s) arguments. Defaults to {}"} - ) - lr_scheduler: Union[str, List[str]] = field( - default=None, metadata={"help": "Learning rate scheduler(s) to use. Defaults to None"} - ) - lr_scheduler_params: Dict = field( - default_factory=dict, metadata={"help": "Learning rate scheduler(s) arguments. Defaults to {}"} - ) - use_grad_scaler: bool = field( - default=False, - metadata={ - "help": "Enable/disable gradient scaler explicitly. It is enabled by default with AMP training. Defaults to False" - }, - ) - allow_tf32: bool = field( - default=False, - metadata={ - "help": "A bool that controls whether TensorFloat-32 tensor cores may be used in matrix multiplications on Ampere or newer GPUs. Default to False." - }, - ) - cudnn_enable: bool = field(default=True, metadata={"help": "Enable/disable cudnn explicitly. Defaults to True"}) - cudnn_deterministic: bool = field( - default=False, - metadata={ - "help": "Enable/disable deterministic cudnn operations. Set this True for reproducibility but it slows down training significantly. Defaults to False." - }, - ) - cudnn_benchmark: bool = field( - default=False, - metadata={ - "help": "Enable/disable cudnn benchmark explicitly. Set this False if your input size change constantly. Defaults to False" - }, - ) - training_seed: int = field( - default=54321, - metadata={"help": "Global seed for torch, random and numpy random number generator. Defaults to 54321"}, - ) - - -@dataclass -class TrainerArgs(Coqpit): - """Trainer arguments that can be accessed from the command line. - - Examples:: - >>> python train.py --restore_path /path/to/checkpoint.pth - """ - - continue_path: str = field( - default="", - metadata={ - "help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder." - }, - ) - restore_path: str = field( - default="", - metadata={ - "help": "Path to a model checkpoit. Restore the model with the given checkpoint and start a new training." 
- }, - ) - best_path: str = field( - default="", - metadata={ - "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used" - }, - ) - use_ddp: bool = field( - default=False, - metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."}, - ) - use_accelerate: bool = field(default=False, metadata={"help": "Use HF Accelerate as the back end for training."}) - grad_accum_steps: int = field( - default=1, - metadata={ - "help": "Number of gradient accumulation steps. It is used to accumulate gradients over multiple batches." - }, - ) - overfit_batch: bool = field(default=False, metadata={"help": "Overfit a single batch for debugging."}) - skip_train_epoch: bool = field( - default=False, - metadata={"help": "Skip training and only run evaluation and test."}, - ) - start_with_eval: bool = field( - default=False, - metadata={"help": "Start with evaluation and test."}, - ) - small_run: int = field( - default=None, - metadata={ - "help": "Only use a subset of the samples for debugging. Set the number of samples to use. Defaults to None. " - }, - ) - gpu: int = field( - default=None, metadata={"help": "GPU ID to use if ```CUDA_VISIBLE_DEVICES``` is not set. Defaults to None."} - ) - # only for DDP - rank: int = field(default=0, metadata={"help": "Process rank in a distributed training. Don't set manually."}) - group_id: str = field( - default="", metadata={"help": "Process group id in a distributed training. Don't set manually."} - ) - - class Trainer: def __init__( # pylint: disable=dangerous-default-value self, @@ -290,19 +67,19 @@ def __init__( # pylint: disable=dangerous-default-value config: Coqpit, output_path: str, c_logger: ConsoleLogger = None, - dashboard_logger: "Logger" = None, + dashboard_logger: BaseDashboardLogger = None, model: nn.Module = None, - get_model: Callable = None, - get_data_samples: Callable = None, - train_samples: List = None, - eval_samples: List = None, - test_samples: List = None, + get_model: Optional[Callable] = None, + get_data_samples: Optional[Callable] = None, + train_samples: Optional[list] = None, + eval_samples: Optional[list] = None, + test_samples: Optional[list] = None, train_loader: DataLoader = None, eval_loader: DataLoader = None, - training_assets: Dict = {}, + training_assets: Optional[dict] = None, parse_command_line_args: bool = True, - callbacks: Dict[str, Callable] = {}, - gpu: int = None, + callbacks: Optional[dict[str, Callable]] = None, + gpu: Optional[int] = None, ) -> None: """Simple yet powerful 🐸💬 TTS trainer for PyTorch. @@ -390,6 +167,11 @@ def __init__( # pylint: disable=dangerous-default-value - Overfitting to a batch. - TPU training """ + if training_assets is None: + training_assets = {} + if callbacks is None: + callbacks = {} + if parse_command_line_args: # parse command-line arguments to override TrainerArgs() args, coqpit_overrides = self.parse_argv(args) @@ -533,7 +315,7 @@ def __init__( # pylint: disable=dangerous-default-value and not isimplemented(self.model, "optimize") ): raise ValueError( - " [!] Coqui Trainer does not support grad_accum_steps for multiple-optimizer setup, please set grad_accum_steps to 1 or implement in your model a custom method called ´optimize` that need to deal with dangling gradients in multiple-optimizer setup!" + " [!] 
Coqui Trainer does not support grad_accum_steps for multiple-optimizer setup, please set grad_accum_steps to 1 or implement in your model a custom method called `optimize` that needs to deal with dangling gradients in a multiple-optimizer setup!" ) # CALLBACK @@ -661,12 +443,12 @@ def save_training_script(self) -> None: if os.path.isfile(file_path): file_name = os.path.basename(file_path) self.dashboard_logger.add_artifact(file_or_dir=file_path, name=file_name, artifact_type="file") - with open(file_path, "r", encoding="utf8") as f: + with open(file_path, encoding="utf8") as f: self.dashboard_logger.add_text("training-script", f"{f.read()}", 0) shutil.copyfile(file_path, os.path.join(self.output_path, file_name)) @staticmethod - def parse_argv(args: Union[Coqpit, List]): + def parse_argv(args: Union[Coqpit, list]): """Parse command line arguments to init or override `TrainerArgs()`.""" if isinstance(args, Coqpit): parser = args.init_argparse(arg_prefix="") @@ -713,8 +495,8 @@ def setup_small_run(self, small_run: Optional[int] = None) -> None: @staticmethod def init_training( - args: TrainerArgs, coqpit_overrides: Dict, config: Coqpit = None - ) -> Tuple[Coqpit, Dict[str, str]]: + args: TrainerArgs, coqpit_overrides: dict, config: Coqpit = None + ) -> tuple[Coqpit, dict[str, str]]: """Initialize training and update model configs from command line arguments. Args: @@ -752,7 +534,7 @@ def init_training( return config, new_fields @staticmethod - def setup_training_environment(args, config, gpu) -> Tuple[bool, int]: + def setup_training_environment(args, config, gpu) -> tuple[bool, int]: if platform.system() != "Windows": # https://github.com/pytorch/pytorch/issues/973 import resource # pylint: disable=import-outside-toplevel @@ -804,11 +586,11 @@ def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> nn.Modul def restore_model( self, config: Coqpit, - restore_path: str, + restore_path: Union[str, os.PathLike[Any]], model: nn.Module, optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler = None, - ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]: + ) -> tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]: """Restore training from an old run. It restores model, optimizer, AMP scaler and training stats. Args: @@ -886,9 +668,9 @@ def _get_loader( self, model: nn.Module, config: Coqpit, - assets: Dict, + assets: dict, is_eval: bool, - samples: List, + samples: list, verbose: bool, num_gpus: int, ) -> DataLoader: @@ -913,7 +695,7 @@ def _get_loader( ), " ❗ len(DataLoader) returns 0. Make sure your dataset is not empty or len(dataset) > 0. " return loader - def get_train_dataloader(self, training_assets: Dict, samples: List, verbose: bool) -> DataLoader: + def get_train_dataloader(self, training_assets: dict, samples: list, verbose: bool) -> DataLoader: """Initialize and return a training data loader. Call ```model.get_train_data_loader``` if it is implemented, else call ```model.get_data_loader``` @@ -950,7 +732,7 @@ def get_train_dataloader(self, training_assets: Dict, samples: List, verbose: bo self.num_gpus, ) - def get_eval_dataloader(self, training_assets: Dict, samples: List, verbose: bool) -> DataLoader: + def get_eval_dataloader(self, training_assets: dict, samples: list, verbose: bool) -> DataLoader: """Initialize and return a evaluation data loader.
Call ```model.get_eval_data_loader``` if it is implemented, else call ```model.get_data_loader``` @@ -987,7 +769,7 @@ def get_eval_dataloader(self, training_assets: Dict, samples: List, verbose: boo self.num_gpus, ) - def get_test_dataloader(self, training_assets: Dict, samples: List, verbose: bool) -> DataLoader: + def get_test_dataloader(self, training_assets: dict, samples: list, verbose: bool) -> DataLoader: """Initialize and return a evaluation data loader. Call ```model.get_test_data_loader``` if it is implemented, else call ```model.get_data_loader``` @@ -1024,7 +806,7 @@ def get_test_dataloader(self, training_assets: Dict, samples: List, verbose: boo self.num_gpus, ) - def format_batch(self, batch: List) -> Dict: + def format_batch(self, batch: list) -> dict: """Format the dataloader output and return a batch. 1. Call ```model.format_batch```. @@ -1078,8 +860,8 @@ def master_params(optimizer: torch.optim.Optimizer): @staticmethod def _model_train_step( - batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: Optional[int] = None - ) -> Tuple[Dict, Dict]: + batch: dict, model: nn.Module, criterion: nn.Module, optimizer_idx: Optional[int] = None + ) -> tuple[dict, dict]: """Perform a trainig forward step. Compute model outputs and losses. Args: @@ -1118,7 +900,7 @@ def _get_autocast_args(self, mixed_precision: bool, precision: str): def detach_loss_dict( self, - loss_dict: Dict, + loss_dict: dict, step_optimizer: bool, optimizer_idx: Optional[int] = None, grad_norm: Optional[float] = None, @@ -1136,7 +918,7 @@ def detach_loss_dict( loss_dict_detached["grad_norm"] = grad_norm return loss_dict_detached - def _compute_loss(self, batch: Dict, model: nn.Module, criterion: nn.Module, config: Coqpit, optimizer_idx: int): + def _compute_loss(self, batch: dict, model: nn.Module, criterion: nn.Module, config: Coqpit, optimizer_idx: int): device, dtype = self._get_autocast_args(config.mixed_precision, config.precision) with torch.autocast(device_type=device, dtype=dtype, enabled=config.mixed_precision): if optimizer_idx is not None: @@ -1162,7 +944,7 @@ def _set_grad_clip_per_optimizer(config: Coqpit, optimizer_idx: int): def _compute_grad_norm(self, optimizer: torch.optim.Optimizer): return torch.norm(torch.cat([param.grad.view(-1) for param in self.master_params(optimizer)], dim=0), p=2) - def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, scaler: "AMPScaler"): + def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, scaler: torch.amp.GradScaler): """Perform gradient clipping""" if grad_clip is not None and grad_clip > 0: if scaler: @@ -1175,17 +957,17 @@ def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, sca def optimize( self, - batch: Dict, + batch: dict, model: nn.Module, optimizer: torch.optim.Optimizer, - scaler: "AMPScaler", + scaler: torch.amp.GradScaler, criterion: nn.Module, - scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List, Dict], # pylint: disable=protected-access + scheduler: Union[torch.optim.lr_scheduler._LRScheduler, list, dict], # pylint: disable=protected-access config: Coqpit, optimizer_idx: Optional[int] = None, step_optimizer: bool = True, num_optimizers: int = 1, - ) -> Tuple[Dict, Dict, int]: + ) -> tuple[dict, dict, int]: """Perform a forward - backward pass and run the optimizer. 
Args: @@ -1297,7 +1079,7 @@ def optimize( loss_dict_detached = self.detach_loss_dict(loss_dict, step_optimizer, optimizer_idx, grad_norm) return outputs, loss_dict_detached, step_time - def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]: + def train_step(self, batch: dict, batch_n_steps: int, step: int, loader_start_time: float) -> tuple[dict, dict]: """Perform a training step on a batch of inputs and log the process. Args: @@ -1417,11 +1199,11 @@ def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_ti lrs = {} if isinstance(self.optimizer, list): for idx, optimizer in enumerate(self.optimizer): - current_lr = self.optimizer[idx].param_groups[0]["lr"] + current_lr = optimizer.param_groups[0]["lr"] lrs.update({f"current_lr_{idx}": current_lr}) elif isinstance(self.optimizer, dict): for key, optimizer in self.optimizer.items(): - current_lr = self.optimizer[key].param_groups[0]["lr"] + current_lr = optimizer.param_groups[0]["lr"] lrs.update({f"current_lr_{key}": current_lr}) else: current_lr = self.optimizer.param_groups[0]["lr"] @@ -1533,8 +1315,8 @@ def train_epoch(self) -> None: ####################### def _model_eval_step( - self, batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: Optional[int] = None - ) -> Tuple[Dict, Dict]: + self, batch: dict, model: nn.Module, criterion: nn.Module, optimizer_idx: Optional[int] = None + ) -> tuple[dict, dict]: """Perform a evaluation forward pass. Compute model outputs and losses with no gradients. Args: @@ -1560,7 +1342,7 @@ def _model_eval_step( return model.eval_step(*input_args) - def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]: + def eval_step(self, batch: dict, step: int) -> tuple[dict, dict]: """Perform a evaluation step on a batch of inputs and log the process. Args: @@ -1762,7 +1544,7 @@ def _fit(self) -> None: self.total_steps_done = self.restore_step - for epoch in range(0, self.config.epochs): + for epoch in range(self.config.epochs): if self.num_gpus > 1: # let all processes sync up before starting with a new epoch of training dist.barrier() @@ -1969,7 +1751,7 @@ def update_training_dashboard_logger(self, batch=None, outputs=None) -> None: ##################### @staticmethod - def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]: + def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, list]: """Receive the optimizer from the model if model implements `get_optimizer()` else check the optimizer parameters in the config and try initiating the optimizer. @@ -1993,7 +1775,7 @@ def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimiz return optimizer @staticmethod - def get_lr(model: nn.Module, config: Coqpit) -> Union[float, List[float]]: + def get_lr(model: nn.Module, config: Coqpit) -> Union[float, list[float]]: """Set the initial learning rate by the model if model implements `get_lr()` else try setting the learning rate fromthe config. 
@@ -2016,8 +1798,8 @@ def get_lr(model: nn.Module, config: Coqpit) -> Union[float, List[float]]: @staticmethod def get_scheduler( - model: nn.Module, config: Coqpit, optimizer: Union[torch.optim.Optimizer, List, Dict] - ) -> Union[torch.optim.lr_scheduler._LRScheduler, List]: # pylint: disable=protected-access + model: nn.Module, config: Coqpit, optimizer: Union[torch.optim.Optimizer, list, dict] + ) -> Union[torch.optim.lr_scheduler._LRScheduler, list]: # pylint: disable=protected-access """Receive the scheduler from the model if model implements `get_scheduler()` else check the config and try initiating the scheduler. @@ -2046,8 +1828,12 @@ def get_scheduler( @staticmethod def restore_scheduler( - scheduler: Union["Scheduler", List, Dict], args: Coqpit, config: Coqpit, restore_epoch: int, restore_step: int - ) -> Union["Scheduler", List]: + scheduler: Union[torch.optim.lr_scheduler._LRScheduler, list, dict], + args: Coqpit, + config: Coqpit, + restore_epoch: int, + restore_step: int, + ) -> Union[torch.optim.lr_scheduler._LRScheduler, list]: """Restore scheduler wrt restored model.""" if scheduler is not None and args.continue_path: if isinstance(scheduler, list): @@ -2087,7 +1873,7 @@ def get_criterion(model: nn.Module) -> nn.Module: #################### @staticmethod - def _detach_loss_dict(loss_dict: Dict) -> Dict: + def _detach_loss_dict(loss_dict: dict) -> dict: """Detach loss values from autograp. Args: @@ -2104,7 +1890,7 @@ def _detach_loss_dict(loss_dict: Dict) -> Dict: loss_dict_detached[key] = value.detach().cpu().item() return loss_dict_detached - def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict: + def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> dict: """Pick the target loss to compare models""" # if the keep_avg_target is None or empty return None diff --git a/trainer/trainer_utils.py b/trainer/trainer_utils.py index b369d57..2a11c19 100644 --- a/trainer/trainer_utils.py +++ b/trainer/trainer_utils.py @@ -1,11 +1,12 @@ import importlib import os import random -from typing import Dict, List, Tuple +from typing import Optional import numpy as np import torch +from trainer.config import TrainerArgs from trainer.logger import logger from trainer.torch import NoamLR, StepwiseGradualLR from trainer.utils.distributed import rank_zero_logger_info @@ -61,7 +62,7 @@ def print_training_env(args, config): def setup_torch_training_env( - args: "TrainerArgs", + args: TrainerArgs, cudnn_enable: bool, cudnn_benchmark: bool, cudnn_deterministic: bool, @@ -69,7 +70,7 @@ def setup_torch_training_env( training_seed=54321, allow_tf32: bool = False, gpu=None, -) -> Tuple[bool, int]: +) -> tuple[bool, int]: """Setup PyTorch environment for training. Args: @@ -119,7 +120,7 @@ def setup_torch_training_env( def get_scheduler( - lr_scheduler: str, lr_scheduler_params: Dict, optimizer: torch.optim.Optimizer + lr_scheduler: str, lr_scheduler_params: dict, optimizer: torch.optim.Optimizer ) -> torch.optim.lr_scheduler._LRScheduler: # pylint: disable=protected-access """Find, initialize and return a Torch scheduler. @@ -147,7 +148,7 @@ def get_optimizer( optimizer_params: dict, lr: float, model: torch.nn.Module = None, - parameters: List = None, + parameters: Optional[list] = None, ) -> torch.optim.Optimizer: """Find, initialize and return a Torch optimizer. 
@@ -162,7 +163,7 @@ def get_optimizer( """ if optimizer_name.lower() == "radam": module = importlib.import_module("TTS.utils.radam") - optimizer = getattr(module, "RAdam") + optimizer = module.RAdam else: optimizer = getattr(torch.optim, optimizer_name) if model is not None: diff --git a/trainer/utils/cuda_memory.py b/trainer/utils/cuda_memory.py index 3eba6d5..5e9c310 100644 --- a/trainer/utils/cuda_memory.py +++ b/trainer/utils/cuda_memory.py @@ -5,8 +5,6 @@ because of OOM conditions. """ -from __future__ import print_function - import gc import torch
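Two behavioral changes above are easy to miss among the type-hint updates. First, get_user_data_dir() now consults environment variables before falling back to the platform default. A minimal sketch of the precedence (the paths are illustrative, and resolve(strict=False) may additionally expand symlinks on some systems):

import os

from trainer.io import get_user_data_dir

# TTS_HOME wins over XDG_DATA_HOME, which wins over the platform default.
os.environ["TTS_HOME"] = "/tmp/tts-home"
os.environ["XDG_DATA_HOME"] = "/tmp/xdg-data"
print(get_user_data_dir("tts"))  # /tmp/tts-home/tts (modulo symlink resolution)

Second, sort_checkpoints() now globs through fsspec rather than pathlib, so it also works on remote storage; since fs.glob() returns paths without the URL scheme, the scheme is re-attached when present. A standalone sketch of that re-prefixing step (the bucket and file names are made up):

from urllib.parse import urlparse

output_path = "s3://my-bucket/run"           # hypothetical remote path
names = ["my-bucket/run/checkpoint_10.pth"]  # what fs.glob() would return
scheme = urlparse(output_path).scheme        # "s3"
if scheme and output_path.startswith(scheme + "://"):
    names = [scheme + "://" + name for name in names]
print(names)  # ['s3://my-bucket/run/checkpoint_10.pth']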