|
| 1 | + |
| 2 | +import java.io.IOException; |
| 3 | +import java.util.HashMap; |
| 4 | +import java.util.concurrent.TimeUnit; |
| 5 | + |
| 6 | +import org.apache.commons.cli.CommandLine; |
| 7 | +import org.apache.commons.cli.CommandLineParser; |
| 8 | +import org.apache.commons.cli.DefaultParser; |
| 9 | +import org.apache.commons.cli.HelpFormatter; |
| 10 | +import org.apache.commons.cli.Option; |
| 11 | +import org.apache.commons.cli.Options; |
| 12 | +import org.apache.commons.cli.ParseException; |
| 13 | +import org.apache.log4j.LogManager; |
| 14 | +import org.apache.log4j.Logger; |
| 15 | + |
| 16 | +import couchbase.loadgen.WorkLoadGenerate; |
| 17 | +import couchbase.sdk.SDKClientPool; |
| 18 | +import couchbase.sdk.Server; |
| 19 | +import elasticsearch.EsClient; |
| 20 | +import utils.docgen.DRConstants; |
| 21 | +import utils.docgen.DocRange; |
| 22 | +import utils.docgen.DocumentGenerator; |
| 23 | +import utils.docgen.WorkLoadSettings; |
| 24 | +import utils.taskmanager.TaskManager; |
| 25 | +import utils.val.MSMARCOEmbeddingProduct; |
| 26 | +import utils.val.MSMARCOSiftEmbeddingProduct; |
| 27 | + |
| 28 | +public class MSMARCOLoader { |
| 29 | + static Logger logger = LogManager.getLogger(MSMARCOLoader.class); |
| 30 | + |
| 31 | + public static void main(String[] args) throws IOException { |
| 32 | + logger.info("#################### Starting MSMARCO Sparse Vector Loader ####################"); |
| 33 | + |
| 34 | + Options options = new Options(); |
| 35 | + |
| 36 | + Option name = new Option("n", "node", true, "IP Address"); |
| 37 | + name.setRequired(true); |
| 38 | + options.addOption(name); |
| 39 | + |
| 40 | + Option rest_username = new Option("user", "rest_username", true, "Username"); |
| 41 | + rest_username.setRequired(true); |
| 42 | + options.addOption(rest_username); |
| 43 | + |
| 44 | + Option rest_password = new Option("pwd", "rest_password", true, "Password"); |
| 45 | + rest_password.setRequired(true); |
| 46 | + options.addOption(rest_password); |
| 47 | + |
| 48 | + Option bucket = new Option("b", "bucket", true, "Bucket"); |
| 49 | + bucket.setRequired(true); |
| 50 | + options.addOption(bucket); |
| 51 | + |
| 52 | + Option port = new Option("p", "port", true, "Memcached Port"); |
| 53 | + port.setRequired(true); |
| 54 | + options.addOption(port); |
| 55 | + |
| 56 | + options.addOption(new Option("scope", true, "Scope")); |
| 57 | + options.addOption(new Option("collection", true, "Collection")); |
| 58 | + |
| 59 | + options.addOption(new Option("create_s", "create_s", true, "Creates Start")); |
| 60 | + options.addOption(new Option("create_e", "create_e", true, "Creates End")); |
| 61 | + options.addOption(new Option("read_s", "read_s", true, "Read Start")); |
| 62 | + options.addOption(new Option("read_e", "read_e", true, "Read End")); |
| 63 | + options.addOption(new Option("update_s", "update_s", true, "Update Start")); |
| 64 | + options.addOption(new Option("update_e", "update_e", true, "Update End")); |
| 65 | + options.addOption(new Option("delete_s", "delete_s", true, "Delete Start")); |
| 66 | + options.addOption(new Option("delete_e", "delete_e", true, "Delete End")); |
| 67 | + options.addOption(new Option("touch_s", "touch_s", true, "Touch Start")); |
| 68 | + options.addOption(new Option("touch_e", "touch_e", true, "Touch End")); |
| 69 | + options.addOption(new Option("replace_s", "replace_s", true, "Replace Start")); |
| 70 | + options.addOption(new Option("replace_e", "replace_e", true, "Replace End")); |
| 71 | + options.addOption(new Option("expiry_s", "expiry_s", true, "Expiry Start")); |
| 72 | + options.addOption(new Option("expiry_e", "expiry_e", true, "Expiry End")); |
| 73 | + |
| 74 | + options.addOption(new Option("cr", "create", true, "Creates%")); |
| 75 | + options.addOption(new Option("up", "update", true, "Updates%")); |
| 76 | + options.addOption(new Option("dl", "delete", true, "Deletes%")); |
| 77 | + options.addOption(new Option("ex", "expiry", true, "Expiry%")); |
| 78 | + options.addOption(new Option("rd", "read", true, "Reads%")); |
| 79 | + |
| 80 | + options.addOption(new Option("w", "workers", true, "Workers")); |
| 81 | + options.addOption(new Option("ops", "ops", true, "Ops/Sec")); |
| 82 | + options.addOption(new Option("keySize", "keySize", true, "Size of the key")); |
| 83 | + options.addOption(new Option("docSize", "docSize", true, "Size of the doc")); |
| 84 | + options.addOption(new Option("loadType", "loadType", true, "Hot/Cold")); |
| 85 | + options.addOption(new Option("keyType", "keyType", true, "Random/Sequential/Reverse")); |
| 86 | + options.addOption(new Option("keyPrefix", "keyPrefix", true, "String")); |
| 87 | + options.addOption(new Option("validate", "validate", true, "Validate Data during Reads")); |
| 88 | + options.addOption(new Option("gtm", "gtm", true, "Go for max doc ops")); |
| 89 | + options.addOption(new Option("deleted", "deleted", true, "To verify deleted docs")); |
| 90 | + options.addOption(new Option("base64", "base64", true, "base64 encoding for Vector embedding")); |
| 91 | + options.addOption(new Option("durability", true, "Durability Level")); |
| 92 | + options.addOption(new Option("mutate", true, "mutate")); |
| 93 | + options.addOption(new Option("mutation_timeout", true, "Mutation timeout in seconds")); |
| 94 | + options.addOption(new Option("mutate_field", true, "Mutate field")); |
| 95 | + options.addOption(new Option("maxTTL", true, "Expiry Time")); |
| 96 | + options.addOption(new Option("maxTTLUnit", true, "Expiry Time unit")); |
| 97 | + options.addOption(new Option("retry", true, "Retry failures n times")); |
| 98 | + |
| 99 | + options.addOption(new Option("vecFilePath", true, "Path to the .vec sparse vector file")); |
| 100 | + options.addOption(new Option("siftFilePath", true, "Path to SIFT bigann_base.bvecs file (required for MSMARCOSiftEmbeddingProduct)")); |
| 101 | + options.addOption(new Option("valueType", true, "Value type to generate (default MSMARCOEmbeddingProduct)")); |
| 102 | + |
| 103 | + Option elastic = new Option("elastic", "elastic", true, "Flag to insert data in ElasticSearch cluster"); |
| 104 | + options.addOption(elastic); |
| 105 | + Option esServer = new Option("esServer", "esServer", true, "ElasticSearch cluster"); |
| 106 | + options.addOption(esServer); |
| 107 | + Option esAPIKey = new Option("esAPIKey", "esAPIKey", true, "ElasticSearch APIKey"); |
| 108 | + options.addOption(esAPIKey); |
| 109 | + Option esSimilarity = new Option("esSimilarity", "esSimilarity", true, "ElasticSearch esSimilarity"); |
| 110 | + options.addOption(esSimilarity); |
| 111 | + Option skipCB = new Option("skipCB", "skipCB", true, "Skip loading data into Couchbase"); |
| 112 | + options.addOption(skipCB); |
| 113 | + |
| 114 | + HelpFormatter formatter = new HelpFormatter(); |
| 115 | + CommandLineParser parser = new DefaultParser(); |
| 116 | + CommandLine cmd; |
| 117 | + try { |
| 118 | + cmd = parser.parse(options, args); |
| 119 | + } catch (ParseException e) { |
| 120 | + e.printStackTrace(); |
| 121 | + System.out.println(e.getMessage()); |
| 122 | + formatter.printHelp("Supported Options", options); |
| 123 | + System.exit(1); |
| 124 | + return; |
| 125 | + } |
| 126 | + |
| 127 | + String vecFilePath = cmd.getOptionValue("vecFilePath"); |
| 128 | + if (vecFilePath == null || vecFilePath.trim().isEmpty()) { |
| 129 | + System.err.println("Error: -vecFilePath is required (path to .vec sparse vector file)"); |
| 130 | + System.exit(1); |
| 131 | + return; |
| 132 | + } |
| 133 | + String valueType = cmd.getOptionValue("valueType", MSMARCOEmbeddingProduct.class.getSimpleName()); |
| 134 | + String siftFilePath = cmd.getOptionValue("siftFilePath"); |
| 135 | + if (MSMARCOSiftEmbeddingProduct.class.getSimpleName().equals(valueType) |
| 136 | + && (siftFilePath == null || siftFilePath.trim().isEmpty())) { |
| 137 | + System.err.println("Error: -siftFilePath is required when -valueType is MSMARCOSiftEmbeddingProduct"); |
| 138 | + System.exit(1); |
| 139 | + return; |
| 140 | + } |
| 141 | + |
| 142 | + Server master = new Server(cmd.getOptionValue("node"), cmd.getOptionValue("port"), |
| 143 | + cmd.getOptionValue("rest_username"), cmd.getOptionValue("rest_password"), cmd.getOptionValue("port")); |
| 144 | + TaskManager tm = new TaskManager(Integer.parseInt(cmd.getOptionValue("workers", "10"))); |
| 145 | + SDKClientPool clientPool = new SDKClientPool(); |
| 146 | + String cb = cmd.getOptionValue("skipCB", "false"); |
| 147 | + if (!Boolean.parseBoolean(cb)) { |
| 148 | + try { |
| 149 | + clientPool.create_clients(cmd.getOptionValue("bucket"), master, 2); |
| 150 | + } catch (Exception e) { |
| 151 | + e.printStackTrace(); |
| 152 | + } |
| 153 | + } |
| 154 | + |
| 155 | + EsClient esClient = null; |
| 156 | + if (Boolean.parseBoolean(cmd.getOptionValue("elastic", "false"))) { |
| 157 | + if (cmd.getOptionValue("esAPIKey") != null) |
| 158 | + esClient = new EsClient(cmd.getOptionValue("esServer"), cmd.getOptionValue("esAPIKey")); |
| 159 | + if (esClient != null) { |
| 160 | + esClient.initializeSDK(); |
| 161 | + esClient.deleteESIndex(cmd.getOptionValue("collection", "_default").replace("_", "")); |
| 162 | + esClient.createESIndex(cmd.getOptionValue("collection", "_default").replace("_", ""), |
| 163 | + cmd.getOptionValue("esSimilarity", "l2_norm"), null); |
| 164 | + } |
| 165 | + } |
| 166 | + |
| 167 | + // Use the same step ranges as MSMARCOEmbeddingProduct |
| 168 | + long[] steps = MSMARCOEmbeddingProduct.getSteps(); |
| 169 | + int poolSize = Integer.parseInt(cmd.getOptionValue("workers", "10")); |
| 170 | + long start_offset = 0, end_offset = 0; |
| 171 | + if (Integer.parseInt(cmd.getOptionValue("cr", "0")) > 0) { |
| 172 | + start_offset = Long.parseLong(cmd.getOptionValue(DRConstants.create_s, "0")); |
| 173 | + end_offset = Long.parseLong(cmd.getOptionValue(DRConstants.create_e, "0")); |
| 174 | + } else if (Integer.parseInt(cmd.getOptionValue("up", "0")) > 0) { |
| 175 | + start_offset = Long.parseLong(cmd.getOptionValue(DRConstants.update_s, "0")); |
| 176 | + end_offset = Long.parseLong(cmd.getOptionValue(DRConstants.update_e, "0")); |
| 177 | + } else if (Integer.parseInt(cmd.getOptionValue("ex", "0")) > 0) { |
| 178 | + start_offset = Long.parseLong(cmd.getOptionValue(DRConstants.expiry_s, "0")); |
| 179 | + end_offset = Long.parseLong(cmd.getOptionValue(DRConstants.expiry_e, "0")); |
| 180 | + } |
| 181 | + |
| 182 | + // Find which step range contains start_offset |
| 183 | + int k = 0; |
| 184 | + while (!(steps[k] <= start_offset && start_offset < steps[k + 1])) |
| 185 | + k += 1; |
| 186 | + |
| 187 | + // Process each step range that overlaps with [start_offset, end_offset) |
| 188 | + while (steps[k] < end_offset) { |
| 189 | + long start = Math.max(start_offset, steps[k]); |
| 190 | + long end = Math.min(end_offset, steps[k + 1]); |
| 191 | + long step = (end - start) / poolSize; |
| 192 | + |
| 193 | + for (int i = 0; i < poolSize; i++) { |
| 194 | + WorkLoadSettings ws = new WorkLoadSettings( |
| 195 | + cmd.getOptionValue("keyPrefix", "msmarco-"), |
| 196 | + Integer.parseInt(cmd.getOptionValue("keySize", "20")), |
| 197 | + Integer.parseInt(cmd.getOptionValue("docSize", "256")), |
| 198 | + Integer.parseInt(cmd.getOptionValue("cr", "0")), |
| 199 | + Integer.parseInt(cmd.getOptionValue("rd", "0")), |
| 200 | + Integer.parseInt(cmd.getOptionValue("up", "0")), |
| 201 | + Integer.parseInt(cmd.getOptionValue("dl", "0")), |
| 202 | + Integer.parseInt(cmd.getOptionValue("ex", "0")), |
| 203 | + Integer.parseInt(cmd.getOptionValue("workers", "10")), |
| 204 | + Integer.parseInt(cmd.getOptionValue("ops", "10000")), |
| 205 | + cmd.getOptionValue("loadType", null), |
| 206 | + cmd.getOptionValue("keyType", "SimpleKey"), |
| 207 | + valueType, |
| 208 | + Boolean.parseBoolean(cmd.getOptionValue("validate", "false")), |
| 209 | + Boolean.parseBoolean(cmd.getOptionValue("gtm", "false")), |
| 210 | + Boolean.parseBoolean(cmd.getOptionValue("deleted", "false")), |
| 211 | + Integer.parseInt(cmd.getOptionValue("mutate", "0")), |
| 212 | + Boolean.parseBoolean(cmd.getOptionValue("elastic", "false")), |
| 213 | + cmd.getOptionValue("model", ""), |
| 214 | + false, |
| 215 | + 0, |
| 216 | + Boolean.parseBoolean(cmd.getOptionValue("base64", "false")), |
| 217 | + cmd.getOptionValue("mutate_field", ""), |
| 218 | + Integer.parseInt(cmd.getOptionValue("mutation_timeout", "0")), |
| 219 | + vecFilePath); |
| 220 | + ws.embeddingFilePath = vecFilePath; |
| 221 | + ws.baseVectorsFilePath = MSMARCOSiftEmbeddingProduct.class.getSimpleName().equals(valueType) |
| 222 | + ? siftFilePath |
| 223 | + : vecFilePath; |
| 224 | + |
| 225 | + long workerStart = start + step * i; |
| 226 | + long workerEnd = (i == poolSize - 1) ? end : start + step * (i + 1); |
| 227 | + HashMap<String, Number> dr = new HashMap<String, Number>(); |
| 228 | + dr.put(DRConstants.create_s, workerStart); |
| 229 | + dr.put(DRConstants.create_e, workerEnd); |
| 230 | + dr.put(DRConstants.read_s, Long.parseLong(cmd.getOptionValue(DRConstants.read_s, "0"))); |
| 231 | + dr.put(DRConstants.read_e, Long.parseLong(cmd.getOptionValue(DRConstants.read_e, "0"))); |
| 232 | + dr.put(DRConstants.update_s, workerStart); |
| 233 | + dr.put(DRConstants.update_e, workerEnd); |
| 234 | + dr.put(DRConstants.delete_s, Long.parseLong(cmd.getOptionValue(DRConstants.delete_s, "0"))); |
| 235 | + dr.put(DRConstants.delete_e, Long.parseLong(cmd.getOptionValue(DRConstants.delete_e, "0"))); |
| 236 | + dr.put(DRConstants.touch_s, Long.parseLong(cmd.getOptionValue(DRConstants.touch_s, "0"))); |
| 237 | + dr.put(DRConstants.touch_e, Long.parseLong(cmd.getOptionValue(DRConstants.touch_e, "0"))); |
| 238 | + dr.put(DRConstants.replace_s, Long.parseLong(cmd.getOptionValue(DRConstants.replace_s, "0"))); |
| 239 | + dr.put(DRConstants.replace_e, Long.parseLong(cmd.getOptionValue(DRConstants.replace_e, "0"))); |
| 240 | + dr.put(DRConstants.expiry_s, workerStart); |
| 241 | + dr.put(DRConstants.expiry_e, workerEnd); |
| 242 | + |
| 243 | + DocRange range = new DocRange(dr); |
| 244 | + ws.dr = range; |
| 245 | + DocumentGenerator dg = null; |
| 246 | + try { |
| 247 | + dg = new DocumentGenerator(ws, ws.keyType, ws.valueType); |
| 248 | + } catch (ClassNotFoundException e1) { |
| 249 | + e1.printStackTrace(); |
| 250 | + } |
| 251 | + try { |
| 252 | + String th_name = "MSMARCOLoader_" + k + "_" + ws.dr.create_s + "_" + ws.dr.create_e; |
| 253 | + boolean trackFailures = false; |
| 254 | + if (Integer.parseInt(cmd.getOptionValue("retry", "0")) > 0) |
| 255 | + trackFailures = true; |
| 256 | + WorkLoadGenerate wlg = new WorkLoadGenerate(th_name, dg, clientPool, esClient, |
| 257 | + cmd.getOptionValue("durability", "NONE"), |
| 258 | + Integer.parseInt(cmd.getOptionValue("maxTTL", "0")), |
| 259 | + cmd.getOptionValue("maxTTLUnit", "seconds"), trackFailures, |
| 260 | + Integer.parseInt(cmd.getOptionValue("retry", "0")), null); |
| 261 | + wlg.set_collection_for_load( |
| 262 | + cmd.getOptionValue("bucket"), |
| 263 | + cmd.getOptionValue("scope", "_default"), |
| 264 | + cmd.getOptionValue("collection", "_default")); |
| 265 | + tm.submit(wlg); |
| 266 | + TimeUnit.MILLISECONDS.sleep(500); |
| 267 | + } catch (Exception e) { |
| 268 | + e.printStackTrace(); |
| 269 | + } |
| 270 | + } |
| 271 | + k += 1; |
| 272 | + } |
| 273 | + tm.getAllTaskResult(); |
| 274 | + tm.shutdown(); |
| 275 | + if (esClient != null) |
| 276 | + esClient.transport.close(); |
| 277 | + } |
| 278 | +} |
0 commit comments