Skip to content

Commit 40df406

Browse files
committed
fix(msmarco): address PR #46 review comments and code review findings
Step-based partitioning and correctness fixes for loadMSMARCODataset:

- Replace the flat even split with a loop over MSMARCOEmbeddingProduct.getSteps(), matching the behaviour of the CLI (MSMARCOLoader) and the REST SIFT loader (loadSIFTDataset).
- Add a start_offset bounds check against STEPS so out-of-range values fail fast with IllegalArgumentException instead of ArrayIndexOutOfBoundsException (AIOOBE).
- Add an end_offset bounds check against steps[last] to prevent an AIOOBE in the outer while loop when the caller sends a value beyond the dataset size.
- Add an end_offset <= start_offset guard to replace the deleted totalDocs <= 0 check, restoring the diagnostic for an otherwise silent no-op.
- Cap workers per range with effectivePool = min(poolSize, end - start) so that integer division cannot yield step = 0, which silently created zero-range workers when the doc count was smaller than processConcurrency.
- Stage all WorkLoadGenerate instances in a local map before flushing to loader_tasks, so a mid-loop failure leaves no orphaned tasks.
- Remove the redundant ws.embeddingFilePath post-construction assignment; the field is already set by the WorkLoadSettings constructor.
- Strip trailing whitespace from three blank lines in SDKClientPool introduced by the PR.

Used Claude Code for code generation. Model used: claude-sonnet-4-6.
1 parent c57c600 commit 40df406

2 files changed

Lines changed: 80 additions & 62 deletions

File tree

src/main/java/RestServer/TaskRequest.java

Lines changed: 77 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import mongo.sdk.MongoSDKClient;
1313
import couchbase.sdk.Result;
1414
import utils.common.FileDownload;
15+
import utils.val.MSMARCOEmbeddingProduct;
1516
import utils.taskmanager.Task;
1617
import utils.taskmanager.TaskManager;
1718

@@ -979,6 +980,7 @@ public ResponseEntity<Map<String, Object>> loadMSMARCODataset() throws IOExcepti
979980
}
980981
}
981982

983+
long[] steps = MSMARCOEmbeddingProduct.getSteps();
982984
int poolSize = this.processConcurrency;
983985
long start_offset = 0, end_offset = 0;
984986
if (this.createPercent > 0) {
@@ -992,68 +994,84 @@ public ResponseEntity<Map<String, Object>> loadMSMARCODataset() throws IOExcepti
992994
end_offset = this.expiryEndIndex;
993995
}
994996

