Skip to content

Commit 2669c52

Browse files
authored
[ML] feat: Add support for Index asset (Azure#35251)
* refactor: Remove duplicatation in workspace api discovery * chore: Generate restclient from AzureAI.Assets typespec Generated using eng/scripts/TypeSpec-Generate-Sdk.ps1 script from Azure/azure-rest-api-specs. Patched it to append the '@azure-tools/typespec-python.generate-packaging-files=false' option since we aren't generating a standalone SDK. Patched the tspconfig.yaml for AzureAI.Assets to: * Force the SDK to generate at sdk/ml/azure-ai-ml/azure/ai/ml/_restclient * Call it a different name from azurekai-resources-autogen * fix: Fix the base genericasset's api endpoint The endpoint in the typespec was out of date. * fix: Don't url encode the {endpoint} parameter. The {endpoint} parameter in the typespec is used to specify the base url for the generic asset api (depends on the workspace location). The python autogenerated sdk assumes that the parameter is part of the URL path, and tries to url encode the {endpoint} parameter. This is problematic since this encodes the forward slashes (//) joining the scheme and network location to `%2F%2F`. * feat: Add Index Asset * feat: Expose Index Asset publicly * feat: Add IndexOperations * feat: Add ml_client.indexes * feat: Add latest label support to IndexOperations.get * chore: Add __init__.py to ai assets rest client * fix: Resolve mypy failures * fix: Address FIXME about non-compliant service response The service response was updated to align with what's defined in typespec. * feat: Support local upload of index files * refactor: Replace `storage_uri` with `path` The rest api's concept is analogous with the `path` attribute for Assets in the SDK. Using `path` for the SDK Index asset keeps the semantics consistent. * docs: Update types in Index docstring * docs: Update CHANGELOG.md * feat: Auto-increment index version * feat: Mark index asset as experimental * feat: Enable dynamic dispatch create_or_update * docs: Update IndexOperations docstring * docs: Remove datastore from docstring for consistency * docs: Add a clarifying comment to Index._to_rest_object * fix: Prevent Project._from_rest_object from dropping discovery_url The initial implementation of Project._from_rest_object seems to drop many of the fields supported by its parent class. One of which is `discovery_url` which this PR depends on to find the correct endpoint to call for the GenericAsset API. The implementation of Project._from_rest_object appears to do a subset of the assignments of Workspace._from_rest_object. This commit removes the override for Project._from_rest_object allowing the Project class to fall back on Workspace._from_rest_object * test: Add unit tests for Index Asset
1 parent 97d41bf commit 2669c52

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+6314
-70
lines changed

sdk/ml/azure-ai-ml/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
- Project and Hub operations supported by workspace operations.
1010
- workspace list operation supports type filtering.
1111
- Add support for Microsoft Entra token (`aad_token`) auth in `invoke` and `get-credentials` operations.
12+
- Add experimental support for working with indexes: `ml_client.indexes`
1213

1314
### Bugs Fixed
1415

sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_artifact_utilities.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,13 @@
5050
from azure.storage.filedatalake import FileSasPermissions, generate_file_sas
5151

5252
if TYPE_CHECKING:
53-
from azure.ai.ml.operations import DataOperations, EnvironmentOperations, FeatureSetOperations, ModelOperations
53+
from azure.ai.ml.operations import (
54+
DataOperations,
55+
EnvironmentOperations,
56+
FeatureSetOperations,
57+
IndexOperations,
58+
ModelOperations,
59+
)
5460
from azure.ai.ml.operations._code_operations import CodeOperations
5561

5662
module_logger = logging.getLogger(__name__)
@@ -454,7 +460,9 @@ def _update_gen2_metadata(name, version, indicator_file, storage_client) -> None
454460

455461
def _check_and_upload_path(
456462
artifact: T,
457-
asset_operations: Union["DataOperations", "ModelOperations", "CodeOperations", "FeatureSetOperations"],
463+
asset_operations: Union[
464+
"DataOperations", "ModelOperations", "CodeOperations", "FeatureSetOperations", "IndexOperations"
465+
],
458466
artifact_type: str,
459467
datastore_name: Optional[str] = None,
460468
sas_uri: Optional[str] = None,
@@ -466,7 +474,7 @@ def _check_and_upload_path(
466474
:param artifact: artifact to check and upload param
467475
:type artifact: T
468476
:param asset_operations: The asset operations to use for uploading
469-
:type asset_operations: Union["DataOperations", "ModelOperations", "CodeOperations"]
477+
:type asset_operations: Union["DataOperations", "ModelOperations", "CodeOperations", "IndexOperations"]
470478
:param artifact_type: The artifact type
471479
:type artifact_type: str
472480
:param datastore_name: the name of the datastore to upload to

sdk/ml/azure-ai-ml/azure/ai/ml/_ml_client.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
Compute,
5656
Datastore,
5757
Environment,
58+
Index,
5859
Job,
5960
Model,
6061
ModelBatchDeployment,
@@ -75,6 +76,7 @@
7576
DataOperations,
7677
DatastoreOperations,
7778
EnvironmentOperations,
79+
IndexOperations,
7880
JobOperations,
7981
ModelOperations,
8082
OnlineDeploymentOperations,
@@ -617,6 +619,18 @@ def __init__(
617619
)
618620
self._operation_container.add(AzureMLResourceType.SCHEDULE, self._schedules)
619621

622+
self._indexes = IndexOperations(
623+
operation_scope=self._operation_scope,
624+
operation_config=self._operation_config,
625+
credential=self._credential,
626+
all_operations=self._operation_container,
627+
datastore_operations=self._datastores,
628+
_service_client_kwargs=kwargs,
629+
requests_pipeline=self._requests_pipeline,
630+
**ops_kwargs,
631+
)
632+
self._operation_container.add(AzureMLResourceType.INDEX, self._indexes)
633+
620634
try:
621635
from azure.ai.ml.operations._virtual_cluster_operations import VirtualClusterOperations
622636

@@ -970,6 +984,15 @@ def schedules(self) -> ScheduleOperations:
970984
"""
971985
return self._schedules
972986

987+
@property
988+
def indexes(self) -> IndexOperations:
989+
"""A collection of index related operations.
990+
991+
:return: Index operations.
992+
:rtype: ~azure.ai.ml.operations.IndexOperations
993+
"""
994+
return self._indexes
995+
973996
@property
974997
def subscription_id(self) -> str:
975998
"""Get the subscription ID of an MLClient object.
@@ -1178,6 +1201,12 @@ def _(entity: Datastore, operations):
11781201
return operations[AzureMLResourceType.DATASTORE].create_or_update(entity)
11791202

11801203

1204+
@_create_or_update.register(Index)
1205+
def _(entity: Index, operations, *args, **kwargs):
1206+
module_logger.debug("Creating or updating indexes")
1207+
return operations[AzureMLResourceType.INDEX].begin_create_or_update(entity, **kwargs)
1208+
1209+
11811210
@singledispatch
11821211
def _begin_create_or_update(entity, operations, **kwargs):
11831212
raise TypeError("Please refer to begin_create_or_update docstring for valid input types.")

sdk/ml/azure-ai-ml/azure/ai/ml/_restclient/azure_ai_assets_v2024_04_01/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# coding=utf-8
2+
# --------------------------------------------------------------------------
3+
# Copyright (c) Microsoft Corporation. All rights reserved.
4+
# Licensed under the MIT License. See License.txt in the project root for license information.
5+
# Code generated by Microsoft (R) Python Code Generator.
6+
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
7+
# --------------------------------------------------------------------------
8+
9+
from ._client import MachineLearningServicesClient
10+
from ._version import VERSION
11+
12+
__version__ = VERSION
13+
14+
try:
15+
from ._patch import __all__ as _patch_all
16+
from ._patch import * # pylint: disable=unused-wildcard-import
17+
except ImportError:
18+
_patch_all = []
19+
from ._patch import patch_sdk as _patch_sdk
20+
21+
__all__ = [
22+
"MachineLearningServicesClient",
23+
]
24+
__all__.extend([p for p in _patch_all if p not in __all__])
25+
26+
_patch_sdk()
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# coding=utf-8
2+
# --------------------------------------------------------------------------
3+
# Copyright (c) Microsoft Corporation. All rights reserved.
4+
# Licensed under the MIT License. See License.txt in the project root for license information.
5+
# Code generated by Microsoft (R) Python Code Generator.
6+
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
7+
# --------------------------------------------------------------------------
8+
9+
from copy import deepcopy
10+
from typing import Any, TYPE_CHECKING
11+
12+
from azure.core import PipelineClient
13+
from azure.core.pipeline import policies
14+
from azure.core.rest import HttpRequest, HttpResponse
15+
16+
from ._configuration import MachineLearningServicesClientConfiguration
17+
from ._serialization import Deserializer, Serializer
18+
from .operations import IndexesOperations
19+
20+
if TYPE_CHECKING:
21+
# pylint: disable=unused-import,ungrouped-imports
22+
from azure.core.credentials import TokenCredential
23+
24+
25+
class MachineLearningServicesClient: # pylint: disable=client-accepts-api-version-keyword
26+
"""MachineLearningServicesClient.
27+
28+
:ivar indexes: IndexesOperations operations
29+
:vartype indexes: azureaiassetsv20240401.operations.IndexesOperations
30+
:param endpoint: Supported Azure-AI asset endpoints. Required.
31+
:type endpoint: str
32+
:param subscription_id: The ID of the target subscription. Required.
33+
:type subscription_id: str
34+
:param resource_group_name: The name of the Resource Group. Required.
35+
:type resource_group_name: str
36+
:param workspace_name: The name of the AzureML workspace or AI project. Required.
37+
:type workspace_name: str
38+
:param credential: Credential used to authenticate requests to the service. Required.
39+
:type credential: ~azure.core.credentials.TokenCredential
40+
:keyword api_version: The API version to use for this operation. Default value is
41+
"2024-04-01-preview". Note that overriding this default value may result in unsupported
42+
behavior.
43+
:paramtype api_version: str
44+
"""
45+
46+
def __init__(
47+
self,
48+
endpoint: str,
49+
subscription_id: str,
50+
resource_group_name: str,
51+
workspace_name: str,
52+
credential: "TokenCredential",
53+
**kwargs: Any
54+
) -> None:
55+
_endpoint = "{endpoint}/genericasset/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.MachineLearningServices/workspaces/{workspaceName}" # pylint: disable=line-too-long
56+
self._config = MachineLearningServicesClientConfiguration(
57+
endpoint=endpoint,
58+
subscription_id=subscription_id,
59+
resource_group_name=resource_group_name,
60+
workspace_name=workspace_name,
61+
credential=credential,
62+
**kwargs
63+
)
64+
_policies = kwargs.pop("policies", None)
65+
if _policies is None:
66+
_policies = [
67+
policies.RequestIdPolicy(**kwargs),
68+
self._config.headers_policy,
69+
self._config.user_agent_policy,
70+
self._config.proxy_policy,
71+
policies.ContentDecodePolicy(**kwargs),
72+
self._config.redirect_policy,
73+
self._config.retry_policy,
74+
self._config.authentication_policy,
75+
self._config.custom_hook_policy,
76+
self._config.logging_policy,
77+
policies.DistributedTracingPolicy(**kwargs),
78+
policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None,
79+
self._config.http_logging_policy,
80+
]
81+
self._client: PipelineClient = PipelineClient(base_url=_endpoint, policies=_policies, **kwargs)
82+
83+
self._serialize = Serializer()
84+
self._deserialize = Deserializer()
85+
self._serialize.client_side_validation = False
86+
self.indexes = IndexesOperations(self._client, self._config, self._serialize, self._deserialize)
87+
88+
def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse:
89+
"""Runs the network request through the client's chained policies.
90+
91+
>>> from azure.core.rest import HttpRequest
92+
>>> request = HttpRequest("GET", "https://www.example.org/")
93+
<HttpRequest [GET], url: 'https://www.example.org/'>
94+
>>> response = client.send_request(request)
95+
<HttpResponse: 200 OK>
96+
97+
For more information on this code flow, see https://aka.ms/azsdk/dpcodegen/python/send_request
98+
99+
:param request: The network request you want to make. Required.
100+
:type request: ~azure.core.rest.HttpRequest
101+
:keyword bool stream: Whether the response payload will be streamed. Defaults to False.
102+
:return: The response of your network call. Does not do error handling on your response.
103+
:rtype: ~azure.core.rest.HttpResponse
104+
"""
105+
106+
request_copy = deepcopy(request)
107+
path_format_arguments = {
108+
"endpoint": self._config.endpoint,
109+
"subscriptionId": self._serialize.url("self._config.subscription_id", self._config.subscription_id, "str"),
110+
"resourceGroupName": self._serialize.url(
111+
"self._config.resource_group_name", self._config.resource_group_name, "str"
112+
),
113+
"workspaceName": self._serialize.url("self._config.workspace_name", self._config.workspace_name, "str"),
114+
}
115+
116+
request_copy.url = self._client.format_url(request_copy.url, **path_format_arguments)
117+
return self._client.send_request(request_copy, stream=stream, **kwargs) # type: ignore
118+
119+
def close(self) -> None:
120+
self._client.close()
121+
122+
def __enter__(self) -> "MachineLearningServicesClient":
123+
self._client.__enter__()
124+
return self
125+
126+
def __exit__(self, *exc_details: Any) -> None:
127+
self._client.__exit__(*exc_details)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# coding=utf-8
2+
# --------------------------------------------------------------------------
3+
# Copyright (c) Microsoft Corporation. All rights reserved.
4+
# Licensed under the MIT License. See License.txt in the project root for license information.
5+
# Code generated by Microsoft (R) Python Code Generator.
6+
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
7+
# --------------------------------------------------------------------------
8+
9+
from typing import Any, TYPE_CHECKING
10+
11+
from azure.core.pipeline import policies
12+
13+
from ._version import VERSION
14+
15+
if TYPE_CHECKING:
16+
# pylint: disable=unused-import,ungrouped-imports
17+
from azure.core.credentials import TokenCredential
18+
19+
20+
class MachineLearningServicesClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long
21+
"""Configuration for MachineLearningServicesClient.
22+
23+
Note that all parameters used to create this instance are saved as instance
24+
attributes.
25+
26+
:param endpoint: Supported Azure-AI asset endpoints. Required.
27+
:type endpoint: str
28+
:param subscription_id: The ID of the target subscription. Required.
29+
:type subscription_id: str
30+
:param resource_group_name: The name of the Resource Group. Required.
31+
:type resource_group_name: str
32+
:param workspace_name: The name of the AzureML workspace or AI project. Required.
33+
:type workspace_name: str
34+
:param credential: Credential used to authenticate requests to the service. Required.
35+
:type credential: ~azure.core.credentials.TokenCredential
36+
:keyword api_version: The API version to use for this operation. Default value is
37+
"2024-04-01-preview". Note that overriding this default value may result in unsupported
38+
behavior.
39+
:paramtype api_version: str
40+
"""
41+
42+
def __init__(
43+
self,
44+
endpoint: str,
45+
subscription_id: str,
46+
resource_group_name: str,
47+
workspace_name: str,
48+
credential: "TokenCredential",
49+
**kwargs: Any
50+
) -> None:
51+
api_version: str = kwargs.pop("api_version", "2024-04-01-preview")
52+
53+
if endpoint is None:
54+
raise ValueError("Parameter 'endpoint' must not be None.")
55+
if subscription_id is None:
56+
raise ValueError("Parameter 'subscription_id' must not be None.")
57+
if resource_group_name is None:
58+
raise ValueError("Parameter 'resource_group_name' must not be None.")
59+
if workspace_name is None:
60+
raise ValueError("Parameter 'workspace_name' must not be None.")
61+
if credential is None:
62+
raise ValueError("Parameter 'credential' must not be None.")
63+
64+
self.endpoint = endpoint
65+
self.subscription_id = subscription_id
66+
self.resource_group_name = resource_group_name
67+
self.workspace_name = workspace_name
68+
self.credential = credential
69+
self.api_version = api_version
70+
self.credential_scopes = kwargs.pop("credential_scopes", ["https://ml.azure.com/.default"])
71+
kwargs.setdefault("sdk_moniker", "azure_ai_assets_v2024_04_01/{}".format(VERSION))
72+
self.polling_interval = kwargs.get("polling_interval", 30)
73+
self._configure(**kwargs)
74+
75+
def _configure(self, **kwargs: Any) -> None:
76+
self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs)
77+
self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs)
78+
self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs)
79+
self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs)
80+
self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs)
81+
self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs)
82+
self.redirect_policy = kwargs.get("redirect_policy") or policies.RedirectPolicy(**kwargs)
83+
self.retry_policy = kwargs.get("retry_policy") or policies.RetryPolicy(**kwargs)
84+
self.authentication_policy = kwargs.get("authentication_policy")
85+
if self.credential and not self.authentication_policy:
86+
self.authentication_policy = policies.BearerTokenCredentialPolicy(
87+
self.credential, *self.credential_scopes, **kwargs
88+
)

0 commit comments

Comments
 (0)