Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Array API #1022

Open
wants to merge 163 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
163 commits
Select commit Hold shift + click to select a range
ed334f8
Create `array_api` namespace
neosunhan Jun 14, 2022
ec0ebdc
Enable `copy` parameter for `asarray`
neosunhan Jun 14, 2022
77813fc
Add `iinfo` and `finfo`
neosunhan Jun 14, 2022
38461c7
Add constants
neosunhan Jun 15, 2022
ae1defb
Modify `iinfo` and `finfo` to read minimum value from torch object
neosunhan Jun 15, 2022
313d22d
Add docstrings
neosunhan Jun 15, 2022
136e1b7
Fix bug where `keepdim` was not working for `all`
neosunhan Jun 15, 2022
9c6b344
Implement `all`
neosunhan Jun 16, 2022
ab30f56
Implement `Array.__eq__`
neosunhan Jun 16, 2022
198a38e
Implement `astype`
neosunhan Jun 16, 2022
2e8cfc3
Implement `Array` object
neosunhan Jun 17, 2022
fed2952
Implement `add`
neosunhan Jun 17, 2022
accbd15
Standardize import order
neosunhan Jun 17, 2022
de16eac
Document function parameters
neosunhan Jun 17, 2022
f0a169b
Fix bug causing precision loss in `abs`
neosunhan Jun 17, 2022
1e0c29b
Implement `abs`
neosunhan Jun 17, 2022
9bcbb01
Implement `add`
neosunhan Jun 17, 2022
ff53a95
Implement `bitwise_and`
neosunhan Jun 17, 2022
afe0b95
Implement `bitwise_left_shift`, `less`, `any`
neosunhan Jun 17, 2022
2ef6e3a
Implement `bitwise_invert`
neosunhan Jun 17, 2022
574f469
Implement `bitwise_or`
neosunhan Jun 17, 2022
b684c86
Fix documentation for `_normalize_two_args`
neosunhan Jun 18, 2022
6c50d7b
Fix bug in `linalg.matrix_norm`
neosunhan Jun 18, 2022
9fb9bdd
Implement `bitwise_xor`
neosunhan Jun 18, 2022
ad655ee
Implement `equal`
neosunhan Jun 18, 2022
41143f7
Implement `floor_divide`
neosunhan Jun 18, 2022
aeb97e4
Implement `greater_equal`
neosunhan Jun 18, 2022
58308db
Implement `greater`
neosunhan Jun 18, 2022
544ab0e
Implement `less_equal`
neosunhan Jun 18, 2022
7ff5119
Implement `remainder`
neosunhan Jun 18, 2022
fd62786
Implement `multiply`
neosunhan Jun 18, 2022
48b66cf
Implement `not_equal`
neosunhan Jun 18, 2022
1c586cc
Implement `negative`
neosunhan Jun 18, 2022
ba8a78c
Implement `positive`
neosunhan Jun 18, 2022
bce5b0a
Implement `pow`
neosunhan Jun 18, 2022
e6236b7
Implement `subtract`
neosunhan Jun 18, 2022
0c93d30
Implement `divide`
neosunhan Jun 18, 2022
a9132ce
Implement reflected arithmetic operators
neosunhan Jun 19, 2022
6533434
Implement `empty`
neosunhan Jun 29, 2022
cc3b03d
Fix bug causing precision loss in `arange`
neosunhan Jun 30, 2022
f99e427
Implement `empty_like`
neosunhan Jun 30, 2022
0b70ff5
Implement `full_like`
neosunhan Jun 30, 2022
1291c15
Allow `linspace` to accept `num=0`
neosunhan Jun 30, 2022
423f19a
Fix precision bug in `linspace`
neosunhan Jun 30, 2022
a0f2080
Enable bool dtype for `arange`
neosunhan Jun 30, 2022
d3318fd
Fix tests for `logspace`
neosunhan Jun 30, 2022
720880d
Fix `abs` tests
neosunhan Jun 30, 2022
f377d5d
Fix `abs` tests
neosunhan Jun 30, 2022
08d99cc
Implement `meshgrid`
neosunhan Jun 30, 2022
694a549
Implement `ones`
neosunhan Jun 30, 2022
69fd0f8
Implement `ones_like`
neosunhan Jun 30, 2022
f966061
Implement `zeros_like` and `full`
neosunhan Jun 30, 2022
778dd83
Implement `result_type`
neosunhan Jun 30, 2022
4857e61
Implement `expand_dims`
neosunhan Jun 30, 2022
fb8ebf8
Implement `flip`
neosunhan Jun 30, 2022
3043a4f
Implement `permute_dims`
neosunhan Jun 30, 2022
2d2df2e
Implement `roll`
neosunhan Jun 30, 2022
76a2ba0
Implement `squeeze`
neosunhan Jun 30, 2022
a67a9a0
Implement `can_cast`
neosunhan Jul 6, 2022
1cbe90e
Implement `acos`
neosunhan Jul 6, 2022
da0a42c
Implement `acosh`
neosunhan Jul 6, 2022
c37effd
Implement `asin` and `asinh`
neosunhan Jul 6, 2022
a3a3644
Implement `atan`, `atan2`, `atanh`
neosunhan Jul 6, 2022
058873a
Implement `ceil`
neosunhan Jul 6, 2022
50b4054
Implement `cos` and `cosh`
neosunhan Jul 6, 2022
b26ccd9
Implement `exp` and `expm1`
neosunhan Jul 6, 2022
666855a
Implement `floor`
neosunhan Jul 6, 2022
4ce40dc
Implement `log`, `log1p`, `log2`, `log10`, `logaddexp`
neosunhan Jul 6, 2022
19294dd
Implement `logical_and`, `logical_not`, `logical_or`, `logical_xor`
neosunhan Jul 6, 2022
a487a4a
Implement `round`
neosunhan Jul 6, 2022
fd767ce
Implement `sign`
neosunhan Jul 6, 2022
2c7dd1a
Implement `sin` and `sinh`
neosunhan Jul 6, 2022
05582b6
Implement `square` and `sqrt`
neosunhan Jul 6, 2022
0936cf9
Implement `tan` and `tanh`
neosunhan Jul 6, 2022
3ee7112
Implement `trunc`
neosunhan Jul 6, 2022
4febd58
Implement `argmax` and `argmin`
neosunhan Jul 6, 2022
d6949f4
Implement `vecdot`
neosunhan Jul 6, 2022
970c1fe
Implement `unique_values`
neosunhan Jul 6, 2022
cca8c2a
Implement `max` and `min`
neosunhan Jul 6, 2022
7246336
Implement `mean`
neosunhan Jul 6, 2022
7cf980c
Implement `std`
neosunhan Jul 6, 2022
438fe1b
Fix documentation
neosunhan Jul 6, 2022
f10b2cb
Add dlpack functions
neosunhan Sep 4, 2022
e216bc5
Merge branch 'main' into array-api
neosunhan Sep 4, 2022
6473b70
Extract changes made in core module
neosunhan Sep 4, 2022
d5fbfda
Restore original `abs`
neosunhan Sep 4, 2022
f81048a
Fix documentation
neosunhan Sep 4, 2022
824fa9b
Add `from_dlpack`
neosunhan Sep 5, 2022
fccadf3
Merge branch 'main' into array-api
neosunhan Sep 13, 2022
ca588a0
Add ci tests
neosunhan Sep 30, 2022
df954d9
Merge branch 'array-api' of github.com:helmholtz-analytics/heat into …
neosunhan Sep 30, 2022
31e996a
Merge branch 'main' into array-api
neosunhan Sep 30, 2022
9e83881
Merge branch 'main' into array-api
ClaudiaComito Feb 10, 2023
47867e1
Merge branch 'main' into array-api
ClaudiaComito May 17, 2023
b25be5f
[skip ci] install MPI, update API reference, introduce manual trigger
ClaudiaComito May 17, 2023
c954cab
[skip ci] trigger on PR label
ClaudiaComito May 17, 2023
74234ac
adding back original trigger
ClaudiaComito May 17, 2023
a9f80b4
back to original trigger
ClaudiaComito May 17, 2023
bfe3f12
[skip ci] fix workflow indentation
ClaudiaComito May 17, 2023
6d4706b
fix URL of test repository
ClaudiaComito May 17, 2023
afc9488
Merge branch 'main' into array-api
ClaudiaComito Aug 28, 2023
b96ad90
Merge branch 'main' into array-api
ClaudiaComito Sep 4, 2023
dcde5a3
Merge branch 'main' into array-api
ClaudiaComito Sep 5, 2023
509c19e
Test Python 3.9 and 3.10
ClaudiaComito Sep 5, 2023
2fd9ad2
Update array-api.yml
mtar Sep 11, 2023
57feb55
change skipfile path
mtar Feb 16, 2024
8d5946e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 16, 2024
e8f7771
Merge branch 'main' into array-api
mtar Feb 19, 2024
7a56c52
Update array-api.yml
mtar Feb 19, 2024
19f5a75
Update array-api.yml
mtar Feb 19, 2024
f729a4e
Update array-api.yml
mtar Feb 19, 2024
e8c9917
Update array-api.yml
mtar Feb 19, 2024
407adfb
Update array-api.yml
mtar Feb 19, 2024
c2a91fe
Update array-api.yml
mtar Feb 19, 2024
b0e15c8
Update array-api.yml
mtar Feb 20, 2024
63442ab
Update array-api.yml
mtar Feb 20, 2024
735314f
Update array-api.yml
mtar Feb 20, 2024
d12e877
Update array-api.yml
mtar Feb 20, 2024
d5a0fee
Update array-api.yml
mtar Feb 20, 2024
a96e91a
update skips file
mtar Feb 20, 2024
4b2c7df
Update array-api.yml
mtar Feb 20, 2024
83bb9a4
Update array-api.yml
mtar Feb 20, 2024
92eace9
skips
mtar Feb 20, 2024
59c64c6
update skips.txt
mtar Feb 20, 2024
3b59826
Update skips.txt
mtar Mar 7, 2024
6c62126
Update skips.txt
mtar Mar 7, 2024
7143211
Update array-api.yml
mtar Mar 8, 2024
1522115
Update array-api.yml
mtar Mar 8, 2024
463a205
Update array-api.yml
mtar Mar 8, 2024
ff8fad1
Update array-api.yml
mtar Mar 8, 2024
3910263
Update array-api.yml
mtar Mar 8, 2024
70aa22b
Update array-api.yml
mtar Mar 8, 2024
c79f4be
Update array-api.yml
mtar Mar 8, 2024
c17464d
Update __init__.py
mtar Mar 8, 2024
4fedbfe
skip failing tests
mtar Mar 11, 2024
5db33f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2024
981aba6
Update array-api.yml
mtar Mar 11, 2024
140fd4c
Update array-api.yml
mtar Mar 11, 2024
b18cf98
update skips
mtar Mar 11, 2024
acbfa05
update skips
mtar Mar 11, 2024
188303a
update skips.txt
mtar Mar 11, 2024
1c9ede0
Merge branch 'main' into array-api
mtar Mar 11, 2024
7fff72b
Update skips.txt
mtar Mar 12, 2024
3b86d54
Update array-api.yml
mtar Mar 27, 2024
cd76039
Merge branch 'main' into array-api
mtar Jun 12, 2024
6b13c3d
add complex types
mtar Jul 10, 2024
6fe8134
update skips.txt/xfails.txt
mtar Jul 11, 2024
74e2c91
update skips
mtar Jul 11, 2024
2548bdb
update skips/xfails
mtar Jul 11, 2024
d95e4f0
update skips/xfails
mtar Jul 11, 2024
ccdf750
put functions in limbo into skips
mtar Jul 12, 2024
36322fb
update skips
mtar Jul 12, 2024
80fbe82
update files
mtar Jul 15, 2024
6e0ef51
add more skips
mtar Jul 15, 2024
56655af
add more skips
mtar Jul 15, 2024
75bd670
even more skips
mtar Jul 15, 2024
64a9cc2
much more skips
mtar Jul 15, 2024
a20c376
and more
mtar Jul 15, 2024
e059414
change to manual trigger
mtar Jul 16, 2024
01b2f4e
Merge branch 'main' into array-api
mtar Jul 16, 2024
979b95c
Merge branch 'main' into array-api
ClaudiaComito Aug 26, 2024
f4e494f
add contributor
ClaudiaComito Aug 26, 2024
980915d
Merge branch 'array-api' of github.com:helmholtz-analytics/heat into …
ClaudiaComito Aug 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions .github/workflows/array-api.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
name: Test Array API

