-
Notifications
You must be signed in to change notification settings - Fork 2k
/
Copy pathKNN.java
104 lines (89 loc) · 3.8 KB
/
KNN.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
package hex.knn;
import hex.*;
import water.DKV;
import water.Key;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
public class KNN extends ModelBuilder<KNNModel,KNNModel.KNNParameters,KNNModel.KNNOutput> {
public KNN(KNNModel.KNNParameters parms) {
super(parms);
init(false);
}
public KNN(boolean startup_once) {
super(new KNNModel.KNNParameters(), startup_once);
}
@Override
protected KNNDriver trainModelImpl() {
return new KNNDriver();
}
@Override
public ModelCategory[] can_build() {
return new ModelCategory[]{ModelCategory.Binomial, ModelCategory.Multinomial};
}
@Override
public boolean isSupervised() {
return true;
}
@Override public void init(boolean expensive) {
super.init(expensive);
if( null == _parms._id_column) {
error("_id_column", "ID column parameter not set.");
}
if( null == _parms._distance) {
error("_distance", "Distance parameter not set.");
}
if (null != _parms._categorical_encoding && Model.Parameters.CategoricalEncodingScheme.Enum != _parms._categorical_encoding
&& Model.Parameters.CategoricalEncodingScheme.AUTO != _parms._categorical_encoding) {
error("_categorical_encoding", "Only enum categorical encoding is supported.");
}
}
class KNNDriver extends Driver {
@Override
public void computeImpl() {
KNNModel model = null;
Frame result = new Frame(Key.make("KNN_distances"));
try {
init(true); // Initialize parameters
if (error_count() > 0) {
throw new IllegalArgumentException("Found validation errors: " + validationErrors());
}
model = new KNNModel(dest(), _parms, new KNNModel.KNNOutput(KNN.this));
model.delete_and_lock(_job);
Frame train = _parms.train();
String idColumn = _parms._id_column;
int idColumnIndex = train.find(idColumn);
byte idType = train.vec(idColumnIndex).get_type();
String responseColumn = _parms._response_column;
int responseColumnIndex = train.find(responseColumn);
int nChunks = train.anyVec().nChunks();
int nCols = train.numCols();
// split data into chunks to calculate distances in parallel task
for (int i = 0; i < nChunks; i++) {
Chunk[] query = new Chunk[nCols];
for (int j = 0; j < nCols; j++) {
query[j] = train.vec(j).chunkForChunkIdx(i).deepCopy();
}
KNNDistanceTask task = new KNNDistanceTask(_parms._k, query, KNNDistanceFactory.createDistance(_parms._distance), idColumnIndex, idColumn, idType, responseColumnIndex, responseColumn);
Frame tmpResult = task.doAll(train).outputFrame();
Scope.untrack(tmpResult);
// merge result from a chunk
result = result.add(tmpResult);
}
Key<Frame> key = result._key;
DKV.put(key, result);
model._output.setDistancesKey(key);
Scope.untrack(result);
model.update(_job);
model.score(_parms.train()).delete();
model._output._training_metrics = ModelMetrics.getFromDKV(model, _parms.train());
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
if (model != null) {
model.unlock(_job);
}
}
}
}
}