Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE;
import static org.opensearch.ml.helper.ModelAccessControlHelper.getResourceSharingClient;
import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz;
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;

import java.util.ArrayList;
Expand Down Expand Up @@ -38,7 +41,6 @@
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.ResourceSharingClientAccessor;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
Expand Down Expand Up @@ -147,11 +149,13 @@ public void search(SdkClient sdkClient, SearchRequest request, String tenantId,
mlFeatureEnabledSetting.isMultiTenancyEnabled(),
CommonValue.ML_MODEL_GROUP_INDEX
);
boolean rsClientPresent = ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null;

if (rsClientPresent && user != null && modelAccessControlHelper.modelAccessControlEnabled() && hasModelGroupIndex) {
if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)
&& user != null
&& modelAccessControlHelper.modelAccessControlEnabled()
&& hasModelGroupIndex) {
// RSC fast-path: get accessible group IDs → gate models (IDs or missing)
ResourceSharingClient rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
ResourceSharingClient rsc = getResourceSharingClient();
rsc.getAccessibleResourceIds(CommonValue.ML_MODEL_GROUP_INDEX, ActionListener.wrap(ids -> {
SearchSourceBuilder gated = Optional.ofNullable(request.source()).orElseGet(SearchSourceBuilder::new);
gated.query(rewriteQueryBuilderRSC(gated.query(), ids)); // ids may be empty → "missing only"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package org.opensearch.ml.action.model_group;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID;

import org.opensearch.ExceptionsHelper;
Expand All @@ -27,7 +29,6 @@
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.ResourceSharingClientAccessor;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction;
Expand Down Expand Up @@ -96,7 +97,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);

// if resource sharing feature is enabled, access will be automatically checked by security plugin, so no need to check again
if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) {
if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) {
checkForAssociatedModels(modelGroupId, tenantId, wrappedListener);
} else {
validateAndDeleteModelGroup(modelGroupId, tenantId, wrappedListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE;
import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz;

import org.opensearch.ExceptionsHelper;
import org.opensearch.OpenSearchStatusException;
Expand All @@ -27,7 +29,6 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.ResourceSharingClientAccessor;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest;
Expand Down Expand Up @@ -186,7 +187,7 @@ private void validateModelGroupAccess(
) {
// if resource sharing feature is enabled, security plugin will have automatically evaluated access to this model group, hence no
// need to validate again
if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) {
if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) {
wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build());
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE;
import static org.opensearch.ml.helper.ModelAccessControlHelper.getResourceSharingClient;
import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz;
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;

import java.util.Collections;
Expand All @@ -20,7 +23,6 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.ResourceSharingClientAccessor;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
Expand Down Expand Up @@ -89,7 +91,7 @@ private void preProcessRoleAndPerformSearch(
.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener));

// If resource-sharing feature is enabled, we fetch accessible model-groups and restrict the search to those model-groups only.
if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) {
if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) {
// If a model-group is shared, then it will have been shared at-least at read access, hence the final result is guaranteed
// to only contain model-groups that the user at-least has read access to.
addAccessibleModelGroupsFilterAndSearch(tenantId, request, doubleWrappedListener);
Expand All @@ -113,7 +115,7 @@ private void addAccessibleModelGroupsFilterAndSearch(
ActionListener<SearchResponse> wrappedListener
) {
SearchSourceBuilder sourceBuilder = request.source() != null ? request.source() : new SearchSourceBuilder();
ResourceSharingClient rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
ResourceSharingClient rsc = getResourceSharingClient();
// filter by accessible model-groups
rsc.getAccessibleResourceIds(ML_MODEL_GROUP_INDEX, ActionListener.wrap(ids -> {
sourceBuilder.query(modelAccessControlHelper.mergeWithAccessFilter(sourceBuilder.query(), ids));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.BACKEND_ROLES_FIELD;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE;
import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;

import java.time.Instant;
Expand Down Expand Up @@ -36,7 +38,6 @@
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.ResourceSharingClientAccessor;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction;
Expand Down Expand Up @@ -150,7 +151,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda
)) {
// NOTE all sharing and revoking must happen through share API exposed by security plugin
// client == null -> feature is disabled, follow old route
if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() == null) {
if (!shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) {
// TODO: At some point, this call must be replaced by the one above, (i.e. no user info to
// be stored in model-group index)
if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.BACKEND_ROLES_FIELD;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED;

import java.util.Collections;
Expand Down Expand Up @@ -98,8 +99,8 @@ public void validateModelGroupAccess(User user, String modelGroupId, String acti
listener.onResponse(true);
return;
}
if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) {
ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) {
ResourceSharingClient resourceSharingClient = getResourceSharingClient();
resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, action, ActionListener.wrap(isAuthorized -> {
if (!isAuthorized) {
listener
Expand Down Expand Up @@ -173,8 +174,8 @@ public void validateModelGroupAccess(
listener.onResponse(true);
return;
}
if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) {
ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) {
ResourceSharingClient resourceSharingClient = getResourceSharingClient();
resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, action, ActionListener.wrap(isAuthorized -> {
if (!isAuthorized) {
listener
Expand Down Expand Up @@ -288,6 +289,20 @@ public void checkModelGroupPermission(MLModelGroup mlModelGroup, User user, Acti
}
}

/**
* Checks whether to utilize new ResourceAuthz
* @param resourceType for which to decide whether to use resource authz
* @return true if the resource-sharing feature is enabled, false otherwise.
*/
public static boolean shouldUseResourceAuthz(String resourceType) {
var client = getResourceSharingClient();
return client != null && client.isFeatureEnabledForType(resourceType);
}

public static ResourceSharingClient getResourceSharingClient() {
return ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
}

public boolean skipModelAccessControl(User user) {
// Case 1: user == null when 1. Security is disabled. 2. When user is super-admin
// Case 2: If Security is enabled and filter is disabled, proceed with search as
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Mockito.mock;
import static org.opensearch.ml.helper.ModelAccessControlHelper.getResourceSharingClient;

import java.util.Iterator;
import java.util.Set;
Expand Down Expand Up @@ -71,15 +72,11 @@ public void testAssignResourceSharingClient_setsClientOnAccessor() {
MLResourceSharingExtension ext = new MLResourceSharingExtension();
ResourceSharingClient mockClient = mock(ResourceSharingClient.class);

assertThat(ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), is(nullValue()));
assertThat(getResourceSharingClient(), is(nullValue()));

ext.assignResourceSharingClient(mockClient);

assertThat(
"Accessor should hold the client passed to extension",
ResourceSharingClientAccessor.getInstance().getResourceSharingClient(),
equalTo(mockClient)
);
assertThat("Accessor should hold the client passed to extension", getResourceSharingClient(), equalTo(mockClient));
}

@Test
Expand All @@ -90,16 +87,12 @@ public void testAssignResourceSharingClient_overwritesExistingClient() {

// Prime with the first client
ResourceSharingClientAccessor.getInstance().setResourceSharingClient(first);
assertThat(ResourceSharingClientAccessor.getInstance().getResourceSharingClient(), equalTo(first));
assertThat(getResourceSharingClient(), equalTo(first));

// Now assign a new one via the extension
ext.assignResourceSharingClient(second);

assertThat(
"Accessor should be updated to the new client",
ResourceSharingClientAccessor.getInstance().getResourceSharingClient(),
equalTo(second)
);
assertThat("Accessor should be updated to the new client", getResourceSharingClient(), equalTo(second));
}

@Test
Expand Down
Loading