997+
if (end_offset <= start_offset)
998+
throw new IllegalArgumentException("No docs to process: start_offset="
999+
+ start_offset + " end_offset=" + end_offset);
1000+
if (start_offset < steps[0] || start_offset >= steps[steps.length - 1])
1001+
throw new IllegalArgumentException("start_offset " + start_offset + " is outside STEPS bounds ["
1002+
+ steps[0] + ", " + steps[steps.length - 1] + ")");
1003+
if (end_offset > steps[steps.length - 1])
1004+
throw new IllegalArgumentException("end_offset " + end_offset
1005+
+ " exceeds STEPS upper bound " + steps[steps.length - 1]);
1006+
9951007
ArrayList<String> task_names = new ArrayList<String>();
996-
long totalDocs = end_offset - start_offset;
997-
if (totalDocs <= 0)
998-
throw new IllegalArgumentException("No docs to process: start_offset=" + start_offset + " end_offset=" + end_offset);
999-
int effectivePool = (int) Math.min(poolSize, totalDocs);
1000-
long step = totalDocs / effectivePool;
1001-
for (int i = 0; i < effectivePool; i++) {
1002-
WorkLoadSettings ws = new WorkLoadSettings(this.keyPrefix,
1003-
this.keySize, this.docSize,
1004-
this.createPercent, this.readPercent,
1005-
this.updatePercent, this.deletePercent, this.expiryPercent, this.processConcurrency,
1006-
this.ops, this.loadType, this.keyType, msmarcoValueType,
1007-
this.validateDocs, this.gtm, this.validateDeletedDocs, this.mutate,
1008-
this.elastic, this.model, this.mockVector,
1009-
this.dim, this.base64, this.mutateField,
1010-
this.mutationTimeout, this.vecFilePath);
1011-
ws.embeddingFilePath = this.vecFilePath;
1012-
ws.baseVectorsFilePath = "MSMARCOSiftEmbeddingProduct".equals(msmarcoValueType)
1013-
? this.baseVectorsFilePath
1014-
: this.vecFilePath;
1015-
1016-
long workerStart = start_offset + step * i;
1017-
long workerEnd = (i == effectivePool - 1) ? end_offset : start_offset + step * (i + 1);
1018-
HashMap<String, Number> dr = new HashMap<>();
1019-
dr.put(DRConstants.create_s, workerStart);
1020-
dr.put(DRConstants.create_e, workerEnd);
1021-
dr.put(DRConstants.read_s, this.readStartIndex);
1022-
dr.put(DRConstants.read_e, this.readEndIndex);
1023-
dr.put(DRConstants.update_s, workerStart);
1024-
dr.put(DRConstants.update_e, workerEnd);
1025-
dr.put(DRConstants.delete_s, this.deleteStartIndex);
1026-
dr.put(DRConstants.delete_e, this.deleteEndIndex);
1027-
dr.put(DRConstants.touch_s, this.touchStartIndex);
1028-
dr.put(DRConstants.touch_e, this.touchEndIndex);
1029-
dr.put(DRConstants.replace_s, this.replaceStartIndex);
1030-
dr.put(DRConstants.replace_e, this.replaceEndIndex);
1031-
dr.put(DRConstants.expiry_s, workerStart);
1032-
dr.put(DRConstants.expiry_e, workerEnd);
1033-
1034-
DocRange range = new DocRange(dr);
1035-
DocumentGenerator dg = null;
1036-
1037-
ws.dr = range;
1038-
try {
1039-
dg = new DocumentGenerator(ws, ws.keyType, ws.valueType);
1040-
} catch (Exception e) {
1041-
body.put("error", "Failed to create doc generator");
1042-
body.put("message", e.toString());
1043-
return new ResponseEntity<>(body, HttpStatus.BAD_REQUEST);
1044-
}
1008+
HashMap<String, WorkLoadGenerate> pendingTasks = new HashMap<>();
1009+
int k = 0;
1010+
while (!(steps[k] <= start_offset && start_offset < steps[k + 1]))
1011+
k += 1;
1012+
while (steps[k] < end_offset) {
1013+
long start = Math.max(start_offset, steps[k]);
1014+
long end = Math.min(end_offset, steps[k + 1]);
1015+
int effectivePool = (int) Math.min(poolSize, end - start);
1016+
long step = (end - start) / effectivePool;
1017+
for (int i = 0; i < effectivePool; i++) {
1018+
WorkLoadSettings ws = new WorkLoadSettings(this.keyPrefix,
1019+
this.keySize, this.docSize,
1020+
this.createPercent, this.readPercent,
1021+
this.updatePercent, this.deletePercent, this.expiryPercent, this.processConcurrency,
1022+
this.ops, this.loadType, this.keyType, msmarcoValueType,
1023+
this.validateDocs, this.gtm, this.validateDeletedDocs, this.mutate,
1024+
this.elastic, this.model, this.mockVector,
1025+
this.dim, this.base64, this.mutateField,
1026+
this.mutationTimeout, this.vecFilePath);
1027+
ws.baseVectorsFilePath = "MSMARCOSiftEmbeddingProduct".equals(msmarcoValueType)
1028+
? this.baseVectorsFilePath
1029+
: this.vecFilePath;
1030+
1031+
long workerStart = start + step * i;
1032+
long workerEnd = (i == effectivePool - 1) ? end : start + step * (i + 1);
1033+
HashMap<String, Number> dr = new HashMap<>();
1034+
dr.put(DRConstants.create_s, workerStart);
1035+
dr.put(DRConstants.create_e, workerEnd);
1036+
dr.put(DRConstants.read_s, this.readStartIndex);
1037+
dr.put(DRConstants.read_e, this.readEndIndex);
1038+
dr.put(DRConstants.update_s, workerStart);
1039+
dr.put(DRConstants.update_e, workerEnd);
1040+
dr.put(DRConstants.delete_s, this.deleteStartIndex);
1041+
dr.put(DRConstants.delete_e, this.deleteEndIndex);
1042+
dr.put(DRConstants.touch_s, this.touchStartIndex);
1043+
dr.put(DRConstants.touch_e, this.touchEndIndex);
1044+
dr.put(DRConstants.replace_s, this.replaceStartIndex);
1045+
dr.put(DRConstants.replace_e, this.replaceEndIndex);
1046+
dr.put(DRConstants.expiry_s, workerStart);
1047+
dr.put(DRConstants.expiry_e, workerEnd);
10451048

1046-
String task_name = "MSMARCOTask_" + TaskRequest.task_id.incrementAndGet() + "_" + ws.dr.create_s + "_"
1047-
+ ws.dr.create_e;
1048-
int retry = 0;
1049-
String th_name = task_name + "_" + i;
1050-
WorkLoadGenerate wlg = new WorkLoadGenerate(th_name, dg, TaskRequest.SDKClientPool, esClient,
1051-
this.durabilityLevel,
1052-
this.docTTL, this.docTTLUnit, this.trackFailures, retry, null);
1053-
wlg.set_collection_for_load(this.bucketName, this.scopeName, this.collectionName);
1054-
TaskRequest.loader_tasks.put(th_name, wlg);
1055-
task_names.add(th_name);
1049+
DocRange range = new DocRange(dr);
1050+
DocumentGenerator dg = null;
1051+
1052+
ws.dr = range;
1053+
try {
1054+
dg = new DocumentGenerator(ws, ws.keyType, ws.valueType);
1055+
} catch (Exception e) {
1056+
body.put("error", "Failed to create doc generator");
1057+
body.put("message", e.toString());
1058+
return new ResponseEntity<>(body, HttpStatus.BAD_REQUEST);
1059+
}
1060+
1061+
String task_name = "MSMARCOTask_" + TaskRequest.task_id.incrementAndGet() + k + "_" + ws.dr.create_s
1062+
+ "_" + ws.dr.create_e;
1063+
int retry = 0;
1064+
String th_name = task_name + "_" + i;
1065+
WorkLoadGenerate wlg = new WorkLoadGenerate(th_name, dg, TaskRequest.SDKClientPool, esClient,
1066+
this.durabilityLevel,
1067+
this.docTTL, this.docTTLUnit, this.trackFailures, retry, null);
1068+
wlg.set_collection_for_load(this.bucketName, this.scopeName, this.collectionName);
1069+
pendingTasks.put(th_name, wlg);
1070+
task_names.add(th_name);
1071+
}
1072+
k += 1;
10561073
}
1074+
TaskRequest.loader_tasks.putAll(pendingTasks);
10571075

10581076
body.put("tasks", task_names);
10591077
body.put("status", true);

src/main/java/couchbase/sdk/SDKClientPool.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ public void release_client(SDKClient client) {
131131
if (client == null || client.bucket == null) {
132132
return;
133133
}
134-
134+
135135
String bucket_key = client.bucket;
136136
String cache_key = bucket_key + ":" + client.scope + ":" + client.collection;
137137

@@ -140,10 +140,10 @@ public void release_client(SDKClient client) {
140140
if (info == null) {
141141
return;
142142
}
143-
143+
144144
// Decrement counter atomically
145145
int newCount = info.counter.decrementAndGet();
146-
146+
147147
if (newCount == 0) {
148148
// Remove from cache atomically
149149
clientCache.remove(cache_key);

0 commit comments

Comments
 (0)