Skip to content

Commit 2aef6c3

Browse files
authored
KG-508 Update Tool API to fix name and descriptor discrepancy and make it more robust (#1226)
**This PR introduces breaking changes in Tool API** Currently `argsSerializer`, `resultSerializer`, `name`, `description` and `descriptor` are meant to be overriden to provide your own values. Since we can also generate `descriptor` automatically now, it introduces a discrepancy between property values and the actual state of things. It's described in [KG-508](https://youtrack.jetbrains.com/issue/KG-508). So current version makes for unstable and somewhat confusing API. In this PR, I've moved these configurable tool properties to the constructors instead. This avoids confusion, since now the ways in which you can configure the tool are visible more explicitly - either provide your own `descriptor`, or opt for an automatic generation and only provide `name` and `description`. Updated API also looks a bit more concise and fluent, IMHO, since you have to write less `override` now. Compare these: ```kt // Before object DummyTool : SimpleTool<Unit>() { override val argsSerializer = Unit.serializer() override val name: String = "dummy" override val description: String = "Dummy tool for testing" override suspend fun doExecute(args: Unit): String = "Dummy result" } // After object DummyTool : SimpleTool<Unit>( argsSerializer = Unit.serializer(), name = "dummy", description = "Dummy tool for testing" ) { override suspend fun execute(args: Unit): String = "Dummy result" } ``` Also, another breaking change, as you have already seen from the example above, is `doExecute` removal. It is not needed anymore, since tools support primitive types now, so you can just use a regular `execute`
1 parent 3a8b117 commit 2aef6c3

File tree

76 files changed

+990
-1152
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+990
-1152
lines changed

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/AIAgentTool.kt

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package ai.koog.agents.core.agent
22

33
import ai.koog.agents.core.agent.AIAgentTool.AgentToolResult
44
import ai.koog.agents.core.tools.Tool
5-
import ai.koog.agents.core.tools.ToolDescriptor
65
import ai.koog.agents.core.tools.annotations.InternalAgentToolsApi
76
import ai.koog.agents.core.tools.asToolDescriptor
87
import kotlinx.serialization.KSerializer
@@ -76,15 +75,19 @@ public inline fun <reified Input, reified Output> AIAgent<Input, Output>.asTool(
7675
* @property outputSerializer A serializer for converting the output type to/from JSON.
7776
* @param parentAgentId Optional ID of the parent AI agent. Tool agent IDs will be generated as "parentAgentId.<number of tool call>"
7877
*/
79-
public class AIAgentTool<Input, Output>(
78+
public class AIAgentTool<Input, Output> @OptIn(InternalAgentToolsApi::class) constructor(
8079
private val agentService: AIAgentService<Input, Output, *>,
8180
private val agentName: String,
8281
private val agentDescription: String,
8382
private val inputDescription: String? = null,
8483
private val inputSerializer: KSerializer<Input>,
8584
private val outputSerializer: KSerializer<Output>,
8685
private val parentAgentId: String? = null
87-
) : Tool<Input, AgentToolResult<Output>>() {
86+
) : Tool<Input, AgentToolResult<Output>>(
87+
argsSerializer = inputSerializer,
88+
resultSerializer = AgentToolResult.serializer(outputSerializer),
89+
descriptor = inputSerializer.descriptor.asToolDescriptor(agentName, agentDescription, inputDescription)
90+
) {
8891
@OptIn(ExperimentalAtomicApi::class)
8992
private val toolCallNumber: AtomicInt = AtomicInt(0)
9093

@@ -105,16 +108,6 @@ public class AIAgentTool<Input, Output>(
105108
val result: Output? = null
106109
)
107110

108-
override val argsSerializer: KSerializer<Input> = inputSerializer
109-
override val resultSerializer: KSerializer<AgentToolResult<Output>> = AgentToolResult.serializer(outputSerializer)
110-
111-
override val name: String = agentName
112-
override val description: String = agentDescription
113-
114-
@OptIn(InternalAgentToolsApi::class)
115-
override val descriptor: ToolDescriptor =
116-
inputSerializer.descriptor.asToolDescriptor(name, description, inputDescription)
117-
118111
@OptIn(InternalAgentToolsApi::class)
119112
override suspend fun execute(args: Input): AgentToolResult<Output> {
120113
return try {

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironment.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ internal class GenericAgentEnvironment(
5858
)
5959
}
6060

61-
val toolDescription = tool.description
61+
val toolDescription = tool.descriptor.description
6262

6363
// Tool Args
6464
val toolArgs = try {

agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/CalculatorTools.kt

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,20 @@ package ai.koog.agents.core
22

33
import ai.koog.agents.core.tools.Tool
44
import ai.koog.agents.core.tools.annotations.LLMDescription
5-
import kotlinx.serialization.KSerializer
65
import kotlinx.serialization.Serializable
76
import kotlin.jvm.JvmInline
87

98
object CalculatorTools {
109

1110
abstract class CalculatorTool(
12-
override val name: String,
13-
override val description: String,
14-
) : Tool<CalculatorTool.Args, CalculatorTool.Result>() {
11+
name: String,
12+
description: String,
13+
) : Tool<CalculatorTool.Args, CalculatorTool.Result>(
14+
argsSerializer = Args.serializer(),
15+
resultSerializer = Result.serializer(),
16+
name = name,
17+
description = description
18+
) {
1519
@Serializable
1620
data class Args(
1721
@property:LLMDescription("First number")
@@ -23,9 +27,6 @@ object CalculatorTools {
2327
@Serializable
2428
@JvmInline
2529
value class Result(val result: Float)
26-
27-
final override val argsSerializer = Args.serializer()
28-
override val resultSerializer: KSerializer<Result> = Result.serializer()
2930
}
3031

3132
object PlusTool : CalculatorTool(

agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/DummyTools.kt

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,23 @@ import ai.koog.agents.core.tools.annotations.LLMDescription
55
import kotlinx.serialization.Serializable
66
import kotlinx.serialization.builtins.serializer
77

8-
object DummyTool : SimpleTool<Unit>() {
9-
override val argsSerializer = Unit.serializer()
10-
11-
override val description: String = "Dummy tool for testing"
12-
13-
override suspend fun doExecute(args: Unit): String = "Dummy result"
8+
object DummyTool : SimpleTool<Unit>(
9+
argsSerializer = Unit.serializer(),
10+
name = "dummy_tool",
11+
description = "Dummy tool for testing"
12+
) {
13+
override suspend fun execute(args: Unit): String = "Dummy result"
1414
}
1515

16-
object CreateTool : SimpleTool<CreateTool.Args>() {
16+
object CreateTool : SimpleTool<CreateTool.Args>(
17+
argsSerializer = Args.serializer(),
18+
name = "create",
19+
description = "Create something"
20+
) {
1721
@Serializable
1822
data class Args(
1923
@property:LLMDescription("Name of the entity to create") val name: String
2024
)
2125

22-
override val argsSerializer = Args.serializer()
23-
24-
override val name: String = "create"
25-
override val description: String = "Create something"
26-
27-
override suspend fun doExecute(args: Args): String = "created"
26+
override suspend fun execute(args: Args): String = "created"
2827
}

agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/FunctionalAIAgentTest.kt

Lines changed: 66 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import ai.koog.prompt.executor.clients.google.GoogleModels
2424
import ai.koog.prompt.executor.clients.openai.OpenAIModels
2525
import ai.koog.prompt.llm.OllamaModels
2626
import kotlinx.coroutines.test.runTest
27-
import kotlinx.serialization.KSerializer
2827
import kotlinx.serialization.Serializable
2928
import kotlinx.serialization.builtins.serializer
3029
import kotlin.test.Test
@@ -272,28 +271,28 @@ class FunctionalAIAgentTest {
272271
data class SimpleOut(val value: String)
273272

274273
object QATools {
275-
object TestEngine : SimpleTool<Spacecraft>() {
276-
override val argsSerializer: KSerializer<Spacecraft> = Spacecraft.serializer()
277-
278-
override val description: String = "Performs testing of the spacecraft engine."
279-
280-
override suspend fun doExecute(args: Spacecraft): String = "Engine is good"
274+
object TestEngine : SimpleTool<Spacecraft>(
275+
argsSerializer = Spacecraft.serializer(),
276+
name = "test_engine",
277+
description = "Performs testing of the spacecraft engine."
278+
) {
279+
override suspend fun execute(args: Spacecraft): String = "Engine is good"
281280
}
282281

283-
object TestBody : SimpleTool<Spacecraft>() {
284-
override val argsSerializer: KSerializer<Spacecraft> = Spacecraft.serializer()
285-
286-
override val description: String = "Performs testing of the spacecraft bofy."
287-
288-
override suspend fun doExecute(args: Spacecraft): String = "Body is good"
282+
object TestBody : SimpleTool<Spacecraft>(
283+
argsSerializer = Spacecraft.serializer(),
284+
name = "test_body",
285+
description = "Performs testing of the spacecraft bofy."
286+
) {
287+
override suspend fun execute(args: Spacecraft): String = "Body is good"
289288
}
290289

291-
object TestBuild : SimpleTool<Spacecraft>() {
292-
override val argsSerializer: KSerializer<Spacecraft> = Spacecraft.serializer()
293-
294-
override val description: String = "Tests how spacecraft is built."
295-
296-
override suspend fun doExecute(args: Spacecraft): String =
290+
object TestBuild : SimpleTool<Spacecraft>(
291+
argsSerializer = Spacecraft.serializer(),
292+
name = "test_build",
293+
description = "Tests how spacecraft is built."
294+
) {
295+
override suspend fun execute(args: Spacecraft): String =
297296
"Spacecraft is built badly... Engine is too big for the body"
298297
}
299298

@@ -302,65 +301,81 @@ class FunctionalAIAgentTest {
302301

303302
// Define sample tools for subtasks, similar in spirit to QATools so tool lists are not empty
304303
object ArchitectureTools {
305-
object AnalyzeRequirements : SimpleTool<String>() {
306-
override val argsSerializer: KSerializer<String> = String.serializer()
307-
override val description: String = "Analyzes high-level mission requirements."
308-
override suspend fun doExecute(args: String): String = "Requirements analyzed: $args"
304+
object AnalyzeRequirements : SimpleTool<String>(
305+
argsSerializer = String.serializer(),
306+
name = "analyze_requirements",
307+
description = "Analyzes high-level mission requirements."
308+
) {
309+
override suspend fun execute(args: String): String = "Requirements analyzed: $args"
309310
}
310311

311-
object DraftArchitecture : SimpleTool<Architecture>() {
312-
override val argsSerializer: KSerializer<Architecture> = Architecture.serializer()
313-
override val description: String = "Drafts an initial spacecraft architecture proposal."
314-
override suspend fun doExecute(args: Architecture): String = "Drafted architecture: ${'$'}{args.name}"
312+
object DraftArchitecture : SimpleTool<Architecture>(
313+
argsSerializer = Architecture.serializer(),
314+
name = "draft_architecture",
315+
description = "Drafts an initial spacecraft architecture proposal."
316+
) {
317+
override suspend fun execute(args: Architecture): String = "Drafted architecture: ${'$'}{args.name}"
315318
}
316319

317320
val tools: List<Tool<*, *>> = listOf(AnalyzeRequirements, DraftArchitecture)
318321
}
319322

320323
object BuildEngineTools {
321-
object EstimateThrust : SimpleTool<Architecture>() {
322-
override val argsSerializer: KSerializer<Architecture> = Architecture.serializer()
323-
override val description: String = "Estimates required thrust for the given architecture."
324-
override suspend fun doExecute(args: Architecture): String = "Estimated thrust for ${'$'}{args.name}"
324+
object EstimateThrust : SimpleTool<Architecture>(
325+
argsSerializer = Architecture.serializer(),
326+
name = "estimate_thrust",
327+
description = "Estimates required thrust for the given architecture."
328+
) {
329+
override suspend fun execute(args: Architecture): String = "Estimated thrust for ${'$'}{args.name}"
325330
}
326331

327-
object SelectFuelType : SimpleTool<Architecture>() {
328-
override val argsSerializer: KSerializer<Architecture> = Architecture.serializer()
329-
override val description: String = "Selects suitable fuel type based on mission profile and constraints."
330-
override suspend fun doExecute(args: Architecture): String = "Fuel selected for ${'$'}{args.name}"
332+
object SelectFuelType : SimpleTool<Architecture>(
333+
argsSerializer = Architecture.serializer(),
334+
name = "select_fuel_type",
335+
description = "Selects suitable fuel type based on mission profile and constraints."
336+
) {
337+
override suspend fun execute(args: Architecture): String = "Fuel selected for ${'$'}{args.name}"
331338
}
332339

333340
val tools: List<Tool<*, *>> = listOf(EstimateThrust, SelectFuelType)
334341
}
335342

336343
object BuildBodyTools {
337-
object ComputeMassBudget : SimpleTool<Architecture>() {
338-
override val argsSerializer: KSerializer<Architecture> = Architecture.serializer()
339-
override val description: String = "Computes mass budget for the spacecraft body."
340-
override suspend fun doExecute(args: Architecture): String = "Mass budget computed for ${'$'}{args.name}"
344+
object ComputeMassBudget : SimpleTool<Architecture>(
345+
argsSerializer = Architecture.serializer(),
346+
name = "compute_mass_budget",
347+
description = "Computes mass budget for the spacecraft body."
348+
) {
349+
override suspend fun execute(args: Architecture): String = "Mass budget computed for ${'$'}{args.name}"
341350
}
342351

343-
object ChooseMaterial : SimpleTool<Architecture>() {
344-
override val argsSerializer: KSerializer<Architecture> = Architecture.serializer()
345-
override val description: String = "Chooses hull material given constraints."
346-
override suspend fun doExecute(args: Architecture): String = "Material chosen for ${'$'}{args.name}"
352+
object ChooseMaterial : SimpleTool<Architecture>(
353+
argsSerializer = Architecture.serializer(),
354+
name = "choose_material",
355+
description = "Chooses hull material given constraints."
356+
) {
357+
override suspend fun execute(args: Architecture): String = "Material chosen for ${'$'}{args.name}"
347358
}
348359

349360
val tools: List<Tool<*, *>> = listOf(ComputeMassBudget, ChooseMaterial)
350361
}
351362

352363
object AssemblyTools {
353-
object CheckInterfaces : SimpleTool<Assembly>() {
354-
override val argsSerializer: KSerializer<Assembly> = Assembly.serializer()
355-
override val description: String = "Checks mechanical, power, and data interfaces between components."
356-
override suspend fun doExecute(args: Assembly): String =
364+
object CheckInterfaces : SimpleTool<Assembly>(
365+
argsSerializer = Assembly.serializer(),
366+
name = "check_interfaces",
367+
description = "Checks mechanical, power, and data interfaces between components."
368+
) {
369+
override suspend fun execute(args: Assembly): String =
357370
"Interfaces check passed for engine ${'$'}{args.engine.name} and body ${'$'}{args.body.name}"
358371
}
359372

360-
object ComputeDryMass : SimpleTool<Assembly>() {
361-
override val argsSerializer: KSerializer<Assembly> = Assembly.serializer()
362-
override val description: String = "Computes total dry mass of the assembly."
363-
override suspend fun doExecute(args: Assembly): String = "Dry mass: ${'$'}{args.totalDryMassKg} kg"
373+
object ComputeDryMass : SimpleTool<Assembly>(
374+
argsSerializer = Assembly.serializer(),
375+
name = "compute_dry_mass",
376+
description = "Computes total dry mass of the assembly."
377+
) {
378+
override suspend fun execute(args: Assembly): String = "Dry mass: ${'$'}{args.totalDryMassKg} kg"
364379
}
365380

366381
val tools: List<Tool<*, *>> = listOf(CheckInterfaces, ComputeDryMass)

agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/context/AIAgentLLMContextTest.kt

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,12 @@ class AIAgentLLMContextTest : AgentTestBase() {
174174
val input: String
175175
)
176176

177-
private class TestTool : SimpleTool<TestToolArgs>() {
178-
override val argsSerializer = TestToolArgs.serializer()
179-
180-
override val name: String = "test-tool"
181-
override val description: String = "A test tool for testing"
182-
183-
override suspend fun doExecute(args: TestToolArgs): String {
177+
private class TestTool : SimpleTool<TestToolArgs>(
178+
argsSerializer = TestToolArgs.serializer(),
179+
name = "test-tool",
180+
description = "A test tool for testing"
181+
) {
182+
override suspend fun execute(args: TestToolArgs): String {
184183
return "Processed: ${args.input}"
185184
}
186185
}

agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMWriteSessionTest.kt

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import ai.koog.prompt.message.ResponseMetaInfo
2525
import ai.koog.prompt.params.LLMParams
2626
import ai.koog.prompt.processor.ResponseProcessor
2727
import kotlinx.coroutines.test.runTest
28-
import kotlinx.serialization.KSerializer
2928
import kotlinx.serialization.Serializable
3029
import kotlin.test.Test
3130
import kotlin.test.assertEquals
@@ -62,40 +61,26 @@ class AIAgentLLMWriteSessionTest {
6261
}
6362
}
6463

65-
class TestTool : SimpleTool<TestTool.Args>() {
64+
class TestTool : SimpleTool<TestTool.Args>(
65+
argsSerializer = Args.serializer(),
66+
name = "test-tool",
67+
description = "A test tool"
68+
) {
6669
@Serializable
6770
data class Args(
6871
@property:LLMDescription("Input parameter")
6972
val input: String
7073
)
7174

72-
override val argsSerializer: KSerializer<Args> = Args.serializer()
73-
74-
override val name: String = "test-tool"
75-
override val description: String = "A test tool"
76-
77-
override suspend fun doExecute(args: Args): String {
75+
override suspend fun execute(args: Args): String {
7876
return "Processed: ${args.input}"
7977
}
8078
}
8179

82-
class CustomTool : Tool<CustomTool.Args, CustomTool.Result>() {
83-
@Serializable
84-
data class Args(val input: String)
85-
86-
@Serializable
87-
data class Result(
88-
@property:LLMDescription("Input parameter")
89-
val output: String
90-
)
91-
92-
override val argsSerializer: KSerializer<Args> = Args.serializer()
93-
override val resultSerializer: KSerializer<Result> = Result.serializer()
94-
95-
override val name: String = "custom-tool"
96-
override val description: String = "A custom tool"
97-
98-
override val descriptor: ToolDescriptor = ToolDescriptor(
80+
class CustomTool : Tool<CustomTool.Args, CustomTool.Result>(
81+
argsSerializer = Args.serializer(),
82+
resultSerializer = Result.serializer(),
83+
descriptor = ToolDescriptor(
9984
name = "custom-tool",
10085
description = "A custom tool",
10186
requiredParameters = listOf(
@@ -106,6 +91,15 @@ class AIAgentLLMWriteSessionTest {
10691
)
10792
)
10893
)
94+
) {
95+
@Serializable
96+
data class Args(val input: String)
97+
98+
@Serializable
99+
data class Result(
100+
@property:LLMDescription("Input parameter")
101+
val output: String
102+
)
109103

110104
override suspend fun execute(args: Args): Result {
111105
return Result("Custom processed: ${args.input}")

0 commit comments

Comments
 (0)