on:
workflow_dispatch:

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11']
mpi: [ 'openmpi' ]

steps:
- name: Checkout
uses: actions/checkout@v3
with:
path: heat
- name: Setup MPI
uses: mpi4py/setup-mpi@v1
with:
mpi: ${{ matrix.mpi }}
- name: Use Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
architecture: x64
- name: Install
run: |
python -m pip install --upgrade pip
python -m pip install ${GITHUB_WORKSPACE}/heat/
- name: Checkout array-api-tests
uses: actions/checkout@v3
with:
repository: data-apis/array-api-tests
path: array-api-tests
submodules: 'true'
- name: Install dependencies
run: |
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
- name: Run the test suite
env:
ARRAY_API_TESTS_MODULE: heat.array_api
run: |
export PYTHONPATH="${GITHUB_WORKSPACE}/heat"
# Skip testing functions with known issues
cd ${GITHUB_WORKSPACE}/array-api-tests
pytest array_api_tests/ -v -rxXfE --ci --xfails-file ${GITHUB_WORKSPACE}/heat/heat/array_api/test/xfails.txt --skips-file ${GITHUB_WORKSPACE}/heat/heat/array_api/test/skips.txt --disable-extension linalg
2 changes: 2 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ authors:
- family-names: Tarnawa
given-names: Michael
# release contributors - add as needed
- family-names: Neo
given-names: Sun Han
repository-code: 'https://github.com/helmholtz-analytics/heat'
url: 'https://helmholtz-analytics.github.io/heat/'
repository: 'https://heat.readthedocs.io/en/stable/'
Expand Down
275 changes: 275 additions & 0 deletions heat/array_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
"""
A Heat sub-namespace that conforms to the Python array API standard.
"""

