Skip to content

Commit c4f132b

Browse files
authored
Switch to functools for caching (#700)
* Switch to functools implementation which introduces 'bug' The bug is that now objects of the same class can not have their cache cleared independently. From the usage, I do not believe that there would otherwise be two different instances of the same class anyway, though I may be wrong (thinking datasplit). Will need to investigate more tomorrow. In any case, the plan is to move to a more friendly implementation anyway where those classes that do need cache eviction simply add a two-liner for that. * Use cached_property instead of caching a property manually * Properly clear properties and also functions Still has the issue that cache for all instances of the class are cleared for functions (but not for properties). * Use functools for caching
1 parent e3eb201 commit c4f132b

File tree

7 files changed

+65
-116
lines changed

7 files changed

+65
-116
lines changed

amlb/benchmark.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from copy import copy
1414
from enum import Enum
15+
from functools import cached_property
1516
from importlib import import_module, invalidate_caches
1617
import logging
1718
import math
@@ -37,7 +38,6 @@
3738
file_lock,
3839
flatten,
3940
json_dump,
40-
lazy_property,
4141
profile,
4242
repr_def,
4343
run_cmd,
@@ -503,7 +503,7 @@ def _results_summary(self, scoreboard=None):
503503
)
504504
return board.as_data_frame()
505505

