Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package ai.koog.integration.tests.agent;

import ai.koog.agents.core.tools.reflect.ToolSet;
import ai.koog.agents.core.tools.annotations.Tool;
import ai.koog.agents.core.tools.annotations.LLMDescription;

public class CalculatorTools implements ToolSet {

@Tool
@LLMDescription(description = "Adds two numbers together")
public int add(
@LLMDescription(description = "First number") int a,
@LLMDescription(description = "Second number") int b
) {
return a + b;
}

@Tool
@LLMDescription(description = "Multiplies two numbers")
public int multiply(
@LLMDescription(description = "First number") int a,
@LLMDescription(description = "Second number") int b
) {
return a * b;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ public void integration_CustomPipelineFeature(LLModel model) {
AtomicBoolean agentStarted = new AtomicBoolean(false);
AtomicBoolean agentCompleted = new AtomicBoolean(false);

JavaInteropUtils.TransactionTools transactionTools = new JavaInteropUtils.TransactionTools();
ToolRegistry toolRegistry = JavaInteropUtils.createToolRegistry(transactionTools);
TransactionTools transactionTools = new TransactionTools();
ToolRegistry toolRegistry = ToolRegistry.builder().tools(transactionTools).build();

AIAgent<String, String> agent = AIAgent.builder()
.promptExecutor(createExecutor(model))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ public void integration_BuilderWithTemperature(LLModel model) {
public void integration_BuilderWithToolRegistry(LLModel model) {
Models.assumeAvailable(model.getProvider());

JavaInteropUtils.CalculatorTools calculator = new JavaInteropUtils.CalculatorTools();
ToolRegistry toolRegistry = JavaInteropUtils.createToolRegistry(calculator);
CalculatorTools calculator = new CalculatorTools();
ToolRegistry toolRegistry = ToolRegistry.builder().tools(calculator).build();

AIAgent<String, String> agent = AIAgent.builder()
.promptExecutor(createExecutor(model))
Expand All @@ -90,8 +90,8 @@ public void integration_EventHandler(LLModel model) {
AtomicBoolean agentCompleted = new AtomicBoolean(false);
AtomicInteger llmCallsCount = new AtomicInteger(0);

JavaInteropUtils.CalculatorTools calculator = new JavaInteropUtils.CalculatorTools();
ToolRegistry toolRegistry = JavaInteropUtils.createToolRegistry(calculator);
CalculatorTools calculator = new CalculatorTools();
ToolRegistry toolRegistry = ToolRegistry.builder().tools(calculator).build();

AIAgent<String, String> agent = AIAgent.builder()
.promptExecutor(createExecutor(model))
Expand Down Expand Up @@ -120,8 +120,8 @@ public void integration_EventHandler(LLModel model) {
public void integration_BuilderWithMaxIterations(LLModel model) {
Models.assumeAvailable(model.getProvider());

JavaInteropUtils.CalculatorTools calculator = new JavaInteropUtils.CalculatorTools();
ToolRegistry toolRegistry = JavaInteropUtils.createToolRegistry(calculator);
CalculatorTools calculator = new CalculatorTools();
ToolRegistry toolRegistry = ToolRegistry.builder().tools(calculator).build();

AIAgent<String, String> agent = AIAgent.builder()
.promptExecutor(createExecutor(model))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai.koog.integration.tests.agent;

import ai.koog.agents.core.agent.AIAgent;
import ai.koog.agents.core.agent.context.AIAgentFunctionalContext;
import ai.koog.agents.core.environment.ReceivedToolResult;
import ai.koog.agents.core.tools.Tool;
import ai.koog.agents.core.tools.ToolRegistry;
Expand All @@ -25,7 +26,7 @@ public class JavaAIAgentFunctionalStrategyIntegrationTest extends KoogJavaTestBa

private String getAssistantContentOrDefault(Message.Response response, String defaultValue) {
if (response instanceof Message.Assistant) {
return JavaInteropUtils.getAssistantContent((Message.Assistant) response);
return response.getContent();
}
return defaultValue;
}
Expand All @@ -36,26 +37,25 @@ public void integration_SimpleFunctionalStrategyWithRetry(LLModel model) {
Models.assumeAvailable(model.getProvider());

// Test simple functional strategy with retry logic
AIAgent<String, String> agent = JavaInteropUtils.buildFunctionalAgent(
JavaInteropUtils.createAgentBuilder()
.promptExecutor(createExecutor(model))
.llmModel(model)
.systemPrompt("You are a helpful assistant.")
.functionalStrategy((context, input) -> {
for (int i = 0; i < 3; i++) {
String result = getAssistantContentOrDefault(
JavaInteropUtils.requestLLM(context, input, true),
""
);
if (!result.isEmpty()) {
return result;
}
AIAgent<String, String> agent = AIAgent.builder()
.promptExecutor(createExecutor(model))
.llmModel(model)
.systemPrompt("You are a helpful assistant.")
.functionalStrategy((AIAgentFunctionalContext context, String input) -> {
for (int i = 0; i < 3; i++) {
String result = getAssistantContentOrDefault(
context.requestLLM(input, true),
""
);
if (!result.isEmpty()) {
return result;
}
return "Failed after retries";
})
);
}
return "Failed after retries";
})
.build();

String result = JavaInteropUtils.runAgentBlocking(agent, "Say hello");
String result = agent.run("Say hello");

assertNotNull(result);
assertFalse(result.isEmpty());
Expand All @@ -67,26 +67,24 @@ public void integration_SimpleFunctionalStrategyWithRetry(LLModel model) {
public void integration_MultiStepFunctionalStrategy(LLModel model) {
Models.assumeAvailable(model.getProvider());

AIAgent<String, String> agent = JavaInteropUtils.buildFunctionalAgent(
JavaInteropUtils.createAgentBuilder()
.promptExecutor(createExecutor(model))
.llmModel(model)
.systemPrompt("You are a helpful assistant.")
.functionalStrategy((context, input) -> {
Message.Response response1 = JavaInteropUtils.requestLLM(context, "First step: " + input, true);
String step1Result = getAssistantContentOrDefault(response1, "");

Message.Response response2 = JavaInteropUtils.requestLLM(
context,
"Second step, previous result was: " + step1Result,
true
);
AIAgent<String, String> agent = AIAgent.builder()
.promptExecutor(createExecutor(model))
.llmModel(model)
.systemPrompt("You are a helpful assistant.")
.functionalStrategy((AIAgentFunctionalContext context, String input) -> {
Message.Response response1 = context.requestLLM("First step: " + input, true);
String step1Result = getAssistantContentOrDefault(response1, "");

return getAssistantContentOrDefault(response2, "Unexpected response type");
})
);
Message.Response response2 = context.requestLLM(
"Second step, previous result was: " + step1Result,
true
);

return getAssistantContentOrDefault(response2, "Unexpected response type");
})
.build();

String result = JavaInteropUtils.runAgentBlocking(agent, "Count to 3");
String result = agent.run("Count to 3");

assertNotNull(result);
assertFalse(result.isEmpty());
Expand All @@ -97,39 +95,37 @@ public void integration_MultiStepFunctionalStrategy(LLModel model) {
public void integration_FunctionalStrategyWithManualToolHandling(LLModel model) {
Models.assumeAvailable(model.getProvider());

JavaInteropUtils.CalculatorTools calculator = new JavaInteropUtils.CalculatorTools();
ToolRegistry toolRegistry = JavaInteropUtils.createToolRegistry(calculator);

AIAgent<String, String> agent = JavaInteropUtils.buildFunctionalAgent(
JavaInteropUtils.createAgentBuilder()
.promptExecutor(createExecutor(model))
.llmModel(model)
.systemPrompt("You are a calculator. Use the add tool to perform calculations.")
.toolRegistry(toolRegistry)
.functionalStrategy((context, input) -> {
Message.Response currentResponse = JavaInteropUtils.requestLLM(
context,
"Calculate: " + input + ". You MUST use the add tool.",
true
);

int maxIterations = 5;
for (int i = 0; i < maxIterations && currentResponse instanceof Message.Tool.Call; i++) {
Message.Tool.Call toolCall = (Message.Tool.Call) currentResponse;
ReceivedToolResult toolResult = JavaInteropUtils.executeTool(context, toolCall);
currentResponse = JavaInteropUtils.sendToolResult(context, toolResult);
}

if (currentResponse instanceof Message.Assistant) {
return JavaInteropUtils.getAssistantContent((Message.Assistant) currentResponse);
} else if (currentResponse instanceof Message.Tool.Call) {
return "Max iterations reached, last tool: " + JavaInteropUtils.getToolName((Message.Tool.Call) currentResponse);
}
return "Unexpected response type";
})
);

String result = JavaInteropUtils.runAgentBlocking(agent, "10 + 5");
CalculatorTools calculator = new CalculatorTools();
ToolRegistry toolRegistry = ToolRegistry.builder().tools(calculator).build();

AIAgent<String, String> agent = AIAgent.builder()
.promptExecutor(createExecutor(model))
.llmModel(model)
.systemPrompt("You are a calculator. Use the add tool to perform calculations.")
.toolRegistry(toolRegistry)
.functionalStrategy((AIAgentFunctionalContext context, String input) -> {
Message.Response currentResponse = context.requestLLM(
"Calculate: " + input + ". You MUST use the add tool.",
true
);

int maxIterations = 5;
for (int i = 0; i < maxIterations && currentResponse instanceof Message.Tool.Call; i++) {
Message.Tool.Call toolCall = (Message.Tool.Call) currentResponse;
ReceivedToolResult toolResult = context.executeTool(toolCall);
currentResponse = context.sendToolResult(toolResult);
}

if (currentResponse instanceof Message.Assistant) {
return currentResponse.getContent();
} else if (currentResponse instanceof Message.Tool.Call) {
return "Max iterations reached, last tool: " + ((Message.Tool.Call) currentResponse).getTool();
}
return "Unexpected response type";
})
.build();

String result = agent.run("10 + 5");

assertNotNull(result);
assertFalse(result.isBlank());
Expand All @@ -143,34 +139,31 @@ public void integration_Subtask(LLModel model) {

MultiLLMPromptExecutor executor = createExecutor(model);

JavaInteropUtils.CalculatorTools calculator = new JavaInteropUtils.CalculatorTools();
CalculatorTools calculator = new CalculatorTools();

List<Tool<?, ?>> calculatorTools = List.of(
calculator.getAddTool(),
calculator.getMultiplyTool()
);

AIAgent<String, String> agent = JavaInteropUtils.buildFunctionalAgent(
JavaInteropUtils.createAgentBuilder()
.promptExecutor(executor)
.llmModel(model)
.systemPrompt("You are a helpful assistant that coordinates calculations.")
.toolRegistry(JavaInteropUtils.createToolRegistry(calculator))
.functionalStrategy((context, input) -> {
String subtaskResult = JavaInteropUtils.runSubtask(
context,
"Calculate: " + input,
input,
String.class,
calculatorTools,
model
);

return "Calculation result: " + subtaskResult;
})
calculator.getTool("add"),
calculator.getTool("multiply")
);

String result = JavaInteropUtils.runAgentBlocking(agent, "What is 5 + 3?");
AIAgent<String, String> agent = AIAgent.builder()
.promptExecutor(executor)
.llmModel(model)
.systemPrompt("You are a helpful assistant that coordinates calculations.")
.toolRegistry(ToolRegistry.builder().tools(calculator).build())
.functionalStrategy((AIAgentFunctionalContext context, String input) -> {
String subtaskResult = context.subtask("Calculate: " + input)
.withInput(input)
.withOutput(String.class)
.withTools(calculatorTools)
.useLLM(model)
.run();

return "Calculation result: " + subtaskResult;
})
.build();

String result = agent.run("What is 5 + 3?");

assertNotNull(result);
assertFalse(result.isEmpty());
Expand All @@ -181,27 +174,25 @@ public void integration_Subtask(LLModel model) {
public void integration_CustomStrategyWithValidation(LLModel model) {
Models.assumeAvailable(model.getProvider());

AIAgent<String, String> agent = JavaInteropUtils.buildFunctionalAgent(
JavaInteropUtils.createAgentBuilder()
.promptExecutor(createExecutor(model))
.llmModel(model)
.systemPrompt("You are a helpful assistant that generates JSON.")
.functionalStrategy((context, input) -> {
Message.Response response = JavaInteropUtils.requestLLM(
context,
"Generate a JSON object with 'status' field set to 'success'",
true
);

String content = getAssistantContentOrDefault(response, "Unexpected response type");
if (content.contains("status") && content.contains("success")) {
return content;
}
return "Validation failed: response doesn't contain expected fields";
})
);

String result = JavaInteropUtils.runAgentBlocking(agent, "Generate status JSON");
AIAgent<String, String> agent = AIAgent.builder()
.promptExecutor(createExecutor(model))
.llmModel(model)
.systemPrompt("You are a helpful assistant that generates JSON.")
.functionalStrategy((AIAgentFunctionalContext context, String input) -> {
Message.Response response = context.requestLLM(
"Generate a JSON object with 'status' field set to 'success'",
true
);

String content = getAssistantContentOrDefault(response, "Unexpected response type");
if (content.contains("status") && content.contains("success")) {
return content;
}
return "Validation failed: response doesn't contain expected fields";
})
.build();

String result = agent.run("Generate status JSON");

assertNotNull(result);
assertFalse(result.isEmpty());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ public void integration_AIAgentServiceCreateAgentAndRun(LLModel model) {
public void integration_AIAgentServiceWithCustomToolRegistry(LLModel model) {
Models.assumeAvailable(model.getProvider());

JavaInteropUtils.CalculatorTools calculator = new JavaInteropUtils.CalculatorTools();
ToolRegistry serviceToolRegistry = JavaInteropUtils.createToolRegistry(calculator);
CalculatorTools calculator = new CalculatorTools();
ToolRegistry serviceToolRegistry = ToolRegistry.builder().tools(calculator).build();

GraphAIAgentService<String, String> service = AIAgentService.builder()
.promptExecutor(createExecutor(model))
Expand Down Expand Up @@ -240,9 +240,9 @@ public void integration_AIAgentServiceBuilderFunctionalStrategy(LLModel model) {
.systemPrompt("You are a helpful assistant.")
.functionalStrategy((context, input) -> {
String inputStr = (input instanceof String) ? (String) input : String.valueOf(input);
Message.Response response = JavaInteropUtils.requestLLM(context, inputStr, true);
Message.Response response = context.requestLLM(inputStr, true);
if (response instanceof Message.Assistant) {
return JavaInteropUtils.getAssistantContent((Message.Assistant) response);
return response.getContent();
}
return "Unexpected response type";
})
Expand Down
Loading
Loading