diff --git a/h2o-algos/src/main/java/hex/schemas/DTV3.java b/h2o-algos/src/main/java/hex/schemas/DTV3.java index 8849387adc35..98dc49570476 100644 --- a/h2o-algos/src/main/java/hex/schemas/DTV3.java +++ b/h2o-algos/src/main/java/hex/schemas/DTV3.java @@ -19,6 +19,7 @@ public static final class DTParametersV3 extends ModelParametersSchemaV3 diff --git a/h2o-algos/src/main/java/hex/tree/dt/CompressedDT.java b/h2o-algos/src/main/java/hex/tree/dt/CompressedDT.java @@ ... @@ public DTPrediction predictRowStartingFromNode(double[] rowValues, int actualNodeIndex, String ruleExplanation) { - + decisionValue + ", probabilities: " + probability + ", " + (1 - probability) + ")"); + double[] probabilities = ((CompressedLeaf) _nodes[actualNodeIndex]).getProbabilities(); + return new DTPrediction((int) decisionValue, probabilities, + ruleExplanation + " -> " + _nodes[actualNodeIndex].toString()); } if (!ruleExplanation.isEmpty()) { ruleExplanation += " and "; } AbstractSplittingRule splittingRule = ((CompressedNode) _nodes[actualNodeIndex]).getSplittingRule(); - // splitting rule is true - left, false - right + // splitting rule is: true - left, false - right if(splittingRule.routeSample(rowValues)) { return predictRowStartingFromNode(rowValues, 2 * actualNodeIndex + 1, ruleExplanation + splittingRule.toString()); @@ -65,7 +65,7 @@ public int extractRulesStartingWithNode(int nodeIndex, String actualRule, int ne if (_nodes[nodeIndex] instanceof CompressedLeaf) { // if node is a leaf, add the rule to the list of rules at index given by the nextFreeSpot parameter _listOfRules[nextFreeSpot] = actualRule + " -> (" + ((CompressedLeaf) _nodes[nodeIndex]).getDecisionValue() - + ", " + ((CompressedLeaf) _nodes[nodeIndex]).getProbabilities() + ")"; + + ", " + Arrays.toString(((CompressedLeaf) _nodes[nodeIndex]).getProbabilities()) + ")"; // move nextFreeSpot to the next index and return it to be used for other branches nextFreeSpot++; return nextFreeSpot; diff --git a/h2o-algos/src/main/java/hex/tree/dt/CompressedLeaf.java b/h2o-algos/src/main/java/hex/tree/dt/CompressedLeaf.java index 7d77f5ae20ec..c8abb04d84b4 100644 --- a/h2o-algos/src/main/java/hex/tree/dt/CompressedLeaf.java +++ b/h2o-algos/src/main/java/hex/tree/dt/CompressedLeaf.java @@ -1,27 +1,32 @@ package hex.tree.dt; +import java.util.Arrays; +import java.util.stream.Collectors; + public class CompressedLeaf extends AbstractCompressedNode { private final double _decisionValue; - private final double _probability; + private final double[] _probabilities; - public CompressedLeaf(double decisionValue, double probabilities) { + public CompressedLeaf(double decisionValue, double[] probabilities) { super(); _decisionValue = decisionValue; - _probability = probabilities; + _probabilities = probabilities; } public double getDecisionValue() { return _decisionValue; } - public double getProbabilities() { - return _probability; + public double[] getProbabilities() { + return _probabilities; } @Override public String toString() { - return "(leaf: " + _decisionValue + ", " + _probability + ", " + (1- _probability) + ")"; + return "(leaf: " + _decisionValue + "; " + + Arrays.stream(_probabilities).mapToObj(Double::toString) .collect(Collectors.joining(", ")) + ")"; } } diff --git a/h2o-algos/src/main/java/hex/tree/dt/DT.java b/h2o-algos/src/main/java/hex/tree/dt/DT.java index 99d747584625..16cebbd514be 100644 --- a/h2o-algos/src/main/java/hex/tree/dt/DT.java +++ b/h2o-algos/src/main/java/hex/tree/dt/DT.java @@ -8,7 +8,6 @@ import hex.tree.dt.binning.Histogram; import hex.tree.dt.mrtasks.GetClassCountsMRTask; import hex.tree.dt.mrtasks.ScoreDTTask; -import org.apache.commons.math3.util.Precision; import org.apache.log4j.Logger; import water.DKV; import
water.exceptions.H2OModelBuilderIllegalArgumentException; @@ -19,7 +18,7 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import static hex.tree.dt.binning.SplitStatistics.entropyBinarySplit; +import static hex.tree.dt.binning.SplitStatistics.entropyMulticlass; /** * Decision Tree @@ -49,8 +48,6 @@ public class DT extends ModelBuilder ((binStatistics._leftCount >= _min_rows) && (binStatistics._rightCount >= _min_rows))) .peek(binStatistics -> Log.debug("split: " + binStatistics._splittingRule + ", counts: " @@ -128,7 +124,7 @@ private AbstractSplittingRule findBestSplitForFeature(Histogram histogram, int f private static double calculateCriterionOfSplit(SplitStatistics binStatistics) { - return binStatistics.binaryEntropy(); + return binStatistics.splitEntropy(); } /** @@ -139,7 +135,7 @@ private static double calculateCriterionOfSplit(SplitStatistics binStatistics) { */ private int selectDecisionValue(int[] countsByClass) { if (_nclass == 1) { - return countsByClass[0]; + return 0; } int currentMaxClass = 0; int currentMax = countsByClass[currentMaxClass]; @@ -155,10 +151,10 @@ private int selectDecisionValue(int[] countsByClass) { /** * Calculates probabilities of each class for a leaf. * - * @param countsByClass counts of 0 and 1 in a leaf - * @return probabilities of 0 or 1 + * @param countsByClass counts of each class in a leaf + * @return probabilities of each class */ - private double[] calculateProbability(int[] countsByClass) { + private double[] calculateProbabilities(int[] countsByClass) { int samplesCount = Arrays.stream(countsByClass).sum(); return Arrays.stream(countsByClass).asDoubleStream().map(n -> n / samplesCount).toArray(); } @@ -171,7 +167,7 @@ private double[] calculateProbability(int[] countsByClass) { * @param nodeIndex node index */ public void makeLeafFromNode(int[] countsByClass, int nodeIndex) { - _tree[nodeIndex] = new CompressedLeaf(selectDecisionValue(countsByClass), calculateProbability(countsByClass)[0]); + _tree[nodeIndex] = new CompressedLeaf(selectDecisionValue(countsByClass), calculateProbabilities(countsByClass)); _leavesCount++; // nothing to return, node is modified inplace } @@ -200,16 +196,19 @@ public void buildNextNode(Queue limitsQueue, int nodeIndex) // [count0, count1, ...] 
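// e.g. with three classes, countsByClass = {7, 2, 1} means 7 samples of class 0, 2 of class 1 and 1 of class 2 (illustrative values)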
int[] countsByClass = countClasses(actualLimits); if (nodeIndex == 0) { - Log.info("Classes counts in dataset: 0 - " + countsByClass[0] + ", 1 - " + countsByClass[1]); + Log.info(IntStream.range(0, countsByClass.length) + .mapToObj(i -> i + " - " + countsByClass[i]) + .collect(Collectors.joining(", ", "Classes counts in dataset: ", ""))); } // compute node depth int nodeDepth = (int) Math.floor(MathUtils.log2(nodeIndex + 1)); - // stop building from this node, the node will be a leaf - if ((nodeDepth >= _parms._max_depth) - || (countsByClass[0] <= _min_rows) - || (countsByClass[1] <= _min_rows) -// || zeroRatio > 0.999 || zeroRatio < 0.001 - ) { + // stop building from this node, the node will be a leaf if: + // - max depth is reached + // - there is only one non-zero count in the countsByClass + // - there are not enough data points in the node + if ((nodeDepth >= _parms._max_depth) + || Arrays.stream(countsByClass).filter(c -> c > 0).count() < 2 + || Arrays.stream(countsByClass).sum() < _min_rows) { // add imaginary left and right children to imitate valid tree structure // left child limitsQueue.add(null); @@ -219,10 +218,10 @@ public void buildNextNode(Queue limitsQueue, int nodeIndex) return; } - Histogram histogram = new Histogram(_train, actualLimits, BinningStrategy.EQUAL_WIDTH/*, minNumSamplesInBin - todo consider*/); + Histogram histogram = new Histogram(_train, actualLimits, BinningStrategy.EQUAL_WIDTH, _nclass); AbstractSplittingRule bestSplittingRule = findBestSplit(histogram); - double criterionForTheParentNode = entropyBinarySplit(1.0 * countsByClass[0] / (countsByClass[0] + countsByClass[1])); + double criterionForTheParentNode = entropyMulticlass(countsByClass, Arrays.stream(countsByClass).sum()); // if no split could be found, make a leaf from current node // if the information gain is low, make a leaf from current node if (bestSplittingRule == null @@ -291,9 +290,6 @@ private void dtChecks() { if (!_response.isCategorical()) { error("_response", "Only categorical response is supported"); } - if (!_response.isBinary()) { - error("_response", "Only binary response is supported"); - } } @Override @@ -365,7 +361,7 @@ public BuilderVisibility builderVisibility() { public ModelCategory[] can_build() { return new ModelCategory[]{ ModelCategory.Binomial, -// ModelCategory.Multinomial, + ModelCategory.Multinomial, // ModelCategory.Ordinal, // ModelCategory.Regression }; diff --git a/h2o-algos/src/main/java/hex/tree/dt/DTModel.java b/h2o-algos/src/main/java/hex/tree/dt/DTModel.java index 58c173e31d67..82277e84d1d2 100644 --- a/h2o-algos/src/main/java/hex/tree/dt/DTModel.java +++ b/h2o-algos/src/main/java/hex/tree/dt/DTModel.java @@ -4,7 +4,6 @@ import org.apache.log4j.Logger; import water.*; -import java.util.Arrays; public class DTModel extends Model { @@ -36,10 +35,10 @@ protected double[] score0(double[] data, double[] preds) { // compute score for given point CompressedDT tree = DKV.getGet(_output._treeKey); DTPrediction prediction = tree.predictRowStartingFromNode(data, 0, ""); - // for now, only pred.
for class 0 is stored, will be improved later preds[0] = prediction.classPrediction; - preds[1] = prediction.probability; - preds[2] = 1 - prediction.probability; + for (int i = 0; i < prediction.probabilities.length; i++) { + preds[i + 1] = prediction.probabilities[i]; + } return preds; } diff --git a/h2o-algos/src/main/java/hex/tree/dt/DTPrediction.java b/h2o-algos/src/main/java/hex/tree/dt/DTPrediction.java index d2d343ee707c..8f6eb135534e 100644 --- a/h2o-algos/src/main/java/hex/tree/dt/DTPrediction.java +++ b/h2o-algos/src/main/java/hex/tree/dt/DTPrediction.java @@ -2,12 +2,12 @@ public class DTPrediction { public int classPrediction; - public double probability; + public double[] probabilities; public String ruleExplanation; - public DTPrediction(int classPrediction, double probability, String ruleExplanation) { + public DTPrediction(int classPrediction, double[] probabilities, String ruleExplanation) { this.classPrediction = classPrediction; - this.probability = probability; + this.probabilities = probabilities; this.ruleExplanation = ruleExplanation; } } diff --git a/h2o-algos/src/main/java/hex/tree/dt/binning/AbstractBin.java b/h2o-algos/src/main/java/hex/tree/dt/binning/AbstractBin.java index 3e981c31a300..eeb2ec4cf274 100644 --- a/h2o-algos/src/main/java/hex/tree/dt/binning/AbstractBin.java +++ b/h2o-algos/src/main/java/hex/tree/dt/binning/AbstractBin.java @@ -5,11 +5,11 @@ * Single bin holding limits (min excluded), count of samples and per-class counts. */ public abstract class AbstractBin { - public int _count0; + public int[] _classesDistribution; public int _count; - public int getCount0() { - return _count0; + public int getClassCount(int i) { + return _classesDistribution[i]; } public abstract AbstractBin clone(); diff --git a/h2o-algos/src/main/java/hex/tree/dt/binning/BinningStrategy.java b/h2o-algos/src/main/java/hex/tree/dt/binning/BinningStrategy.java index aa4b65be5fd2..559b3642fc38 100644 --- a/h2o-algos/src/main/java/hex/tree/dt/binning/BinningStrategy.java +++ b/h2o-algos/src/main/java/hex/tree/dt/binning/BinningStrategy.java @@ -9,13 +9,12 @@ import java.math.BigDecimal; import java.math.RoundingMode; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; - -import static hex.tree.dt.mrtasks.CountBinsSamplesCountsMRTask.COUNT; -import static hex.tree.dt.mrtasks.CountBinsSamplesCountsMRTask.COUNT_0; +import java.util.stream.DoubleStream; /** - * Strategy for binning. Creates bins for single feature. + * Strategy for binning. Creates bins for a single feature.
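+ * (e.g. EQUAL_WIDTH splits each numeric feature's [min, max] range into NUM_BINS equal-width bins)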
*/ public enum BinningStrategy { @@ -39,13 +38,13 @@ public enum BinningStrategy { return roundToNDecimalPoints(number, DECIMALS_TO_CONSIDER); } - private List createEmptyBinsFromBinningValues(List binningValues, double realMin, double realMax) { + private List createEmptyBinsFromBinningValues(List binningValues, double realMin, double realMax, int nclass) { List emptyBins = new ArrayList<>(); // create bins between nearest binning values, don't create bin starting with the last value (on index size - 1) for (int i = 0; i < binningValues.size() - 1; i++) { emptyBins.add( new NumericBin(roundToNDecimalPoints(binningValues.get(i)), - roundToNDecimalPoints(binningValues.get(i + 1)))); + roundToNDecimalPoints(binningValues.get(i + 1)), nclass)); } // set the first min to some lower value (relative to step) so the actual value equal to min is not lost ((NumericBin) emptyBins.get(0)).setMin(realMin - MIN_REL_COEFF * (binningValues.get(1) - binningValues.get(0))); @@ -55,7 +54,7 @@ private List createEmptyBinsFromBinningValues(List binningV } @Override - List createFeatureBins(Frame originData, DataFeaturesLimits featuresLimits, int feature) { + List createFeatureBins(Frame originData, DataFeaturesLimits featuresLimits, int feature, int nclass) { if (originData.vec(feature).isNumeric()) { NumericFeatureLimits featureLimits = (NumericFeatureLimits) featuresLimits.getFeatureLimits(feature); double step = (featureLimits._max - featureLimits._min) / NUM_BINS; @@ -69,7 +68,7 @@ List createFeatureBins(Frame originData, DataFeaturesLimits feature binningValues.add(value); } List emptyBins = createEmptyBinsFromBinningValues( - binningValues, featureLimits._min, featureLimits._max); + binningValues, featureLimits._min, featureLimits._max, nclass); return calculateNumericBinSamplesCount(originData, emptyBins, featuresLimits.toDoubles(), feature); } else { @@ -78,7 +77,7 @@ List createFeatureBins(Frame originData, DataFeaturesLimits feature for (int category = 0; category < featureLimits._mask.length; category++) { // if the category is present in feature values, add new bin for this category if (featureLimits._mask[category]) { - emptyBins.add(new CategoricalBin(category)); + emptyBins.add(new CategoricalBin(category, nclass)); } } @@ -90,23 +89,23 @@ List createFeatureBins(Frame originData, DataFeaturesLimits feature }, /** - * Equal height: bins have approximately the same size - todo + * Equal height: bins have approximately the same size (not implemented yet) * - probably too costly to do it with MR task, better leave equal-width */ EQUAL_HEIGHT { @Override - List createFeatureBins(Frame originData, DataFeaturesLimits featuresLimits, int feature) { + List createFeatureBins(Frame originData, DataFeaturesLimits featuresLimits, int feature, int nclass) { return null; } }, /** - * Custom bins: works with provided bins limits - todo + * Custom bins: works with provided bins limits (not implemented yet) */ CUSTOM_BINS { @Override - List createFeatureBins(Frame originData, DataFeaturesLimits featuresLimits, int feature) { + List createFeatureBins(Frame originData, DataFeaturesLimits featuresLimits, int feature, int nclass) { return null; } }; @@ -121,7 +120,7 @@ List createFeatureBins(Frame originData, DataFeaturesLimits feature * @param feature selected feature index * @return list of created bins */ - abstract List createFeatureBins(Frame originData, DataFeaturesLimits featuresLimits, int feature); + abstract List createFeatureBins(Frame originData, DataFeaturesLimits featuresLimits, int feature, int
nclass); /** * Calculates samples count for given bins for categorical feature. @@ -136,11 +135,14 @@ private static List calculateCategoricalBinSamplesCount(Frame data, double[][] featuresLimits, int feature) { // run MR task to compute accumulated statistic for bins - one task for one feature, calculates all bins at once double[][] binsArray = bins.stream().map(AbstractBin::toDoubles).toArray(double[][]::new); - CountBinsSamplesCountsMRTask task = new CountBinsSamplesCountsMRTask(feature, featuresLimits, binsArray); + int countsOffset = CountBinsSamplesCountsMRTask.CAT_COUNT_OFFSET; + CountBinsSamplesCountsMRTask task = new CountBinsSamplesCountsMRTask(feature, featuresLimits, binsArray, countsOffset); task.doAll(data); - for(int i = 0; i < binsArray.length; i ++) { - bins.get(i)._count = (int) task._bins[i][COUNT]; - bins.get(i)._count0 = (int) task._bins[i][COUNT_0]; + for (int i = 0; i < binsArray.length; i++) { + bins.get(i)._count = (int) task._bins[i][countsOffset]; + bins.get(i)._classesDistribution = + DoubleStream.of(Arrays.copyOfRange(task._bins[i], countsOffset + 1, task._bins[i].length)) + .mapToInt(c -> (int) c).toArray(); } return bins; } @@ -158,11 +160,15 @@ private static List calculateNumericBinSamplesCount(Frame data, Lis double[][] featuresLimits, int feature) { // run MR task to compute accumulated statistic for bins - one task for one feature, calculates all bins at once double[][] binsArray = bins.stream().map(AbstractBin::toDoubles).toArray(double[][]::new); - CountBinsSamplesCountsMRTask task = new CountBinsSamplesCountsMRTask(feature, featuresLimits, binsArray); + int countsOffset = CountBinsSamplesCountsMRTask.NUM_COUNT_OFFSET; + CountBinsSamplesCountsMRTask task = new CountBinsSamplesCountsMRTask(feature, featuresLimits, binsArray, countsOffset); task.doAll(data); - for(int i = 0; i < binsArray.length; i ++) { - bins.get(i)._count = (int) task._bins[i][COUNT]; - bins.get(i)._count0 = (int) task._bins[i][COUNT_0]; + + for (int i = 0; i < binsArray.length; i++) { + bins.get(i)._count = (int) task._bins[i][countsOffset]; + bins.get(i)._classesDistribution = + DoubleStream.of(Arrays.copyOfRange(task._bins[i], countsOffset + 1, task._bins[i].length)) + .mapToInt(c -> (int) c).toArray(); } return bins; } diff --git a/h2o-algos/src/main/java/hex/tree/dt/binning/CategoricalBin.java b/h2o-algos/src/main/java/hex/tree/dt/binning/CategoricalBin.java index 6008836ed655..3213a3a62875 100644 --- a/h2o-algos/src/main/java/hex/tree/dt/binning/CategoricalBin.java +++ b/h2o-algos/src/main/java/hex/tree/dt/binning/CategoricalBin.java @@ -1,21 +1,25 @@ package hex.tree.dt.binning; +import org.apache.commons.lang.ArrayUtils; + +import java.util.Arrays; + /** * For categorical features values are already binned to categories - each bin corresponds to one value (category) */ public class CategoricalBin extends AbstractBin { public int _category; - public CategoricalBin(int category, int count, int count0) { + public CategoricalBin(int category, int[] classesDistribution, int count) { _category = category; + _classesDistribution = classesDistribution; _count = count; - _count0 = count0; } - public CategoricalBin(int category) { + public CategoricalBin(int category, int nclass) { _category = category; + _classesDistribution = new int[nclass]; _count = 0; - _count0 = 0; } public int getCategory() { @@ -23,11 +27,13 @@ public int getCategory() { } public CategoricalBin clone() { - return new CategoricalBin(_category, _count, _count0); + return new CategoricalBin(_category, 
_classesDistribution, _count); } public double[] toDoubles() { - return new double[]{_category, _count, _count0}; + // category|count|class0|class1|... + return ArrayUtils.addAll(new double[]{_category, _count}, + Arrays.stream(_classesDistribution).asDoubleStream().toArray()); } } diff --git a/h2o-algos/src/main/java/hex/tree/dt/binning/FeatureBins.java b/h2o-algos/src/main/java/hex/tree/dt/binning/FeatureBins.java index 4e7fffc96fbc..911109e63ee0 100644 --- a/h2o-algos/src/main/java/hex/tree/dt/binning/FeatureBins.java +++ b/h2o-algos/src/main/java/hex/tree/dt/binning/FeatureBins.java @@ -9,7 +9,7 @@ public class FeatureBins { private List _bins; - private final boolean _isConstant; // todo - test this + private final boolean _isConstant; private int _numOfCategories; public FeatureBins(List bins) { @@ -32,18 +32,18 @@ public FeatureBins(List bins, int numOfCategories) { * * @return list of accumulated statistics, matches original bins */ - public List calculateSplitStatisticsForNumericFeature() { + public List calculateSplitStatisticsForNumericFeature(int nclass) { // init list with empty instances List statistics = _bins.stream() - .map(b -> new SplitStatistics()).collect(Collectors.toList()); + .map(b -> new SplitStatistics(nclass)).collect(Collectors.toList()); // calculate accumulative statistics for each split: // left split - bins to the left + current; // right split - bins to the right. - SplitStatistics tmpAccumulatorLeft = new SplitStatistics(); - SplitStatistics tmpAccumulatorRight = new SplitStatistics(); + SplitStatistics tmpAccumulatorLeft = new SplitStatistics(nclass); + SplitStatistics tmpAccumulatorRight = new SplitStatistics(nclass); int rightIndex; for (int leftIndex = 0; leftIndex < statistics.size(); leftIndex++) { - tmpAccumulatorLeft.accumulateLeftStatistics(_bins.get(leftIndex)._count, _bins.get(leftIndex)._count0); + tmpAccumulatorLeft.accumulateLeftStatistics(_bins.get(leftIndex)._count, _bins.get(leftIndex)._classesDistribution); statistics.get(leftIndex).copyLeftValues(tmpAccumulatorLeft); statistics.get(leftIndex)._splittingRule = new NumericSplittingRule(((NumericBin) _bins.get(leftIndex))._max); // accumulate from the right (from the end of bins array) @@ -51,7 +51,7 @@ public List calculateSplitStatisticsForNumericFeature() { // firstly accumulate with old values, then add the actual bin for the future statistics // as the values of the actual bin are not included in its right statistics statistics.get(rightIndex).copyRightValues(tmpAccumulatorRight); - tmpAccumulatorRight.accumulateRightStatistics(_bins.get(rightIndex)._count, _bins.get(rightIndex)._count0); + tmpAccumulatorRight.accumulateRightStatistics(_bins.get(rightIndex)._count, _bins.get(rightIndex)._classesDistribution); } return statistics; } @@ -64,15 +64,17 @@ List getFeatureBins() { return _bins.stream().map(AbstractBin::clone).collect(Collectors.toList()); } - public List calculateSplitStatisticsForCategoricalFeature() { - // for binomial classification sort bins by the frequency of one class and split similarly to the sequential feature - return calculateStatisticsForCategoricalFeatureBinomialClassification(); - - // full approach for binomial/multinomial/regression, works fpr up to 10 categories -// return calculateStatisticsForCategoricalFeatureFullApproach(); + public List calculateSplitStatisticsForCategoricalFeature(int nclass) { + if(nclass == 2) { + // for binomial classification sort bins by the frequency of one class and split similarly to the sequential feature + return 
calculateStatisticsForCategoricalFeatureBinomialClassification(nclass); + } else { + // full approach for binomial/multinomial/regression, works for up to 10 categories + return calculateStatisticsForCategoricalFeatureFullApproach(nclass); + } } - private List calculateStatisticsForCategoricalFeatureFullApproach() { + private List calculateStatisticsForCategoricalFeatureFullApproach(int nclass) { // calculate accumulative statistics for each subset of categories: // left split - categories included in the subset; // right split - categories not included in subset. @@ -80,17 +82,17 @@ private List calculateStatisticsForCategoricalFeatureFullApproa // currently the max supported category value is 9; for more categories the faster sequential approach should be used // init list with empty instances String categories = _bins.stream().map(b -> String.valueOf(((CategoricalBin) b)._category)) - .collect(Collectors.joining("")); // is it always 0 to _bins.size()? + .collect(Collectors.joining("")); Set splits = findAllCategoricalSplits(categories); List statistics = new ArrayList<>(); for (boolean[] splitMask : splits) { - SplitStatistics splitStatistics = new SplitStatistics(); + SplitStatistics splitStatistics = new SplitStatistics(nclass); for (AbstractBin bin : _bins) { // if bin category is in the mask, it belongs to the left split, otherwise it belongs to the right split if (splitMask[((CategoricalBin) bin)._category]) { - splitStatistics.accumulateLeftStatistics(bin._count, bin._count0); + splitStatistics.accumulateLeftStatistics(bin._count, bin._classesDistribution); } else { - splitStatistics.accumulateRightStatistics(bin._count, bin._count0); + splitStatistics.accumulateRightStatistics(bin._count, bin._classesDistribution); } } splitStatistics._splittingRule = new CategoricalSplittingRule(splitMask); @@ -122,7 +124,7 @@ else for (String s : categories.split("")) { } private void rec(Set masks, String current, String categories, int stepsToGo) { - if (stepsToGo == 0) { + if (stepsToGo == 0 || categories.isEmpty()) { masks.add(createMaskFromString(current)); return; } @@ -147,23 +149,23 @@ private boolean[] createMaskFromBins(List bins) { return mask; } - public List calculateStatisticsForCategoricalFeatureBinomialClassification() { + public List calculateStatisticsForCategoricalFeatureBinomialClassification(int nclass) { List sortedBins = _bins.stream() .map(b -> (CategoricalBin) b) - .sorted(Comparator.comparingInt(CategoricalBin::getCount0)) + .sorted(Comparator.comparingDouble(c -> (double) c.getClassCount(0) / c._count)) .collect(Collectors.toList()); // init list with empty instances List statistics = sortedBins.stream() - .map(b -> new SplitStatistics()).collect(Collectors.toList()); + .map(b -> new SplitStatistics(nclass)).collect(Collectors.toList()); // calculate accumulative statistics for each split: // left split - bins to the left + current; // right split - bins to the right.
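// e.g. for sorted bin counts {4, 3, 3}, the accumulated left counts are {4, 7, 10} and the right counts are {6, 3, 0}; the current bin always goes to the left split (illustrative values)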
- SplitStatistics tmpAccumulatorLeft = new SplitStatistics(); - SplitStatistics tmpAccumulatorRight = new SplitStatistics(); + SplitStatistics tmpAccumulatorLeft = new SplitStatistics(nclass); + SplitStatistics tmpAccumulatorRight = new SplitStatistics(nclass); int rightIndex; for (int leftIndex = 0; leftIndex < statistics.size(); leftIndex++) { - tmpAccumulatorLeft.accumulateLeftStatistics(sortedBins.get(leftIndex)._count, sortedBins.get(leftIndex)._count0); + tmpAccumulatorLeft.accumulateLeftStatistics(sortedBins.get(leftIndex)._count, sortedBins.get(leftIndex)._classesDistribution); statistics.get(leftIndex).copyLeftValues(tmpAccumulatorLeft); statistics.get(leftIndex)._splittingRule = new CategoricalSplittingRule( createMaskFromBins(sortedBins.subList(0, leftIndex + 1))); // subList takes toIndex exclusive, so +1 @@ -172,7 +174,7 @@ public List calculateStatisticsForCategoricalFeatureBinomialCla // firstly accumulate with old values, then add the actual bin for the future statistics // as the values of the actual bin are not included in its right statistics statistics.get(rightIndex).copyRightValues(tmpAccumulatorRight); - tmpAccumulatorRight.accumulateRightStatistics(sortedBins.get(rightIndex)._count, sortedBins.get(rightIndex)._count0); + tmpAccumulatorRight.accumulateRightStatistics(sortedBins.get(rightIndex)._count, sortedBins.get(rightIndex)._classesDistribution); } return statistics; } diff --git a/h2o-algos/src/main/java/hex/tree/dt/binning/Histogram.java b/h2o-algos/src/main/java/hex/tree/dt/binning/Histogram.java index 57b215c38e6e..7622e4f1ff92 100644 --- a/h2o-algos/src/main/java/hex/tree/dt/binning/Histogram.java +++ b/h2o-algos/src/main/java/hex/tree/dt/binning/Histogram.java @@ -15,7 +15,7 @@ public class Histogram { private final List _featuresBins; private final BinningStrategy _binningStrategy; - public Histogram(Frame originData, DataFeaturesLimits conditionLimits, BinningStrategy binningStrategy) { + public Histogram(Frame originData, DataFeaturesLimits conditionLimits, BinningStrategy binningStrategy, int nclass) { _binningStrategy = binningStrategy; // get real features limits where the conditions are fulfilled DataFeaturesLimits featuresLimitsForConditions = getFeaturesLimitsForConditions(originData, conditionLimits); @@ -23,7 +23,7 @@ public Histogram(Frame originData, DataFeaturesLimits conditionLimits, BinningSt _featuresBins = IntStream .range(0, originData.numCols() - 1/*exclude the last prediction column*/) .mapToObj(i -> new FeatureBins( - _binningStrategy.createFeatureBins(originData, featuresLimitsForConditions, i), + _binningStrategy.createFeatureBins(originData, featuresLimitsForConditions, i, nclass), originData.vec(i).cardinality())) .collect(Collectors.toList()); } @@ -57,12 +57,12 @@ public static DataFeaturesLimits getFeaturesLimitsForConditions(Frame originData return new DataFeaturesLimits(task._realFeatureLimits); } - public List calculateSplitStatisticsForNumericFeature(int feature) { - return _featuresBins.get(feature).calculateSplitStatisticsForNumericFeature(); + public List calculateSplitStatisticsForNumericFeature(int feature, int nclass) { + return _featuresBins.get(feature).calculateSplitStatisticsForNumericFeature(nclass); } - public List calculateSplitStatisticsForCategoricalFeature(int feature) { - return _featuresBins.get(feature).calculateSplitStatisticsForCategoricalFeature(); + public List calculateSplitStatisticsForCategoricalFeature(int feature, int nclass) { + return 
_featuresBins.get(feature).calculateSplitStatisticsForCategoricalFeature(nclass); } public boolean isConstant(int featureIndex) { diff --git a/h2o-algos/src/main/java/hex/tree/dt/binning/NumericBin.java b/h2o-algos/src/main/java/hex/tree/dt/binning/NumericBin.java index 513b9a866679..a37a616d8138 100644 --- a/h2o-algos/src/main/java/hex/tree/dt/binning/NumericBin.java +++ b/h2o-algos/src/main/java/hex/tree/dt/binning/NumericBin.java @@ -1,7 +1,10 @@ package hex.tree.dt.binning; +import org.apache.commons.lang.ArrayUtils; import water.util.Pair; +import java.util.Arrays; + /** * Single bin holding limits (min excluded), count of samples and per-class counts. */ @@ -9,38 +12,39 @@ public class NumericBin extends AbstractBin { public double _min; public double _max; - public static final int MIN_INDEX = 3; - public static final int MAX_INDEX = 4; + public static final int MIN_INDEX = 1; + public static final int MAX_INDEX = 2; - public NumericBin(double min, double max, int count, int count0) { + public NumericBin(double min, double max, int[] classesDistribution, int count) { _min = min; _max = max; - + _classesDistribution = classesDistribution; _count = count; - _count0 = count0; } - public NumericBin(double min, double max) { + public NumericBin(double min, double max, int nclass) { _min = min; _max = max; + _classesDistribution = new int[nclass]; _count = 0; - _count0 = 0; } - public NumericBin(Pair binLimits) { + public NumericBin(Pair binLimits, int nclass) { _min = binLimits._1(); _max = binLimits._2(); + _classesDistribution = new int[nclass]; _count = 0; - _count0 = 0; } public NumericBin clone() { - return new NumericBin(_min, _max, _count, _count0); + return new NumericBin(_min, _max, _classesDistribution, _count); } public double[] toDoubles() { // Place numeric flag -1.0 on the index 0 to mark that the feature is numeric - return new double[]{-1.0, _count, _count0, _min, _max}; + // -1|min|max|count|class0|class1|... + return ArrayUtils.addAll(new double[]{-1.0, _min, _max, _count}, + Arrays.stream(_classesDistribution).asDoubleStream().toArray()); } public void setMin(double min) { diff --git a/h2o-algos/src/main/java/hex/tree/dt/binning/SplitStatistics.java b/h2o-algos/src/main/java/hex/tree/dt/binning/SplitStatistics.java index cf19a653ff4c..4ddd60a54440 100644 --- a/h2o-algos/src/main/java/hex/tree/dt/binning/SplitStatistics.java +++ b/h2o-algos/src/main/java/hex/tree/dt/binning/SplitStatistics.java @@ -1,7 +1,8 @@ package hex.tree.dt.binning; import hex.tree.dt.AbstractSplittingRule; -import org.apache.commons.math3.util.Precision; + +import java.util.Arrays; /** * Potential split including splitting rule and statistics on count of samples and distribution of target variable.
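The next hunk generalizes the binary entropy to a multiclass Shannon entropy. As a quick sanity check of the new formula, here is a minimal standalone sketch (not part of the patch; the class counts are made up):

    // standalone sketch of the multiclass entropy for class counts {5, 3, 2}
    int[] counts = {5, 3, 2};
    int total = java.util.Arrays.stream(counts).sum();
    double entropy = -1 * java.util.Arrays.stream(counts)
            .mapToDouble(c -> c == 0 ? 0 : (c * 1.0 / total) * Math.log(c * 1.0 / total))
            .sum();
    // gives about 1.0297 nats; a pure node such as {10, 0, 0} gives 0.0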
@@ -11,36 +12,45 @@ public class SplitStatistics { public AbstractSplittingRule _splittingRule; public int _leftCount; - public int _leftCount0; + public int[] _leftClassDistribution; public int _rightCount; - public int _rightCount0; + public int[] _rightClassDistribution; - public SplitStatistics() { + public SplitStatistics(int numClasses) { _leftCount = 0; - _leftCount0 = 0; + _leftClassDistribution = new int[numClasses]; _rightCount = 0; - _rightCount0 = 0; + _rightClassDistribution = new int[numClasses]; } - public void accumulateLeftStatistics(int leftCount, int leftCount0) { + public SplitStatistics() { + // call constructor with default value + this(1); + } + + public void accumulateLeftStatistics(int leftCount, int[] leftClassDistribution) { _leftCount += leftCount; - _leftCount0 += leftCount0; + for (int i = 0; i < _leftClassDistribution.length; i++) { + _leftClassDistribution[i] += leftClassDistribution[i]; + } } - public void accumulateRightStatistics(int rightCount, int rightCount0) { + public void accumulateRightStatistics(int rightCount, int[] rightClassDistribution) { _rightCount += rightCount; - _rightCount0 += rightCount0; + for (int i = 0; i < _rightClassDistribution.length; i++) { + _rightClassDistribution[i] += rightClassDistribution[i]; + } } public void copyLeftValues(SplitStatistics toCopy) { _leftCount = toCopy._leftCount; - _leftCount0 = toCopy._leftCount0; + _leftClassDistribution = Arrays.copyOf(toCopy._leftClassDistribution, toCopy._leftClassDistribution.length); } public void copyRightValues(SplitStatistics toCopy) { _rightCount = toCopy._rightCount; - _rightCount0 = toCopy._rightCount0; + _rightClassDistribution = Arrays.copyOf(toCopy._rightClassDistribution, toCopy._rightClassDistribution.length); } public SplitStatistics setCriterionValue(double criterionOfSplit) { @@ -53,16 +63,17 @@ public SplitStatistics setFeatureIndex(int featureIndex) { return this; } - public static double entropyBinarySplit(final double oneClassFrequency) { - return -1 * ((oneClassFrequency < Precision.EPSILON ? 0 : (oneClassFrequency * Math.log(oneClassFrequency))) - + ((1 - oneClassFrequency) < Precision.EPSILON ? 0 : ((1 - oneClassFrequency) * Math.log(1 - oneClassFrequency)))); + + public static double entropyMulticlass(final int[] classCountsDistribution, final int totalCount) { + return -1 * Arrays.stream(classCountsDistribution) + .mapToDouble(count -> (count == 0) ? 
0 : (count * 1.0 / totalCount) * Math.log((count * 1.0 / totalCount))) + .sum(); } - - public Double binaryEntropy() { - double a1 = (entropyBinarySplit(_leftCount0 * 1.0 / _leftCount) - * _leftCount / (_leftCount + _rightCount)); - double a2 = (entropyBinarySplit(_rightCount0 * 1.0 / _rightCount) - * _rightCount / (_leftCount + _rightCount)); - return a1 + a2; + + public Double splitEntropy() { + int totalCount = _leftCount + _rightCount; + double res = entropyMulticlass(_leftClassDistribution, _leftCount) * _leftCount / totalCount + + entropyMulticlass(_rightClassDistribution, _rightCount) * _rightCount / totalCount; + return res; } } diff --git a/h2o-algos/src/main/java/hex/tree/dt/mrtasks/CountBinsSamplesCountsMRTask.java b/h2o-algos/src/main/java/hex/tree/dt/mrtasks/CountBinsSamplesCountsMRTask.java index a20eb8a6149d..d8e44726fbdf 100644 --- a/h2o-algos/src/main/java/hex/tree/dt/mrtasks/CountBinsSamplesCountsMRTask.java +++ b/h2o-algos/src/main/java/hex/tree/dt/mrtasks/CountBinsSamplesCountsMRTask.java @@ -1,6 +1,5 @@ package hex.tree.dt.mrtasks; -import org.apache.commons.math3.util.Precision; import water.MRTask; import water.fvec.Chunk; @@ -16,25 +15,30 @@ */ public class CountBinsSamplesCountsMRTask extends MRTask { public final int _featureSplit; + // numCol x 2 - min and max for each feature final double[][] _featuresLimits; - // binsCount x bin_encoding_len (5 or 3), depending on feature type: - // for numeric feature bin_encoding_len = 5: {numeric flag (-1.0), count, count0, min, max} - // for categorical feature bin_encoding_len = 3: {category, count, count0} + + // binsCount x bin_encoding_len, depending on feature type: + // for numeric feature: {numeric flag (-1.0), min, max, count, count0, count1, count2, ...} + // for categorical feature: {category, count, count0, count1, count2, ...} + // for accessing specific class count use {offset+1+class} index - e.g. 
for numeric count1: _bins[NUM_COUNT_OFFSET+1+1] public double[][] _bins; // indices for the serialized array public static final int NUMERICAL_FLAG = 0; + public static final int NUM_COUNT_OFFSET = 3; // follows numeric flag, min and max + public static final int CAT_COUNT_OFFSET = 1; // follows category + + private final int _countsOffset; + - // for both numeric and categorical features indices of count and count0 are the same - public static final int COUNT = 1; - public static final int COUNT_0 = 2; - - public CountBinsSamplesCountsMRTask(int featureSplit, double[][] featuresLimits, double[][] bins) { + public CountBinsSamplesCountsMRTask(int featureSplit, double[][] featuresLimits, double[][] bins, int countsOffset) { _featureSplit = featureSplit; _featuresLimits = featuresLimits; _bins = bins; + _countsOffset = countsOffset; } @Override @@ -44,6 +48,9 @@ public void map(Chunk[] cs) { double[][] tmpBins = new double[_bins.length][]; for (int b = 0; b < _bins.length; b++) { tmpBins[b] = Arrays.copyOf(_bins[b], _bins[b].length); + for(int c = _countsOffset; c < _bins[b].length; c++) { + tmpBins[b][c] = 0; // set all the counts to 0 - throw away existing counts if any + } } _bins = tmpBins; } @@ -64,20 +71,18 @@ public void map(Chunk[] cs) { for (int i = 0; i < _bins.length; i++) { // find bin by category if (_bins[i][0] == cs[_featureSplit].atd(row)) { - _bins[i][COUNT]++; - if (Precision.equals(cs[classFeature].atd(row), 0, Precision.EPSILON)) { - _bins[i][COUNT_0]++; - } + _bins[i][_countsOffset]++; + // calc index as {offset+1+class} + _bins[i][_countsOffset + 1 + (int)cs[classFeature].atd(row)]++; } } } else { for (int i = 0; i < _bins.length; i++) { // count feature values in the current bin if (checkBinBelonging(cs[_featureSplit].atd(row), i)) { - _bins[i][COUNT]++; - if (Precision.equals(cs[classFeature].atd(row), 0, Precision.EPSILON)) { - _bins[i][COUNT_0]++; - } + _bins[i][_countsOffset]++; + // calc index as {offset+1+class} + _bins[i][_countsOffset + 1 + (int)cs[classFeature].atd(row)]++; } } } @@ -107,8 +112,9 @@ private boolean checkBinBelonging(double featureValue, int bin) { @Override public void reduce(CountBinsSamplesCountsMRTask mrt) { for (int i = 0; i < _bins.length; i++) { - _bins[i][COUNT] += mrt._bins[i][COUNT]; - _bins[i][COUNT_0] += mrt._bins[i][COUNT_0]; + for(int c = _countsOffset; c < _bins[i].length; c++) { + _bins[i][c] += mrt._bins[i][c]; + } } } } diff --git a/h2o-algos/src/test/java/hex/tree/dt/BinningTest.java b/h2o-algos/src/test/java/hex/tree/dt/BinningTest.java index b58757df3ea3..d8ee98f15342 100644 --- a/h2o-algos/src/test/java/hex/tree/dt/BinningTest.java +++ b/h2o-algos/src/test/java/hex/tree/dt/BinningTest.java @@ -55,7 +55,7 @@ public void testBinningBasicData() { DataFeaturesLimits wholeDataLimits = getInitialFeaturesLimits(basicData); - Histogram histogram = new Histogram(basicData, wholeDataLimits, BinningStrategy.EQUAL_WIDTH); + Histogram histogram = new Histogram(basicData, wholeDataLimits, BinningStrategy.EQUAL_WIDTH, 2); // count of features assertEquals(basicData.numCols() - 1, histogram.featuresCount()); int numRows = (int) basicData.numRows(); @@ -67,13 +67,13 @@ public void testBinningBasicData() { histogram.getFeatureBins(0).stream().map(b -> b._count).collect(Collectors.toList())); // feature 0, count 0 assertEquals(Arrays.asList(0, 0, 1, 0, 1, 0, 1, 0, 0, 0), - histogram.getFeatureBins(0).stream().map(b -> b._count0).collect(Collectors.toList())); + histogram.getFeatureBins(0).stream().map(b -> 
b._classesDistribution[0]).collect(Collectors.toList())); // feature 1, count all assertEquals(Arrays.asList(4, 3, 3), histogram.getFeatureBins(1).stream().map(b -> b._count).collect(Collectors.toList())); // feature 1, count 0 assertEquals(Arrays.asList(1, 1, 1), - histogram.getFeatureBins(1).stream().map(b -> b._count0).collect(Collectors.toList())); + histogram.getFeatureBins(1).stream().map(b -> b._classesDistribution[0]).collect(Collectors.toList())); } finally { Scope.exit(); @@ -93,58 +93,58 @@ public void testBinSamplesCountBasicData() { .build(); DataFeaturesLimits dataLimits = getInitialFeaturesLimits(basicData); - Histogram histogram = new Histogram(basicData, dataLimits, BinningStrategy.EQUAL_WIDTH); + Histogram histogram = new Histogram(basicData, dataLimits, BinningStrategy.EQUAL_WIDTH, 2); - // extracting bins from the histogram and throwing away calculated values to test the calculation separately + // extracting bins from the histogram double[][] binsArray = histogram.getFeatureBins(0).stream() - .map(bin -> new double[]{-1.0, 0, 0, ((NumericBin) bin)._min, ((NumericBin) bin)._max}).toArray(double[][]::new); + .map(AbstractBin::toDoubles).toArray(double[][]::new); CountBinsSamplesCountsMRTask task = new CountBinsSamplesCountsMRTask( - 0, dataLimits.toDoubles(), binsArray); + 0, dataLimits.toDoubles(), binsArray, NUM_COUNT_OFFSET); task.doAll(basicData); assertEquals(10, task._bins.length); assert(task._bins[0][NUMERICAL_FLAG] == -1); assert(task._bins[0][MIN_INDEX] < basicData.vec(0).min()); assert(task._bins[0][MAX_INDEX] < 1 && task._bins[0][MAX_INDEX] > 0.8); - assert(task._bins[0][COUNT] == 1.0); - assert(task._bins[0][COUNT_0] == 0.0); + assert(task._bins[0][NUM_COUNT_OFFSET] == 1.0); + assert(task._bins[0][NUM_COUNT_OFFSET + 1] == 0.0); assert(task._bins[1][NUMERICAL_FLAG] == -1); assert(task._bins[1][MIN_INDEX] == task._bins[0][MAX_INDEX]); assert(task._bins[1][MAX_INDEX] < 2); - assert(task._bins[1][COUNT] == 1.0); - assert(task._bins[1][COUNT_0] == 0.0); + assert(task._bins[1][NUM_COUNT_OFFSET] == 1.0); + assert(task._bins[1][NUM_COUNT_OFFSET + 1] == 0.0); assert(task._bins[2][NUMERICAL_FLAG] == -1); assert(task._bins[2][MIN_INDEX] == task._bins[1][MAX_INDEX]); assert(task._bins[2][MAX_INDEX] < 3); - assert(task._bins[2][COUNT] == 1.0); - assert(task._bins[2][COUNT_0] == 1.0); + assert(task._bins[2][NUM_COUNT_OFFSET] == 1.0); + assert(task._bins[2][NUM_COUNT_OFFSET + 1] == 1.0); // extracting bins from the histogram and throwing away calculated values to test the calculation separately binsArray = histogram.getFeatureBins(1).stream() - .map(bin -> new double[]{((CategoricalBin) bin)._category, 0, 0}).toArray(double[][]::new); + .map(bin -> new double[]{((CategoricalBin) bin)._category, 0, 0, 0}).toArray(double[][]::new); - task = new CountBinsSamplesCountsMRTask(1, dataLimits.toDoubles(), binsArray).doAll(basicData); + task = new CountBinsSamplesCountsMRTask(1, dataLimits.toDoubles(), binsArray, CAT_COUNT_OFFSET).doAll(basicData); assertEquals(3, task._bins.length); // category assert(task._bins[0][0] == 0); - assert(task._bins[0][COUNT] == 4); - assert(task._bins[0][COUNT_0] == 1); + assert(task._bins[0][CAT_COUNT_OFFSET] == 4); + assert(task._bins[0][CAT_COUNT_OFFSET + 1] == 1); // category assert(task._bins[1][0] == 1); - assert(task._bins[1][COUNT] == 3); - assert(task._bins[1][COUNT_0] == 1); + assert(task._bins[1][CAT_COUNT_OFFSET] == 3); + assert(task._bins[1][CAT_COUNT_OFFSET + 1] == 1); // category assert(task._bins[2][0] == 2); - 
assert(task._bins[2][COUNT] == 3); - assert(task._bins[2][COUNT_0] == 1); + assert(task._bins[2][CAT_COUNT_OFFSET] == 3); + assert(task._bins[2][CAT_COUNT_OFFSET + 1] == 1); } finally { Scope.exit(); @@ -179,7 +179,7 @@ public void testBinningProstateData() { } - Histogram histogram = new Histogram(prostateData, wholeDataLimits, BinningStrategy.EQUAL_WIDTH); + Histogram histogram = new Histogram(prostateData, wholeDataLimits, BinningStrategy.EQUAL_WIDTH, 2); // count of features assertEquals(prostateData.numCols() - 1, histogram.featuresCount()); int numRows = (int) prostateData.numRows(); @@ -218,7 +218,7 @@ public void testBinningAirlinesData() { assertEquals(data.vec(i).cardinality(), wholeDataLimits.getFeatureLimits(i).toDoubles().length); } - Histogram histogram = new Histogram(data, wholeDataLimits, BinningStrategy.EQUAL_WIDTH); + Histogram histogram = new Histogram(data, wholeDataLimits, BinningStrategy.EQUAL_WIDTH, 2); // count of features assertEquals(data.numCols() - 1, histogram.featuresCount()); int numRows = (int) data.numRows(); diff --git a/h2o-algos/src/test/java/hex/tree/dt/DTTest.java b/h2o-algos/src/test/java/hex/tree/dt/DTTest.java index b98efd1d9c6a..6a65964569be 100644 --- a/h2o-algos/src/test/java/hex/tree/dt/DTTest.java +++ b/h2o-algos/src/test/java/hex/tree/dt/DTTest.java @@ -24,7 +24,7 @@ public class DTTest extends TestUtil { @Test - public void testBasicData() { + public void testBasicDataBinomial() { try { Scope.enter(); Frame train = new TestFrameBuilder() @@ -43,7 +43,7 @@ public void testBasicData() { p._train = train._key; p._seed = 0xDECAF; p._max_depth = 5; - p._min_rows = 2; + p._min_rows = 1; p._response_column = "Prediction"; testDataset(train, p); @@ -83,13 +83,84 @@ public void testBasicData() { assertEquals(1, prediction.vec(0).at(6), 0.1); assertEquals(1, prediction.vec(0).at(7), 0.1); assertEquals(1, prediction.vec(0).at(8), 0.1); - assertEquals(1, prediction.vec(0).at(9), 0.1); // the only one false positive + assertEquals(0, prediction.vec(0).at(9), 0.1); } finally { Scope.exit(); } } + + @Test + public void testBasicDataMultinomial() { + try { + Scope.enter(); + Frame train = new TestFrameBuilder() + .withVecTypes(Vec.T_NUM, Vec.T_CAT, Vec.T_CAT) + .withDataForCol(0, ard(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0)) + .withDataForCol(1, ar("two", "one", "three", "one", "three", "two", "two", "one", "one", "one", "three", "three")) + .withDataForCol(2, ar("0", "2", "2", "2", "2", "0", "0", "1", "1", "1", "1", "0")) + .withColNames("First", "Second", "Prediction") + .build(); + + Scope.track(train); + + + DTModel.DTParameters p = + new DTModel.DTParameters(); + p._train = train._key; + p._seed = 0xDECAF; + p._max_depth = 5; + p._min_rows = 1; + p._response_column = "Prediction"; + + testDataset(train, p); + + DT dt = new DT(p); + DTModel model = dt.trainModel().get(); + assert model != null; + Scope.track_generic(model); + Frame out = model.score(train); + Scope.track(out); + System.out.println(Arrays.toString(out.names())); + assertEquals(train.numRows(), out.numRows()); + + +// System.out.println(DKV.getGet(model._output._treeKey)); + + Frame test = new TestFrameBuilder() + .withVecTypes(Vec.T_NUM, Vec.T_CAT) + .withDataForCol(0, ard(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0)) + .withDataForCol(1, ar("two", "one", "three", "one", "three", "two", "two", "one", "one", "one", "three", "three")) + .withColNames("First", "Second") + .build(); + Scope.track(test); + + 
System.out.println(Arrays.deepToString(((CompressedDT) DKV.getGet(model._output._treeKey)).getNodes())); + System.out.println(String.join("\n", ((CompressedDT) DKV.getGet(model._output._treeKey)).getListOfRules())); + + Frame prediction = model.score(test); + Scope.track(prediction); + System.out.println(Arrays.toString(FrameUtils.asInts(prediction.vec(0)))); + assertEquals(0, prediction.vec(0).at(0), 0.1); + assertEquals(2, prediction.vec(0).at(1), 0.1); + assertEquals(2, prediction.vec(0).at(2), 0.1); + assertEquals(2, prediction.vec(0).at(3), 0.1); + assertEquals(2, prediction.vec(0).at(4), 0.1); + assertEquals(0, prediction.vec(0).at(5), 0.1); + assertEquals(0, prediction.vec(0).at(6), 0.1); + assertEquals(1, prediction.vec(0).at(7), 0.1); + assertEquals(1, prediction.vec(0).at(8), 0.1); + assertEquals(1, prediction.vec(0).at(9), 0.1); + assertEquals(1, prediction.vec(0).at(10), 0.1); + assertEquals(0, prediction.vec(0).at(11), 0.1); + + } finally { + Scope.exit(); + } + } + + @Test public void testNaNsChecks() { try { @@ -131,7 +202,7 @@ public void testPredictionColumnChecks() { .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM) .withDataForCol(0, ard(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)) .withDataForCol(1, ard(1.88, 1.5, 0.88, 1.5, 0.88, 1.5, 0.88, 1.5, 8.0, 9.0)) - .withDataForCol(2, ard(1, 2, 2, 1, 0, 1, 0, 1, 1, 1)) + .withDataForCol(2, ard(1, 3, 2, 1, 0, 1, 0, 1, 1, 1)) .withColNames("First", "Second", "Prediction") .build(); @@ -146,7 +217,6 @@ public void testPredictionColumnChecks() { fail("should have thrown validation error"); } catch (H2OModelBuilderIllegalArgumentException e) { assertTrue(e.getMessage().contains("Only categorical response is supported")); - assertTrue(e.getMessage().contains("Only binary response is supported")); } finally { Scope.exit(); } @@ -203,6 +273,48 @@ public void testAirlinesSmallData() { Scope.exit(); } + @Ignore // uses local dataset + @Test + public void testMultinomialSmallDataLocal() { + Scope.enter(); + Frame train = Scope.track(parseTestFile("smalldata/sdt/sdt_3EnumCols_10kRows_multinomial.csv")) + .toCategoricalCol(3); + Frame test = Scope.track(parseTestFile("smalldata/sdt/sdt_3EnumCols_10kRows_multinomial.csv")) + .toCategoricalCol(3); + + DTModel.DTParameters p = + new DTModel.DTParameters(); + p._train = train._key; + p._valid = train._key; + p._seed = 0xDECAF; + p._max_depth = 5; + p._response_column = "response"; + + testDataset(test, p); + Scope.exit(); + } + + @Test + public void testMultinomialSmallDataGAMData() { + Scope.enter(); + Frame train = Scope.track(parseTestFile("smalldata/predictMultinomialGAM3.csv")) + .toCategoricalCol(0); + Frame test = Scope.track(parseTestFile("smalldata/predictMultinomialGAM3.csv")) + .toCategoricalCol(0); + + + DTModel.DTParameters p = + new DTModel.DTParameters(); + p._train = train._key; + p._valid = train._key; + p._seed = 0xDECAF; + p._max_depth = 5; + p._response_column = "C1"; + + testDataset(test, p); + Scope.exit(); + } + @Test public void testBigDataCreditCard() { @@ -257,9 +369,39 @@ public void testDataset(Frame test, DTModel.DTParameters p) { test.vec(p._response_column).toCategoricalVec(), out.vec(0).toCategoricalVec()); System.out.println("Max depth: " + p._max_depth); - System.out.println("DT:"); System.out.println("Accuracy: " + cm.accuracy()); - System.out.println("F1: " + cm.f1()); +// System.out.println("F1: " + cm.f1()); + // Calculate precision, recall, and F1 score for each class manually as it's not available for multiclass data + int nClasses = cm.nclasses(); + 
double[] precision = new double[nClasses]; + double[] recall = new double[nClasses]; + double[] f1 = new double[nClasses]; + + for (int i = 0; i < nClasses; i++) { + double truePositive = cm._cm[i][i]; + double falsePositive = 0; + double falseNegative = 0; + + for (int j = 0; j < nClasses; j++) { + if (i != j) { + falsePositive += cm._cm[j][i]; + falseNegative += cm._cm[i][j]; + } + } + + precision[i] = truePositive / (truePositive + falsePositive); + recall[i] = truePositive / (truePositive + falseNegative); + f1[i] = 2 * (precision[i] * recall[i]) / (precision[i] + recall[i]); + } + + // Print class-wise precision, recall, and F1 score + System.out.println("Class-wise Precision, Recall, and F1 Score:"); + for (int i = 0; i < nClasses; i++) { + System.out.printf("Class %d - Precision: %.4f, Recall: %.4f, F1 Score: %.4f%n", + i, precision[i], recall[i], f1[i]); + } + + // // check for model metrics // assertNotNull(model._output._training_metrics); diff --git a/h2o-algos/src/test/java/hex/tree/dt/SplittingTest.java b/h2o-algos/src/test/java/hex/tree/dt/SplittingTest.java index 092678d78fbb..2dac002a12cf 100644 --- a/h2o-algos/src/test/java/hex/tree/dt/SplittingTest.java +++ b/h2o-algos/src/test/java/hex/tree/dt/SplittingTest.java @@ -37,7 +37,7 @@ public void testNumericSplitting() { DataFeaturesLimits wholeDataLimits = getInitialFeaturesLimits(basicData); - Histogram histogram = new Histogram(basicData, wholeDataLimits, BinningStrategy.EQUAL_WIDTH); + Histogram histogram = new Histogram(basicData, wholeDataLimits, BinningStrategy.EQUAL_WIDTH, 2); // count of features assertEquals(basicData.numCols() - 1, histogram.featuresCount()); int numRows = (int) basicData.numRows(); @@ -49,13 +49,13 @@ public void testNumericSplitting() { histogram.getFeatureBins(0).stream().map(b -> b._count).collect(Collectors.toList())); // feature 0, count 0 assertEquals(Arrays.asList(0, 0, 1, 0, 1, 0, 1, 0, 0, 0), - histogram.getFeatureBins(0).stream().map(b -> b._count0).collect(Collectors.toList())); + histogram.getFeatureBins(0).stream().map(b -> b._classesDistribution[0]).collect(Collectors.toList())); // feature 1, count all assertEquals(Arrays.asList(4, 3, 3), histogram.getFeatureBins(1).stream().map(b -> b._count).collect(Collectors.toList())); // feature 1, count 0 assertEquals(Arrays.asList(1, 1, 1), - histogram.getFeatureBins(1).stream().map(b -> b._count0).collect(Collectors.toList())); + histogram.getFeatureBins(1).stream().map(b -> b._classesDistribution[0]).collect(Collectors.toList())); } finally { Scope.exit(); diff --git a/h2o-py/h2o/estimators/decision_tree.py b/h2o-py/h2o/estimators/decision_tree.py index e598396b2a82..ee044612ef2a 100644 --- a/h2o-py/h2o/estimators/decision_tree.py +++ b/h2o-py/h2o/estimators/decision_tree.py @@ -29,6 +29,7 @@ def __init__(self, categorical_encoding="auto", # type: Literal["auto", "enum", "one_hot_internal", "one_hot_explicit", "binary", "eigen", "label_encoder", "sort_by_response", "enum_limited"] response_column=None, # type: Optional[str] seed=-1, # type: int + distribution="auto", # type: Literal["auto", "bernoulli", "multinomial", "gaussian", "poisson", "gamma", "tweedie", "laplace", "quantile", "huber"] max_depth=20, # type: int min_rows=10, # type: int ): @@ -55,6 +56,10 @@ def __init__(self, :param seed: Seed for random numbers (affects sampling) Defaults to ``-1``. :type seed: int + :param distribution: Distribution function + Defaults to ``"auto"``. 
+ :type distribution: Literal["auto", "bernoulli", "multinomial", "gaussian", "poisson", "gamma", "tweedie", "laplace", + "quantile", "huber"] :param max_depth: Max depth of tree. Defaults to ``20``. :type max_depth: int @@ -71,6 +76,7 @@ def __init__(self, self.categorical_encoding = categorical_encoding self.response_column = response_column self.seed = seed + self.distribution = distribution self.max_depth = max_depth self.min_rows = min_rows @@ -158,6 +164,21 @@ def seed(self, seed): assert_is_type(seed, None, int) self._parms["seed"] = seed + @property + def distribution(self): + """ + Distribution function + + Type: ``Literal["auto", "bernoulli", "multinomial", "gaussian", "poisson", "gamma", "tweedie", "laplace", + "quantile", "huber"]``, defaults to ``"auto"``. + """ + return self._parms.get("distribution") + + @distribution.setter + def distribution(self, distribution): + assert_is_type(distribution, None, Enum("auto", "bernoulli", "multinomial", "gaussian", "poisson", "gamma", "tweedie", "laplace", "quantile", "huber")) + self._parms["distribution"] = distribution + @property def max_depth(self): """ diff --git a/h2o-py/tests/testdir_algos/dt/pyunit_dt_classification.py b/h2o-py/tests/testdir_algos/dt/pyunit_dt_classification.py index 50b577a704d0..80901f1646df 100644 --- a/h2o-py/tests/testdir_algos/dt/pyunit_dt_classification.py +++ b/h2o-py/tests/testdir_algos/dt/pyunit_dt_classification.py @@ -1,3 +1,6 @@ +import sys +sys.path.insert(1, "../../../") + import h2o from h2o.estimators import H2ODecisionTreeEstimator from tests import pyunit_utils diff --git a/h2o-py/tests/testdir_algos/dt/pyunit_dt_multinomial.py b/h2o-py/tests/testdir_algos/dt/pyunit_dt_multinomial.py new file mode 100644 index 000000000000..c164c26a2321 --- /dev/null +++ b/h2o-py/tests/testdir_algos/dt/pyunit_dt_multinomial.py @@ -0,0 +1,37 @@ +import sys + +from sklearn.metrics import accuracy_score + +sys.path.insert(1, "../../../") + +import h2o +from tests import pyunit_utils +from h2o.estimators import H2ODecisionTreeEstimator + + +def test_dt_multinomial(): + data = h2o.import_file(pyunit_utils.locate("smalldata/sdt/sdt_3EnumCols_10kRows_multinomial.csv")) + response_col = "response" + data[response_col] = data[response_col].asfactor() + + predictors = ["C1", "C2", "C3"] + + # train model + dt = H2ODecisionTreeEstimator(max_depth=3) + dt.train(x=predictors, y=response_col, training_frame=data) + + dt.show() + + pred_train = dt.predict(data).as_data_frame(use_pandas=True)['predict'] + + train_accuracy = accuracy_score(data[response_col].as_data_frame(use_pandas=True), pred_train) + + print(train_accuracy) + + assert train_accuracy >= 0.8555 + + +if __name__ == "__main__": + pyunit_utils.standalone_test(test_dt_multinomial) +else: + test_dt_multinomial() diff --git a/h2o-r/h2o-package/R/decisiontree.R b/h2o-r/h2o-package/R/decisiontree.R index 85d51865a0bf..b9eed69d9d8a 100644 --- a/h2o-r/h2o-package/R/decisiontree.R +++ b/h2o-r/h2o-package/R/decisiontree.R @@ -19,6 +19,8 @@ #' "Binary", "Eigen", "LabelEncoder", "SortByResponse", "EnumLimited". Defaults to AUTO. #' @param seed Seed for random numbers (affects certain parts of the algo that are stochastic and those might or might not be enabled by default). #' Defaults to -1 (time-based random number). +#' @param distribution Distribution function Must be one of: "AUTO", "bernoulli", "multinomial", "gaussian", "poisson", "gamma", +#' "tweedie", "laplace", "quantile", "huber". Defaults to AUTO. #' @param max_depth Max depth of tree. Defaults to 20. 
#' @param min_rows Fewest allowed (weighted) observations in a leaf. Defaults to 10. #' @return Creates a \linkS4class{H2OModel} object of the right type. @@ -48,6 +50,7 @@ h2o.decision_tree <- function(x, ignore_const_cols = TRUE, categorical_encoding = c("AUTO", "Enum", "OneHotInternal", "OneHotExplicit", "Binary", "Eigen", "LabelEncoder", "SortByResponse", "EnumLimited"), seed = -1, + distribution = c("AUTO", "bernoulli", "multinomial", "gaussian", "poisson", "gamma", "tweedie", "laplace", "quantile", "huber"), max_depth = 20, min_rows = 10) { @@ -94,6 +97,7 @@ h2o.decision_tree <- function(x, ignore_const_cols = TRUE, categorical_encoding = c("AUTO", "Enum", "OneHotInternal", "OneHotExplicit", "Binary", "Eigen", "LabelEncoder", "SortByResponse", "EnumLimited"), seed = -1, + distribution = c("AUTO", "bernoulli", "multinomial", "gaussian", "poisson", "gamma", "tweedie", "laplace", "quantile", "huber"), max_depth = 20, min_rows = 10, segment_columns = NULL, diff --git a/h2o-r/tests/testdir_algos/dt/runit_dt_multinomial.R b/h2o-r/tests/testdir_algos/dt/runit_dt_multinomial.R new file mode 100644 index 000000000000..ce57c98411d7 --- /dev/null +++ b/h2o-r/tests/testdir_algos/dt/runit_dt_multinomial.R @@ -0,0 +1,26 @@ +setwd(normalizePath(dirname(R.utils::commandArgs(asValues=TRUE)$"f"))) +source("../../../scripts/h2o-r-test-setup.R") +library(rpart) + +# Initialize H2O cluster +h2o.init() + +# Define the test function +test_dt_multinomial <- function() { + # Load the data + data <- h2o.importFile(path = h2o:::.h2o.locate("smalldata/sdt/sdt_3EnumCols_10kRows_multinomial.csv")) + response_col <- "response" + data[, response_col] <- as.factor(data[, response_col]) + + predictors <- c("C1", "C2", "C3") + + # Train the model + dt <- h2o.decision_tree(max_depth = 3, x = predictors, y = response_col, training_frame = data) + pred <- predict(dt, data) + + # Print model summary + print(dt) + expect_equal(nrow(unique(as.data.frame(pred$predict))), 3) +} + +doTest("Decision tree: multinomial classification", test_dt_multinomial)
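For reference, DTModel.score0 above now writes the predicted class to preds[0] followed by one probability per class; a minimal sketch with hypothetical values:

    // hypothetical 3-class prediction: predicted class 0, leaf probabilities {0.7, 0.2, 0.1}
    double[] probabilities = {0.7, 0.2, 0.1};
    double[] preds = new double[1 + probabilities.length];
    preds[0] = 0;                                // predicted class (majority class of the leaf)
    for (int i = 0; i < probabilities.length; i++) {
        preds[i + 1] = probabilities[i];         // per-class probabilities, as filled by score0
    }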