@@ -20,6 +20,12 @@ std::vector<int> batch_write_tensor_impl(const std::vector<std::string> &keys,
2020 const ReplicateConfig &config,
2121 const char *operation_name,
2222 BatchWriteFromFn &&batch_write_from) {
23+ auto group_ids_error =
24+ ValidateGroupIdsForBatchConfig (config, keys.size (), operation_name);
25+ if (!group_ids_error.empty ()) {
26+ return group_ids_error;
27+ }
28+
2329 std::vector<int > results (keys.size (), 0 );
2430
2531 {
@@ -65,8 +71,10 @@ std::vector<int> batch_write_tensor_impl(const std::vector<std::string> &keys,
6571 }
6672
6773 if (!valid_keys.empty ()) {
68- std::vector<int > op_results =
69- batch_write_from (valid_keys, buffer_ptrs, buffer_sizes);
74+ ReplicateConfig write_config =
75+ MakeIndexedConfig (config, original_indices);
76+ std::vector<int > op_results = batch_write_from (
77+ valid_keys, buffer_ptrs, buffer_sizes, write_config);
7078 for (size_t i = 0 ; i < op_results.size (); ++i) {
7179 results[original_indices[i]] = op_results[i];
7280 }
@@ -905,6 +913,12 @@ std::vector<int> batch_put_tensor_with_parallelism(
905913 const py::object ¶llelisms = py::none(),
906914 const ReplicateConfig &config = ReplicateConfig{},
907915 const py::object &writer_partitions = py::none()) {
916+ auto group_ids_error = ValidateGroupIdsForBatchConfig (
917+ config, keys.size (), " batch_put_tensor_with_parallelism" );
918+ if (!group_ids_error.empty ()) {
919+ return group_ids_error;
920+ }
921+
908922 return execute_batch_parallelism_write_requests (
909923 keys, tensors_list.size (), parallelisms, writer_partitions,
910924 " batch_put_tensor_with_parallelism" ,
@@ -921,14 +935,16 @@ std::vector<int> batch_put_tensor_with_parallelism(
921935 },
922936 [this , &keys, &tensors_list, &config](size_t i,
923937 const py::handle ¶llelism) {
938+ ReplicateConfig key_config = config.ForSingleKey (i);
924939 return put_tensor_with_parallelism (
925940 keys[i], tensors_list[i],
926- py::reinterpret_borrow<py::object>(parallelism), config );
941+ py::reinterpret_borrow<py::object>(parallelism), key_config );
927942 },
928943 [this , &keys, &tensors_list, &config](
929944 size_t i, const py::handle &writer_partition) {
945+ ReplicateConfig key_config = config.ForSingleKey (i);
930946 return put_tensor_with_parallelism (
931- keys[i], tensors_list[i], py::none (), config ,
947+ keys[i], tensors_list[i], py::none (), key_config ,
932948 py::reinterpret_borrow<py::object>(writer_partition));
933949 });
934950}
@@ -1029,6 +1045,12 @@ std::vector<int> batch_put_tensor_with_parallelism_from(
10291045 const py::object ¶llelisms = py::none(),
10301046 const ReplicateConfig &config = ReplicateConfig{},
10311047 const py::object &writer_partitions = py::none()) {
1048+ auto group_ids_error = ValidateGroupIdsForBatchConfig (
1049+ config, keys.size (), " batch_put_tensor_with_parallelism_from" );
1050+ if (!group_ids_error.empty ()) {
1051+ return group_ids_error;
1052+ }
1053+
10321054 return execute_batch_parallelism_write_requests (
10331055 keys, buffer_ptrs.size (), parallelisms, writer_partitions,
10341056 " batch_put_tensor_with_parallelism_from" ,
@@ -1070,14 +1092,16 @@ std::vector<int> batch_put_tensor_with_parallelism_from(
10701092 },
10711093 [this , &keys, &buffer_ptrs, &sizes, &config](
10721094 size_t i, const py::handle ¶llelism) {
1095+ ReplicateConfig key_config = config.ForSingleKey (i);
10731096 return put_tensor_with_parallelism_from (
10741097 keys[i], buffer_ptrs[i], sizes[i],
1075- py::reinterpret_borrow<py::object>(parallelism), config );
1098+ py::reinterpret_borrow<py::object>(parallelism), key_config );
10761099 },
10771100 [this , &keys, &buffer_ptrs, &sizes, &config](
10781101 size_t i, const py::handle &writer_partition) {
1102+ ReplicateConfig key_config = config.ForSingleKey (i);
10791103 return put_tensor_with_parallelism_from (
1080- keys[i], buffer_ptrs[i], sizes[i], py::none (), config ,
1104+ keys[i], buffer_ptrs[i], sizes[i], py::none (), key_config ,
10811105 py::reinterpret_borrow<py::object>(writer_partition));
10821106 });
10831107}
@@ -1345,6 +1369,12 @@ std::vector<int> batch_upsert_tensor_with_parallelism(
13451369 const py::object ¶llelisms = py::none(),
13461370 const ReplicateConfig &config = ReplicateConfig{},
13471371 const py::object &writer_partitions = py::none()) {
1372+ auto group_ids_error = ValidateGroupIdsForBatchConfig (
1373+ config, keys.size (), " batch_upsert_tensor_with_parallelism" );
1374+ if (!group_ids_error.empty ()) {
1375+ return group_ids_error;
1376+ }
1377+
13481378 return execute_batch_parallelism_write_requests (
13491379 keys, tensors_list.size (), parallelisms, writer_partitions,
13501380 " batch_upsert_tensor_with_parallelism" ,
@@ -1361,14 +1391,16 @@ std::vector<int> batch_upsert_tensor_with_parallelism(
13611391 },
13621392 [this , &keys, &tensors_list, &config](size_t i,
13631393 const py::handle ¶llelism) {
1394+ ReplicateConfig key_config = config.ForSingleKey (i);
13641395 return upsert_tensor_with_parallelism (
13651396 keys[i], tensors_list[i],
1366- py::reinterpret_borrow<py::object>(parallelism), config );
1397+ py::reinterpret_borrow<py::object>(parallelism), key_config );
13671398 },
13681399 [this , &keys, &tensors_list, &config](
13691400 size_t i, const py::handle &writer_partition) {
1401+ ReplicateConfig key_config = config.ForSingleKey (i);
13701402 return upsert_tensor_with_parallelism (
1371- keys[i], tensors_list[i], py::none (), config ,
1403+ keys[i], tensors_list[i], py::none (), key_config ,
13721404 py::reinterpret_borrow<py::object>(writer_partition));
13731405 });
13741406}
@@ -1379,6 +1411,12 @@ std::vector<int> batch_upsert_tensor_with_parallelism_from(
13791411 const py::object ¶llelisms = py::none(),
13801412 const ReplicateConfig &config = ReplicateConfig{},
13811413 const py::object &writer_partitions = py::none()) {
1414+ auto group_ids_error = ValidateGroupIdsForBatchConfig (
1415+ config, keys.size (), " batch_upsert_tensor_with_parallelism_from" );
1416+ if (!group_ids_error.empty ()) {
1417+ return group_ids_error;
1418+ }
1419+
13821420 return execute_batch_parallelism_write_requests (
13831421 keys, buffer_ptrs.size (), parallelisms, writer_partitions,
13841422 " batch_upsert_tensor_with_parallelism_from" ,
@@ -1430,14 +1468,16 @@ std::vector<int> batch_upsert_tensor_with_parallelism_from(
14301468 },
14311469 [this , &keys, &buffer_ptrs, &sizes, &config](
14321470 size_t i, const py::handle ¶llelism) {
1471+ ReplicateConfig key_config = config.ForSingleKey (i);
14331472 return upsert_tensor_with_parallelism_from (
14341473 keys[i], buffer_ptrs[i], sizes[i],
1435- py::reinterpret_borrow<py::object>(parallelism), config );
1474+ py::reinterpret_borrow<py::object>(parallelism), key_config );
14361475 },
14371476 [this , &keys, &buffer_ptrs, &sizes, &config](
14381477 size_t i, const py::handle &writer_partition) {
1478+ ReplicateConfig key_config = config.ForSingleKey (i);
14391479 return upsert_tensor_with_parallelism_from (
1440- keys[i], buffer_ptrs[i], sizes[i], py::none (), config ,
1480+ keys[i], buffer_ptrs[i], sizes[i], py::none (), key_config ,
14411481 py::reinterpret_borrow<py::object>(writer_partition));
14421482 });
14431483}
0 commit comments