506-
@lazy_property
506+
@cached_property
507507
def output_dirs(self):
508508
return routput_dirs(
509509
rconfig().output_dir,

amlb/data.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from abc import ABC, abstractmethod
1717
from enum import Enum
18+
from functools import cached_property
1819
import logging
1920
from typing import List, Union, Iterable
2021

@@ -24,7 +25,7 @@
2425
from typing_extensions import TypeAlias
2526

2627
from .datautils import Encoder
27-
from .utils import clear_cache, lazy_property, profile, repr_def
28+
from .utils import clear_cache, profile, repr_def
2829

2930
log = logging.getLogger(__name__)
3031

@@ -66,7 +67,7 @@ def is_categorical(self, strict: bool = True) -> bool:
6667
def is_numerical(self) -> bool:
6768
return self.data_type in ["int", "float", "number"]
6869

69-
@lazy_property
70+
@cached_property
7071
def label_encoder(self) -> Encoder:
7172
return Encoder(
7273
"label" if self.values is not None else "no-op",
@@ -77,7 +78,7 @@ def label_encoder(self) -> Encoder:
7778
normalize_fn=Feature.normalize,
7879
).fit(self.values)
7980

80-
@lazy_property
81+
@cached_property
8182
def one_hot_encoder(self) -> Encoder:
8283
return Encoder(
8384
"one-hot" if self.values is not None else "no-op",
@@ -127,15 +128,15 @@ def data_path(self, format: str) -> str:
127128
"""
128129
pass
129130

130-
@property
131+
@cached_property
131132
@abstractmethod
132133
def data(self) -> DF:
133134
"""
134135
:return: all the columns (predictors + target) as a pandas DataFrame.
135136
"""
136137
pass
137138

138-
@lazy_property
139+
@cached_property
139140
@profile(logger=log)
140141
def X(self) -> DF:
141142
"""
@@ -144,15 +145,15 @@ def X(self) -> DF:
144145
predictors_ind = [p.index for p in self.dataset.predictors]
145146
return self.data.iloc[:, predictors_ind]
146147

147-
@lazy_property
148+
@cached_property
148149
@profile(logger=log)
149150
def y(self) -> DF:
150151
"""
151152
:return:the target column as a pandas DataFrame: if you need a Series, just call `y.squeeze()`.
152153
"""
153154
return self.data.iloc[:, [self.dataset.target.index]] # type: ignore
154155

155-
@lazy_property
156+
@cached_property
156157
@profile(logger=log)
157158
def data_enc(self) -> AM:
158159
encoded_cols = [
@@ -162,15 +163,15 @@ def data_enc(self) -> AM:
162163
# optimize mem usage : frameworks use either raw data or encoded ones,
163164
# so we can clear the cached raw data once they've been encoded
164165
self.release(["data", "X", "y"])
165-
return np.hstack(tuple(col.reshape(-1, 1) for col in encoded_cols))
166+
return np.hstack(tuple(col.reshape(-1, 1) for col in encoded_cols)) # type: ignore[union-attr]
166167

167-
@lazy_property
168+
@cached_property
168169
@profile(logger=log)
169170
def X_enc(self) -> AM:
170171
predictors_ind = [p.index for p in self.dataset.predictors]
171172
return self.data_enc[:, predictors_ind]
172173

173-
@lazy_property
174+
@cached_property
174175
@profile(logger=log)
175176
def y_enc(self) -> AM:
176177
# return self.dataset.target.label_encoder.transform(self.y)

amlb/datasets/file.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import re
77
import tempfile
8+
from functools import cache, cached_property
89
from typing import List
910

1011
import arff
@@ -18,9 +19,7 @@
1819
from ..utils import (
1920
Namespace as ns,
2021
as_list,
21-
lazy_property,
2222
list_all_files,
23-
memoize,
2423
path_from_split,
2524
profile,
2625
repr_def,
@@ -257,7 +256,7 @@ def features(self) -> List[Feature]:
257256
def target(self) -> Feature:
258257
return self._get_metadata("target")
259258

260-
@memoize
259+
@cache
261260
def _get_metadata(self, prop):
262261
meta = self._train.load_metadata()
263262
return meta[prop]
@@ -281,7 +280,7 @@ def data_path(self, format):
281280
)
282281
return self._get_data(format)
283282

284-
@lazy_property
283+
@cached_property
285284
def data(self):
286285
# use codecs for unicode support: path = codecs.load(self._path, 'rb', 'utf-8')
287286
log.debug("Loading datasplit %s.", self.path)

amlb/datasets/openml.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from abc import abstractmethod
1010
import copy
1111
import functools
12+
from functools import cached_property
1213
import logging
1314
import os
1415
import re
@@ -27,7 +28,6 @@
2728
from ..resources import config as rconfig, get as rget
2829
from ..utils import (
2930
as_list,
30-
lazy_property,
3131
path_from_split,
3232
profile,
3333
split_path,
@@ -107,7 +107,7 @@ def nrows(self) -> int:
107107
self._nrows = len(self._load_full_data(fmt="dataframe"))
108108
return self._nrows
109109

110-
@lazy_property
110+
@cached_property
111111
def type(self):
112112
def get_type(card):
113113
if card > 2:
@@ -262,7 +262,7 @@ def get_non_empty_columns(data: DF) -> list[Hashable]:
262262

263263
return subsample_path
264264

265-
@lazy_property
265+
@cached_property
266266
@profile(logger=log)
267267
def features(self):
268268
def has_missing_values(f) -> bool:
@@ -298,7 +298,7 @@ def to_feature_type(dt):
298298
)
299299
]
300300

301-
@lazy_property
301+
@cached_property
302302
def target(self):
303303
return next(f for f in self.features if f.is_target)
304304

@@ -347,12 +347,12 @@ def data_path(self, format):
347347
)
348348
return self._get_data(format)
349349

350-
@lazy_property
350+
@cached_property
351351
@profile(logger=log)
352352
def data(self) -> DF:
353353
return self._get_data("dataframe")
354354

355-
@lazy_property
355+
@cached_property
356356
@profile(logger=log)
357357
def data_enc(self) -> AM:
358358
return self._get_data("array")

amlb/resources.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@
1212
import random
1313
import re
1414
import sys
15+
from functools import cache, cached_property
1516

1617
from amlb.benchmarks.parser import benchmark_load
1718
from amlb.frameworks import default_tag, load_framework_definitions
1819
from .frameworks.definitions import TaskConstraint
1920
from .utils import (
2021
Namespace,
21-
lazy_property,
22-
memoize,
2322
normalize_path,
2423
run_cmd,
2524
str_sanitize,
@@ -66,15 +65,15 @@ def __init__(self, config: Namespace):
6665
sys.path.append(common_dirs["user"])
6766
log.debug("Extended Python sys.path to user directory: %s.", sys.path)
6867

69-
@lazy_property
68+
@cached_property
7069
def project_info(self):
7170
split_url = self.config.project_repository.split("#", 1)
7271
repo = split_url[0]
7372
tag = None if len(split_url) == 1 else split_url[1]
7473
branch = tag or "master"
7574
return Namespace(repo=repo, tag=tag, branch=branch)
7675

77-
@lazy_property
76+
@cached_property
7877
def git_info(self):
7978
def git(cmd, defval=None):
8079
try:
@@ -99,7 +98,7 @@ def git(cmd, defval=None):
9998
repo=repo, branch=branch, commit=commit, tags=tags, status=status
10099
)
101100

102-
@lazy_property
101+
@cached_property
103102
def app_version(self):
104103
v = __version__
105104
if v != dev:
@@ -118,7 +117,7 @@ def seed(self, fold=None):
118117
else:
119118
return self._seed
120119

121-
@lazy_property
120+
@cached_property
122121
def _seed(self):
123122
if str(self.config.seed).lower() in ["none", ""]:
124123
return None
@@ -167,12 +166,12 @@ def framework_definition(self, name, tag=None):
167166
)
168167
return framework, framework.name
169168