import warnings

warnings.warn("The heat.array_api submodule is not fully implemented.", stacklevel=2)

__array_api_version__ = "2023.12"

__all__ = ["__array_api_version__"]

from ._constants import e, inf, nan, newaxis, pi

__all__ += ["e", "inf", "nan", "newaxis", "pi"]

from ._creation_functions import (
arange,
asarray,
empty,
empty_like,
eye,
from_dlpack,
full,
full_like,
linspace,
meshgrid,
ones,
ones_like,
tril,
triu,
zeros,
zeros_like,
)

__all__ += [
"arange",
"asarray",
"empty",
"empty_like",
"eye",
"from_dlpack",
"full",
"full_like",
"linspace",
"meshgrid",
"ones",
"ones_like",
"tril",
"triu",
"zeros",
"zeros_like",
]

from ._data_type_functions import (
astype,
broadcast_arrays,
broadcast_to,
can_cast,
finfo,
iinfo,
result_type,
)

__all__ += [
"astype",
"broadcast_arrays",
"broadcast_to",
"can_cast",
"finfo",
"iinfo",
"result_type",
]

from heat.core.devices import cpu

__all__ += ["cpu"]

import heat.core.devices

if hasattr(heat.core.devices, "gpu"):
from heat.core.devices import gpu

__all__ += ["gpu"]

