Skip to content

Commit e86422f

Browse files
authored
Merge branch 'master' into use-auditwheel
2 parents f2638a3 + 9612485 commit e86422f

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

tests/python_package_test/test_basic.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,49 @@ def test_basic(tmp_path):
9595
np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg, bst.predict, tname)
9696

9797

98+
def test_booster_rollback_one_iter(rng):
99+
"""Test that Booster.rollback_one_iter() correctly rolls back one boosting iteration."""
100+
X = rng.uniform(size=(100, 5))
101+
y = rng.integers(0, 2, size=(100,))
102+
X_test = rng.uniform(size=(10, 5))
103+
104+
train_data = lgb.Dataset(X, label=y)
105+
params = {
106+
"objective": "binary",
107+
"verbose": -1,
108+
}
109+
bst = lgb.Booster(params, train_data)
110+
111+
# Train for 10 iterations
112+
num_iterations = 10
113+
for _ in range(num_iterations):
114+
bst.update()
115+
116+
assert bst.current_iteration() == num_iterations
117+
assert bst.num_trees() == num_iterations
118+
119+
# Get predictions before rollback
120+
pred_before = bst.predict(X_test)
121+
122+
# Rollback one iteration
123+
result = bst.rollback_one_iter()
124+
125+
# Verify rollback decremented both iteration count and tree count
126+
assert bst.current_iteration() == num_iterations - 1
127+
assert bst.num_trees() == num_iterations - 1
128+
# Verify it returns self for method chaining
129+
assert result is bst
130+
131+
# Verify predictions actually changed (proves tree was removed, not just counter)
132+
pred_after = bst.predict(X_test)
133+
assert not np.allclose(pred_before, pred_after)
134+
135+
# Verify multiple rollbacks work
136+
bst.rollback_one_iter()
137+
assert bst.current_iteration() == num_iterations - 2
138+
assert bst.num_trees() == num_iterations - 2
139+
140+
98141
class NumpySequence(lgb.Sequence):
99142
def __init__(self, ndarray, batch_size):
100143
self.ndarray = ndarray

0 commit comments

Comments
 (0)