diff --git a/src/main/java/RestServer/TaskRequest.java b/src/main/java/RestServer/TaskRequest.java index 2926f4d..2854a8b 100644 --- a/src/main/java/RestServer/TaskRequest.java +++ b/src/main/java/RestServer/TaskRequest.java @@ -12,6 +12,7 @@ import mongo.sdk.MongoSDKClient; import couchbase.sdk.Result; import utils.common.FileDownload; +import utils.val.MSMARCOEmbeddingProduct; import utils.taskmanager.Task; import utils.taskmanager.TaskManager; @@ -979,6 +980,7 @@ public ResponseEntity> loadMSMARCODataset() throws IOExcepti } } + long[] steps = MSMARCOEmbeddingProduct.getSteps(); int poolSize = this.processConcurrency; long start_offset = 0, end_offset = 0; if (this.createPercent > 0) { @@ -992,68 +994,84 @@ public ResponseEntity> loadMSMARCODataset() throws IOExcepti end_offset = this.expiryEndIndex; } + if (end_offset <= start_offset) + throw new IllegalArgumentException("No docs to process: start_offset=" + + start_offset + " end_offset=" + end_offset); + if (start_offset < steps[0] || start_offset >= steps[steps.length - 1]) + throw new IllegalArgumentException("start_offset " + start_offset + " is outside STEPS bounds [" + + steps[0] + ", " + steps[steps.length - 1] + ")"); + if (end_offset > steps[steps.length - 1]) + throw new IllegalArgumentException("end_offset " + end_offset + + " exceeds STEPS upper bound " + steps[steps.length - 1]); + ArrayList task_names = new ArrayList(); - long totalDocs = end_offset - start_offset; - if (totalDocs <= 0) - throw new IllegalArgumentException("No docs to process: start_offset=" + start_offset + " end_offset=" + end_offset); - int effectivePool = (int) Math.min(poolSize, totalDocs); - long step = totalDocs / effectivePool; - for (int i = 0; i < effectivePool; i++) { - WorkLoadSettings ws = new WorkLoadSettings(this.keyPrefix, - this.keySize, this.docSize, - this.createPercent, this.readPercent, - this.updatePercent, this.deletePercent, this.expiryPercent, this.processConcurrency, - this.ops, this.loadType, this.keyType, msmarcoValueType, - this.validateDocs, this.gtm, this.validateDeletedDocs, this.mutate, - this.elastic, this.model, this.mockVector, - this.dim, this.base64, this.mutateField, - this.mutationTimeout, this.vecFilePath); - ws.embeddingFilePath = this.vecFilePath; - ws.baseVectorsFilePath = "MSMARCOSiftEmbeddingProduct".equals(msmarcoValueType) - ? this.baseVectorsFilePath - : this.vecFilePath; - - long workerStart = start_offset + step * i; - long workerEnd = (i == effectivePool - 1) ? end_offset : start_offset + step * (i + 1); - HashMap dr = new HashMap<>(); - dr.put(DRConstants.create_s, workerStart); - dr.put(DRConstants.create_e, workerEnd); - dr.put(DRConstants.read_s, this.readStartIndex); - dr.put(DRConstants.read_e, this.readEndIndex); - dr.put(DRConstants.update_s, workerStart); - dr.put(DRConstants.update_e, workerEnd); - dr.put(DRConstants.delete_s, this.deleteStartIndex); - dr.put(DRConstants.delete_e, this.deleteEndIndex); - dr.put(DRConstants.touch_s, this.touchStartIndex); - dr.put(DRConstants.touch_e, this.touchEndIndex); - dr.put(DRConstants.replace_s, this.replaceStartIndex); - dr.put(DRConstants.replace_e, this.replaceEndIndex); - dr.put(DRConstants.expiry_s, workerStart); - dr.put(DRConstants.expiry_e, workerEnd); - - DocRange range = new DocRange(dr); - DocumentGenerator dg = null; - - ws.dr = range; - try { - dg = new DocumentGenerator(ws, ws.keyType, ws.valueType); - } catch (Exception e) { - body.put("error", "Failed to create doc generator"); - body.put("message", e.toString()); - return new ResponseEntity<>(body, HttpStatus.BAD_REQUEST); - } + HashMap pendingTasks = new HashMap<>(); + int k = 0; + while (!(steps[k] <= start_offset && start_offset < steps[k + 1])) + k += 1; + while (steps[k] < end_offset) { + long start = Math.max(start_offset, steps[k]); + long end = Math.min(end_offset, steps[k + 1]); + int effectivePool = (int) Math.min(poolSize, end - start); + long step = (end - start) / effectivePool; + for (int i = 0; i < effectivePool; i++) { + WorkLoadSettings ws = new WorkLoadSettings(this.keyPrefix, + this.keySize, this.docSize, + this.createPercent, this.readPercent, + this.updatePercent, this.deletePercent, this.expiryPercent, this.processConcurrency, + this.ops, this.loadType, this.keyType, msmarcoValueType, + this.validateDocs, this.gtm, this.validateDeletedDocs, this.mutate, + this.elastic, this.model, this.mockVector, + this.dim, this.base64, this.mutateField, + this.mutationTimeout, this.vecFilePath); + ws.baseVectorsFilePath = "MSMARCOSiftEmbeddingProduct".equals(msmarcoValueType) + ? this.baseVectorsFilePath + : this.vecFilePath; + + long workerStart = start + step * i; + long workerEnd = (i == effectivePool - 1) ? end : start + step * (i + 1); + HashMap dr = new HashMap<>(); + dr.put(DRConstants.create_s, workerStart); + dr.put(DRConstants.create_e, workerEnd); + dr.put(DRConstants.read_s, this.readStartIndex); + dr.put(DRConstants.read_e, this.readEndIndex); + dr.put(DRConstants.update_s, workerStart); + dr.put(DRConstants.update_e, workerEnd); + dr.put(DRConstants.delete_s, this.deleteStartIndex); + dr.put(DRConstants.delete_e, this.deleteEndIndex); + dr.put(DRConstants.touch_s, this.touchStartIndex); + dr.put(DRConstants.touch_e, this.touchEndIndex); + dr.put(DRConstants.replace_s, this.replaceStartIndex); + dr.put(DRConstants.replace_e, this.replaceEndIndex); + dr.put(DRConstants.expiry_s, workerStart); + dr.put(DRConstants.expiry_e, workerEnd); - String task_name = "MSMARCOTask_" + TaskRequest.task_id.incrementAndGet() + "_" + ws.dr.create_s + "_" - + ws.dr.create_e; - int retry = 0; - String th_name = task_name + "_" + i; - WorkLoadGenerate wlg = new WorkLoadGenerate(th_name, dg, TaskRequest.SDKClientPool, esClient, - this.durabilityLevel, - this.docTTL, this.docTTLUnit, this.trackFailures, retry, null); - wlg.set_collection_for_load(this.bucketName, this.scopeName, this.collectionName); - TaskRequest.loader_tasks.put(th_name, wlg); - task_names.add(th_name); + DocRange range = new DocRange(dr); + DocumentGenerator dg = null; + + ws.dr = range; + try { + dg = new DocumentGenerator(ws, ws.keyType, ws.valueType); + } catch (Exception e) { + body.put("error", "Failed to create doc generator"); + body.put("message", e.toString()); + return new ResponseEntity<>(body, HttpStatus.BAD_REQUEST); + } + + String task_name = "MSMARCOTask_" + TaskRequest.task_id.incrementAndGet() + k + "_" + ws.dr.create_s + + "_" + ws.dr.create_e; + int retry = 0; + String th_name = task_name + "_" + i; + WorkLoadGenerate wlg = new WorkLoadGenerate(th_name, dg, TaskRequest.SDKClientPool, esClient, + this.durabilityLevel, + this.docTTL, this.docTTLUnit, this.trackFailures, retry, null); + wlg.set_collection_for_load(this.bucketName, this.scopeName, this.collectionName); + pendingTasks.put(th_name, wlg); + task_names.add(th_name); + } + k += 1; } + TaskRequest.loader_tasks.putAll(pendingTasks); body.put("tasks", task_names); body.put("status", true); diff --git a/src/main/java/couchbase/sdk/SDKClientPool.java b/src/main/java/couchbase/sdk/SDKClientPool.java index 317325e..39596f4 100644 --- a/src/main/java/couchbase/sdk/SDKClientPool.java +++ b/src/main/java/couchbase/sdk/SDKClientPool.java @@ -131,7 +131,7 @@ public void release_client(SDKClient client) { if (client == null || client.bucket == null) { return; } - + String bucket_key = client.bucket; String cache_key = bucket_key + ":" + client.scope + ":" + client.collection; @@ -140,10 +140,10 @@ public void release_client(SDKClient client) { if (info == null) { return; } - + // Decrement counter atomically int newCount = info.counter.decrementAndGet(); - + if (newCount == 0) { // Remove from cache atomically clientCache.remove(cache_key);