170-
@lazy_property
169+
@cached_property
171170
def _frameworks(self):
172171
frameworks_file = self.config.frameworks.definition_file
173172
return load_framework_definitions(frameworks_file, self.config)
174173

175-
@memoize
174+
@cache
176175
def constraint_definition(self, name: str) -> TaskConstraint:
177176
"""
178177
:param name: name of the benchmark constraint definition as defined in the constraints file
@@ -187,7 +186,7 @@ def constraint_definition(self, name: str) -> TaskConstraint:
187186
)
188187
return TaskConstraint(**Namespace.dict(constraint))
189188

190-
@lazy_property
189+
@cached_property
191190
def _constraints(self):
192191
constraints_file = self.config.benchmarks.constraints_file
193192
log.info("Loading benchmark constraint definitions from %s.", constraints_file)

amlb/results.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import collections
99
import io
1010
import logging
11+
from functools import cache
12+
1113
import math
1214
import os
1315
import re
@@ -43,11 +45,9 @@
4345
from .utils import (
4446
Namespace,
4547
backup_file,
46-
cached,
4748
datetime_iso,
4849
get_metadata,
4950
json_load,
50-
memoize,
5151
profile,
5252
set_metadata,
5353
)
@@ -185,7 +185,7 @@ def __init__(
185185
else None
186186
)
187187

188-
@cached
188+
@cache
189189
def as_data_frame(self):
190190
# index = ['task', 'framework', 'fold']
191191
index = []
@@ -236,7 +236,7 @@ def as_data_frame(self):
236236
log.debug("Scores columns: %s.", df.columns)
237237
return df
238238

239-
@memoize
239+
@cache
240240
def as_printable_data_frame(self, verbosity=3):
241241
def none_like_as_empty(val: Any) -> str:
242242
return (
@@ -450,7 +450,7 @@ def save_predictions(
450450
] # reorder columns alphabetically: necessary to match label encoding
451451
if any(prob_cols != df.columns.values):
452452
encoding_map = {
453-
prob_cols.index(col): i
453+
prob_cols.index(col): i # type: ignore[union-attr]
454454
for i, col in enumerate(df.columns.values)
455455
}
456456
remap = np.vectorize(lambda v: encoding_map[v])
@@ -606,11 +606,11 @@ def __init__(
606606
)
607607
self._metadata = metadata
608608

609-
@cached
609+
@cache
610610
def get_result(self):
611611
return self.load_predictions(self._predictions_file)
612612

613-
@cached
613+
@cache
614614
def get_result_metadata(self):
615615
return self._metadata or self.load_metadata(self._metadata_file)
616616

0 commit comments

Comments
 (0)