diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index b0ce007..7801023 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -7,6 +7,7 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: dask + package-version: '>= 2024.9.0' module-name: dask.array extra-requires: numpy pytest-extra-args: --disable-deadline --max-examples=5 diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 254e4e6..e0d5d84 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -40,7 +40,8 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] + # min version of dask we needs drops support for python 3.9 + python-version: ${{ inputs.package-name == 'dask' && fromJson('[''3.10'', ''3.11'', ''3.12'']') || fromJson('[''3.9'', ''3.10'', ''3.11'', ''3.12'']') }} steps: - name: Checkout array-api-compat diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 03e0cd7..ce0e609 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -6,3 +6,4 @@ __array_api_version__ = '2022.12' __import__(__package__ + '.linalg') +__import__(__package__ + '.fft') diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index cf57c82..a24694f 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -9,13 +9,9 @@ import numpy as np from numpy import ( - # Constants - e, - inf, - nan, - pi, - newaxis, # Dtypes + iinfo, + finfo, bool_ as bool, float32, float64, @@ -29,8 +25,6 @@ uint64, complex64, complex128, - iinfo, - finfo, can_cast, result_type, ) @@ -206,19 +200,18 @@ def _isscalar(a): return astype(xp.minimum(xp.maximum(x, min), max), x.dtype) -# exclude these from all since +# exclude these from all since dask.array has no sorting functions _da_unsupported = ['sort', 'argsort'] -common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported] +_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported] -__all__ = common_aliases + ['__array_namespace_info__', 'asarray', 'bool', - 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', 'e', - 'inf', 'nan', 'pi', 'newaxis', 'float32', - 'float64', 'int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', - 'complex64', 'complex128', 'iinfo', 'finfo', - 'can_cast', 'result_type'] +__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'acos', + 'acosh', 'asin', 'asinh', 'atan', 'atan2', + 'atanh', 'bitwise_left_shift', 'bitwise_invert', + 'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast', + 'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', + 'uint8', 'uint16', 'uint32', 'uint64', + 'complex64', 'complex128', 'iinfo', 'finfo', + 'can_cast', 'result_type'] -_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np'] +_all_ignore = ["get_xp", "da", "np"] diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py new file mode 100644 index 0000000..aebd86f --- /dev/null +++ b/array_api_compat/dask/array/fft.py @@ -0,0 +1,24 @@ +from dask.array.fft import * # noqa: F403 +# dask.array.fft doesn't have __all__. If it is added, replace this with +# +# from dask.array.fft import __all__ as linalg_all +_n = {} +exec('from dask.array.fft import *', _n) +del _n['__builtins__'] +fft_all = list(_n) +del _n + +from ...common import _fft +from ..._internal import get_xp + +import dask.array as da + +fftfreq = get_xp(da)(_fft.fftfreq) +rfftfreq = get_xp(da)(_fft.rfftfreq) + +__all__ = [elem for elem in fft_all if elem != "annotations"] + ["fftfreq", "rfftfreq"] + +del get_xp +del da +del fft_all +del _fft diff --git a/dask-skips.txt b/dask-skips.txt index 2a67d75..63a09e4 100644 --- a/dask-skips.txt +++ b/dask-skips.txt @@ -1,17 +1,2 @@ -# FFT isn't conformant -array_api_tests/test_fft.py -array_api_tests/test_signatures.py::test_extension_func_signature[fft.fft] -array_api_tests/test_signatures.py::test_extension_func_signature[fft.ifft] -array_api_tests/test_signatures.py::test_extension_func_signature[fft.fftn] -array_api_tests/test_signatures.py::test_extension_func_signature[fft.ifftn] -array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfft] -array_api_tests/test_signatures.py::test_extension_func_signature[fft.irfft] -array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfftn] -array_api_tests/test_signatures.py::test_extension_func_signature[fft.irfftn] -array_api_tests/test_signatures.py::test_extension_func_signature[fft.hfft] -array_api_tests/test_signatures.py::test_extension_func_signature[fft.ihfft] -array_api_tests/test_signatures.py::test_extension_func_signature[fft.fftfreq] -array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfftfreq] - # slow and not implemented in dask array_api_tests/test_linalg.py::test_matrix_power