@@ -95,6 +95,49 @@ def test_basic(tmp_path):
9595 np .testing .assert_raises_regex (lgb .basic .LightGBMError , bad_shape_error_msg , bst .predict , tname )
9696
9797
98+ def test_booster_rollback_one_iter (rng ):
99+ """Test that Booster.rollback_one_iter() correctly rolls back one boosting iteration."""
100+ X = rng .uniform (size = (100 , 5 ))
101+ y = rng .integers (0 , 2 , size = (100 ,))
102+ X_test = rng .uniform (size = (10 , 5 ))
103+
104+ train_data = lgb .Dataset (X , label = y )
105+ params = {
106+ "objective" : "binary" ,
107+ "verbose" : - 1 ,
108+ }
109+ bst = lgb .Booster (params , train_data )
110+
111+ # Train for 10 iterations
112+ num_iterations = 10
113+ for _ in range (num_iterations ):
114+ bst .update ()
115+
116+ assert bst .current_iteration () == num_iterations
117+ assert bst .num_trees () == num_iterations
118+
119+ # Get predictions before rollback
120+ pred_before = bst .predict (X_test )
121+
122+ # Rollback one iteration
123+ result = bst .rollback_one_iter ()
124+
125+ # Verify rollback decremented both iteration count and tree count
126+ assert bst .current_iteration () == num_iterations - 1
127+ assert bst .num_trees () == num_iterations - 1
128+ # Verify it returns self for method chaining
129+ assert result is bst
130+
131+ # Verify predictions actually changed (proves tree was removed, not just counter)
132+ pred_after = bst .predict (X_test )
133+ assert not np .allclose (pred_before , pred_after )
134+
135+ # Verify multiple rollbacks work
136+ bst .rollback_one_iter ()
137+ assert bst .current_iteration () == num_iterations - 2
138+ assert bst .num_trees () == num_iterations - 2
139+
140+
98141class NumpySequence (lgb .Sequence ):
99142 def __init__ (self , ndarray , batch_size ):
100143 self .ndarray = ndarray
0 commit comments