Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,16 @@ public static boolean isJson(String json) {
* prepareJsonValue("{\"key\":123}") → {\"key\":123} (valid JSON object, unchanged)
* </pre>
* @param input
* @param escape
* @return
*/
public static String prepareJsonValue(String input) {
public static String prepareJsonValue(String input, boolean escape) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to change API documentation based on new behavior

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, the documentation need update as well

if (isJson(input)) {
return input;
if (!escape) {
return input;
} else {
return escapeJson(input);
}
}
return escapeJson(input);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
public class ToolUtils {

public static final String TOOL_OUTPUT_FILTERS_FIELD = "output_filter";
public static final String TOOL_OUTPUT_ESCAPED = "output_escaped";
public static final String TOOL_REQUIRED_PARAMS = "required_parameters";
public static final String NO_ESCAPE_PARAMS = "no_escape_params";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -861,14 +861,21 @@ public void testValidateFields_InvalidCharacterSet() {
@Test
public void prepareJsonValue_returnsRawIfJson() {
String json = "{\"key\": 123}";
String result = StringUtils.prepareJsonValue(json);
String result = StringUtils.prepareJsonValue(json, false);
assertSame(json, result); // branch where isJson(input)==true
}

@Test
public void prepareJsonValue_returnEscapeJsonIfForce() {
String json = "{\"key\": 123}";
String result = StringUtils.prepareJsonValue(json, true);
assertEquals("{\\\"key\\\": 123}", result);
}

@Test
public void prepareJsonValue_escapesBadCharsOtherwise() {
String input = "Tom & Jerry \"<script>";
String escaped = StringUtils.prepareJsonValue(input);
String escaped = StringUtils.prepareJsonValue(input, false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something is misleading here, you are actually escaping the plain text, but passing "escape" parameter as false. May be rename "escape" parameter in the API definition, which should understand it is only about json string.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sense, i will update it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename it to escapeJson

assertNotEquals(input, escaped);
assertFalse(StringUtils.isJson(escaped));
assertEquals("Tom & Jerry \\\"<script>", escaped);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
import static org.opensearch.ml.common.conversation.ActionConstants.MEMORY_ID;
import static org.opensearch.ml.common.conversation.ActionConstants.PARENT_INTERACTION_ID_FIELD;
import static org.opensearch.ml.common.utils.ToolUtils.TOOL_OUTPUT_ESCAPED;
import static org.opensearch.ml.common.utils.ToolUtils.TOOL_OUTPUT_FILTERS_FIELD;
import static org.opensearch.ml.common.utils.ToolUtils.filterToolOutput;
import static org.opensearch.ml.common.utils.ToolUtils.getToolName;
Expand Down Expand Up @@ -262,7 +263,12 @@ private void processOutput(
String outputKey = toolName + ".output";
Map<String, String> toolParameters = ToolUtils.buildToolParameters(params, previousToolSpec, tenantId);
String filteredOutput = parseResponse(filterToolOutput(toolParameters, output));
params.put(outputKey, StringUtils.prepareJsonValue(filteredOutput));
params
.put(
outputKey,
StringUtils
.prepareJsonValue(filteredOutput, Boolean.parseBoolean(toolParameters.getOrDefault(TOOL_OUTPUT_ESCAPED, "false")))
);
boolean traceDisabled = params.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(params.get(DISABLE_TRACE));

if (previousToolSpec.isIncludeOutputInAgentResponse() || finalI == toolSpecs.size()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine.algorithms.agent;

import static org.opensearch.ml.common.utils.ToolUtils.TOOL_OUTPUT_ESCAPED;
import static org.opensearch.ml.common.utils.ToolUtils.TOOL_OUTPUT_FILTERS_FIELD;
import static org.opensearch.ml.common.utils.ToolUtils.filterToolOutput;
import static org.opensearch.ml.common.utils.ToolUtils.getToolName;
Expand Down Expand Up @@ -114,7 +115,15 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
String outputKey = toolName + ".output";
Map<String, String> toolParameters = ToolUtils.buildToolParameters(params, previousToolSpec, mlAgent.getTenantId());
String filteredOutput = parseResponse(filterToolOutput(toolParameters, output));
params.put(outputKey, StringUtils.prepareJsonValue(filteredOutput));
params
.put(
outputKey,
StringUtils
.prepareJsonValue(
filteredOutput,
Boolean.parseBoolean(toolParameters.getOrDefault(TOOL_OUTPUT_ESCAPED, "false"))
)
);
if (previousToolSpec.isIncludeOutputInAgentResponse() || finalI == toolSpecs.size()) {
if (toolParameters.containsKey(TOOL_OUTPUT_FILTERS_FIELD)) {
flowAgentOutput.add(ModelTensor.builder().name(outputKey).result(filteredOutput).build());
Expand Down
Loading