Skip to content

Commit a833ab4

Browse files
authored
Update for scikit-learn 1.2 (#267)
1 parent 8ca556a commit a833ab4

File tree

5 files changed

+16
-7
lines changed

5 files changed

+16
-7
lines changed

deeptime/base.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
from collections import defaultdict
33
from inspect import signature
44
from typing import Optional, List
5-
6-
from sklearn.base import _pprint as pprint_sklearn
5+
from pprint import PrettyPrinter
76

87

98
class _BaseMethodsMixin(abc.ABC):
@@ -12,10 +11,12 @@ class _BaseMethodsMixin(abc.ABC):
1211
"""
1312

1413
def __repr__(self):
14+
pp = PrettyPrinter(indent=1, depth=2)
1515
name = '{cls}-{id}:'.format(id=id(self), cls=self.__class__.__name__)
16-
return '{name}{params}]'.format(
17-
name=name, params=pprint_sklearn(self.get_params(), offset=len(name), )
18-
)
16+
offset = "".join([' '] * len(name))
17+
params = pp.pformat(self.get_params())
18+
params = params.replace('\n', '\n' + offset)
19+
return '{name}[{params}]'.format(name=name, params=params)
1920

2021
def get_params(self, deep=False):
2122
r"""Get the parameters.

deeptime/sindy/_sindy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ class STLSQ(LinearRegression):
409409

410410
def __init__(self, threshold=0.1, alpha=0.05, max_iter=20, ridge_kw=None, normalize=False, fit_intercept=False,
411411
copy_X=True):
412-
super().__init__(fit_intercept=fit_intercept, normalize=normalize, copy_X=copy_X)
412+
super().__init__(fit_intercept=fit_intercept, copy_X=copy_X)
413413
self.threshold = threshold
414414
self.alpha = alpha
415415
self.max_iter = max_iter

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ tests = [
4949
'tqdm==4.64.0',
5050
"cython>=0.29.30",
5151
"pybind11>=2.10.1",
52+
"networkx",
53+
"matplotlib",
5254
"cmake>=3.24",
5355
"ninja; platform_system!='Windows'"
5456
]

tests/base/test_base_interface.py

+7
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,10 @@ def test_mock_model():
5757

5858
with assert_raises(ValueError):
5959
m.set_params(nope=33)
60+
61+
representation = repr(m)
62+
things_that_should_be_represented = [
63+
"MockModel", "A", "a", "55", "p1", "1.0", "p2", "2.0", "p3", "3.0", "p4", "55"
64+
]
65+
for s in things_that_should_be_represented:
66+
assert_(s in representation)

tests/markov/hmm/test_integration.py

-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def test_observation_probabilities(hmm_scenario):
9797
assert_almost_equal(minerr, 0, decimal=2)
9898

9999

100-
@flaky(max_runs=3)
101100
def test_stationary_distribution(hmm_scenario):
102101
model = hmm_scenario.hmm
103102
minerr = 1e6

0 commit comments

Comments
 (0)