Skip to content

Commit aa26f8a

Browse files
committed
FIX:ffm predictor
1 parent 47bbbb9 commit aa26f8a

7 files changed

Lines changed: 32 additions & 22 deletions

File tree

bin/predict.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ resultFileSuffix="_"${model_name}"_"${resultSaveMode}
2424

2525
# max error data format tolerate number
2626
max_error_tol=100
27+
# auc,mae,rmse,confusion_matrix
2728
eval_metric="auc,mae"
2829
#value or leafid
2930
predict_type="value"

config/model/ffm.conf

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ data {
1919
field_delim : "@"
2020
},
2121

22+
max_feature_dim: ???,
2223
// ["0@0.1","1@0.5",...]
2324
y_sampling : [],
2425
assigned : false,

demo/ffm/binary_classification/ffm.conf

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ data {
1919
field_delim : "@"
2020
},
2121

22+
max_feature_dim: 117,
2223
// ["0@0.1","1@0.5",...]
2324
y_sampling : [],
2425
assigned : false,

demo/ffm/regression/ffm.conf

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ data {
1919
field_delim : "@"
2020
},
2121

22+
max_feature_dim: 39,
2223
// ["0@0.1","1@0.5",...]
2324
y_sampling : [],
2425
assigned : false,

src/main/java/com/fenbi/ytklearn/dataflow/FFMModelDataFlow.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ public class FFMModelDataFlow extends ContinuousDataFlow {
5858
private Map<String, Integer> field2IndexMap = new HashMap<>();
5959
private int fieldSize;
6060
private int maxFeatureNum = -1;
61+
private int maxFeatureDim = 100;
6162

6263
private RandomParams randomParams;
6364

@@ -88,6 +89,7 @@ public FFMModelDataFlow(IFileSystem fs,
8889

8990
fieldDelim = config.getString("data.delim.field_delim");
9091
fieldDictPath = config.getString("model.field_dict_path");
92+
maxFeatureDim = config.getInt("data.max_feature_dim");
9193

9294
randomParams = new RandomParams(config, "");
9395

@@ -103,6 +105,7 @@ public FFMModelDataFlow(IFileSystem fs,
103105
@Data
104106
public static class FFMCoreData extends ContinuousCoreData {
105107
private int maxFeatureNum;
108+
private int maxFeatureDim;
106109
private String fieldDelim;
107110
private Map<String, Integer> field2IndexMap;
108111

@@ -320,7 +323,7 @@ protected void loadModel() throws IOException, Mp4jException {
320323
@Override
321324
protected void handleOtherTrainInfo() throws Mp4jException {
322325
this.maxFeatureNum = ((FFMCoreData)threadTrainCoreDatas[0]).getMaxFeatureNum();
323-
LOG_UTILS.importantInfo("train line max feature num:" + maxFeatureNum);
326+
LOG_UTILS.importantInfo("train line max feature num:" + maxFeatureNum + ", config max feature dim:" + maxFeatureDim);
324327
}
325328

326329
@Override

src/main/java/com/fenbi/ytklearn/predictor/ContinuousOnlinePredictor.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,6 @@ public double batchPredictFromFiles(String modelName,
327327
}
328328
}
329329
double predict = predict(fmap, otherinfo);
330-
331330
if (hasLabel) {
332331
label = Float.parseFloat(linfo[0]);
333332
loss += weight * loss(fmap, label, otherinfo); // score not predict?
@@ -388,10 +387,10 @@ public double batchPredictFromFiles(String modelName,
388387
", local error total number:" + errorNum +
389388
", max error tol:" + maxErrorTol +
390389
", has read real lines:" + realcnt +
391-
", weight lines:" + weightCnt);
390+
", weight lines:" + weightCnt, e);
392391
if (errorNum > maxErrorTol) {
393392
LOG.error("[ERROR] error number:" + errorNum +
394-
" > " + "max tol:" + maxErrorTol);
393+
" > " + "max tol:" + maxErrorTol, e);
395394
throw e;
396395
}
397396
}

src/main/java/com/fenbi/ytklearn/predictor/FFMOnlinePredictor.java

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public class FFMOnlinePredictor extends ContinuousOnlinePredictor<float[]> {
4343
private final Map<String, Integer> field2IndexMap = new HashMap<>();
4444
private int fieldSize;
4545

46-
private final ThreadLocal<double[]> assistbuffer = new ThreadLocal<>();
46+
private final ThreadLocal<float[]> assistbuffer = new ThreadLocal<>();
4747
private final ThreadLocal<int[]> fieldbuffer = new ThreadLocal<>();
4848
private final ThreadLocal<float[]> valbuffer = new ThreadLocal<>();
4949

@@ -55,9 +55,9 @@ public FFMOnlinePredictor(String configPath) throws Exception {
5555
List<Integer> klist = config.getIntList("k");
5656
K = klist.get(1);
5757

58-
fieldDelim = config.getString("field_delim");
58+
fieldDelim = config.getString("data.delim.field_delim");
5959
fieldDictPath = config.getString("model.field_dict_path");
60-
maxFeatureNum = config.getInt("max_line_feature_num");
60+
maxFeatureNum = config.getInt("data.max_feature_dim") + 1;
6161

6262
loadModel();
6363
}
@@ -68,9 +68,9 @@ public FFMOnlinePredictor(Reader configReader) throws Exception {
6868
List<Integer> klist = config.getIntList("k");
6969
K = klist.get(1);
7070

71-
fieldDelim = config.getString("field_delim");
71+
fieldDelim = config.getString("data.delim.field_delim");
7272
fieldDictPath = config.getString("model.field_dict_path");
73-
maxFeatureNum = config.getInt("max_line_feature_num");
73+
maxFeatureNum = config.getInt("data.max_feature_dim") + 1;
7474

7575
loadModel();
7676
}
@@ -104,24 +104,24 @@ protected OnlinePredictor loadModel() throws Exception {
104104
}
105105

106106
int cnt = 0;
107-
iterators = fs.read(Arrays.asList(fieldDictPath));
107+
iterators = fs.read(Arrays.asList(modelParams.data_path));
108108
for (Iterator<String> it : iterators) {
109109
while (it.hasNext()) {
110110
String line = it.next();
111111
if (line.trim().length() == 0) {
112-
LOG.error("invalid model line:" + line);
112+
LOG.error("invalid model line(length=0):" + line);
113113
continue;
114114
}
115115
String []info = line.trim().split(modelParams.delim);
116-
if (fieldSize != (info.length - 5) / K) {
117-
LOG.info("invalid model line:" + line);
118-
continue;
119-
}
116+
// if (fieldSize != (info.length - 5) / K) {
117+
// LOG.info("invalid model line:" + line);
118+
// continue;
119+
// }
120120

121-
if (info.length < 2) {
122-
LOG.error("[invalid model line:" + line);
123-
continue;
124-
}
121+
// if (info.length < 2) {
122+
// LOG.error("[invalid model line:" + line);
123+
// continue;
124+
// }
125125

126126

127127
float []w = modelMap.get(info[0]);
@@ -156,9 +156,9 @@ public double score(Map<String, Float> features, Object other) {
156156

157157
int stride = fieldSize * K;
158158

159-
double []assist = assistbuffer.get();
159+
float []assist = assistbuffer.get();
160160
if (assist == null) {
161-
assist = new double[K * fieldSize * (maxFeatureNum + 1)];
161+
assist = new float[K * fieldSize * (maxFeatureNum + 1)];
162162
assistbuffer.set(assist);
163163
}
164164

@@ -183,7 +183,11 @@ public double score(Map<String, Float> features, Object other) {
183183
for (Map.Entry<String, Float> feature : features.entrySet()) {
184184

185185
// field idx
186-
fieldIdxArr[cidx] = field2IndexMap.get(feature.getKey().split(fieldDelim)[0]);
186+
Integer fieldIdx = field2IndexMap.get(feature.getKey().split(fieldDelim)[0]);
187+
if (fieldIdx == null) {
188+
continue;
189+
}
190+
fieldIdxArr[cidx] = fieldIdx;
187191

188192
float val = transform(feature.getKey(), feature.getValue());
189193
// val

0 commit comments

Comments
 (0)