Skip to content

Commit b7b7c67

Browse files
committed
Improve API consumption while _handle_bulk_insert_op
1 parent 7af6b98 commit b7b7c67

File tree

2 files changed

+43
-55
lines changed
  • community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts

2 files changed

+43
-55
lines changed

community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/resume.py

Lines changed: 43 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,8 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717

18-
from typing import List, Optional, Dict
18+
from functools import lru_cache
19+
from typing import List, Optional, Dict, Any
1920
import argparse
2021
from datetime import timedelta
2122
import shlex
@@ -33,7 +34,6 @@
3334
chunked,
3435
ensure_execute,
3536
execute_with_futures,
36-
get_insert_operations,
3737
log_api_request,
3838
map_with_futures,
3939
run,
@@ -375,6 +375,40 @@ def resume_nodes(nodes: List[str], resume_data: Optional[ResumeData]):
375375
_handle_bulk_insert_op(op, grouped_nodes[group].nodes, resume_data)
376376

377377

378+
def _get_failed_zonal_instance_inserts(bulk_op: Any, zone: str, lkp: util.Lookup) -> list[Any]:
379+
group_id = bulk_op["operationGroupId"]
380+
user = bulk_op["user"]
381+
started = bulk_op["startTime"]
382+
ended = bulk_op["endTime"]
383+
384+
fltr = f'(user eq "{user}") AND (operationType eq "insert") AND (creationTimestamp > "{started}") AND (creationTimestamp < "{ended}")'
385+
act = lkp.compute.zoneOperations()
386+
req = act.list(project=lkp.project, zone=zone, filter=fltr)
387+
ops = []
388+
while req is not None:
389+
result = util.ensure_execute(req)
390+
for op in result.get("items", []):
391+
if op.get("operationGroupId") == group_id and "error" in op:
392+
ops.append(op)
393+
req = act.list_next(req, result)
394+
return ops
395+
396+
397+
def _get_failed_instance_inserts(bulk_op: Any, lkp: util.Lookup) -> list[Any]:
398+
zones = set() # gather zones that had failed inserts
399+
for loc, stat in bulk_op.get("instancesBulkInsertOperationMetadata", {}).get("perLocationStatus", {}).items():
400+
pref, zone = loc.split("/", 1)
401+
if not pref == "zones":
402+
log.error(f"Unexpected location: {loc} in operation {bulk_op['name']}")
403+
continue
404+
if stat.get("targetVmCount", 0) != stat.get("createdVmCount", 0):
405+
zones.add(zone)
406+
407+
res = []
408+
for zone in zones:
409+
res.extend(_get_failed_zonal_instance_inserts(bulk_op, zone, lkp))
410+
return res
411+
378412
def _handle_bulk_insert_op(op: Dict, nodes: List[str], resume_data: Optional[ResumeData]) -> None:
379413
"""
380414
Handles **DONE** BulkInsert operations
@@ -384,10 +418,9 @@ def _handle_bulk_insert_op(op: Dict, nodes: List[str], resume_data: Optional[Res
384418
group_id = op["operationGroupId"]
385419
if "error" in op:
386420
error = op["error"]["errors"][0]
387-
log.warning(
421+
log.error(
388422
f"bulkInsert operation error: {error['code']} name={op['name']} operationGroupId={group_id} nodes={to_hostlist(nodes)}"
389423
)
390-
# TODO: does it make sense to query for insert-ops in case of bulkInsert-op error?
391424

392425
created = 0
393426
for status in op["instancesBulkInsertOperationMetadata"]["perLocationStatus"].values():
@@ -396,18 +429,13 @@ def _handle_bulk_insert_op(op: Dict, nodes: List[str], resume_data: Optional[Res
396429
log.info(f"created {len(nodes)} instances: nodes={to_hostlist(nodes)}")
397430
return # no need to gather status of insert-operations.
398431

399-
# TODO:
400-
# * don't perform globalOperations aggregateList request to gather insert-operations,
401-
# instead use specific locations from `instancesBulkInsertOperationMetadata`,
402-
# most of the time single zone should be sufficient.
403-
# * don't gather insert-operations per bulkInsert request, instead aggregate it across
404-
# all bulkInserts (goes one level above this function)
405-
successful_inserts, failed_inserts = separate(
406-
lambda op: "error" in op, get_insert_operations(group_id)
407-
)
408-
# Apparently multiple errors are possible... so join with +.
432+
# TODO: don't gather insert-operations per bulkInsert request, instead aggregate it
433+
# across all bulkInserts (goes one level above this function)
434+
failed = _get_failed_instance_inserts(op, util.lookup())
435+
436+
# Multiple errors are possible, group by all of them (joined string codes)
409437
by_error_inserts = util.groupby_unsorted(
410-
failed_inserts,
438+
failed,
411439
lambda op: "+".join(err["code"] for err in op["error"]["errors"]),
412440
)
413441
for code, failed_ops in by_error_inserts:
@@ -428,10 +456,6 @@ def _handle_bulk_insert_op(op: Dict, nodes: List[str], resume_data: Optional[Res
428456
f"errors from insert for node '{failed_nodes[0]}' ({failed_ops[0]['name']}): {msg}"
429457
)
430458

431-
ready_nodes = {trim_self_link(op["targetLink"]) for op in successful_inserts}
432-
if len(ready_nodes) > 0:
433-
log.info(f"created {len(ready_nodes)} instances: nodes={to_hostlist(ready_nodes)}")
434-
435459

436460
def down_nodes_notify_jobs(nodes: List[str], reason: str, resume_data: Optional[ResumeData]) -> None:
437461
"""set nodes down with reason"""

community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/util.py

Lines changed: 0 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -1245,42 +1245,6 @@ def wait_for_operations(operations):
12451245
]
12461246

12471247

1248-
def get_filtered_operations(op_filter):
1249-
"""get list of operations associated with group id"""
1250-
project = lookup().project
1251-
operations: List[Any] = []
1252-
1253-
def get_aggregated_operations(items):
1254-
# items is a dict of location key to value: dict(operations=<list of operations>) or an empty dict
1255-
operations.extend(
1256-
chain.from_iterable(
1257-
ops["operations"] for ops in items.values() if "operations" in ops
1258-
)
1259-
)
1260-
1261-
act = lookup().compute.globalOperations()
1262-
op = act.aggregatedList(
1263-
project=project, filter=op_filter, fields="items.*.operations,nextPageToken"
1264-
)
1265-
1266-
while op is not None:
1267-
result = ensure_execute(op)
1268-
get_aggregated_operations(result["items"])
1269-
op = act.aggregatedList_next(op, result)
1270-
return operations
1271-
1272-
1273-
def get_insert_operations(group_ids):
1274-
"""get all insert operations from a list of operationGroupId"""
1275-
if isinstance(group_ids, str):
1276-
group_ids = group_ids.split(",")
1277-
filters = [
1278-
"operationType=insert",
1279-
" OR ".join(f"(operationGroupId={id})" for id in group_ids),
1280-
]
1281-
return get_filtered_operations(" AND ".join(f"({f})" for f in filters if f))
1282-
1283-
12841248
def getThreadsPerCore(template) -> int:
12851249
if not template.machine_type.supports_smt:
12861250
return 1

0 commit comments

Comments
 (0)