Skip to content

Commit 77fc4fa

Browse files
committed
Improve API consumption while _handle_bulk_insert_op
1 parent 7af6b98 commit 77fc4fa

File tree

2 files changed

+42
-55
lines changed
  • community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts

2 files changed

+42
-55
lines changed

community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717

18-
from typing import List, Optional, Dict
18+
from typing import List, Optional, Dict, Any
1919
import argparse
2020
from datetime import timedelta
2121
import shlex
@@ -33,7 +33,6 @@
3333
chunked,
3434
ensure_execute,
3535
execute_with_futures,
36-
get_insert_operations,
3736
log_api_request,
3837
map_with_futures,
3938
run,
@@ -375,6 +374,40 @@ def resume_nodes(nodes: List[str], resume_data: Optional[ResumeData]):
375374
_handle_bulk_insert_op(op, grouped_nodes[group].nodes, resume_data)
376375

377376

377+
def _get_failed_zonal_instance_inserts(bulk_op: Any, zone: str, lkp: util.Lookup) -> list[Any]:
378+
group_id = bulk_op["operationGroupId"]
379+
user = bulk_op["user"]
380+
started = bulk_op["startTime"]
381+
ended = bulk_op["endTime"]
382+
383+
fltr = f'(user eq "{user}") AND (operationType eq "insert") AND (creationTimestamp > "{started}") AND (creationTimestamp < "{ended}")'
384+
act = lkp.compute.zoneOperations()
385+
req = act.list(project=lkp.project, zone=zone, filter=fltr)
386+
ops = []
387+
while req is not None:
388+
result = util.ensure_execute(req)
389+
for op in result.get("items", []):
390+
if op.get("operationGroupId") == group_id and "error" in op:
391+
ops.append(op)
392+
req = act.list_next(req, result)
393+
return ops
394+
395+
396+
def _get_failed_instance_inserts(bulk_op: Any, lkp: util.Lookup) -> list[Any]:
397+
zones = set() # gather zones that had failed inserts
398+
for loc, stat in bulk_op.get("instancesBulkInsertOperationMetadata", {}).get("perLocationStatus", {}).items():
399+
pref, zone = loc.split("/", 1)
400+
if not pref == "zones":
401+
log.error(f"Unexpected location: {loc} in operation {bulk_op['name']}")
402+
continue
403+
if stat.get("targetVmCount", 0) != stat.get("createdVmCount", 0):
404+
zones.add(zone)
405+
406+
res = []
407+
for zone in zones:
408+
res.extend(_get_failed_zonal_instance_inserts(bulk_op, zone, lkp))
409+
return res
410+
378411
def _handle_bulk_insert_op(op: Dict, nodes: List[str], resume_data: Optional[ResumeData]) -> None:
379412
"""
380413
Handles **DONE** BulkInsert operations
@@ -384,10 +417,9 @@ def _handle_bulk_insert_op(op: Dict, nodes: List[str], resume_data: Optional[Res
384417
group_id = op["operationGroupId"]
385418
if "error" in op:
386419
error = op["error"]["errors"][0]
387-
log.warning(
420+
log.error(
388421
f"bulkInsert operation error: {error['code']} name={op['name']} operationGroupId={group_id} nodes={to_hostlist(nodes)}"
389422
)
390-
# TODO: does it make sense to query for insert-ops in case of bulkInsert-op error?
391423

392424
created = 0
393425
for status in op["instancesBulkInsertOperationMetadata"]["perLocationStatus"].values():
@@ -396,18 +428,13 @@ def _handle_bulk_insert_op(op: Dict, nodes: List[str], resume_data: Optional[Res
396428
log.info(f"created {len(nodes)} instances: nodes={to_hostlist(nodes)}")
397429
return # no need to gather status of insert-operations.
398430

399-
# TODO:
400-
# * don't perform globalOperations aggregateList request to gather insert-operations,
401-
# instead use specific locations from `instancesBulkInsertOperationMetadata`,
402-
# most of the time single zone should be sufficient.
403-
# * don't gather insert-operations per bulkInsert request, instead aggregate it across
404-
# all bulkInserts (goes one level above this function)
405-
successful_inserts, failed_inserts = separate(
406-
lambda op: "error" in op, get_insert_operations(group_id)
407-
)
408-
# Apparently multiple errors are possible... so join with +.
431+
# TODO: don't gather insert-operations per bulkInsert request, instead aggregate it
432+
# across all bulkInserts (goes one level above this function)
433+
failed = _get_failed_instance_inserts(op, util.lookup())
434+
435+
# Multiple errors are possible, group by all of them (joined string codes)
409436
by_error_inserts = util.groupby_unsorted(
410-
failed_inserts,
437+
failed,
411438
lambda op: "+".join(err["code"] for err in op["error"]["errors"]),
412439
)
413440
for code, failed_ops in by_error_inserts:
@@ -428,10 +455,6 @@ def _handle_bulk_insert_op(op: Dict, nodes: List[str], resume_data: Optional[Res
428455
f"errors from insert for node '{failed_nodes[0]}' ({failed_ops[0]['name']}): {msg}"
429456
)
430457

431-
ready_nodes = {trim_self_link(op["targetLink"]) for op in successful_inserts}
432-
if len(ready_nodes) > 0:
433-
log.info(f"created {len(ready_nodes)} instances: nodes={to_hostlist(ready_nodes)}")
434-
435458

436459
def down_nodes_notify_jobs(nodes: List[str], reason: str, resume_data: Optional[ResumeData]) -> None:
437460
"""set nodes down with reason"""

community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/util.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,42 +1245,6 @@ def wait_for_operations(operations):
12451245
]
12461246

12471247

1248-
def get_filtered_operations(op_filter):
1249-
"""get list of operations associated with group id"""
1250-
project = lookup().project
1251-
operations: List[Any] = []
1252-
1253-
def get_aggregated_operations(items):
1254-
# items is a dict of location key to value: dict(operations=<list of operations>) or an empty dict
1255-
operations.extend(
1256-
chain.from_iterable(
1257-
ops["operations"] for ops in items.values() if "operations" in ops
1258-
)
1259-
)
1260-
1261-
act = lookup().compute.globalOperations()
1262-
op = act.aggregatedList(
1263-
project=project, filter=op_filter, fields="items.*.operations,nextPageToken"
1264-
)
1265-
1266-
while op is not None:
1267-
result = ensure_execute(op)
1268-
get_aggregated_operations(result["items"])
1269-
op = act.aggregatedList_next(op, result)
1270-
return operations
1271-
1272-
1273-
def get_insert_operations(group_ids):
1274-
"""get all insert operations from a list of operationGroupId"""
1275-
if isinstance(group_ids, str):
1276-
group_ids = group_ids.split(",")
1277-
filters = [
1278-
"operationType=insert",
1279-
" OR ".join(f"(operationGroupId={id})" for id in group_ids),
1280-
]
1281-
return get_filtered_operations(" AND ".join(f"({f})" for f in filters if f))
1282-
1283-
12841248
def getThreadsPerCore(template) -> int:
12851249
if not template.machine_type.supports_smt:
12861250
return 1

0 commit comments

Comments
 (0)