-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMatrixUpdater.java
120 lines (89 loc) · 3.9 KB
/
MatrixUpdater.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
/**
 * MapReduce job that rescales row vectors element-wise.
 *
 * <p>Input lines have the form {@code <rowId>\t<payload>}, where the payload is either
 * {@code <name>:<v1>,<v2>,...} with {@code name} one of {@code m}, {@code a}, {@code b},
 * or a bare {@code <v1>,<v2>,...} list, which the mapper tags as {@code m}. For every row
 * id the reducer emits {@code m[j] * (a[j] / b[j])} for each column {@code j}, or
 * {@code m[j] * sqrt(a[j] / b[j])} when the {@code sqrt} flag is set.
 *
 * <p>NOTE(review): this looks like the multiplicative update step of NMF (the job jar is
 * resolved through {@code MRNMF}); confirm against the caller.
 */
public class MatrixUpdater {

    /** Configuration key holding the number of columns in every row vector. */
    private static final String WIDTH_KEY = "mw";

    /** Configuration key selecting the square-root variant of the update. */
    private static final String SQRT_KEY = "sqrt";

    /**
     * Re-keys every input line by its numeric row id and makes sure the payload carries an
     * array-name prefix.
     */
    private static class UMapper extends Mapper<LongWritable, Text, LongWritable, Text> {

        /**
         * Splits the line on tabs, parses the first field as the row id and forwards the
         * second field, prefixed with {@code m:} when it has no {@code name:} prefix yet.
         * Fields after the second one are ignored.
         *
         * @throws IOException           if the line has no tab-separated payload
         * @throws NumberFormatException if the row id is not a valid {@code long}
         */
        @Override
        public void map(LongWritable key, Text value, Context context)
                throws IOException, NumberFormatException, InterruptedException {
            String line = value.toString();
            String[] vals = line.split("\t");
            if (vals.length < 2) {
                // Fail with the offending line instead of a bare index error.
                throw new IOException("Expected '<rowId>\\t<values>' but got: '" + line + "'");
            }
            String payload = vals[1].contains(":") ? vals[1] : "m:" + vals[1];
            context.write(new LongWritable(Long.parseLong(vals[0])), new Text(payload));
        }
    }

    /**
     * Combines the {@code m}, {@code a} and {@code b} vectors of one row into the updated
     * row.
     */
    private static class UReducer extends Reducer<LongWritable, Text, LongWritable, Text> {

        /** Vector width, read from the {@code mw} configuration key. */
        private int k;

        /** Whether the ratio {@code a/b} is square-rooted before scaling {@code m}. */
        private boolean sqrt;

        /**
         * Reads the vector width and the sqrt flag once per task.
         *
         * @throws IllegalStateException if {@code mw} is missing or negative
         */
        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            Configuration conf = context.getConfiguration();
            k = conf.getInt(WIDTH_KEY, -1);
            if (k < 0) {
                throw new IllegalStateException(
                        "Configuration key '" + WIDTH_KEY + "' must be a non-negative vector width, got " + k);
            }
            sqrt = conf.getBoolean(SQRT_KEY, false);
        }

