Skip to content

Commit 6d76050

Browse files
committed
retry all steps during self delete if the delete was not successful and change self-delete permission assignment to happen during container creation and update benchmarks to set more random seeds
1 parent 9fc4eb5 commit 6d76050

File tree

3 files changed

+137
-99
lines changed

3 files changed

+137
-99
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,7 @@ We also build on top of many great packages. Please check them out!
622622

623623
# Papers that use or compare EBMs
624624

625+
- [Challenging the Performance-Interpretability Trade-off: An Evaluation of Interpretable Machine Learning Models](https://arxiv.org/pdf/2409.14429)
625626
- [Data Science with LLMs and Interpretable Models](https://arxiv.org/pdf/2402.14474v1.pdf)
626627
- [DimVis: Interpreting Visual Clusters in Dimensionality Reduction With Explainable Boosting Machine](https://arxiv.org/pdf/2402.06885.pdf)
627628
- [Distill knowledge of additive tree models into generalized linear models](https://detralytics.com/wp-content/uploads/2023/10/Detra-Note_Additive-tree-ensembles.pdf)

docs/benchmarks/ebm-benchmark.ipynb

+6-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"force_recreate = False\n",
1818
"exist_ok = True\n",
1919
"TIMEOUT_SEC = 60 * 60 * 24 * 180 # 180 days\n",
20-
"wheel_filepaths = ['interpret_core-0.6.3-py3-none-any.whl', 'powerlift-0.1.11-py3-none-any.whl']\n",
20+
"wheel_filepaths = ['interpret_core-0.6.4-py3-none-any.whl', 'powerlift-0.1.12-py3-none-any.whl']\n",
2121
"\n",
2222
"import datetime\n",
2323
"experiment_name = datetime.datetime.now().strftime('%Y_%m_%d_%H%M__') + 'myexperiment'\n",
@@ -230,6 +230,10 @@
230230
" import warnings\n",
231231
" import gc\n",
232232
" import re\n",
233+
" import random\n",
234+
"\n",
235+
" random.seed(seed)\n",
236+
" np.random.seed(seed)\n",
233237
"\n",
234238
" X, y = trial.task.data()\n",
235239
"\n",
@@ -851,7 +855,7 @@
851855
" executor = AzureContainerInstance(\n",
852856
" store, azure_tenant_id, subscription_id, azure_client_id, credential,\n",
853857
" resource_group=resource_group,\n",
854-
" pip_install= requirements + \" interpret-core\",\n",
858+
" pip_install=requirements + \" interpret-core\",\n",
855859
" wheel_filepaths=wheel_filepaths,\n",
856860
" n_running_containers=n_containers\n",
857861
" )\n",

python/powerlift/powerlift/run_azure/__main__.py

+130-97
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,80 @@
11
"""This is called to run a trial by worker nodes (local / remote)."""
22

33

4+
def assign_delete_permissions(
5+
aci_client,
6+
auth_client,
7+
max_undead_containers,
8+
credential,
9+
subscription_id,
10+
client_id,
11+
resource_group_name,
12+
container_groups,
13+
):
14+
from heapq import heappush, heappop
15+
from datetime import datetime
16+
import time
17+
import uuid
18+
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
19+
from azure.mgmt.authorization import AuthorizationManagementClient
20+
from azure.mgmt.authorization.models import RoleAssignmentCreateParameters
21+
from azure.core.exceptions import HttpResponseError
22+
23+
# Contributor Role
24+
role_definition_id = f"/subscriptions/{subscription_id}/providers/Microsoft.Authorization/roleDefinitions/b24988ac-6180-42a0-ab88-20f7382dd24c"
25+
26+
while max_undead_containers < len(container_groups):
27+
_, container_group_name, started = heappop(container_groups)
28+
try:
29+
if started is not None:
30+
if not started.done():
31+
heappush(
32+
container_groups,
33+
(datetime.now(), container_group_name, started),
34+
)
35+
time.sleep(1)
36+
continue
37+
started = None
38+
39+
if aci_client is None:
40+
aci_client = ContainerInstanceManagementClient(
41+
credential, subscription_id
42+
)
43+
44+
container_group = aci_client.container_groups.get(
45+
resource_group_name, container_group_name
46+
)
47+
48+
role_assignment_params1 = RoleAssignmentCreateParameters(
49+
role_definition_id=role_definition_id,
50+
principal_id=container_group.identity.principal_id,
51+
principal_type="ServicePrincipal",
52+
)
53+
role_assignment_params2 = RoleAssignmentCreateParameters(
54+
role_definition_id=role_definition_id,
55+
principal_id=client_id,
56+
principal_type="User",
57+
)
58+
scope = f"/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.ContainerInstance/containerGroups/{container_group_name}"
59+
60+
if auth_client is None:
61+
auth_client = AuthorizationManagementClient(credential, subscription_id)
62+
63+
auth_client.role_assignments.create(
64+
scope, str(uuid.uuid4()), role_assignment_params1
65+
)
66+
auth_client.role_assignments.create(
67+
scope, str(uuid.uuid4()), role_assignment_params2
68+
)
69+
except HttpResponseError:
70+
aci_client = None
71+
auth_client = None
72+
heappush(container_groups, (datetime.now(), container_group_name, started))
73+
time.sleep(1)
74+
75+
return aci_client, auth_client
76+
77+
478
def run_azure_process(
579
experiment_id,
680
n_runners,
@@ -16,46 +90,33 @@ def run_azure_process(
1690
):
1791
startup_script = """
1892
self_delete() {
19-
echo "Attempt to self-delete this container group. Exit code was $1."
93+
echo "Attempt to self-delete this container group. Exit code was: $1"
94+
95+
if [ $1 -ne 0 ]; then
96+
echo "Waiting 10 mintues to allow inspection of the logs..."
97+
sleep 600
98+
fi
99+
100+
SUBSCRIPTION_ID=${SUBSCRIPTION_ID}
101+
RESOURCE_GROUP_NAME=${RESOURCE_GROUP_NAME}
102+
CONTAINER_GROUP_NAME=${CONTAINER_GROUP_NAME}
20103
21104
retry_count=0
22105
while true; do
23106
echo "Downloading azure tools."
24107
25108
curl -sL https://aka.ms/InstallAzureCLIDeb -o install_script.sh
26109
exit_code=$?
27-
if [ $exit_code -eq 0 ]; then
28-
break
110+
if [ $exit_code -ne 0 ]; then
111+
echo "curl failed with exit code $exit_code."
29112
fi
30113
31-
echo "curl failed with exit code $exit_code."
32-
if [ $retry_count -ge 300 ]; then
33-
echo "Maximum number of retries reached. Command failed."
34-
exit 62
114+
bash install_script.sh
115+
exit_code=$?
116+
if [ $exit_code -ne 0 ]; then
117+
echo "Failed to install azure tools with exit code $exit_code. Attempting to delete anyway."
35118
fi
36-
retry_count=$((retry_count + 1))
37-
echo "Sleeping."
38-
sleep 300
39-
echo "Retrying."
40-
done
41-
42-
bash install_script.sh
43-
exit_code=$?
44-
if [ $exit_code -ne 0 ]; then
45-
echo "Failed to install azure tools with exit code $exit_code. Attempting to delete anyway."
46-
fi
47-
48-
SUBSCRIPTION_ID=${SUBSCRIPTION_ID}
49-
RESOURCE_GROUP_NAME=${RESOURCE_GROUP_NAME}
50-
CONTAINER_GROUP_NAME=${CONTAINER_GROUP_NAME}
51119
52-
if [ $1 -ne 0 ]; then
53-
echo "Waiting 10 mintues to allow inspection of the logs..."
54-
sleep 600
55-
fi
56-
57-
retry_count=0
58-
while true; do
59120
echo "Logging into azure to delete this container."
60121
az login --identity
61122
exit_code=$?
@@ -78,8 +139,10 @@ def run_azure_process(
78139
break
79140
fi
80141
retry_count=$((retry_count + 1))
81-
done
82142
143+
echo "Retrying."
144+
done
145+
83146
exit $1 # failed to self-kill the container we are running this on.
84147
}
85148
@@ -297,12 +360,10 @@ def run_azure_process(
297360

298361
import time
299362
import uuid
300-
from multiprocessing.pool import MaybeEncodingError
301-
363+
from heapq import heappush
364+
from datetime import datetime
302365
from azure.core.exceptions import HttpResponseError
303366
from azure.identity import ClientSecretCredential
304-
from azure.mgmt.authorization import AuthorizationManagementClient
305-
from azure.mgmt.authorization.models import RoleAssignmentCreateParameters
306367
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
307368
from azure.mgmt.containerinstance.models import (
308369
Container,
@@ -315,6 +376,8 @@ def run_azure_process(
315376
)
316377
from azure.mgmt.resource import ResourceManagementClient
317378

379+
max_undead_containers = 20
380+
318381
client_id = azure_json["client_id"]
319382

320383
if credential is None:
@@ -327,11 +390,15 @@ def run_azure_process(
327390
resource_group_name = azure_json["resource_group"]
328391
subscription_id = azure_json["subscription_id"]
329392

330-
aci_client = ContainerInstanceManagementClient(credential, subscription_id)
393+
aci_client = None
394+
auth_client = None
395+
container_groups = []
331396
res_client = ResourceManagementClient(credential, subscription_id)
332397

333398
# If this first call fails, then allow the Exception to propagate.
334-
resource_group = res_client.resource_groups.get(resource_group_name)
399+
resource_group_location = res_client.resource_groups.get(
400+
resource_group_name
401+
).location
335402

336403
container_resource_requests = ResourceRequests(
337404
cpu=num_cores,
@@ -342,7 +409,6 @@ def run_azure_process(
342409
)
343410

344411
container_group_names = set()
345-
starts = []
346412
for runner_id in range(n_runners):
347413
container_group_name = f"powerlift-container-group-{batch_id}-{runner_id:04}"
348414

@@ -366,85 +432,52 @@ def run_azure_process(
366432
environment_variables=env_vars,
367433
)
368434
container_group = ContainerGroup(
369-
location=resource_group.location,
435+
location=resource_group_location,
370436
containers=[container],
371437
os_type=OperatingSystemTypes.linux,
372438
restart_policy=ContainerGroupRestartPolicy.never,
373439
identity={"type": "SystemAssigned"},
374440
)
375441

442+
if aci_client is None:
443+
aci_client = ContainerInstanceManagementClient(credential, subscription_id)
444+
376445
while True:
377446
try:
378447
# begin_create_or_update returns LROPoller,
379448
# but this is only indicates when the containter is started
380449
started = aci_client.container_groups.begin_create_or_update(
381-
resource_group.name, container_group_name, container_group
450+
resource_group_name, container_group_name, container_group
382451
)
383452
break
384453
except HttpResponseError:
385454
time.sleep(1)
386455

387-
starts.append(started)
388-
389456
container_group_names.add(container_group_name)
457+
heappush(container_groups, (datetime.now(), container_group_name, started))
458+
aci_client, auth_client = assign_delete_permissions(
459+
aci_client,
460+
auth_client,
461+
max_undead_containers,
462+
credential,
463+
subscription_id,
464+
client_id,
465+
resource_group_name,
466+
container_groups,
467+
)
390468

391-
# make sure they have all started before exiting the process
392-
for started in starts:
393-
while True:
394-
try:
395-
while not started.done():
396-
time.sleep(1)
397-
break
398-
except HttpResponseError:
399-
time.sleep(1)
469+
assign_delete_permissions(
470+
aci_client,
471+
auth_client,
472+
0,
473+
credential,
474+
subscription_id,
475+
client_id,
476+
resource_group_name,
477+
container_groups,
478+
)
400479

401480
if delete_group_container_on_complete:
402-
auth_client = AuthorizationManagementClient(credential, subscription_id)
403-
404-
# Contributor Role
405-
role_definition_id = f"/subscriptions/{subscription_id}/providers/Microsoft.Authorization/roleDefinitions/b24988ac-6180-42a0-ab88-20f7382dd24c"
406-
407-
for container_group_name in container_group_names:
408-
while True:
409-
try:
410-
container_group = aci_client.container_groups.get(
411-
resource_group_name, container_group_name
412-
)
413-
break
414-
except HttpResponseError:
415-
time.sleep(1)
416-
417-
scope = f"/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.ContainerInstance/containerGroups/{container_group_name}"
418-
role_assignment_params = RoleAssignmentCreateParameters(
419-
role_definition_id=role_definition_id,
420-
principal_id=container_group.identity.principal_id,
421-
principal_type="ServicePrincipal",
422-
)
423-
424-
while True:
425-
try:
426-
auth_client.role_assignments.create(
427-
scope, str(uuid.uuid4()), role_assignment_params
428-
)
429-
break
430-
except HttpResponseError:
431-
time.sleep(1)
432-
433-
role_assignment_params = RoleAssignmentCreateParameters(
434-
role_definition_id=role_definition_id,
435-
principal_id=client_id,
436-
principal_type="User",
437-
)
438-
439-
while True:
440-
try:
441-
auth_client.role_assignments.create(
442-
scope, str(uuid.uuid4()), role_assignment_params
443-
)
444-
break
445-
except HttpResponseError:
446-
time.sleep(1)
447-
448481
deletes = []
449482
while len(container_group_names) != 0:
450483
remove_after = []
@@ -466,7 +499,7 @@ def run_azure_process(
466499
)
467500
deletes.append(deleted)
468501
remove_after.append(container_group_name)
469-
except (HttpResponseError, MaybeEncodingError):
502+
except HttpResponseError:
470503
pass
471504

472505
for container_group_name in remove_after:

0 commit comments

Comments
 (0)