Skip to content

Commit 2dc8d7c

Browse files
committed
improve serialization and memory allocation
1 parent 7daccfa commit 2dc8d7c

20 files changed

Lines changed: 385 additions & 61 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -66,3 +66,6 @@ package-r/inst/c/
6666

6767
# Environment variables
6868
.env
69+
70+
# Temp folder
71+
tmp/

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ members = [
66

77
[package]
88
name = "perpetual"
9-
version = "2.0.0"
9+
version = "2.1.0"
1010
edition = "2024"
1111
authors = ["Mutlu Simsek <msimsek@perpetual-ml.com>", "Serkan Korkmaz <serkor1@duck.com>", "Pieter Pel <pelpieter@gmail.com>"]
1212
homepage = "https://perpetual-ml.com"

package-python/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "py-perpetual"
3-
version = "2.0.0"
3+
version = "2.1.0"
44
edition = "2024"
55
authors = ["Mutlu Simsek <msimsek@perpetual-ml.com>", "Serkan Korkmaz <serkor1@duck.com>", "Pieter Pel <pelpieter@gmail.com>"]
66
homepage = "https://perpetual-ml.com"
@@ -17,7 +17,7 @@ name = "perpetual"
1717
crate-type = ["cdylib", "rlib"]
1818

1919
[dependencies]
20-
perpetual_rs = {package="perpetual", version = "2.0.0", path = "../" }
20+
perpetual_rs = {package="perpetual", version = "2.1.0", path = "../" }
2121
pyo3 = { version = "0.28.2", features = ["extension-module"] }
2222
numpy = "0.28.0"
2323
ndarray = "0.17.2"

package-python/docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
author = "Mutlu Simsek, Serkan Korkmaz, Pieter Pel"
99

1010
# The full version, including alpha/beta/rc tags
11-
release = "2.0.0"
11+
release = "2.1.0"
1212

1313
# -- General configuration ---------------------------------------------------
1414

package-python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ build-backend = "maturin"
44

55
[project]
66
name = "perpetual"
7-
version = "2.0.0"
7+
version = "2.1.0"
88
description = "A self-generalizing gradient boosting machine that doesn't need hyperparameter optimization"
99
readme = "README_PYTHON.md"
1010
license = "Apache-2.0"

package-python/tests/core/test_tabarena.py

Lines changed: 112 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import os
2+
import pickle
23
import time
34
import urllib.request
45

@@ -17,22 +18,29 @@
1718
"tabarena_wide_y.csv": "https://github.com/user-attachments/files/25682566/y.csv",
1819
"tabarena_var_x.csv": "https://github.com/user-attachments/files/25739177/X.csv",
1920
"tabarena_var_y.csv": "https://github.com/user-attachments/files/25739178/y.csv",
21+
"tabarena_load_x.csv": "https://github.com/user-attachments/files/25764663/X.csv",
22+
"tabarena_load_y.csv": "https://github.com/user-attachments/files/25764665/y.csv",
2023
}
2124

2225

2326
def _get_csv(filename):
2427
"""Try to read a CSV from resources dir, then download, or skip the test."""
28+
print(f"Loading {filename}...")
2529
local_path = os.path.join(RESOURCES_DIR, filename)
2630
if os.path.isfile(local_path):
27-
return pd.read_csv(local_path)
31+
res = pd.read_csv(local_path)
32+
print(f"Loaded {filename} from resources.")
33+
return res
2834

2935
# Try downloading to the resources folder
3036
url = DOWNLOAD_URLS.get(filename)
3137
if url is None:
3238
pytest.skip(f"No download URL configured for {filename}")
3339
try:
40+
print(f"Downloading {filename} from {url}...")
3441
os.makedirs(RESOURCES_DIR, exist_ok=True)
3542
urllib.request.urlretrieve(url, local_path)
43+
print(f"Downloaded {filename}.")
3644
except Exception as exc:
3745
pytest.skip(f"Could not download {filename}: {exc}")
3846

