diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index 0052c8bfaa..bd603929b4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -7,6 +7,9 @@ import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE; +import static org.opensearch.ml.helper.ModelAccessControlHelper.getResourceSharingClient; +import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import java.util.ArrayList; @@ -38,7 +41,6 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; -import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; @@ -147,11 +149,13 @@ public void search(SdkClient sdkClient, SearchRequest request, String tenantId, mlFeatureEnabledSetting.isMultiTenancyEnabled(), CommonValue.ML_MODEL_GROUP_INDEX ); - boolean rsClientPresent = ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null; - if (rsClientPresent && user != null && modelAccessControlHelper.modelAccessControlEnabled() && hasModelGroupIndex) { + if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE) + && user != null + && modelAccessControlHelper.modelAccessControlEnabled() + && hasModelGroupIndex) { // RSC fast-path: get accessible group IDs → gate models (IDs or missing) - ResourceSharingClient rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + ResourceSharingClient rsc = getResourceSharingClient(); rsc.getAccessibleResourceIds(CommonValue.ML_MODEL_GROUP_INDEX, ActionListener.wrap(ids -> { SearchSourceBuilder gated = Optional.ofNullable(request.source()).orElseGet(SearchSourceBuilder::new); gated.query(rewriteQueryBuilderRSC(gated.query(), ids)); // ids may be empty → "missing only" diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java index fbb87c0ff6..e486e7def6 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java @@ -6,7 +6,9 @@ package org.opensearch.ml.action.model_group; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID; import org.opensearch.ExceptionsHelper; @@ -27,7 +29,6 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; -import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction; @@ -96,7 +97,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); // if resource sharing feature is enabled, access will be automatically checked by security plugin, so no need to check again - if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { + if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) { checkForAssociatedModels(modelGroupId, tenantId, wrappedListener); } else { validateAndDeleteModelGroup(modelGroupId, tenantId, wrappedListener); diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java index 74fc87da30..915a4d8a4f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java @@ -8,6 +8,8 @@ import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE; +import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz; import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; @@ -27,7 +29,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.MLModelGroup; -import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest; @@ -186,7 +187,7 @@ private void validateModelGroupAccess( ) { // if resource sharing feature is enabled, security plugin will have automatically evaluated access to this model group, hence no // need to validate again - if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { + if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) { wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build()); return; } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java index b9f00a706e..10cee89eca 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java @@ -7,6 +7,9 @@ import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE; +import static org.opensearch.ml.helper.ModelAccessControlHelper.getResourceSharingClient; +import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import java.util.Collections; @@ -20,7 +23,6 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; -import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; import org.opensearch.ml.common.transport.search.MLSearchActionRequest; @@ -89,7 +91,7 @@ private void preProcessRoleAndPerformSearch( .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); // If resource-sharing feature is enabled, we fetch accessible model-groups and restrict the search to those model-groups only. - if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { + if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) { // If a model-group is shared, then it will have been shared at-least at read access, hence the final result is guaranteed // to only contain model-groups that the user at-least has read access to. addAccessibleModelGroupsFilterAndSearch(tenantId, request, doubleWrappedListener); @@ -113,7 +115,7 @@ private void addAccessibleModelGroupsFilterAndSearch( ActionListener wrappedListener ) { SearchSourceBuilder sourceBuilder = request.source() != null ? request.source() : new SearchSourceBuilder(); - ResourceSharingClient rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + ResourceSharingClient rsc = getResourceSharingClient(); // filter by accessible model-groups rsc.getAccessibleResourceIds(ML_MODEL_GROUP_INDEX, ActionListener.wrap(ids -> { sourceBuilder.query(modelAccessControlHelper.mergeWithAccessFilter(sourceBuilder.query(), ids)); diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index 2db39e2952..9c9aef0683 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -9,6 +9,8 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.BACKEND_ROLES_FIELD; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE; +import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz; import static org.opensearch.ml.utils.MLExceptionUtils.logException; import java.time.Instant; @@ -36,7 +38,6 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; -import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction; @@ -150,7 +151,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener feature is disabled, follow old route - if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() == null) { + if (!shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) { // TODO: At some point, this call must be replaced by the one above, (i.e. no user info to // be stored in model-group index) if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) { diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java index adcc5d196e..4697e4aa2f 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java @@ -9,6 +9,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.BACKEND_ROLES_FIELD; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED; import java.util.Collections; @@ -98,8 +99,8 @@ public void validateModelGroupAccess(User user, String modelGroupId, String acti listener.onResponse(true); return; } - if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { - ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) { + ResourceSharingClient resourceSharingClient = getResourceSharingClient(); resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, action, ActionListener.wrap(isAuthorized -> { if (!isAuthorized) { listener @@ -173,8 +174,8 @@ public void validateModelGroupAccess( listener.onResponse(true); return; } - if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) { - ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) { + ResourceSharingClient resourceSharingClient = getResourceSharingClient(); resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, action, ActionListener.wrap(isAuthorized -> { if (!isAuthorized) { listener @@ -288,6 +289,20 @@ public void checkModelGroupPermission(MLModelGroup mlModelGroup, User user, Acti } } + /** + * Checks whether to utilize new ResourceAuthz + * @param resourceType for which to decide whether to use resource authz + * @return true if the resource-sharing feature is enabled, false otherwise. + */ + public static boolean shouldUseResourceAuthz(String resourceType) { + var client = getResourceSharingClient(); + return client != null && client.isFeatureEnabledForType(resourceType); + } + + public static ResourceSharingClient getResourceSharingClient() { + return ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + } + public boolean skipModelAccessControl(User user) { // Case 1: user == null when 1. Security is disabled. 2. When user is super-admin // Case 2: If Security is enabled and filter is disabled, proceed with search as diff --git a/plugin/src/test/java/org/opensearch/ml/resources/MLResourceSharingExtensionTests.java b/plugin/src/test/java/org/opensearch/ml/resources/MLResourceSharingExtensionTests.java index 03bcad8ec4..0ca10e518a 100644 --- a/plugin/src/test/java/org/opensearch/ml/resources/MLResourceSharingExtensionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/resources/MLResourceSharingExtensionTests.java @@ -11,6 +11,7 @@ import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.mockito.Mockito.mock; +import static org.opensearch.ml.helper.ModelAccessControlHelper.getResourceSharingClient; import java.util.Iterator; import java.util.Set; @@ -71,15 +72,11 @@ public void testAssignResourceSharingClient_setsClientOnAccessor() { MLResourceSharingExtension ext = new MLResourceSharingExtension(); ResourceSharingClient mockClient = mock(ResourceSharingClient.class); - assertThat(ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), is(nullValue())); + assertThat(getResourceSharingClient(), is(nullValue())); ext.assignResourceSharingClient(mockClient); - assertThat( - "Accessor should hold the client passed to extension", - ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), - equalTo(mockClient) - ); + assertThat("Accessor should hold the client passed to extension", getResourceSharingClient(), equalTo(mockClient)); } @Test @@ -90,16 +87,12 @@ public void testAssignResourceSharingClient_overwritesExistingClient() { // Prime with the first client ResourceSharingClientAccessor.getInstance().setResourceSharingClient(first); - assertThat(ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), equalTo(first)); + assertThat(getResourceSharingClient(), equalTo(first)); // Now assign a new one via the extension ext.assignResourceSharingClient(second); - assertThat( - "Accessor should be updated to the new client", - ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), - equalTo(second) - ); + assertThat("Accessor should be updated to the new client", getResourceSharingClient(), equalTo(second)); } @Test