Skip to content

Commit

Permalink
[Backport] Remove ppl tool execution setting (#383)
Browse files Browse the repository at this point in the history
* Remove ppl tool execution setting

Signed-off-by: zane-neo <[email protected]>

* fix failure UTs

Signed-off-by: zane-neo <[email protected]>

* backport 381 to 2.x

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo authored Aug 3, 2024
1 parent c52dbea commit fc9ae93
Show file tree
Hide file tree
Showing 7 changed files with 5 additions and 131 deletions.
12 changes: 1 addition & 11 deletions src/main/java/org/opensearch/agent/ToolPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import java.util.List;
import java.util.function.Supplier;

import org.opensearch.agent.common.SkillSettings;
import org.opensearch.agent.tools.CreateAnomalyDetectorTool;
import org.opensearch.agent.tools.NeuralSparseSearchTool;
import org.opensearch.agent.tools.PPLTool;
Expand All @@ -20,12 +19,9 @@
import org.opensearch.agent.tools.SearchAnomalyResultsTool;
import org.opensearch.agent.tools.SearchMonitorsTool;
import org.opensearch.agent.tools.VectorDBTool;
import org.opensearch.agent.tools.utils.ClusterSettingHelper;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
Expand Down Expand Up @@ -64,9 +60,7 @@ public Collection<Object> createComponents(
this.client = client;
this.clusterService = clusterService;
this.xContentRegistry = xContentRegistry;
Settings settings = environment.settings();
ClusterSettingHelper clusterSettingHelper = new ClusterSettingHelper(settings, clusterService);
PPLTool.Factory.getInstance().init(client, clusterSettingHelper);
PPLTool.Factory.getInstance().init(client);
NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry);
VectorDBTool.Factory.getInstance().init(client, xContentRegistry);
RAGTool.Factory.getInstance().init(client, xContentRegistry);
Expand Down Expand Up @@ -94,8 +88,4 @@ public List<Tool.Factory<? extends Tool>> getToolFactories() {
);
}

@Override
public List<Setting<?>> getSettings() {
return List.of(SkillSettings.PPL_EXECUTION_ENABLED);
}
}
22 changes: 0 additions & 22 deletions src/main/java/org/opensearch/agent/common/SkillSettings.java

This file was deleted.