@@ -171,8 +179,8 @@ def test_tabarena_wide():
171179
objective="LogLoss",
172180
budget=2.0,
173181
categorical_features=categorical_features,
174-
memory_limit=1, # 30
175-
iteration_limit=3, # 10000
182+
memory_limit=1,
183+
iteration_limit=3,
176184
timeout=60 * 15,
177185
)
178186

@@ -225,3 +233,104 @@ def test_tabarena_var():
225233
categorical_features=categorical_features, iteration_limit=3, memory_limit=1
226234
)
227235
model.fit(X, y)
236+
237+
238+
def test_tabarena_save_load():
239+
print("test_tabarena_save_load started.")
240+
X = _get_csv("tabarena_load_x.csv")
241+
y = _get_csv("tabarena_load_y.csv")
242+
243+
categorical_features = [
244+
"position_-30",
245+
"position_-29",
246+
"position_-28",
247+
"position_-27",
248+
"position_-26",
249+
"position_-25",
250+
"position_-24",
251+
"position_-23",
252+
"position_-22",
253+
"position_-21",
254+
"position_-20",
255+
"position_-19",
256+
"position_-18",
257+
"position_-17",
258+
"position_-16",
259+
"position_-15",
260+
"position_-14",
261+
"position_-13",
262+
"position_-12",
263+
"position_-11",
264+
"position_-10",
265+
"position_-9",
266+
"position_-8",
267+
"position_-7",
268+
"position_-6",
269+
"position_-5",
270+
"position_-4",
271+
"position_-3",
272+
"position_-2",
273+
"position_-1",
274+
"position_1",
275+
"position_2",
276+
"position_3",
277+
"position_4",
278+
"position_5",
279+
"position_6",
280+
"position_7",
281+
"position_8",
282+
"position_9",
283+
"position_10",
284+
"position_11",
285+
"position_12",
286+
"position_13",
287+
"position_14",
288+
"position_15",
289+
"position_16",
290+
"position_17",
291+
"position_18",
292+
"position_19",
293+
"position_20",
294+
"position_21",
295+
"position_22",
296+
"position_23",
297+
"position_24",
298+
"position_25",
299+
"position_26",
300+
"position_27",
301+
"position_28",
302+
"position_29",
303+
"position_30",
304+
]
305+
306+
model = PerpetualBooster(
307+
categorical_features=categorical_features,
308+
memory_limit=3,
309+
num_threads=8,
310+
objective="LogLoss",
311+
iteration_limit=10,
312+
budget=2.0,
313+
)
314+
315+
print(f"Starting fit... memory_limit={model.memory_limit}")
316+
model.fit(X, y)
317+
print("Fit completed.")
318+
print(f"Number of trees: {model.number_of_trees}")
319+
320+
model_path = os.path.join(RESOURCES_DIR, "model.pkl")
321+
322+
print(f"Saving model to {model_path}...")
323+
with open(model_path, "wb") as f:
324+
pickle.dump(model, f)
325+
print("Model saved.")
326+
327+
print("Loading model...")
328+
with open(model_path, "rb") as f:
329+
loaded_model = pickle.load(f)
330+
print("Model loaded.")
331+
332+
del loaded_model
333+
try:
334+
os.remove(model_path)
335+
except OSError:
336+
pass

package-python/uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package-r/DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
Package: perpetual
22
Type: Package
33
Title: PerpetualBooster
4-
Version: 2.0.0
4+
Version: 2.1.0
55
Authors@R: c(
66
person("Mutlu", "Simsek", email = "msimsek@perpetual-ml.com", role = c("aut", "cre")),
77
person("Serkan", "Korkmaz", email = "serkor1@duck.com", role = "aut"),

package-r/src/rust/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33

44
[package]
55
name = "perpetual_r"
6-
version = "2.0.0"
6+
version = "2.1.0"
77
edition = "2021"
88

99
[lib]

scripts/increment_version.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import subprocess
44
from pathlib import Path
55

6-
# python scripts/increment_version.py 2.0.0 --dry-run
6+
# python scripts/increment_version.py 2.1.0 --dry-run
77

88

99
def update_file(file_path, pattern, replacement, dry_run=False):

0 commit comments

Comments
 (0)