        /**
         * Parses every {@code name:v1,v2,...} value of the row into the matching array and
         * writes the comma-joined updated row.
         *
         * <p>An array that never appears for a row stays all zeros. A zero entry in
         * {@code b} therefore yields NaN or Infinity in the output (plain {@code double}
         * division); this mirrors the original behavior and is not masked here.
         *
         * @throws IOException if a value is malformed, names an unknown array, or holds
         *                     fewer than {@code k} entries
         */
        @Override
        public void reduce(LongWritable key, Iterable<Text> values,
                Context context) throws IOException, InterruptedException {
            double[] m = new double[k];
            double[] a = new double[k];
            double[] b = new double[k];
            Map<String, double[]> arrays = new HashMap<>();
            arrays.put("m", m);
            arrays.put("a", a);
            arrays.put("b", b);

            for (Text value : values) {
                String record = value.toString();
                String[] keyVal = record.split(":");
                if (keyVal.length < 2) {
                    throw new IOException("Malformed value for row " + key + ": '" + record + "'");
                }
                double[] target = arrays.get(keyVal[0]);
                if (target == null) {
                    throw new IOException("Unknown array name '" + keyVal[0] + "' for row " + key);
                }
                String[] xi = keyVal[1].split(",");
                if (xi.length < k) {
                    throw new IOException("Row " + key + ", array '" + keyVal[0] + "': expected "
                            + k + " entries but found " + xi.length);
                }
                // Entries beyond the configured width are ignored.
                for (int j = 0; j < k; j++) {
                    target[j] = Double.parseDouble(xi[j]);
                }
            }

            StringBuilder result = new StringBuilder();
            for (int j = 0; j < k; j++) {
                double frac = a[j] / b[j];
                if (sqrt) {
                    frac = Math.sqrt(frac);
                }
                if (j > 0) {
                    result.append(',');
                }
                result.append(m[j] * frac);
            }
            context.write(key, new Text(result.toString()));
        }
    }

    /**
     * Registers {@code path} as job input. For a directory, only the direct children whose
     * file name contains {@code "part"} are added; any other path is added as is.
     *
     * <p>The method name is misspelled but kept unchanged for existing callers.
     *
     * @param job  job to add the input to; its configuration is used to resolve the
     *             file system
     * @param path file or directory to add
     * @throws IOException if the file system cannot be accessed
     */
    public static void addInpuPath(Job job, Path path) throws IOException {
        // Resolve the file system with the job's own settings rather than bare defaults.
        FileSystem fs = path.getFileSystem(job.getConfiguration());
        if (!fs.isDirectory(path)) {
            FileInputFormat.addInputPath(job, path);
            return;
        }
        for (Path p : FileUtil.stat2Paths(fs.listStatus(path))) {
            // Test the file name only, so that a parent directory containing "part"
            // does not pull in marker files such as _SUCCESS.
            if (p.getName().contains("part")) {
                FileInputFormat.addInputPath(job, p);
            }
        }
    }

    private final Configuration configuration;
    private final String[] inputPaths;
    private final String outputPath;
    private final boolean sqrt;

    /**
     * Creates an updater that uses the plain ratio {@code a/b} (no square root).
     *
     * @param configuration base configuration for the job; must carry the {@code mw} key
     * @param inputPaths    files or directories holding the {@code m}, {@code a} and
     *                      {@code b} rows
     * @param outputPath    directory the job writes to
     */
    public MatrixUpdater(Configuration configuration, String[] inputPaths, String outputPath) {
        this(configuration, inputPaths, outputPath, false);
    }

    /**
     * Creates an updater.
     *
     * @param configuration base configuration for the job; must carry the {@code mw} key
     * @param inputPaths    files or directories holding the {@code m}, {@code a} and
     *                      {@code b} rows
     * @param outputPath    directory the job writes to
     * @param sqrt          whether to take the square root of {@code a/b} before scaling
     */
    public MatrixUpdater(Configuration configuration, String[] inputPaths, String outputPath, boolean sqrt) {
        this.configuration = configuration;
        this.inputPaths = inputPaths;
        this.outputPath = outputPath;
        this.sqrt = sqrt;
    }

    /**
     * Configures the job, submits it and blocks until it finishes.
     *
     * <p>Side effect: the {@code sqrt} flag is written into the configuration passed to
     * the constructor.
     *
     * @throws IOException if the job cannot be set up or does not complete successfully
     */
    public void run() throws IOException, ClassNotFoundException, InterruptedException {
        configuration.setBoolean(SQRT_KEY, sqrt);
        Job job = Job.getInstance(configuration, "com.lsdp.util.MatrixUpdater");
        job.setJarByClass(MRNMF.class);

        for (String path : inputPaths) {
            addInpuPath(job, new Path(path));
        }
        FileOutputFormat.setOutputPath(job, new Path(outputPath));

        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setMapperClass(UMapper.class);
        job.setReducerClass(UReducer.class);
        job.setMapOutputKeyClass(LongWritable.class);
        job.setMapOutputValueClass(Text.class);
        job.setOutputKeyClass(LongWritable.class);
        job.setOutputValueClass(Text.class);

        // Surface job failure to the caller instead of silently returning.
        if (!job.waitForCompletion(true)) {
            throw new IOException("MatrixUpdater job failed (output path: " + outputPath + ")");
        }
    }
}