23 changes: 3 additions & 20 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
import org.opensearch.action.ActionRequest;
import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.agent.common.SkillSettings;
import org.opensearch.agent.tools.utils.ClusterSettingHelper;
import org.opensearch.agent.tools.utils.ToolHelper;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.MappingMetadata;
Expand Down Expand Up @@ -98,9 +96,7 @@ public class PPLTool implements Tool {

private int head;

private ClusterSettingHelper clusterSettingHelper;

private static Gson gson = new Gson();
private static Gson gson = org.opensearch.ml.common.utils.StringUtils.gson;

private static Map<String, String> DEFAULT_PROMPT_DICT;

Expand Down Expand Up @@ -153,7 +149,6 @@ public static PPLModelType from(String value) {

public PPLTool(
Client client,
ClusterSettingHelper clusterSettingHelper,
String modelId,
String contextPrompt,
String pplModelType,
Expand All @@ -172,7 +167,6 @@ public PPLTool(
this.previousToolKey = previousToolKey;
this.head = head;
this.execute = execute;
this.clusterSettingHelper = clusterSettingHelper;
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -222,14 +216,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0);
Map<String, String> dataAsMap = (Map<String, String>) modelTensor.getDataAsMap();
String ppl = parseOutput(dataAsMap.get("response"), indexName);
boolean pplExecutedEnabled = clusterSettingHelper.getClusterSettings(SkillSettings.PPL_EXECUTION_ENABLED);
if (!pplExecutedEnabled || !this.execute) {
if (!pplExecutedEnabled) {
log
.debug(
"PPL execution is disabled, the query will be returned directly, to enable this, please set plugins.skills.ppl_execution_enabled to true"
);
}
if (!this.execute) {
Map<String, String> ret = ImmutableMap.of("ppl", ppl);
listener.onResponse((T) AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(ret)));
return;
Expand Down Expand Up @@ -298,8 +285,6 @@ public boolean validate(Map<String, String> parameters) {
public static class Factory implements Tool.Factory<PPLTool> {
private Client client;

private ClusterSettingHelper clusterSettingHelper;

private static Factory INSTANCE;

public static Factory getInstance() {
Expand All @@ -315,17 +300,15 @@ public static Factory getInstance() {
}
}

public void init(Client client, ClusterSettingHelper clusterSettingHelper) {
public void init(Client client) {
this.client = client;
this.clusterSettingHelper = clusterSettingHelper;
}

@Override
public PPLTool create(Map<String, Object> map) {
validatePPLToolParameters(map);
return new PPLTool(
client,
clusterSettingHelper,
(String) map.get("model_id"),
(String) map.getOrDefault("prompt", ""),
(String) map.getOrDefault("model_type", ""),
Expand Down

This file was deleted.

35 changes: 1 addition & 34 deletions src/test/java/org/opensearch/agent/tools/PPLToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
Expand All @@ -26,15 +24,10 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.agent.common.SkillSettings;
import org.opensearch.agent.tools.utils.ClusterSettingHelper;
import org.opensearch.client.AdminClient;
import org.opensearch.client.Client;
import org.opensearch.client.IndicesAdminClient;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
Expand Down Expand Up @@ -128,13 +121,7 @@ public void setup() {
listener.onResponse(transportPPLQueryResponse);
return null;
}).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any());

Settings settings = Settings.builder().put(SkillSettings.PPL_EXECUTION_ENABLED.getKey(), true).build();
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.getSettings()).thenReturn(settings);
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(SkillSettings.PPL_EXECUTION_ENABLED)));
ClusterSettingHelper clusterSettingHelper = new ClusterSettingHelper(settings, clusterService);
PPLTool.Factory.getInstance().init(client, clusterSettingHelper);
PPLTool.Factory.getInstance().init(client);
}

@Test
Expand Down Expand Up @@ -413,26 +400,6 @@ public void testTool_executePPLFailure() {
);
}

@Test
public void test_pplTool_whenPPLExecutionDisabled_returnOnlyContainsPPL() {
Settings settings = Settings.builder().put(SkillSettings.PPL_EXECUTION_ENABLED.getKey(), false).build();
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.getSettings()).thenReturn(settings);
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(SkillSettings.PPL_EXECUTION_ENABLED)));
ClusterSettingHelper clusterSettingHelper = new ClusterSettingHelper(settings, clusterService);
PPLTool.Factory.getInstance().init(client, clusterSettingHelper);
PPLTool tool = PPLTool.Factory
.getInstance()
.create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt", "head", "100"));
assertEquals(PPLTool.TYPE, tool.getName());

tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.<String>wrap(executePPLResult -> {
Map<String, String> returnResults = gson.fromJson(executePPLResult, Map.class);
assertNull(returnResults.get("executionResult"));
assertEquals("source=demo| head 1", returnResults.get("ppl"));
}, log::error));
}

private void createMappings() {
indexMappings = new HashMap<>();
indexMappings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ public void updateClusterSettings() {
updateClusterSettings("plugins.ml_commons.jvm_heap_memory_threshold", 100);
updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true);
updateClusterSettings("plugins.ml_commons.agent_framework_enabled", true);
updateClusterSettings("plugins.skills.ppl_execution_enabled", true);
}

@SneakyThrows
Expand Down
8 changes: 0 additions & 8 deletions src/test/java/org/opensearch/integTest/PPLToolIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,6 @@ public void testPPLTool() {
);
}

public void test_PPLTool_whenPPLExecutionDisabled_ResultOnlyContainsPPL() {
updateClusterSettings("plugins.skills.ppl_execution_enabled", false);
prepareIndex();
String agentId = registerAgent();
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"correct\", \"index\": \"employee\"}}");
assertEquals("{\"ppl\":\"source\\u003demployee| where age \\u003e 56 | stats COUNT() as cnt\"}", result);
}

public void testPPLTool_withWrongPPLGenerated_thenThrowException() {
prepareIndex();
String agentId = registerAgent();
Expand Down

0 comments on commit fc9ae93

Please sign in to comment.