Skip to content

Commit

Permalink
Merge pull request #365 from FlorianPfaff/pytorch-backend
Browse files Browse the repository at this point in the history
Can now change backend to pytorch
  • Loading branch information
FlorianPfaff authored Oct 26, 2023
2 parents 0d23a92 + 7f8337f commit 552c9da
Show file tree
Hide file tree
Showing 203 changed files with 8,612 additions and 3,206 deletions.
10 changes: 6 additions & 4 deletions .github/workflows/mega-linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ jobs:
key: ${{ runner.os }}-alpine-wheels-${{ hashFiles('requirements-dev.txt') }}
restore-keys: |
${{ runner.os }}-alpine-wheels-${{ hashFiles('requirements-dev.txt') }}
- name: Set up Alpine Linux
if: steps.cache-wheels.outputs.cache-hit != 'true'
uses: jirutka/setup-alpine@v1
Expand All @@ -76,7 +75,6 @@ jobs:
py3-pkgconfig
curl-dev
zlib-dev
- name: List workspace
run: ls -l .

Expand All @@ -93,6 +91,12 @@ jobs:
sed 's/==.*//' requirements-dev.txt > requirements-dev_no_version.txt
shell: alpine.sh {0}

- name: Remove torch entry (unsupported by alpine)
if: steps.cache-wheels.outputs.cache-hit != 'true'
run: |
sed -i '/^torch/d' requirements-dev_no_version.txt
shell: alpine.sh {0}

- name: Run CMake to find LAPACK
if: steps.cache-wheels.outputs.cache-hit != 'true'
run: |
Expand Down Expand Up @@ -186,7 +190,6 @@ jobs:
path: |
megalinter-reports
mega-linter.log
# Create Pull Request step
- name: Create Pull Request with applied fixes
id: cpr
Expand All @@ -206,7 +209,6 @@ jobs:
run: |
echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}"
echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}"
# Push new commit if applicable (for now works only on PR from same repository, not from forks)
- name: Prepare commit
if: steps.ml.outputs.has_updated_sources == 1 && (env.APPLY_FIXES_EVENT == 'all' || env.APPLY_FIXES_EVENT == github.event_name) && env.APPLY_FIXES_MODE == 'commit' && github.ref != 'refs/heads/main' && (github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository) && !contains(github.event.head_commit.message, 'skip fix')
Expand Down
30 changes: 23 additions & 7 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,41 @@ jobs:

- name: Install dependencies
run: |
export CUDA_VISIBLE_DEVICES=""
python -m pip install --upgrade pip
python -m pip install poetry
poetry install --extras healpy_support
poetry env use python
poetry install --extras "healpy_support" --extras "pytorch_support"
poetry run python -m pip install torch==2.1.0+cpu torchaudio==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: List files and check Python and package versions
run: |
ls -al
python -c 'import sys; print(sys.version_info[:])'
python -m pip freeze
poetry env use python
poetry run python -c 'import sys; print(sys.version_info[:])'
poetry run python -m pip freeze
poetry run python -c "import torch; print(torch.version.cuda)"
- name: Run tests with numpy backend
run: |
poetry env use python
export PYRECEST_BACKEND=numpy
poetry run python -m pytest --rootdir . -v --strict-config --junitxml=junit_test_results_numpy.xml ./pyrecest
env:
PYTHONPATH: ${{ github.workspace }}

- name: Run tests
- name: Run tests with pytorch backend
if: always()
run: |
poetry env use python
poetry run python -m pytest --rootdir . -v --strict-config --junitxml=junit_test_results.xml ./pyrecest
export PYRECEST_BACKEND=pytorch
poetry run python -m pytest --rootdir . -v --strict-config --junitxml=junit_test_results_pytorch.xml ./pyrecest
env:
PYTHONPATH: ${{ github.workspace }}

- name: Publish test results
if: always()
uses: EnricoMi/publish-unit-test-result-action@v2
with:
files: junit_test_results.xml
files: |
junit_test_results_numpy.xml
junit_test_results_pytorch.xml
2 changes: 1 addition & 1 deletion .github/workflows/update-requirements.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
run: python -m poetry update

- name: Update requirements.txt
run: python -m poetry export --format requirements.txt --output requirements.txt --extras healpy_support --without-hashes
run: python -m poetry export --format requirements.txt --output requirements.txt --extras healpy_support --extras pytorch_support --without-hashes

- name: Update requirements-dev.txt
run: python -m poetry export --with dev --format requirements.txt --output requirements-dev.txt --without-hashes
Expand Down
5 changes: 5 additions & 0 deletions .jscpd.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"ignore": [
"pyrecest/_backend/**"
]
}
2 changes: 2 additions & 0 deletions .mega-linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ DISABLE_LINTERS:
- JSON_JSONLINT # Disable because there is only .devcontainer.json, for which it throws an unwanted warning
- MAKEFILE_CHECKMAKE # Not using a Makefile
- SPELL_LYCHEE # Takes pretty long

FILTER_REGEX_EXCLUDE: "pyrecest/_backend/*"
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ ignore-patterns=^\.#
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis). It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
ignored-modules=pyrecest.backend

# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

*Recursive Bayesian Estimation for Python*

pyRecEst is a Python library designed for recursive Bayesian estimation. It is currently unstable and lacks a lot of features. Use with caution.
pyRecEst is a Python library designed for recursive Bayesian estimation, which supports numpy and pytorch as backends. It is currently unstable and lacks a lot of features. Use with caution.

## Usage

Expand All @@ -12,7 +12,7 @@ Please refer to the test cases for usage examples.

- Florian Pfaff (<pfaff@kit.edu>)

pyRecEst borrows its structure from libDirectional and follows its code closely for many classes. libDirectional, a project to which I contributed extensively, is [available on GitHub](https://github.com/libDirectional).
pyRecEst borrows its structure from libDirectional and follows its code closely for many classes. libDirectional, a project to which I contributed extensively, is [available on GitHub](https://github.com/libDirectional). The backend implementations are based on those of [geomstats](https://github.com/geomstats/geomstats).

## License
`pyRecEst` is licensed under the MIT License.
Loading

0 comments on commit 552c9da

Please sign in to comment.