Skip to content

Commit b1c449d

Browse files
committed
Implement review suggestions
1 parent b2c4762 commit b1c449d

6 files changed

Lines changed: 238 additions & 133 deletions

File tree

h2o-algos/src/main/java/hex/glm/GLM.java

Lines changed: 64 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import java.text.DecimalFormat;
3838
import java.text.NumberFormat;
3939
import java.util.*;
40+
import java.util.function.Consumer;
4041
import java.util.stream.Collectors;
4142
import java.util.stream.IntStream;
4243

@@ -3398,18 +3399,39 @@ private void scoreAndUpdateModel() {
33983399
_model._useRemoveOffsetEffects = _model._parms._remove_offset_effects;
33993400
long t2 = System.currentTimeMillis();
34003401
_model.score(train, null, CFuncRef.from(_parms._custom_metric_func)).delete();
3401-
scorePostProcessingRestrictedModel(train, t2);
3402+
scorePostProcessingRestricted(train, t2,
3403+
m -> _model._output._training_metrics = m,
3404+
m -> _model._output._validation_metrics = m,
3405+
_scoringHistory, _model._useControlVariables);
3406+
_model.addScoringInfo(_parms, nclasses(), t2, _state._iter);
3407+
_model._output._scoring_history = _scoringHistory != null ? _scoringHistory.to2dTable(_parms, null, null) : null;
3408+
34023409
if (_model._parms._control_variables != null && _model._parms._remove_offset_effects) {
3410+
// CV-only
34033411
_model._useControlVariables = true;
34043412
_model._useRemoveOffsetEffects = false;
34053413
t2 = System.currentTimeMillis();
34063414
_model.score(train, null, CFuncRef.from(_parms._custom_metric_func)).delete();
3407-
scorePostProcessingRestrictedModelCVEnabled(train, t2);
3415+
scorePostProcessingRestricted(train, t2,
3416+
m -> _model._output._training_metrics_restricted_model_cv = m,
3417+
m -> _model._output._validation_metrics_restricted_model_cv = m,
3418+
_scoringHistoryControlValEnabled, true);
3419+
_model.addRestrictedModelScoringInfoCV(_parms, nclasses(), t2, _state._iter);
3420+
_model._output._scoring_history_restricted_model_cv = _scoringHistoryControlValEnabled != null
3421+
? _scoringHistoryControlValEnabled.to2dTable(_parms, null, null) : null;
3422+
3423+
// RO-only
34083424
_model._useControlVariables = false;
34093425
_model._useRemoveOffsetEffects = true;
34103426
t2 = System.currentTimeMillis();
34113427
_model.score(train, null, CFuncRef.from(_parms._custom_metric_func)).delete();
3412-
scorePostProcessingRestrictedModelROEnabled(train, t2);
3428+
scorePostProcessingRestricted(train, t2,
3429+
m -> _model._output._training_metrics_restricted_model_ro = m,
3430+
m -> _model._output._validation_metrics_restricted_model_ro = m,
3431+
_scoringHistoryRemoveOffsetEnabled, false);
3432+
_model.addRestrictedModelScoringInfoRO(_parms, nclasses(), t2, _state._iter);
3433+
_model._output._scoring_history_restricted_model_ro = _scoringHistoryRemoveOffsetEnabled != null
3434+
? _scoringHistoryRemoveOffsetEnabled.to2dTable(_parms, null, null) : null;
34133435
}
34143436
} finally {
34153437
_model._useControlVariables = false;
@@ -3423,133 +3445,46 @@ private void scoreAndUpdateModel() {
34233445
_model.generateSummary(_parms._train, _state._iter);
34243446
}
34253447

3426-
private void scorePostProcessingRestrictedModel(Frame train, long t1) {
3427-
ModelMetrics mtrain = ModelMetrics.getFromDKV(_model, train); // updated by model.scoreAndUpdateModel
3428-
long t2 = System.currentTimeMillis();
3429-
if (mtrain != null) {
3430-
_model._output._training_metrics = mtrain;
3431-
_model._output._training_time_ms = t2 - _model._output._start_time; // remember training time
3432-
ScoreKeeper trainScore = new ScoreKeeper(Double.NaN);
3433-
trainScore.fillFrom(mtrain);
3434-
Log.info(LogMsg(mtrain.toString()));
3435-
} else {
3436-
Log.info(LogMsg("ModelMetrics mtrain is null"));
3437-
}
3438-
Log.info(LogMsg("Restricted model training metrics computed in " + (t2 - t1) + "ms"));
3439-
if (_valid != null) {
3440-
Frame valid = DKV.<Frame>getGet(_parms._valid);
3441-
_model.score(_parms.valid(), null, CFuncRef.from(_parms._custom_metric_func)).delete();
3442-
_model._output._validation_metrics = ModelMetrics.getFromDKV(_model, valid); //updated by model.scoreAndUpdateModel
3443-
ScoreKeeper validScore = new ScoreKeeper(Double.NaN);
3444-
validScore.fillFrom(_model._output._validation_metrics);
3445-
}
3446-
_model.addScoringInfo(_parms, nclasses(), t2, _state._iter); // add to scoringInfo for early stopping
3447-
3448-
if (_parms._generate_scoring_history) { // update scoring history with deviance train and valid if available
3449-
if(_model._useControlVariables) {
3450-
double[] betaContrVal = _model._output.getControlValBeta(_state.expandBeta(_state.beta()).clone());
3451-
GLMResDevTask task = new GLMResDevTask(_job._key, _dinfo, _parms, betaContrVal).doAll(_dinfo._adaptedFrame);
3452-
double objectiveControlVal = _state.objective(betaContrVal, task._likelihood);
3453-
3454-
if ((mtrain != null) && (_valid != null)) {
3455-
_scoringHistory.addIterationScore(true, true, _state._iter, task._likelihood,
3456-
objectiveControlVal, _state.deviance(task._likelihood), ((GLMMetrics) _model._output._validation_metrics).residual_deviance(),
3457-
mtrain._nobs, _model._output._validation_metrics._nobs, _state.lambda(), _state.alpha());
3458-
} else { // only doing training deviance
3459-
_scoringHistory.addIterationScore(true, false, _state._iter, task._likelihood,
3460-
objectiveControlVal, _state.deviance(task._likelihood), Double.NaN, mtrain._nobs, 1, _state.lambda(),
3461-
_state.alpha());
3462-
}
3463-
} else if (_model._useRemoveOffsetEffects) {
3464-
if ((mtrain != null) && (_valid != null)) {
3465-
_scoringHistory.addIterationScore(true, true, _state._iter, _state.likelihood(),
3466-
_state.objective(), _state.deviance(), ((GLMMetrics) _model._output._validation_metrics).residual_deviance(),
3467-
mtrain._nobs, _model._output._validation_metrics._nobs, _state.lambda(), _state.alpha());
3468-
} else { // only doing training deviance
3469-
_scoringHistory.addIterationScore(true, false, _state._iter, _state.likelihood(),
3470-
_state.objective(), _state.deviance(), Double.NaN, mtrain._nobs, 1, _state.lambda(),
3471-
_state.alpha());
3472-
}
3448+
private void scorePostProcessingRestricted(Frame train, long t1,
3449+
Consumer<ModelMetrics> storeTrain,
3450+
Consumer<ModelMetrics> storeValid,
3451+
ScoringHistory sh,
3452+
boolean useControlVarBeta) {
3453+
ModelMetrics mtrain = ModelMetrics.getFromDKV(_model, train);
3454+
long t2 = System.currentTimeMillis();
3455+
if (mtrain != null) {
3456+
storeTrain.accept(mtrain);
3457+
_model._output._training_time_ms = t2 - _model._output._start_time;
3458+
Log.info(LogMsg(mtrain.toString()));
34733459
}
3474-
}
3475-
_model._output._scoring_history = _scoringHistory != null ? _scoringHistory.to2dTable(_parms, null, null) : null;
3476-
}
3477-
3478-
private void scorePostProcessingRestrictedModelCVEnabled(Frame train, long t1) {
3479-
ModelMetrics mtrain = ModelMetrics.getFromDKV(_model, train); // updated by model.scoreAndUpdateModel
3480-
long t2 = System.currentTimeMillis();
3481-
if (mtrain != null) {
3482-
_model._output._training_metrics_restricted_model_cv = mtrain;
3483-
_model._output._training_time_ms = t2 - _model._output._start_time; // remember training time
3484-
ScoreKeeper trainScore = new ScoreKeeper(Double.NaN);
3485-
trainScore.fillFrom(mtrain);
3486-
Log.info(LogMsg(mtrain.toString()));
3487-
} else {
3488-
Log.info(LogMsg("ModelMetrics mtrain is null"));
3489-
}
3490-
Log.info(LogMsg("Restricted model where control variables feature is enabled training metrics computed in " + (t2 - t1) + "ms"));
3491-
if (_valid != null) {
3492-
Frame valid = DKV.<Frame>getGet(_parms._valid);
3493-
_model.score(_parms.valid(), null, CFuncRef.from(_parms._custom_metric_func)).delete();
3494-
_model._output._validation_metrics_restricted_model_cv = ModelMetrics.getFromDKV(_model, valid); //updated by model.scoreAndUpdateModel
3495-
ScoreKeeper validScore = new ScoreKeeper(Double.NaN);
3496-
validScore.fillFrom(_model._output._validation_metrics_restricted_model_cv);
3497-
}
3498-
_model.addRestrictedModelScoringInfoCV(_parms, nclasses(), t2, _state._iter); // add to scoringInfo for early stopping
3499-
3500-
if (_parms._generate_scoring_history) { // update scoring history with deviance train and valid if available
3501-
double[] betaContrVal = _model._output.getControlValBeta(_state.expandBeta(_state.beta()).clone());
3502-
GLMResDevTask task = new GLMResDevTask(_job._key, _dinfo, _parms, betaContrVal).doAll(_dinfo._adaptedFrame);
3503-
double objectiveControlVal = _state.objective(betaContrVal, task._likelihood);
3504-
3505-
if ((mtrain != null) && (_valid != null)) {
3506-
_scoringHistoryControlValEnabled.addIterationScore(true, true, _state._iter, task._likelihood,
3507-
objectiveControlVal, _state.deviance(task._likelihood), ((GLMMetrics) _model._output._validation_metrics_restricted_model_cv).residual_deviance(),
3508-
mtrain._nobs, _model._output._validation_metrics_restricted_model_cv._nobs, _state.lambda(), _state.alpha());
3509-
} else { // only doing training deviance
3510-
_scoringHistoryControlValEnabled.addIterationScore(true, false, _state._iter, task._likelihood,
3511-
objectiveControlVal, _state.deviance(task._likelihood), Double.NaN, mtrain._nobs, 1, _state.lambda(),
3512-
_state.alpha());
3513-
3460+
Log.info(LogMsg("Restricted model training metrics computed in " + (t2 - t1) + "ms"));
3461+
3462+
ModelMetrics mvalid = null;
3463+
if (_valid != null) {
3464+
Frame valid = DKV.<Frame>getGet(_parms._valid);
3465+
_model.score(_parms.valid(), null, CFuncRef.from(_parms._custom_metric_func)).delete();
3466+
mvalid = ModelMetrics.getFromDKV(_model, valid);
3467+
storeValid.accept(mvalid);
35143468
}
3515-
_model._output._scoring_history_restricted_model_cv = _scoringHistoryControlValEnabled != null ? _scoringHistoryControlValEnabled.to2dTable(_parms, null, null) : null;
3516-
}
3517-
}
3518-
3519-
private void scorePostProcessingRestrictedModelROEnabled(Frame train, long t1) {
3520-
ModelMetrics mtrain = ModelMetrics.getFromDKV(_model, train); // updated by model.scoreAndUpdateModel
3521-
long t2 = System.currentTimeMillis();
3522-
if (mtrain != null) {
3523-
_model._output._training_metrics_restricted_model_ro = mtrain;
3524-
_model._output._training_time_ms = t2 - _model._output._start_time; // remember training time
3525-
ScoreKeeper trainScore = new ScoreKeeper(Double.NaN);
3526-
trainScore.fillFrom(mtrain);
3527-
Log.info(LogMsg(mtrain.toString()));
3528-
} else {
3529-
Log.info(LogMsg("ModelMetrics mtrain is null"));
3530-
}
3531-
Log.info(LogMsg("Restricted model where remove offset feature is enabled training metrics computed in " + (t2 - t1) + "ms"));
3532-
if (_valid != null) {
3533-
Frame valid = DKV.<Frame>getGet(_parms._valid);
3534-
_model.score(_parms.valid(), null, CFuncRef.from(_parms._custom_metric_func)).delete();
3535-
_model._output._validation_metrics_restricted_model_ro = ModelMetrics.getFromDKV(_model, valid); //updated by model.scoreAndUpdateModel
3536-
ScoreKeeper validScore = new ScoreKeeper(Double.NaN);
3537-
validScore.fillFrom(_model._output._validation_metrics_restricted_model_ro);
3538-
}
3539-
_model.addRestrictedModelScoringInfoRO(_parms, nclasses(), t2, _state._iter); // add to scoringInfo for early stopping
35403469

3541-
if (_parms._generate_scoring_history) { // update scoring history with deviance train and valid if available
3542-
if ((mtrain != null) && (_valid != null)) {
3543-
_scoringHistoryRemoveOffsetEnabled.addIterationScore(true, true, _state._iter, _state.likelihood(),
3544-
_state.objective(), _state.deviance(), ((GLMMetrics) _model._output._validation_metrics_restricted_model_ro).residual_deviance(),
3545-
_model._output._training_metrics_restricted_model_ro._nobs, _model._output._validation_metrics_restricted_model_ro._nobs, _state.lambda(), _state.alpha());
3546-
} else { // only doing training deviance
3547-
_scoringHistoryRemoveOffsetEnabled.addIterationScore(true, false, _state._iter, _state.likelihood(),
3548-
_state.objective(), _state.deviance(), Double.NaN, _model._output._training_metrics_restricted_model_ro._nobs, 1, _state.lambda(),
3549-
_state.alpha());
3470+
if (sh != null && _parms._generate_scoring_history && mtrain != null) {
3471+
double likelihood, objective, deviance;
3472+
if (useControlVarBeta) {
3473+
double[] beta = _model._output.getControlValBeta(_state.expandBeta(_state.beta()).clone());
3474+
GLMResDevTask task = new GLMResDevTask(_job._key, _dinfo, _parms, beta).doAll(_dinfo._adaptedFrame);
3475+
likelihood = task._likelihood;
3476+
objective = _state.objective(beta, task._likelihood);
3477+
deviance = _state.deviance(task._likelihood);
3478+
} else {
3479+
likelihood = _state.likelihood();
3480+
objective = _state.objective();
3481+
deviance = _state.deviance();
3482+
}
3483+
double validDev = mvalid != null ? ((GLMMetrics) mvalid).residual_deviance() : Double.NaN;
3484+
long validNobs = mvalid != null ? mvalid._nobs : 1;
3485+
sh.addIterationScore(true, mvalid != null, _state._iter, likelihood, objective,
3486+
deviance, validDev, mtrain._nobs, validNobs, _state.lambda(), _state.alpha());
35503487
}
3551-
}
3552-
_model._output._scoring_history_restricted_model_ro = _scoringHistoryRemoveOffsetEnabled != null ? _scoringHistoryRemoveOffsetEnabled.to2dTable(_parms, null, null) : null;
35533488
}
35543489

35553490
private void scorePostProcessing(Frame train, long t1) {
@@ -3767,7 +3702,8 @@ public void computeImpl() {
37673702
if (_parms._keepBetaDiffVar)
37683703
keepFrameKeys(keep, _model._output._betadiff_var);
37693704
Scope.untrack(keep.toArray(new Key[keep.size()]));
3770-
}_model.update(_job._key);
3705+
}
3706+
_model.update(_job._key);
37713707
_model.generateSummary(_parms._train, _state._iter);
37723708
_model.unlock(_job);
37733709
}
@@ -4154,6 +4090,7 @@ public boolean progress(double[] beta, double likelihood) {
41544090
// hence I need to make sure that updateProgress is not run inside cv_computeAndSetOptimalParameters.
41554091
if ((!_parms._lambda_search || _parms._generate_scoring_history) && !_insideCVCheck)
41564092
updateProgress(true);
4093+
_job.update(_workPerIteration, _state.toString());
41574094
boolean converged = !_earlyStopEnabled && _state.converged();
41584095
if (converged) Log.info(LogMsg(_state.convergenceMsg));
41594096
return !stop_requested() && !converged && _state._iter < _parms._max_iterations && !_earlyStop;
@@ -4176,9 +4113,7 @@ protected void updateProgress(boolean canScore) {
41764113
} else {
41774114
_scoringHistory.addIterationScore(_state._iter, _state.likelihood(), _state.objective());
41784115
}
4179-
_job.update(_workPerIteration, _state.toString()); // glm specific scoring history is updated every iteration
41804116
}
4181-
41824117
if (canScore && (_parms._score_each_iteration || timeSinceLastScoring() > _scoringInterval ||
41834118
((_parms._score_iteration_interval > 0) && ((_state._iter % _parms._score_iteration_interval) == 0)))) {
41844119
_model.update(_state.expandBeta(_state.beta()), -1, -1, _state._iter);

0 commit comments

Comments
 (0)