Skip to content

Commit

Permalink
use jax.experimental.array_api (#9)
Browse files Browse the repository at this point in the history
* use jax.experimental.array_api
* disable beartype checking (until we can get stuff to pass)

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
nstarman and pre-commit-ci[bot] authored Dec 31, 2023
1 parent 18c163d commit 440fdf6
Show file tree
Hide file tree
Showing 21 changed files with 3,299 additions and 336 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ repos:
- id: mixed-line-ending
- id: name-tests-test
args: ["--pytest-test-first"]
exclude: tests/myarray.py
- id: requirements-txt-fixer
- id: trailing-whitespace

Expand Down Expand Up @@ -55,7 +56,7 @@ repos:
rev: "v1.8.0"
hooks:
- id: mypy
files: src|tests
files: src
args: []
additional_dependencies:
- numpy
Expand Down
22 changes: 13 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
xfail_strict = true
filterwarnings = [
"error",
# jaxtyping
"ignore:ast\\.Str is deprecated and will be removed in Python 3.14:DeprecationWarning",
]
log_cli_level = "INFO"
testpaths = [
Expand All @@ -90,7 +92,7 @@ port.exclude_lines = [
]

[tool.mypy]
files = ["src", "tests"]
files = ["src"]
python_version = "3.10"
warn_unused_configs = true
strict = true
Expand All @@ -111,6 +113,7 @@ plugins = [
[[tool.mypy.overrides]]
module = [
"jax.*",
"jaxtyping.*",
"plum.*",
"quax.*",
]
Expand All @@ -124,23 +127,24 @@ src = ["src"]
[tool.ruff.lint]
extend-select = ["ALL"]
ignore = [
"A001", # Variable is shadowing a Python builtin
"A002", # Argument is shadowing a Python builtin
"A001", # Variable is shadowing a Python builtin
"A002", # Argument is shadowing a Python builtin
"ANN101", # Missing type annotation for self in method
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed # TODO
"ARG001", # Unused function argument
"D103", # Missing docstring in public function # TODO
"D203", # one-blank-line-before-class
"D213", # Multi-line docstring summary should start at the second line
"D103", # Missing docstring in public function # TODO
"D203", # one-blank-line-before-class
"D213", # Multi-line docstring summary should start at the second line
"ERA001", # Found commented-out code
"FIX002", # Line contains TODO, consider resolving the issue
"PD011", # Pandas
"PYI041", # Use `float` instead of `int | float`
"TD002", # Missing author in TODO; try: `# TODO(<author_name>): .
"TD003", # Missing issue link on the line following this TODO
"TD002", # Missing author in TODO; try: `# TODO(<author_name>): .
"TD003", # Missing issue link on the line following this TODO
]

[tool.ruff.lint.per-file-ignores]
"tests/**" = ["INP001", "S101", "T20"]
"tests/**" = ["ANN", "INP001", "PLR0913", "S101", "T20"]
"__init__.py" = ["F403"]
"noxfile.py" = ["T20"]
"docs/conf.py" = ["INP001"]
Expand Down
64 changes: 34 additions & 30 deletions src/array_api_jax_compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,41 @@

from typing import Any

from . import (
_constants,
_creation_functions,
_data_type_functions,
_elementwise_functions,
_indexing_functions,
_linear_algebra_functions,
_manipulation_functions,
_searching_functions,
_set_functions,
_sorting_functions,
_statistical_functions,
_utility_functions,
fft,
linalg,
)
from ._constants import *
from ._creation_functions import *
from ._data_type_functions import *
from ._elementwise_functions import *
from ._indexing_functions import *
from ._linear_algebra_functions import *
from ._manipulation_functions import *
from ._searching_functions import *
from ._set_functions import *
from ._sorting_functions import *
from ._statistical_functions import *
from ._utility_functions import *
from ._version import version as __version__
from jax.experimental.array_api import __array_api_version__
from jaxtyping import install_import_hook

__all__ = ["__version__", "fft", "linalg"]
with install_import_hook("array_api_jax_compat", None):
from . import (
_constants,
_creation_functions,
_data_type_functions,
_elementwise_functions,
_indexing_functions,
_linear_algebra_functions,
_manipulation_functions,
_searching_functions,
_set_functions,
_sorting_functions,
_statistical_functions,
_utility_functions,
fft,
linalg,
)
from ._constants import *
from ._creation_functions import *
from ._data_type_functions import *
from ._elementwise_functions import *
from ._indexing_functions import *
from ._linear_algebra_functions import *
from ._manipulation_functions import *
from ._searching_functions import *
from ._set_functions import *
from ._sorting_functions import *
from ._statistical_functions import *
from ._utility_functions import *
from ._version import version as __version__

__all__ = ["__version__", "__array_api_version__", "fft", "linalg"]
__all__ += _constants.__all__
__all__ += _creation_functions.__all__
__all__ += _data_type_functions.__all__
Expand Down
2 changes: 1 addition & 1 deletion src/array_api_jax_compat/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

__all__ = ["e", "inf", "nan", "newaxis", "pi"]

from jax.numpy import e, inf, nan, newaxis, pi
from jax.experimental.array_api import e, inf, nan, newaxis, pi
Loading

0 comments on commit 440fdf6

Please sign in to comment.