Skip to content

Commit d6772ab

Browse files
authored
Merge pull request #201 from atong01/torch-2.0
Torch 2.0 Compatibility
2 parents 92dd9ba + 2b0de91 commit d6772ab

File tree

5 files changed

+37
-8
lines changed

5 files changed

+37
-8
lines changed

.github/workflows/os-coverage.yml

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,30 @@ jobs:
99
max-parallel: 15
1010
matrix:
1111
os: [ubuntu-latest, macos-latest, windows-latest]
12-
python-version: ['3.8', '3.9', '3.10', '3.11']
12+
python-version: ["3.8", "3.9", "3.10", "3.11"]
13+
torch-version: ["1.8.1", "1.9.1", "1.10.0", "1.11.0", "1.12.0", "1.13.1", "2.0.0"]
14+
exclude:
15+
# python >= 3.10 does not support pytorch < 1.11.0
16+
- torch-version: "1.8.1"
17+
python-version: "3.10"
18+
- torch-version: "1.9.1"
19+
python-version: "3.10"
20+
- torch-version: "1.10.0"
21+
python-version: "3.10"
22+
# python >= 3.11 does not support pytorch < 1.13.0
23+
- torch-version: "1.8.1"
24+
python-version: "3.11"
25+
- torch-version: "1.9.1"
26+
python-version: "3.11"
27+
- torch-version: "1.10.0"
28+
python-version: "3.11"
29+
- torch-version: "1.11.0"
30+
python-version: "3.11"
31+
- torch-version: "1.12.0"
32+
python-version: "3.11"
33+
- torch-version: "1.13.1"
34+
python-version: "3.11"
35+
1336
defaults:
1437
run:
1538
shell: bash
@@ -33,14 +56,18 @@ jobs:
3356
uses: actions/cache@v3
3457
with:
3558
path: ~/.cache
36-
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
59+
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.torch-version }}-${{ hashFiles('**/poetry.lock') }}
3760

3861
- name: Install dependencies # hack for 🐛: don't let poetry try installing Torch https://github.com/pytorch/pytorch/issues/88049
3962
run: |
4063
pip install pytest pytest-cov papermill poethepoet>=0.10.0
41-
pip install torch>=1.8.1 torchvision pytorch-lightning scikit-learn torchsde torchcde>=0.2.3 scipy matplotlib ipykernel ipywidgets
64+
pip install torch==${{ matrix.torch-version }} pytorch-lightning scikit-learn torchsde torchcde>=0.2.3 scipy matplotlib ipykernel ipywidgets
4265
poetry install --only-root
4366
poetry run pip install setuptools
67+
68+
- name: List dependencies
69+
run: |
70+
pip list
4471
4572
- name: Run pytest checks
4673
run: |

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ packages = [
1010

1111
[tool.poetry.dependencies]
1212
python = "^3.8"
13-
torch = "^1.8.1"
13+
torch = ">=1.8.1"
1414
torchsde="*"
1515
torchcde="^0.2.3"
1616
scikit-learn = "*"

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
description="PyTorch package for all things neural differential equations.",
2121
url="https://github.com/DiffEqML/torchdyn",
2222
install_requires=[
23-
"torch>=1.6.0",
23+
"torch>=1.8.1",
2424
"pytorch-lightning>=0.8.4",
2525
"matplotlib",
2626
"scikit-learn",

test/models/test_ode.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212

13+
from packaging.version import parse
1314
import pytest
1415
import torch
1516
import torch.nn as nn
@@ -261,6 +262,8 @@ def forward(self, t, x, u, v, z, args={}):
261262
grad(sol2.sum(), x0)
262263

263264

265+
@pytest.mark.skipif(parse(torch.__version__) < parse("1.11.0"),
266+
reason="adjoint support added in torch 1.11.0")
264267
def test_complex_ode():
265268
"""Test odeint for complex numbers with a simple complex-valued ODE, corresponding
266269
to Rabi oscillations of quantum two-level system."""
@@ -312,4 +315,4 @@ def test_odeint(solver):
312315
t_span = torch.linspace(0., 2., 10)
313316
sys = Lorenz()
314317

315-
odeint(sys, x0, t_span, solver=solver)
318+
odeint(sys, x0, t_span, solver=solver)

torchdyn/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212

13-
__version__ = '1.0'
13+
__version__ = '1.0.5'
1414
__author__ = 'Michael Poli, Stefano Massaroli et al.'
1515

1616
from torch import Tensor
1717
from typing import Tuple
1818

1919
TTuple = Tuple[Tensor, Tensor]
20-

0 commit comments

Comments
 (0)