Skip to content

Commit 063ceb1

Browse files
committed
Sparse Embedding Support(MSMARCO)
Change-Id: I478723e781d2eb188eca9900bfc3d491caad601c
1 parent 16c9e83 commit 063ceb1

7 files changed

Lines changed: 1181 additions & 0 deletions

File tree

src/main/java/MSMARCOLoader.java

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
2+
import java.io.IOException;
3+
import java.util.HashMap;
4+
import java.util.concurrent.TimeUnit;
5+
6+
import org.apache.commons.cli.CommandLine;
7+
import org.apache.commons.cli.CommandLineParser;
8+
import org.apache.commons.cli.DefaultParser;
9+
import org.apache.commons.cli.HelpFormatter;
10+
import org.apache.commons.cli.Option;
11+
import org.apache.commons.cli.Options;
12+
import org.apache.commons.cli.ParseException;
13+
import org.apache.log4j.LogManager;
14+
import org.apache.log4j.Logger;
15+
16+
import couchbase.loadgen.WorkLoadGenerate;
17+
import couchbase.sdk.SDKClientPool;
18+
import couchbase.sdk.Server;
19+
import elasticsearch.EsClient;
20+
import utils.docgen.DRConstants;
21+
import utils.docgen.DocRange;
22+
import utils.docgen.DocumentGenerator;
23+
import utils.docgen.WorkLoadSettings;
24+
import utils.taskmanager.TaskManager;
25+
import utils.val.MSMARCOEmbeddingProduct;
26+
import utils.val.MSMARCOSiftEmbeddingProduct;
27+
28+
public class MSMARCOLoader {
29+
static Logger logger = LogManager.getLogger(MSMARCOLoader.class);
30+
31+
public static void main(String[] args) throws IOException {
32+
logger.info("#################### Starting MSMARCO Sparse Vector Loader ####################");
33+
34+
Options options = new Options();
35+
36+
Option name = new Option("n", "node", true, "IP Address");
37+
name.setRequired(true);
38+
options.addOption(name);
39+
40+
Option rest_username = new Option("user", "rest_username", true, "Username");
41+
rest_username.setRequired(true);
42+
options.addOption(rest_username);
43+
44+
Option rest_password = new Option("pwd", "rest_password", true, "Password");
45+
rest_password.setRequired(true);
46+
options.addOption(rest_password);
47+
48+
Option bucket = new Option("b", "bucket", true, "Bucket");
49+
bucket.setRequired(true);
50+
options.addOption(bucket);
51+
52+
Option port = new Option("p", "port", true, "Memcached Port");
53+
port.setRequired(true);
54+
options.addOption(port);
55+
56+
options.addOption(new Option("scope", true, "Scope"));
57+
options.addOption(new Option("collection", true, "Collection"));
58+
59+
options.addOption(new Option("create_s", "create_s", true, "Creates Start"));
60+
options.addOption(new Option("create_e", "create_e", true, "Creates End"));
61+
options.addOption(new Option("read_s", "read_s", true, "Read Start"));
62+
options.addOption(new Option("read_e", "read_e", true, "Read End"));
63+
options.addOption(new Option("update_s", "update_s", true, "Update Start"));
64+
options.addOption(new Option("update_e", "update_e", true, "Update End"));
65+
options.addOption(new Option("delete_s", "delete_s", true, "Delete Start"));
66+
options.addOption(new Option("delete_e", "delete_e", true, "Delete End"));
67+
options.addOption(new Option("touch_s", "touch_s", true, "Touch Start"));
68+
options.addOption(new Option("touch_e", "touch_e", true, "Touch End"));
69+
options.addOption(new Option("replace_s", "replace_s", true, "Replace Start"));
70+
options.addOption(new Option("replace_e", "replace_e", true, "Replace End"));
71+
options.addOption(new Option("expiry_s", "expiry_s", true, "Expiry Start"));
72+
options.addOption(new Option("expiry_e", "expiry_e", true, "Expiry End"));
73+
74+
options.addOption(new Option("cr", "create", true, "Creates%"));
75+
options.addOption(new Option("up", "update", true, "Updates%"));
76+
options.addOption(new Option("dl", "delete", true, "Deletes%"));
77+
options.addOption(new Option("ex", "expiry", true, "Expiry%"));
78+
options.addOption(new Option("rd", "read", true, "Reads%"));
79+
80+
options.addOption(new Option("w", "workers", true, "Workers"));
81+
options.addOption(new Option("ops", "ops", true, "Ops/Sec"));
82+
options.addOption(new Option("keySize", "keySize", true, "Size of the key"));
83+
options.addOption(new Option("docSize", "docSize", true, "Size of the doc"));
84+
options.addOption(new Option("loadType", "loadType", true, "Hot/Cold"));
85+
options.addOption(new Option("keyType", "keyType", true, "Random/Sequential/Reverse"));
86+
options.addOption(new Option("keyPrefix", "keyPrefix", true, "String"));
87+
options.addOption(new Option("validate", "validate", true, "Validate Data during Reads"));
88+
options.addOption(new Option("gtm", "gtm", true, "Go for max doc ops"));
89+
options.addOption(new Option("deleted", "deleted", true, "To verify deleted docs"));
90+
options.addOption(new Option("base64", "base64", true, "base64 encoding for Vector embedding"));
91+
options.addOption(new Option("durability", true, "Durability Level"));
92+
options.addOption(new Option("mutate", true, "mutate"));
93+
options.addOption(new Option("mutation_timeout", true, "Mutation timeout in seconds"));
94+
options.addOption(new Option("mutate_field", true, "Mutate field"));
95+
options.addOption(new Option("maxTTL", true, "Expiry Time"));
96+
options.addOption(new Option("maxTTLUnit", true, "Expiry Time unit"));
97+
options.addOption(new Option("retry", true, "Retry failures n times"));
98+
99+
options.addOption(new Option("vecFilePath", true, "Path to the .vec sparse vector file"));
100+
options.addOption(new Option("siftFilePath", true, "Path to SIFT bigann_base.bvecs file (required for MSMARCOSiftEmbeddingProduct)"));
101+
options.addOption(new Option("valueType", true, "Value type to generate (default MSMARCOEmbeddingProduct)"));
102+
103+
Option elastic = new Option("elastic", "elastic", true, "Flag to insert data in ElasticSearch cluster");
104+
options.addOption(elastic);
105+
Option esServer = new Option("esServer", "esServer", true, "ElasticSearch cluster");
106+
options.addOption(esServer);
107+
Option esAPIKey = new Option("esAPIKey", "esAPIKey", true, "ElasticSearch APIKey");
108+
options.addOption(esAPIKey);
109+
Option esSimilarity = new Option("esSimilarity", "esSimilarity", true, "ElasticSearch esSimilarity");
110+
options.addOption(esSimilarity);
111+
Option skipCB = new Option("skipCB", "skipCB", true, "Skip loading data into Couchbase");
112+
options.addOption(skipCB);
113+
114+
HelpFormatter formatter = new HelpFormatter();
115+
CommandLineParser parser = new DefaultParser();
116+
CommandLine cmd;
117+
try {
118+
cmd = parser.parse(options, args);
119+
} catch (ParseException e) {
120+
e.printStackTrace();
121+
System.out.println(e.getMessage());
122+
formatter.printHelp("Supported Options", options);
123+
System.exit(1);
124+
return;
125+
}
126+
127+
String vecFilePath = cmd.getOptionValue("vecFilePath");
128+
if (vecFilePath == null || vecFilePath.trim().isEmpty()) {
129+
System.err.println("Error: -vecFilePath is required (path to .vec sparse vector file)");
130+
System.exit(1);
131+
return;
132+
}
133+
String valueType = cmd.getOptionValue("valueType", MSMARCOEmbeddingProduct.class.getSimpleName());
134+
String siftFilePath = cmd.getOptionValue("siftFilePath");
135+
if (MSMARCOSiftEmbeddingProduct.class.getSimpleName().equals(valueType)
136+
&& (siftFilePath == null || siftFilePath.trim().isEmpty())) {
137+
System.err.println("Error: -siftFilePath is required when -valueType is MSMARCOSiftEmbeddingProduct");
138+
System.exit(1);
139+
return;
140+
}
141+
142+
Server master = new Server(cmd.getOptionValue("node"), cmd.getOptionValue("port"),
143+
cmd.getOptionValue("rest_username"), cmd.getOptionValue("rest_password"), cmd.getOptionValue("port"));
144+
TaskManager tm = new TaskManager(Integer.parseInt(cmd.getOptionValue("workers", "10")));
145+
SDKClientPool clientPool = new SDKClientPool();
146+
String cb = cmd.getOptionValue("skipCB", "false");
147+
if (!Boolean.parseBoolean(cb)) {
148+
try {
149+
clientPool.create_clients(cmd.getOptionValue("bucket"), master, 2);
150+
} catch (Exception e) {
151+
e.printStackTrace();
152+
}
153+
}
154+
155+
EsClient esClient = null;
156+
if (Boolean.parseBoolean(cmd.getOptionValue("elastic", "false"))) {
157+
if (cmd.getOptionValue("esAPIKey") != null)
158+
esClient = new EsClient(cmd.getOptionValue("esServer"), cmd.getOptionValue("esAPIKey"));
159+
if (esClient != null) {
160+
esClient.initializeSDK();
161+
esClient.deleteESIndex(cmd.getOptionValue("collection", "_default").replace("_", ""));
162+
esClient.createESIndex(cmd.getOptionValue("collection", "_default").replace("_", ""),
163+
cmd.getOptionValue("esSimilarity", "l2_norm"), null);
164+
}
165+
}
166+
167+
// Use the same step ranges as MSMARCOEmbeddingProduct
168+
long[] steps = MSMARCOEmbeddingProduct.getSteps();
169+
int poolSize = Integer.parseInt(cmd.getOptionValue("workers", "10"));
170+
long start_offset = 0, end_offset = 0;
171+
if (Integer.parseInt(cmd.getOptionValue("cr", "0")) > 0) {
172+
start_offset = Long.parseLong(cmd.getOptionValue(DRConstants.create_s, "0"));
173+
end_offset = Long.parseLong(cmd.getOptionValue(DRConstants.create_e, "0"));
174+
} else if (Integer.parseInt(cmd.getOptionValue("up", "0")) > 0) {
175+
start_offset = Long.parseLong(cmd.getOptionValue(DRConstants.update_s, "0"));
176+
end_offset = Long.parseLong(cmd.getOptionValue(DRConstants.update_e, "0"));
177+
} else if (Integer.parseInt(cmd.getOptionValue("ex", "0")) > 0) {
178+
start_offset = Long.parseLong(cmd.getOptionValue(DRConstants.expiry_s, "0"));
179+
end_offset = Long.parseLong(cmd.getOptionValue(DRConstants.expiry_e, "0"));
180+
}
181+
182+
// Find which step range contains start_offset
183+
int k = 0;
184+
while (!(steps[k] <= start_offset && start_offset < steps[k + 1]))
185+
k += 1;
186+
187+
// Process each step range that overlaps with [start_offset, end_offset)
188+
while (steps[k] < end_offset) {
189+
long start = Math.max(start_offset, steps[k]);
190+
long end = Math.min(end_offset, steps[k + 1]);
191+
long step = (end - start) / poolSize;
192+
193+
for (int i = 0; i < poolSize; i++) {
194+
WorkLoadSettings ws = new WorkLoadSettings(
195+
cmd.getOptionValue("keyPrefix", "msmarco-"),
196+
Integer.parseInt(cmd.getOptionValue("keySize", "20")),
197+
Integer.parseInt(cmd.getOptionValue("docSize", "256")),
198+
Integer.parseInt(cmd.getOptionValue("cr", "0")),
199+
Integer.parseInt(cmd.getOptionValue("rd", "0")),
200+
Integer.parseInt(cmd.getOptionValue("up", "0")),
201+
Integer.parseInt(cmd.getOptionValue("dl", "0")),
202+
Integer.parseInt(cmd.getOptionValue("ex", "0")),
203+
Integer.parseInt(cmd.getOptionValue("workers", "10")),
204+
Integer.parseInt(cmd.getOptionValue("ops", "10000")),
205+
cmd.getOptionValue("loadType", null),
206+
cmd.getOptionValue("keyType", "SimpleKey"),
207+
valueType,
208+
Boolean.parseBoolean(cmd.getOptionValue("validate", "false")),
209+
Boolean.parseBoolean(cmd.getOptionValue("gtm", "false")),
210+
Boolean.parseBoolean(cmd.getOptionValue("deleted", "false")),
211+
Integer.parseInt(cmd.getOptionValue("mutate", "0")),
212+
Boolean.parseBoolean(cmd.getOptionValue("elastic", "false")),
213+
cmd.getOptionValue("model", ""),
214+
false,
215+
0,
216+
Boolean.parseBoolean(cmd.getOptionValue("base64", "false")),
217+
cmd.getOptionValue("mutate_field", ""),
218+
Integer.parseInt(cmd.getOptionValue("mutation_timeout", "0")),
219+
vecFilePath);
220+
ws.embeddingFilePath = vecFilePath;
221+
ws.baseVectorsFilePath = MSMARCOSiftEmbeddingProduct.class.getSimpleName().equals(valueType)
222+
? siftFilePath
223+
: vecFilePath;
224+
225+
long workerStart = start + step * i;
226+
long workerEnd = (i == poolSize - 1) ? end : start + step * (i + 1);
227+
HashMap<String, Number> dr = new HashMap<String, Number>();
228+
dr.put(DRConstants.create_s, workerStart);
229+
dr.put(DRConstants.create_e, workerEnd);
230+
dr.put(DRConstants.read_s, Long.parseLong(cmd.getOptionValue(DRConstants.read_s, "0")));
231+
dr.put(DRConstants.read_e, Long.parseLong(cmd.getOptionValue(DRConstants.read_e, "0")));
232+
dr.put(DRConstants.update_s, workerStart);
233+
dr.put(DRConstants.update_e, workerEnd);
234+
dr.put(DRConstants.delete_s, Long.parseLong(cmd.getOptionValue(DRConstants.delete_s, "0")));
235+
dr.put(DRConstants.delete_e, Long.parseLong(cmd.getOptionValue(DRConstants.delete_e, "0")));
236+
dr.put(DRConstants.touch_s, Long.parseLong(cmd.getOptionValue(DRConstants.touch_s, "0")));
237+
dr.put(DRConstants.touch_e, Long.parseLong(cmd.getOptionValue(DRConstants.touch_e, "0")));
238+
dr.put(DRConstants.replace_s, Long.parseLong(cmd.getOptionValue(DRConstants.replace_s, "0")));
239+
dr.put(DRConstants.replace_e, Long.parseLong(cmd.getOptionValue(DRConstants.replace_e, "0")));
240+
dr.put(DRConstants.expiry_s, workerStart);
241+
dr.put(DRConstants.expiry_e, workerEnd);
242+
243+
DocRange range = new DocRange(dr);
244+
ws.dr = range;
245+
DocumentGenerator dg = null;
246+
try {
247+
dg = new DocumentGenerator(ws, ws.keyType, ws.valueType);
248+
} catch (ClassNotFoundException e1) {
249+
e1.printStackTrace();
250+
}
251+
try {
252+
String th_name = "MSMARCOLoader_" + k + "_" + ws.dr.create_s + "_" + ws.dr.create_e;
253+
boolean trackFailures = false;
254+
if (Integer.parseInt(cmd.getOptionValue("retry", "0")) > 0)
255+
trackFailures = true;
256+
WorkLoadGenerate wlg = new WorkLoadGenerate(th_name, dg, clientPool, esClient,
257+
cmd.getOptionValue("durability", "NONE"),
258+
Integer.parseInt(cmd.getOptionValue("maxTTL", "0")),
259+
cmd.getOptionValue("maxTTLUnit", "seconds"), trackFailures,
260+
Integer.parseInt(cmd.getOptionValue("retry", "0")), null);
261+
wlg.set_collection_for_load(
262+
cmd.getOptionValue("bucket"),
263+
cmd.getOptionValue("scope", "_default"),
264+
cmd.getOptionValue("collection", "_default"));
265+
tm.submit(wlg);
266+
TimeUnit.MILLISECONDS.sleep(500);
267+
} catch (Exception e) {
268+
e.printStackTrace();
269+
}
270+
}
271+
k += 1;
272+
}
273+
tm.getAllTaskResult();
274+
tm.shutdown();
275+
if (esClient != null)
276+
esClient.transport.close();
277+
}
278+
}

src/main/java/RestServer/RestApplication.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,18 @@ public ResponseEntity<Map<String, Object>> check_sift_file(@RequestBody TaskRequ
146146
}
147147
}
148148

149+
@PostMapping(value="/msmarco_doc_load")
150+
public ResponseEntity<Map<String, Object>> msmarco_doc_load(@RequestBody TaskRequest taskRequest) {
151+
try {
152+
return taskRequest.loadMSMARCODataset();
153+
} catch (Exception e) {
154+
Map<String, Object> body = new HashMap<>();
155+
body.put("error", e.toString());
156+
body.put("status", false);
157+
return new ResponseEntity<>(body, HttpStatus.BAD_REQUEST);
158+
}
159+
}
160+
149161
@PostMapping(value="/sift_doc_load")
150162
public ResponseEntity<Map<String, Object>> sift_doc_load(@RequestBody TaskRequest taskRequest) {
151163
try {

0 commit comments

Comments
 (0)