Skip to content

Commit 5aa2937

Browse files
authored
Fix workspace operations when connected by name, improve tests (#706)
* Improve error message in arm request * Use connection params instead of duplicating in workspace, add live test to test workspace by name * Helper method name improvement
1 parent 0df2ac7 commit 5aa2937

File tree

4 files changed

+368
-60
lines changed

4 files changed

+368
-60
lines changed

azure-quantum/azure/quantum/workspace.py

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,6 @@ def __init__(
159159

160160
self._connection_params = connection_params
161161
self._storage = storage
162-
self._subscription_id = connection_params.subscription_id
163-
self._resource_group = connection_params.resource_group
164-
self._workspace_name = connection_params.workspace_name
165162

166163
if not self._mgmt_client:
167164
credential = connection_params.get_credential_or_default()
@@ -404,9 +401,9 @@ def _get_linked_storage_sas_uri(
404401
container_name=container_name, blob_name=blob_name
405402
)
406403
container_uri = client.get_sas_uri(
407-
self._subscription_id,
408-
self._resource_group,
409-
self._workspace_name,
404+
self.subscription_id,
405+
self.resource_group,
406+
self.name,
410407
blob_details=blob_details)
411408

412409
logger.debug("Container URI from service: %s", container_uri)
@@ -424,9 +421,9 @@ def submit_job(self, job: Job) -> Job:
424421
"""
425422
client = self._get_jobs_client()
426423
details = client.create_or_replace(
427-
self._subscription_id,
428-
self._resource_group,
429-
self._workspace_name,
424+
self.subscription_id,
425+
self.resource_group,
426+
self.name,
430427
job.details.id,
431428
job.details
432429
)
@@ -445,14 +442,14 @@ def cancel_job(self, job: Job) -> Job:
445442
"""
446443
client = self._get_jobs_client()
447444
client.delete(
448-
self._subscription_id,
449-
self._resource_group,
450-
self._workspace_name,
445+
self.subscription_id,
446+
self.resource_group,
447+
self.name,
451448
job.details.id)
452449
details = client.get(
453-
self._subscription_id,
454-
self._resource_group,
455-
self._workspace_name,
450+
self.subscription_id,
451+
self.resource_group,
452+
self.name,
456453
job.id)
457454
return Job(self, details)
458455

@@ -472,9 +469,9 @@ def get_job(self, job_id: str) -> Job:
472469

473470
client = self._get_jobs_client()
474471
details = client.get(
475-
self._subscription_id,
476-
self._resource_group,
477-
self._workspace_name,
472+
self.subscription_id,
473+
self.resource_group,
474+
self.name,
478475
job_id)
479476
target_factory = TargetFactory(base_cls=Target, workspace=self)
480477
# pylint: disable=protected-access
@@ -557,7 +554,7 @@ def list_jobs_paginated(
557554
)
558555
orderby = self._create_orderby(orderby_property, is_asc)
559556

560-
return client.list(subscription_id=self.subscription_id, resource_group_name=self.resource_group, workspace_name=self._workspace_name, filter=job_filter, orderby=orderby, top = top, skip = skip)
557+
return client.list(subscription_id=self.subscription_id, resource_group_name=self.resource_group, workspace_name=self.name, filter=job_filter, orderby=orderby, top = top, skip = skip)
561558

562559
def _get_target_status(
563560
self,
@@ -580,9 +577,9 @@ def _get_target_status(
580577
return [
581578
(provider.id, target)
582579
for provider in self._client.providers.list(
583-
self._subscription_id,
584-
self._resource_group,
585-
self._workspace_name)
580+
self.subscription_id,
581+
self.resource_group,
582+
self.name)
586583
for target in provider.targets
587584
if (provider_id is None or provider.id.lower() == provider_id.lower())
588585
and (name is None or target.id.lower() == name.lower())
@@ -639,9 +636,9 @@ def get_quotas(self) -> List[Dict[str, Any]]:
639636
"""
640637
client = self._get_quotas_client()
641638
return [q.as_dict() for q in client.list(
642-
self._subscription_id,
643-
self._resource_group,
644-
self._workspace_name
639+
self.subscription_id,
640+
self.resource_group,
641+
self.name
645642
)]
646643

647644
def list_top_level_items(
@@ -712,7 +709,7 @@ def list_top_level_items_paginated(
712709
)
713710
orderby = self._create_orderby(orderby_property, is_asc)
714711

715-
return client.list(subscription_id=self.subscription_id, resource_group_name=self.resource_group, workspace_name=self._workspace_name, filter=top_level_item_filter, orderby=orderby, top = top, skip = skip)
712+
return client.list(subscription_id=self.subscription_id, resource_group_name=self.resource_group, workspace_name=self.name, filter=top_level_item_filter, orderby=orderby, top = top, skip = skip)
716713

717714
def list_sessions(
718715
self,
@@ -773,7 +770,7 @@ def list_sessions_paginated(
773770

774771
orderby = self._create_orderby(orderby_property=orderby_property, is_asc=is_asc)
775772

776-
return client.list(subscription_id=self.subscription_id, resource_group_name=self.resource_group, workspace_name=self._workspace_name, filter = session_filter, orderby=orderby, skip=skip, top=top)
773+
return client.list(subscription_id=self.subscription_id, resource_group_name=self.resource_group, workspace_name=self.name, filter = session_filter, orderby=orderby, skip=skip, top=top)
777774

778775
def open_session(
779776
self,
@@ -790,9 +787,9 @@ def open_session(
790787
"""
791788
client = self._get_sessions_client()
792789
session.details = client.create_or_replace(
793-
self._subscription_id,
794-
self._resource_group,
795-
self._workspace_name,
790+
self.subscription_id,
791+
self.resource_group,
792+
self.name,
796793
session.id,
797794
session.details)
798795

@@ -811,15 +808,15 @@ def close_session(
811808
client = self._get_sessions_client()
812809
if not session.is_in_terminal_state():
813810
session.details = client.close(
814-
self._subscription_id,
815-
self._resource_group,
816-
self._workspace_name,
811+
self.subscription_id,
812+
self.resource_group,
813+
self.name,
817814
session_id=session.id)
818815
else:
819816
session.details = client.get(
820-
self._subscription_id,
821-
self._resource_group,
822-
self._workspace_name,
817+
self.subscription_id,
818+
self.resource_group,
819+
self.name,
823820
session_id=session.id)
824821

825822
if session.target:
@@ -855,9 +852,9 @@ def get_session(
855852
"""
856853
client = self._get_sessions_client()
857854
session_details = client.get(
858-
self._subscription_id,
859-
self._resource_group,
860-
self._workspace_name,
855+
self.subscription_id,
856+
self.resource_group,
857+
self.name,
861858
session_id=session_id)
862859
result = Session(workspace=self, details=session_details)
863860
return result
@@ -919,7 +916,7 @@ def list_session_jobs_paginated(
919916

920917
orderby = self._create_orderby(orderby_property=orderby_property, is_asc=is_asc)
921918

922-
return client.jobs_list(subscription_id=self.subscription_id, resource_group_name=self.resource_group, workspace_name=self._workspace_name, session_id=session_id, filter = session_job_filter, orderby=orderby, skip=skip, top=top)
919+
return client.jobs_list(subscription_id=self.subscription_id, resource_group_name=self.resource_group, workspace_name=self.name, session_id=session_id, filter = session_job_filter, orderby=orderby, skip=skip, top=top)
923920

924921
def get_container_uri(
925922
self,

0 commit comments

Comments
 (0)