from ._dtypes import (
bool,
int8,
int16,
int32,
int64,
uint8,
# uint16,
# uint32,
# uint64,
float32,
float64,
complex64,
complex128,
)

__all__ += [
"bool",
"int8",
"int16",
"int32",
"int64",
"uint8",
# "uint16",
# "uint32",
# "uint64",
"float32",
"float64",
"complex64",
"complex128",
]

from ._elementwise_functions import (
abs,
acos,
acosh,
add,
asin,
asinh,
atan,
atan2,
atanh,
bitwise_and,
bitwise_left_shift,
bitwise_invert,
bitwise_or,
bitwise_right_shift,
bitwise_xor,
ceil,
cos,
cosh,
divide,
equal,
exp,
expm1,
floor,
floor_divide,
greater,
greater_equal,
isfinite,
isinf,
isnan,
less,
less_equal,
log,
log1p,
log2,
log10,
logaddexp,
logical_and,
logical_not,
logical_or,
logical_xor,
multiply,
negative,
not_equal,
positive,
pow,
remainder,
round,
sign,
sin,
sinh,
square,
sqrt,
subtract,
tan,
tanh,
trunc,
)

__all__ += [
"abs",
"acos",
"acosh",
"add",
"asin",
"asinh",
"atan",
"atan2",
"atanh",
"bitwise_and",
"bitwise_left_shift",
"bitwise_invert",
"bitwise_or",
"bitwise_right_shift",
"bitwise_xor",
"ceil",
"cos",
"cosh",
"divide",
"equal",
"exp",
"expm1",
"floor",
"floor_divide",
"greater",
"greater_equal",
"isfinite",
"isinf",
"isnan",
"less",
"less_equal",
"log",
"log1p",
"log2",
"log10",
"logaddexp",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"multiply",
"negative",
"not_equal",
"positive",
"pow",
"remainder",
"round",
"sign",
"sin",
"sinh",
"square",
"sqrt",
"subtract",
"tan",
"tanh",
"trunc",
]

from . import linalg

__all__ += ["linalg"]

from .linalg import matmul, matrix_transpose, tensordot, vecdot

__all__ += ["matmul", "matrix_transpose", "tensordot", "vecdot"]

from ._manipulation_functions import (
concat,
expand_dims,
flip,
permute_dims,
reshape,
roll,
squeeze,
stack,
)

__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"]

from ._searching_functions import argmax, argmin, nonzero, where

__all__ += ["argmax", "argmin", "nonzero", "where"]

from ._set_functions import unique_inverse, unique_values

__all__ += ["unique_inverse", "unique_values"]

from ._sorting_functions import sort

__all__ += ["sort"]

from ._statistical_functions import max, mean, min, prod, std, sum, var

__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]

from ._utility_functions import all, any

__all__ += ["all", "any"]
Loading