Skip to content

Commit 1ada47a

Browse files
committed
Fix DeepLearning params modification issue
1 parent d0b0868 commit 1ada47a

2 files changed

Lines changed: 41 additions & 2 deletions

File tree

h2o-algos/src/test/java/hex/deeplearning/DeepLearningTest.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,6 +1318,45 @@ public void testCheckpointBackwards() {
13181318
}
13191319
}
13201320

1321+
// GH-16717: Checkpoint resumption fails when distribution=AUTO gets resolved during training.
1322+
// Training mutates the caller's parms object (AUTO -> multinomial), so cloning after training
1323+
// picks up the resolved value. Checkpoint validation then sees AUTO (from _input_parms) vs
1324+
// multinomial (from the cloned parms) and incorrectly throws.
1325+
@Test
1326+
public void testCheckpointAutoDistribution() {
1327+
Frame tfr = null;
1328+
DeepLearningModel dl = null;
1329+
DeepLearningModel dl2 = null;
1330+
1331+
try {
1332+
tfr = parseTestFile("./smalldata/iris/iris.csv");
1333+
DeepLearningParameters parms = new DeepLearningParameters();
1334+
parms._train = tfr._key;
1335+
parms._epochs = 1;
1336+
parms._response_column = "C5";
1337+
parms._reproducible = true;
1338+
parms._hidden = new int[]{2, 2};
1339+
parms._seed = 0xdecaf;
1340+
1341+
// distribution defaults to AUTO; training resolves it to multinomial and mutates parms
1342+
dl = new DeepLearning(parms).trainModel().get();
1343+
1344+
// Clone AFTER training — parms._distribution is now multinomial (mutated by training)
1345+
DeepLearningParameters parms2 = (DeepLearningParameters) parms.clone();
1346+
parms2._epochs = 2;
1347+
parms2._checkpoint = dl._key;
1348+
1349+
// This should succeed but fails with:
1350+
// "Cannot change parameter: '_distribution': AUTO -> multinomial"
1351+
dl2 = new DeepLearning(parms2).trainModel().get();
1352+
Assert.assertTrue(dl2.epoch_counter > dl.epoch_counter);
1353+
} finally {
1354+
if (tfr != null) tfr.delete();
1355+
if (dl != null) dl.delete();
1356+
if (dl2 != null) dl2.delete();
1357+
}
1358+
}
1359+
13211360
@Test
13221361
public void testConvergenceLogloss() {
13231362
Frame tfr = null;

h2o-core/src/main/java/hex/ModelBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,16 @@ protected ModelBuilder(P parms) {
7878
/** Unique new job and named result key */
7979
protected ModelBuilder(P parms, Key<M> key) {
8080
_job = new Job<>(_result = key, parms.javaName(), parms.algoName());
81-
_parms = parms;
8281
_input_parms = (P) parms.clone();
82+
_parms = (P) parms.clone();
8383
}
8484

8585
/** Shared pre-existing Job and unique new result key */
8686
protected ModelBuilder(P parms, Job<M> job) {
8787
_job = job;
8888
_result = defaultKey(parms.algoName());
89-
_parms = parms;
9089
_input_parms = (P) parms.clone();
90+
_parms = (P) parms.clone();
9191
}
9292

9393
/** List of known ModelBuilders with all default args; endlessly cloned by

0 commit comments

Comments
 (0)