1- package ai.koog.a2a.server
2-
3- import ai.koog.a2a.client.A2AClient
4- import ai.koog.a2a.client.UrlAgentCardResolver
5- import ai.koog.a2a.consts.A2AConsts
6- import ai.koog.a2a.model.AgentCapabilities
7- import ai.koog.a2a.model.AgentCard
8- import ai.koog.a2a.model.AgentSkill
1+ package ai.koog.a2a.server.jsonrpc
2+
93import ai.koog.a2a.model.Message
104import ai.koog.a2a.model.MessageSendConfiguration
115import ai.koog.a2a.model.MessageSendParams
@@ -15,31 +9,19 @@ import ai.koog.a2a.model.TaskIdParams
159import ai.koog.a2a.model.TaskState
1610import ai.koog.a2a.model.TaskStatusUpdateEvent
1711import ai.koog.a2a.model.TextPart
18- import ai.koog.a2a.model.TransportProtocol
19- import ai.koog.a2a.server.notifications.InMemoryPushNotificationConfigStorage
20- import ai.koog.a2a.test.BaseA2AProtocolTest
2112import ai.koog.a2a.transport.Request
22- import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport
23- import ai.koog.a2a.transport.server.jsonrpc.http.HttpJSONRPCServerTransport
2413import io.kotest.inspectors.shouldForAll
2514import io.kotest.inspectors.shouldForAtLeastOne
2615import io.kotest.matchers.nulls.shouldNotBeNull
2716import io.kotest.matchers.should
2817import io.kotest.matchers.shouldBe
2918import io.kotest.matchers.string.shouldStartWith
3019import io.kotest.matchers.types.shouldBeInstanceOf
31- import io.ktor.client.HttpClient
32- import io.ktor.client.engine.cio.CIO
33- import io.ktor.client.plugins.HttpTimeout
34- import io.ktor.client.plugins.logging.LogLevel
35- import io.ktor.client.plugins.logging.Logging
36- import io.ktor.server.netty.Netty
3720import kotlinx.coroutines.Dispatchers
3821import kotlinx.coroutines.delay
3922import kotlinx.coroutines.flow.toList
4023import kotlinx.coroutines.joinAll
4124import kotlinx.coroutines.launch
42- import kotlinx.coroutines.runBlocking
4325import kotlinx.coroutines.test.runTest
4426import kotlinx.coroutines.withContext
4527import org.junit.jupiter.api.AfterAll
@@ -48,10 +30,9 @@ import org.junit.jupiter.api.RepeatedTest
4830import org.junit.jupiter.api.TestInstance
4931import org.junit.jupiter.api.parallel.Execution
5032import org.junit.jupiter.api.parallel.ExecutionMode
51- import java.net.ServerSocket
5233import kotlin.test.BeforeTest
5334import kotlin.test.Test
54- import kotlin.time.Duration.Companion.minutes
35+ import kotlin.time.Duration.Companion.seconds
5536import kotlin.uuid.ExperimentalUuidApi
5637import kotlin.uuid.Uuid
5738
@@ -63,170 +44,55 @@ import kotlin.uuid.Uuid
6344@OptIn(ExperimentalUuidApi ::class )
6445@TestInstance(TestInstance .Lifecycle .PER_CLASS )
6546@Execution(ExecutionMode .SAME_THREAD , reason = " Working with the same instance of test server." )
66- class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest () {
67- override val testTimeout = 10 .minutes
68-
69- private var testPort: Int? = null
70- private val testPath = " /a2a"
71- private lateinit var serverUrl: String
47+ class A2AServerJsonRpcIntegrationTest : BaseA2AServerJsonRpcTest () {
48+ override val testTimeout = 10 .seconds
7249
73- private lateinit var serverTransport: HttpJSONRPCServerTransport
74- private lateinit var clientTransport: HttpJSONRPCClientTransport
75- private lateinit var httpClient: HttpClient
50+ @BeforeAll
51+ override fun setup () {
52+ super .setup()
53+ }
7654
77- override lateinit var client: A2AClient
55+ @BeforeTest
56+ override fun initClient () {
57+ super .initClient()
58+ }
7859
79- @BeforeAll
80- fun setup (): Unit = runBlocking {
81- // Discover and take any free port
82- testPort = ServerSocket (0 ).use { it.localPort }
83- serverUrl = " http://localhost:$testPort$testPath "
84-
85- // Create agent cards
86- val agentCard = createAgentCard()
87- val agentCardExtended = createExtendedAgentCard()
88-
89- // Create test agent executor
90- val testAgentExecutor = TestAgentExecutor ()
91-
92- // Create A2A server
93- val a2aServer = A2AServer (
94- agentExecutor = testAgentExecutor,
95- agentCard = agentCard,
96- agentCardExtended = agentCardExtended,
97- pushConfigStorage = InMemoryPushNotificationConfigStorage ()
98- )
60+ @AfterAll
61+ override fun tearDown () {
62+ super .tearDown()
63+ }
9964
100- // Create server transport
101- serverTransport = HttpJSONRPCServerTransport (a2aServer)
102-
103- // Start server
104- serverTransport.start(
105- engineFactory = Netty ,
106- port = testPort!! ,
107- path = testPath,
108- wait = false ,
109- agentCard = agentCard,
110- agentCardPath = A2AConsts .AGENT_CARD_WELL_KNOWN_PATH ,
111- )
65+ @Test
66+ override fun `test get agent card` () =
67+ super .`test get agent card`()
11268
113- // Create client transport
114- httpClient = HttpClient (CIO ) {
115- install(Logging ) {
116- level = LogLevel .ALL
117- }
69+ @Test
70+ override fun `test get authenticated extended agent card` () =
71+ super .`test get authenticated extended agent card`()
11872
119- install(HttpTimeout ) {
120- requestTimeoutMillis = testTimeout.inWholeMilliseconds
121- }
122- }
73+ @Test
74+ override fun `test send message` () =
75+ super .`test send message`()
12376
124- clientTransport = HttpJSONRPCClientTransport (serverUrl, httpClient)
77+ @Test
78+ override fun `test send message streaming` () =
79+ super .`test send message streaming`()
12580
126- client = A2AClient (
127- transport = clientTransport,
128- agentCardResolver = UrlAgentCardResolver (
129- baseUrl = serverUrl,
130- path = A2AConsts .AGENT_CARD_WELL_KNOWN_PATH ,
131- baseHttpClient = httpClient,
132- )
133- )
134- }
81+ @Test
82+ override fun `test get task` () =
83+ super .`test get task`()
13584
136- @BeforeTest
137- fun initClient (): Unit = runBlocking {
138- client.connect()
139- }
85+ @Test
86+ override fun `test cancel task` () =
87+ super .`test cancel task`()
14088
141- @AfterAll
142- fun tearDown (): Unit = runBlocking {
143- clientTransport.close()
144- serverTransport.stop()
145- }
89+ @Test
90+ override fun `test resubscribe task` () =
91+ super .`test resubscribe task`()
14692
147- private fun createAgentCard (): AgentCard = AgentCard (
148- protocolVersion = " 0.3.0" ,
149- name = " Hello World Agent" ,
150- description = " Just a hello world agent" ,
151- url = " http://localhost:9999/" ,
152- preferredTransport = TransportProtocol .JSONRPC ,
153- additionalInterfaces = null ,
154- iconUrl = null ,
155- provider = null ,
156- version = " 1.0.0" ,
157- documentationUrl = null ,
158- capabilities = AgentCapabilities (
159- streaming = true ,
160- pushNotifications = true ,
161- stateTransitionHistory = null ,
162- extensions = null
163- ),
164- securitySchemes = null ,
165- security = null ,
166- defaultInputModes = listOf (" text" ),
167- defaultOutputModes = listOf (" text" ),
168- skills = listOf (
169- AgentSkill (
170- id = " hello_world" ,
171- name = " Returns hello world" ,
172- description = " just returns hello world" ,
173- tags = listOf (" hello world" ),
174- examples = listOf (" hi" , " hello world" ),
175- inputModes = null ,
176- outputModes = null ,
177- security = null
178- )
179- ),
180- supportsAuthenticatedExtendedCard = true ,
181- signatures = null
182- )
183-
184- private fun createExtendedAgentCard (): AgentCard = AgentCard (
185- protocolVersion = " 0.3.0" ,
186- name = " Hello World Agent - Extended Edition" ,
187- description = " The full-featured hello world agent for authenticated users." ,
188- url = " http://localhost:9999/" ,
189- preferredTransport = TransportProtocol .JSONRPC ,
190- additionalInterfaces = null ,
191- iconUrl = null ,
192- provider = null ,
193- version = " 1.0.1" ,
194- documentationUrl = null ,
195- capabilities = AgentCapabilities (
196- streaming = true ,
197- pushNotifications = true ,
198- stateTransitionHistory = null ,
199- extensions = null
200- ),
201- securitySchemes = null ,
202- security = null ,
203- defaultInputModes = listOf (" text" ),
204- defaultOutputModes = listOf (" text" ),
205- skills = listOf (
206- AgentSkill (
207- id = " hello_world" ,
208- name = " Returns hello world" ,
209- description = " just returns hello world" ,
210- tags = listOf (" hello world" ),
211- examples = listOf (" hi" , " hello world" ),
212- inputModes = null ,
213- outputModes = null ,
214- security = null
215- ),
216- AgentSkill (
217- id = " super_hello_world" ,
218- name = " Returns a SUPER Hello World" ,
219- description = " A more enthusiastic greeting, only for authenticated users." ,
220- tags = listOf (" hello world" , " super" , " extended" ),
221- examples = listOf (" super hi" , " give me a super hello" ),
222- inputModes = null ,
223- outputModes = null ,
224- security = null
225- )
226- ),
227- supportsAuthenticatedExtendedCard = true ,
228- signatures = null
229- )
93+ @Test
94+ override fun `test push notification configs` () =
95+ super .`test push notification configs`()
23096
23197 /* *
23298 * Extended test that wouldn't work with Python A2A SDK server, because their implementation has some problems.
@@ -240,7 +106,7 @@ class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() {
240106 val createTaskRequest = Request (
241107 data = MessageSendParams (
242108 message = Message (
243- messageId = Uuid .random().toString(),
109+ messageId = Uuid .Companion . random().toString(),
244110 role = Role .User ,
245111 parts = listOf (
246112 TextPart (" do long-running task" ),
@@ -340,7 +206,7 @@ class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() {
340206 ) = Request (
341207 data = MessageSendParams (
342208 message = Message (
343- messageId = Uuid .random().toString(),
209+ messageId = Uuid .Companion . random().toString(),
344210 role = Role .User ,
345211 parts = listOf (
346212 TextPart (" do long-running task" ),
@@ -388,7 +254,7 @@ class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() {
388254 val createTaskRequest = Request (
389255 data = MessageSendParams (
390256 message = Message (
391- messageId = Uuid .random().toString(),
257+ messageId = Uuid .Companion . random().toString(),
392258 role = Role .User ,
393259 parts = listOf (
394260 TextPart (" do cancelable task" ),
@@ -426,7 +292,7 @@ class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() {
426292 val request = Request (
427293 data = MessageSendParams (
428294 message = Message (
429- messageId = Uuid .random().toString(),
295+ messageId = Uuid .Companion . random().toString(),
430296 role = Role .User ,
431297 parts = listOf (
432298 TextPart (" hello world" ),
0 commit comments