Skip to content

Commit 21a940e

Browse files
authored
Different criterion in regression and survival contrast sets. Some other fixes in contrast sets.
1 parent ea05815 commit 21a940e

File tree

13 files changed

+421
-275
lines changed

13 files changed

+421
-275
lines changed

adaa.analytics.rules/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ codeQuality {
2727
}
2828

2929
sourceCompatibility = 1.8
30-
version = '1.4.5'
30+
version = '1.4.8'
3131

3232

3333
jar {

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ContingencyTable.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ public class ContingencyTable {
2929
public double median_y = 0;
3030
public double mean_y = 0;
3131
public double stddev_y = 0;
32+
33+
public double targetLabel;
3234

3335
public ContingencyTable() { }
3436

@@ -42,6 +44,7 @@ public ContingencyTable(double p, double n, double P, double N) {
4244
public void clear() {
4345
weighted_p = weighted_n = weighted_P = weighted_N = 0;
4446
mean_y = median_y = stddev_y = 0;
47+
targetLabel = -1;
4548
}
4649

4750
}

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ContrastRegressionFinder.java

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.apache.commons.math3.stat.inference.MannWhitneyUTest;
1212

1313
import java.io.Serializable;
14+
import java.security.InvalidParameterException;
1415
import java.util.HashSet;
1516
import java.util.Map;
1617
import java.util.Set;
@@ -30,24 +31,26 @@ public String getName() {
3031

3132
@Override
3233
public double calculate(ExampleSet dataset, ContingencyTable ct) {
33-
Covering cov = (Covering)ct;
3434

35-
double positiveSum = 0;
36-
double totalSum = 0;
35+
ContrastRegressionExampleSet cer = (dataset instanceof ContrastExampleSet) ? (ContrastRegressionExampleSet)dataset : null;
36+
if (cer == null) {
37+
throw new InvalidParameterException("ContrastSurvivalRuleSet supports only ContrastRegressionExampleSet instances");
38+
}
39+
40+
Covering cov = (Covering)ct;
41+
double sum = 0;
3742

3843
int i = 0;
3944
for (int e : cov.positives) {
40-
double label = dataset.getExample(e).getLabel();
41-
positiveSum += label;
45+
sum += dataset.getExample(e).getLabel();
4246
}
43-
44-
totalSum = positiveSum;
4547
for (int e : cov.negatives) {
46-
totalSum += dataset.getExample(e).getLabel();
48+
sum += dataset.getExample(e).getLabel();
4749
}
4850

4951
// the smaller the difference in means, the better the contrast set
50-
double diff = Math.abs(positiveSum / cov.weighted_p - totalSum / (cov.weighted_p + cov.weighted_n));
52+
double groupEstimator = cer.getGroupEstimators().get((int)ct.targetLabel);
53+
double diff = Math.abs(sum / (cov.weighted_p + cov.weighted_n) - groupEstimator);
5154
return -diff;
5255
}
5356

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ContrastSnC.java

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
import com.rapidminer.example.Attribute;
55
import com.rapidminer.example.ExampleSet;
66
import com.rapidminer.example.set.AttributeValueFilter;
7+
import com.rapidminer.example.set.AttributeValueFilterSingleCondition;
78
import com.rapidminer.example.set.ConditionedExampleSet;
9+
import com.rapidminer.example.set.SimpleExampleSet;
810
import com.rapidminer.example.table.NominalMapping;
11+
import com.rapidminer.operator.tools.ExpressionEvaluationException;
912

1013
import java.util.ArrayList;
1114
import java.util.List;
@@ -18,7 +21,7 @@ public ContrastSnC(AbstractFinder finder, InductionParameters params) {
1821

1922
int ruleType = RuleFactory.CONTRAST;
2023

21-
if (finder instanceof ContrastRegressionFinder) {
24+
if (finder instanceof ContrastRegressionFinder) {
2225
ruleType = RuleFactory.CONTRAST_REGRESSION;
2326
} else if (finder instanceof ContrastSurvivalFinder) {
2427
ruleType = RuleFactory.CONTRAST_SURVIVAL;
@@ -34,7 +37,19 @@ public ContrastSnC(AbstractFinder finder, InductionParameters params) {
3437
* @return Rule set.
3538
*/
3639
public RuleSetBase run(ExampleSet dataset) {
37-
ContrastRuleSet rs = (ContrastRuleSet) factory.create(dataset);
40+
41+
// make a contrast dataset
42+
ContrastExampleSet ces;
43+
44+
if (factory.getType() == RuleFactory.CONTRAST_REGRESSION) {
45+
ces = new ContrastRegressionExampleSet((SimpleExampleSet) dataset);
46+
} else if (factory.getType() == RuleFactory.CONTRAST_SURVIVAL) {
47+
ces = new ContrastSurvivalExampleSet((SimpleExampleSet) dataset);
48+
} else {
49+
ces = new ContrastExampleSet((SimpleExampleSet) dataset);
50+
}
51+
52+
ContrastRuleSet rs = (ContrastRuleSet) factory.create(ces);
3853
IPenalizedFinder pf = (IPenalizedFinder)finder;
3954

4055
// reset penalties
@@ -52,7 +67,7 @@ public RuleSetBase run(ExampleSet dataset) {
5267
for (double mincovAll : mincovs) {
5368
params.setMinimumCoveredAll(mincovAll);
5469

55-
run(dataset, rs);
70+
run(ces, rs);
5671

5772
// reset penalty when multiple passes
5873
if (params.getMaxPassesCount() > 1) {
@@ -68,13 +83,11 @@ public RuleSetBase run(ExampleSet dataset) {
6883
* @param dataset Training data set.
6984
* @param crs Contrast rule set passed in by the caller.
7085
*/
71-
public void run(ExampleSet dataset, ContrastRuleSet crs) {
86+
protected void run(ContrastExampleSet dataset, ContrastRuleSet crs) {
7287
Logger.log("ContrastSnC.run()\n", Level.FINE);
7388

7489
// try to get contrast attribute (use label if not specified)
75-
final Attribute contrastAttr = (dataset.getAttributes().getSpecial(ContrastRule.CONTRAST_ATTRIBUTE_ROLE) == null)
76-
? dataset.getAttributes().getLabel()
77-
: dataset.getAttributes().getSpecial(ContrastRule.CONTRAST_ATTRIBUTE_ROLE);
90+
final Attribute contrastAttr = dataset.getContrastAttribute();
7891

7992
NominalMapping mapping = contrastAttr.getMapping();
8093
IPenalizedFinder pf = (IPenalizedFinder)finder;

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ContrastSurvivalFinder.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,13 @@
44
import adaa.analytics.rules.logic.quality.LogRank;
55
import adaa.analytics.rules.logic.quality.NegativeControlledMeasure;
66
import adaa.analytics.rules.logic.representation.*;
7-
import com.rapidminer.example.Attribute;
87
import com.rapidminer.example.ExampleSet;
9-
import com.rapidminer.example.table.NominalMapping;
108
import com.rapidminer.tools.container.Pair;
119

1210
import java.io.Serializable;
11+
import java.security.InvalidParameterException;
1312
import java.util.HashSet;
14-
import java.util.Map;
1513
import java.util.Set;
16-
import java.util.TreeMap;
1714

1815
public class ContrastSurvivalFinder extends SurvivalLogRankFinder implements IPenalizedFinder {
1916

@@ -29,18 +26,23 @@ public String getName() {
2926

3027
@Override
3128
public double calculate(ExampleSet dataset, ContingencyTable ct) {
29+
30+
ContrastSurvivalExampleSet ces = (dataset instanceof ContrastExampleSet) ? (ContrastSurvivalExampleSet)dataset : null;
31+
if (ces == null) {
32+
throw new InvalidParameterException("ContrastSurvivalRuleSet supports only ContrastSurvivalExampleSet instances");
33+
}
34+
3235
Covering cov = (Covering)ct;
3336
Set<Integer> examples = new HashSet<>();
3437
examples.addAll(cov.positives);
35-
KaplanMeierEstimator positiveEstimator = new KaplanMeierEstimator(dataset, examples);
36-
3738
examples.addAll(cov.negatives);
3839
KaplanMeierEstimator entireEstimator = new KaplanMeierEstimator(dataset, examples);
3940

4041
// compare estimators of:
4142
// - all covered examples (entire contrast set)
42-
// - covered positives
43-
Pair<Double,Double> statsAndPValue = super.compareEstimators(positiveEstimator, entireEstimator);
43+
// - entire group
44+
KaplanMeierEstimator groupEstimator = ces.getGroupEstimators().get((int)ct.targetLabel);
45+
Pair<Double,Double> statsAndPValue = super.compareEstimators(groupEstimator, entireEstimator);
4446

4547
// smaller test statistics -> smaller difference -> better contrast set
4648
return -statsAndPValue.getFirst();

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/ClassificationRule.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ public Covering covers(ExampleSet set, Set<Integer> filterIds) {
101101
}
102102
}
103103
}
104+
105+
covered.targetLabel = ((SingletonSet)this.getConsequence().getValueSet()).getValue();
104106
return covered;
105107
}
106108

@@ -145,6 +147,8 @@ public void covers(ExampleSet set, ContingencyTable ct) {
145147
}
146148
}
147149
}
150+
151+
ct.targetLabel = ((SingletonSet)this.getConsequence().getValueSet()).getValue();
148152
}
149153

150154
/**
@@ -178,5 +182,7 @@ public void covers(ExampleSet set, ContingencyTable ct, Set<Integer> positives,
178182
}
179183
++id;
180184
}
185+
186+
ct.targetLabel = ((SingletonSet)this.getConsequence().getValueSet()).getValue();
181187
}
182188
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package adaa.analytics.rules.logic.representation;
2+
3+
import com.rapidminer.example.Attribute;
4+
import com.rapidminer.example.ExampleSet;
5+
import com.rapidminer.example.set.AttributeValueFilterSingleCondition;
6+
import com.rapidminer.example.set.ConditionedExampleSet;
7+
import com.rapidminer.example.set.SimpleExampleSet;
8+
import com.rapidminer.example.table.NominalMapping;
9+
import com.rapidminer.operator.tools.ExpressionEvaluationException;
10+
11+
import java.util.ArrayList;
12+
import java.util.List;
13+
14+
public class ContrastExampleSet extends SimpleExampleSet {
15+
16+
protected Attribute contrastAttribute;
17+
18+
public Attribute getContrastAttribute() { return contrastAttribute; }
19+
20+
public ContrastExampleSet(SimpleExampleSet exampleSet) {
21+
super(exampleSet);
22+
23+
contrastAttribute = (exampleSet.getAttributes().getSpecial(ContrastRule.CONTRAST_ATTRIBUTE_ROLE) == null)
24+
? exampleSet.getAttributes().getLabel()
25+
: exampleSet.getAttributes().getSpecial(ContrastRule.CONTRAST_ATTRIBUTE_ROLE);
26+
}
27+
}
28+
29+
30+
31+
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package adaa.analytics.rules.logic.representation;
2+
3+
import com.rapidminer.example.Attribute;
4+
import com.rapidminer.example.ExampleSet;
5+
import com.rapidminer.example.Statistics;
6+
import com.rapidminer.example.set.AttributeValueFilterSingleCondition;
7+
import com.rapidminer.example.set.ConditionedExampleSet;
8+
import com.rapidminer.example.set.SimpleExampleSet;
9+
import com.rapidminer.example.table.NominalMapping;
10+
import com.rapidminer.operator.tools.ExpressionEvaluationException;
11+
12+
import java.util.ArrayList;
13+
import java.util.List;
14+
15+
public class ContrastRegressionExampleSet extends ContrastExampleSet {
16+
17+
/** Training set estimator. */
18+
protected double trainingEstimator;
19+
20+
/** Collection of Kaplan-Meier estimators for contrast groups. */
21+
protected List<Double> groupEstimators = new ArrayList<Double>();
22+
23+
/** Gets {@link #groupEstimators} */
24+
public List<Double> getGroupEstimators() { return groupEstimators; }
25+
26+
/** Gets {@link #trainingEstimator}}. */
27+
public double getTrainingEstimator() { return trainingEstimator; }
28+
29+
public ContrastRegressionExampleSet(SimpleExampleSet exampleSet) {
30+
super(exampleSet);
31+
32+
String averageName = (exampleSet.getAttributes().getWeight() != null)
33+
? Statistics.AVERAGE_WEIGHTED : Statistics.AVERAGE;
34+
35+
// establish training estimator
36+
Attribute label = exampleSet.getAttributes().getLabel();
37+
exampleSet.recalculateAttributeStatistics(label);
38+
trainingEstimator = exampleSet.getStatistics(label, averageName);
39+
40+
// establish contrast groups estimator
41+
try {
42+
NominalMapping mapping = contrastAttribute.getMapping();
43+
44+
for (int i = 0; i < mapping.size(); ++i) {
45+
AttributeValueFilterSingleCondition cnd = new AttributeValueFilterSingleCondition(
46+
contrastAttribute, AttributeValueFilterSingleCondition.EQUALS, mapping.mapIndex(i));
47+
48+
ExampleSet conditionedSet = new ConditionedExampleSet(exampleSet,cnd);
49+
conditionedSet.recalculateAttributeStatistics(label);
50+
groupEstimators.add(conditionedSet.getStatistics(label, averageName));
51+
}
52+
53+
} catch (ExpressionEvaluationException e) {
54+
e.printStackTrace();
55+
}
56+
}
57+
}

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/ContrastRegressionRuleSet.java

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import com.rapidminer.example.table.NominalMapping;
99
import com.rapidminer.operator.tools.ExpressionEvaluationException;
1010

11+
import java.security.InvalidParameterException;
1112
import java.util.ArrayList;
1213
import java.util.List;
1314

@@ -38,31 +39,13 @@ public class ContrastRegressionRuleSet extends ContrastRuleSet {
3839
/**
 * Creates the rule set and copies the training mean and the contrast group means
 * precomputed by the example set.
 *
 * @param exampleSet training set; must be a {@link ContrastRegressionExampleSet}
 * @param isVoting passed to the superclass constructor
 * @param params passed to the superclass constructor
 * @param knowledge passed to the superclass constructor
 * @throws InvalidParameterException if {@code exampleSet} is not a {@link ContrastRegressionExampleSet}
 */
public ContrastRegressionRuleSet(ExampleSet exampleSet, boolean isVoting, InductionParameters params, Knowledge knowledge) {
    super(exampleSet, isVoting, params, knowledge);

    // Check against the exact type we cast to. Testing only for the ContrastExampleSet
    // base type would let other subclasses through and end in a ClassCastException.
    if (!(exampleSet instanceof ContrastRegressionExampleSet)) {
        throw new InvalidParameterException("ContrastRegressionRuleSet supports only ContrastRegressionExampleSet instances");
    }
    ContrastRegressionExampleSet cer = (ContrastRegressionExampleSet) exampleSet;

    trainingMean = cer.getTrainingEstimator();
    groupMeans.addAll(cer.getGroupEstimators());
}
6750

6851
/**

adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/ContrastRuleSet.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ public class ContrastRuleSet extends ClassificationRuleSet {
2121

2222
private Map<String, Integer> numDuplicates = new TreeMap<>();
2323

24+
public List<ContrastRule> getAllSets() {
25+
List<ContrastRule> out = new ArrayList<ContrastRule>();
26+
for (String key: sets.keySet()) {
27+
List<ContrastRule> cs = sets.get(key);
28+
for (ContrastRule r : cs) {
29+
out.add(r);
30+
}
31+
}
32+
33+
return out;
34+
}
35+
2436
public int getTotalDuplicates() {
2537
int total = 0;
2638
for (int v: numDuplicates.values()) {

0 commit comments

Comments
 (0)