Skip to content

Commit

Permalink
chore: run nightly tf in CI and fix compatibility with TF 2.9.0 (#271)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb authored Apr 23, 2022
1 parent 7f79748 commit 6e01458
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
11 changes: 7 additions & 4 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
push:
branches: [ master ]
pull_request:
schedule:
# run every day at midnight
- cron: "0 0 * * *"

jobs:
Linting:
Expand All @@ -26,7 +29,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8]
python-version: ["3.7", "3.8", "3.9"]

steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -58,7 +61,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
python-version: ["3.9"]

steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -100,7 +103,7 @@ jobs:
strategy:
matrix:
tf-version: [2.7.0]
python-version: [3.7, 3.8]
python-version: ["3.7", "3.9"]
sklearn-version: [1.0.0]

steps:
Expand Down Expand Up @@ -135,7 +138,7 @@ jobs:
strategy:
matrix:
os: [MacOS, Windows] # test all OSs (except Ubuntu, which is already running other tests)
python-version: [3.7, 3.9] # test only the two extremes of supported Python versions
python-version: ["3.7", "3.9"] # test only the two extremes of supported Python versions

steps:
- uses: actions/checkout@v2
Expand Down
8 changes: 3 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repository = "https://github.com/adriangb/scikeras"
version = "0.6.1"

[tool.poetry.dependencies]
importlib-metadata = {version = "^3", python = "<3.8"}
importlib-metadata = {version = ">=3", python = "<3.8"}
python = ">=3.7.0,<3.10.0"
scikit-learn = ">=1.0.0"
packaging = ">=0.21,<22.0"
Expand All @@ -43,7 +43,6 @@ tensorflow-cpu = ["tensorflow-cpu"]
[tool.poetry.dev-dependencies]
tensorflow = ">=2.7.0"
coverage = {extras = ["toml"], version = ">=5.4"}
dataclasses = {version = "^0.8", python = "<3.7"}
insipid-sphinx-theme = ">=0.2.2"
ipykernel = ">=5.4.2"
jupyter = ">=1.0.0"
Expand All @@ -54,7 +53,6 @@ numpydoc = ">=1.1.0"
pre-commit = ">=2.10.1"
pytest = ">=6.2.2"
pytest-cov = ">=2.11.1"
pytest-sugar = "v0.9.4"
sphinx = ">=3.2.1"

[tool.isort]
Expand Down Expand Up @@ -84,5 +82,5 @@ source = ["scikeras/"]
show_missing = true

[build-system]
build-backend = "poetry.masonry.api"
requires = ["poetry>=1.0.10"]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core>=1.0.8"]
5 changes: 5 additions & 0 deletions scikeras/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,11 @@ def _check_array_dtype(arr, force_numeric):
f"X has {len(X_shape_)} dimensions, but this {self.__name__}"
f" is expecting {len(self.X_shape_)} dimensions in X."
)
if X_shape_[1:] != self.X_shape_[1:]:
raise ValueError(
f"X has shape {X_shape_[1:]}, but this {self.__name__}"
f" is expecting X of shape {self.X_shape_[1:]}"
)
return X, y

def _type_of_target(self, y: np.ndarray) -> str:
Expand Down

0 comments on commit 6e01458

Please sign in to comment.