Skip to content

Commit 1a80d6c

Browse files
committed
[a2a] Add stress tests, fix race conditions, and refine testing approach
1 parent 80abf56 commit 1a80d6c

File tree

19 files changed

+440
-246
lines changed

19 files changed

+440
-246
lines changed

a2a/a2a-client/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ kotlin {
3737
implementation(kotlin("test-junit5"))
3838
implementation(project(":a2a:a2a-test"))
3939
implementation(project(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http"))
40+
implementation(project(":test-utils"))
4041

4142
implementation(libs.ktor.client.cio)
4243
implementation(libs.ktor.client.logging)

a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,18 @@ import ai.koog.a2a.transport.ClientTransport
1515
import ai.koog.a2a.transport.Request
1616
import ai.koog.a2a.transport.Response
1717
import kotlinx.coroutines.flow.Flow
18-
import kotlin.concurrent.Volatile
18+
import kotlin.concurrent.atomics.AtomicReference
19+
import kotlin.concurrent.atomics.ExperimentalAtomicApi
1920

2021
/**
2122
* A2A client responsible for sending requests to A2A server.
2223
*/
24+
@OptIn(ExperimentalAtomicApi::class)
2325
public open class A2AClient(
2426
private val transport: ClientTransport,
2527
private val agentCardResolver: AgentCardResolver,
2628
) {
27-
@Volatile
28-
protected var agentCard: AgentCard? = null
29+
protected var agentCard: AtomicReference<AgentCard?> = AtomicReference(null)
2930

3031
/**
3132
* Performs initialization logic.
@@ -41,7 +42,7 @@ public open class A2AClient(
4142
*/
4243
public open suspend fun getAgentCard(): AgentCard {
4344
return agentCardResolver.resolve().also {
44-
agentCard = it
45+
agentCard.exchange(it)
4546
}
4647
}
4748

@@ -51,7 +52,7 @@ public open class A2AClient(
5152
* @throws [IllegalStateException] if it's not initialized
5253
*/
5354
public open fun cachedAgentCard(): AgentCard {
54-
return checkNotNull(agentCard) { "Agent card is not initialized." }
55+
return checkNotNull(agentCard.load()) { "Agent card is not initialized." }
5556
}
5657

5758
/**
@@ -64,12 +65,12 @@ public open class A2AClient(
6465
request: Request<Nothing?>,
6566
ctx: ClientCallContext = ClientCallContext.Default
6667
): Response<AgentCard> {
67-
check(getAgentCard().supportsAuthenticatedExtendedCard == true) {
68+
check(cachedAgentCard().supportsAuthenticatedExtendedCard == true) {
6869
"Agent card reports that authenticated extended agent card is not supported."
6970
}
7071

7172
return transport.getAuthenticatedExtendedAgentCard(request, ctx).also {
72-
agentCard = it.data
73+
agentCard.exchange(it.data)
7374
}
7475
}
7576

a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,23 @@ package ai.koog.a2a.client
22

33
import ai.koog.a2a.test.BaseA2AProtocolTest
44
import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport
5+
import ai.koog.test.utils.DockerAvailableCondition
56
import io.ktor.client.HttpClient
67
import io.ktor.client.plugins.logging.LogLevel
78
import io.ktor.client.plugins.logging.Logging
89
import kotlinx.coroutines.test.runTest
910
import org.junit.jupiter.api.AfterAll
1011
import org.junit.jupiter.api.BeforeAll
1112
import org.junit.jupiter.api.TestInstance
12-
import org.junit.jupiter.api.condition.EnabledOnOs
13-
import org.junit.jupiter.api.condition.OS
13+
import org.junit.jupiter.api.extension.ExtendWith
1414
import org.junit.jupiter.api.parallel.Execution
1515
import org.junit.jupiter.api.parallel.ExecutionMode
1616
import org.testcontainers.containers.GenericContainer
1717
import org.testcontainers.containers.wait.strategy.Wait
1818
import org.testcontainers.junit.jupiter.Container
1919
import org.testcontainers.junit.jupiter.Testcontainers
20-
import kotlin.time.Duration.Companion.minutes
20+
import kotlin.test.Test
21+
import kotlin.time.Duration.Companion.seconds
2122

2223
/**
2324
* Integration test class for testing the JSON-RPC HTTP communication in the A2A client context.
@@ -26,7 +27,7 @@ import kotlin.time.Duration.Companion.minutes
2627
*/
2728
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
2829
@Testcontainers
29-
@EnabledOnOs(OS.LINUX)
30+
@ExtendWith(DockerAvailableCondition::class)
3031
@Execution(ExecutionMode.SAME_THREAD, reason = "Working with the same instance of test server.")
3132
class A2AClientJsonRpcIntegrationTest : BaseA2AProtocolTest() {
3233
companion object {
@@ -37,7 +38,7 @@ class A2AClientJsonRpcIntegrationTest : BaseA2AProtocolTest() {
3738
.waitingFor(Wait.forListeningPort())
3839
}
3940

40-
override val testTimeout = 1.minutes
41+
override val testTimeout = 10.seconds
4142

4243
private val httpClient = HttpClient {
4344
install(Logging) {
@@ -74,4 +75,36 @@ class A2AClientJsonRpcIntegrationTest : BaseA2AProtocolTest() {
7475
fun tearDown() = runTest {
7576
transport.close()
7677
}
78+
79+
@Test
80+
override fun `test get agent card`() =
81+
super.`test get agent card`()
82+
83+
@Test
84+
override fun `test get authenticated extended agent card`() =
85+
super.`test get authenticated extended agent card`()
86+
87+
@Test
88+
override fun `test send message`() =
89+
super.`test send message`()
90+
91+
@Test
92+
override fun `test send message streaming`() =
93+
super.`test send message streaming`()
94+
95+
@Test
96+
override fun `test get task`() =
97+
super.`test get task`()
98+
99+
@Test
100+
override fun `test cancel task`() =
101+
super.`test cancel task`()
102+
103+
@Test
104+
override fun `test resubscribe task`() =
105+
super.`test resubscribe task`()
106+
107+
@Test
108+
override fun `test push notification configs`() =
109+
super.`test push notification configs`()
77110
}

a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,9 @@ public data class AgentCard(
110110
@JvmInline
111111
@Serializable
112112
public value class TransportProtocol(public val value: String) {
113-
@Suppress("MissingKDocForPublicAPI")
113+
/**
114+
* List of known transport protocols.
115+
*/
114116
public companion object {
115117
/**
116118
* JSON-RPC protocol.

a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,20 @@ import ai.koog.a2a.transport.Response
3939
import ai.koog.a2a.transport.ServerCallContext
4040
import ai.koog.a2a.utils.KeyedMutex
4141
import ai.koog.a2a.utils.withLock
42+
import io.github.oshai.kotlinlogging.KotlinLogging
4243
import kotlinx.coroutines.CancellationException
44+
import kotlinx.coroutines.CompletableJob
4345
import kotlinx.coroutines.CoroutineScope
46+
import kotlinx.coroutines.Job
4447
import kotlinx.coroutines.SupervisorJob
4548
import kotlinx.coroutines.cancel
4649
import kotlinx.coroutines.flow.Flow
4750
import kotlinx.coroutines.flow.channelFlow
48-
import kotlinx.coroutines.flow.first
51+
import kotlinx.coroutines.flow.firstOrNull
4952
import kotlinx.coroutines.flow.flow
50-
import kotlinx.coroutines.flow.last
53+
import kotlinx.coroutines.flow.lastOrNull
5154
import kotlinx.coroutines.flow.map
55+
import kotlinx.coroutines.flow.onStart
5256
import kotlinx.coroutines.launch
5357

5458
/**
@@ -321,6 +325,10 @@ public open class A2AServer(
321325
protected val idGenerator: IdGenerator = UuidIdGenerator,
322326
protected val coroutineScope: CoroutineScope = CoroutineScope(SupervisorJob()),
323327
) : RequestHandler {
328+
private companion object {
329+
private val logger = KotlinLogging.logger {}
330+
}
331+
324332
/**
325333
* Mutex for locking specific tasks by their IDs.
326334
*/
@@ -374,7 +382,7 @@ public open class A2AServer(
374382

375383
val taskId = message.taskId ?: idGenerator.generateTaskId(message)
376384

377-
val session = tasksMutex.withLock(taskId) {
385+
val (session, monitoringStarted) = tasksMutex.withLock(taskId) {
378386
// If there's a currently running session for the same task, wait for it to finish.
379387
sessionManager.getSession(taskId)?.join()
380388

@@ -412,22 +420,36 @@ public open class A2AServer(
412420
eventProcessor = eventProcessor,
413421
) {
414422
agentExecutor.execute(requestContext, eventProcessor)
415-
}.also {
416-
sessionManager.addSession(it)
423+
}.let {
424+
it to sessionManager.addSession(it)
417425
}
418426
}
419427

428+
// Signal that event collection is setup
429+
val collectionStarted: CompletableJob = Job()
430+
420431
// Subscribe to events stream and start emitting them.
421432
launch {
422433
session.events
434+
.onStart {
435+
collectionStarted.complete()
436+
}
423437
.collect { event ->
424438
send(Response(data = event, id = request.id))
425439
}
426440
}
427441

428-
// Start the session to execute the agent and wait for it to finish.
429-
// Using await here to propagate any exceptions thrown by the agent execution.
442+
// Ensure event collection is setup to stream events in response.
443+
collectionStarted.join()
444+
// Ensure monitoring is ready to monitor the session.
445+
monitoringStarted.join()
446+
447+
/*
448+
Start the session to execute the agent and wait for it to finish.
449+
Using await here to propagate any exceptions thrown by the agent execution.
450+
*/
430451
session.agentJob.await()
452+
session.join()
431453
}
432454

433455
override suspend fun onSendMessage(
@@ -440,10 +462,10 @@ public open class A2AServer(
440462

441463
val event = if (messageConfiguration?.blocking == true) {
442464
// If blocking is requested, attempt to wait for the last event, until the current turn of the agent execution is finished.
443-
eventStream.last()
465+
eventStream.lastOrNull()
444466
} else {
445-
eventStream.first()
446-
}
467+
eventStream.firstOrNull()
468+
} ?: throw IllegalStateException("Can't get response from the agent: event stream is empty")
447469

448470
return when (val eventData = event.data) {
449471
is Message -> Response(data = eventData, id = event.id)

a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,25 @@ import kotlinx.coroutines.flow.collect
1414
*
1515
* @property eventProcessor The session event processor
1616
* @property agentJob The execution process associated with this session's execution
17-
* @property events A stream of events generated during this session
1817
*/
1918
public class Session(
2019
public val eventProcessor: SessionEventProcessor,
2120
public val agentJob: Deferred<Unit>
2221
) {
23-
public val contextId: String get() = eventProcessor.contextId
24-
public val taskId: String get() = eventProcessor.taskId
25-
public val events: Flow<Event> get() = eventProcessor.events
22+
/**
23+
* Context ID associated with this session.
24+
*/
25+
public val contextId: String = eventProcessor.contextId
26+
27+
/**
28+
* Task ID associated with this session.
29+
*/
30+
public val taskId: String = eventProcessor.taskId
31+
32+
/**
33+
* A stream of events associated with this session.
34+
*/
35+
public val events: Flow<Event> = eventProcessor.events
2636

2737
/**
2838
* Starts the [agentJob], if it hasn't already been started.
@@ -31,7 +41,7 @@ public class Session(
3141
agentJob.start()
3242
}
3343

34-
/*
44+
/**
3545
* Suspends until the session, i.e., event stream and agent job, complete.
3646
* Waits for the event stream to finish first, to avoid triggering the agent job prematurely.
3747
* Assumes that by the time event stream is finished, agent job will already be completed or canceled.

a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@ import ai.koog.a2a.server.exceptions.SessionNotActiveException
1010
import ai.koog.a2a.server.tasks.TaskStorage
1111
import kotlinx.coroutines.flow.Flow
1212
import kotlinx.coroutines.flow.MutableSharedFlow
13-
import kotlinx.coroutines.flow.emptyFlow
1413
import kotlinx.coroutines.flow.filterIsInstance
15-
import kotlinx.coroutines.flow.flow
1614
import kotlinx.coroutines.flow.map
15+
import kotlinx.coroutines.flow.onSubscription
1716
import kotlinx.coroutines.flow.takeWhile
1817
import kotlinx.coroutines.sync.Mutex
1918
import kotlinx.coroutines.sync.withLock
@@ -42,8 +41,6 @@ import kotlin.jvm.JvmInline
4241
* from the incoming request or a newly generated ID that must be used if creating a new task.
4342
* Note: This taskId might not correspond to an actually existing task initially - it serves as the
4443
* identifier that will be validated against all [TaskEvent] in this session.
45-
* @property isOpen Whether the session is open.
46-
* @property events A hot flow of events in this session that can be subscribed to.
4744
*/
4845
@OptIn(ExperimentalAtomicApi::class)
4946
public class SessionEventProcessor(
@@ -63,6 +60,10 @@ public class SessionEventProcessor(
6360
}
6461

6562
private val _isOpen: AtomicBoolean = AtomicBoolean(true)
63+
64+
/**
65+
* Whether the session is open.
66+
*/
6667
public val isOpen: Boolean get() = _isOpen.load()
6768

6869
/**
@@ -82,17 +83,15 @@ public class SessionEventProcessor(
8283
}
8384

8485
private val _events = MutableSharedFlow<FlowEvent>()
85-
public val events: Flow<Event>
86-
get() = flow {
87-
if (isOpen) {
88-
_events
89-
.takeWhile { it !is FlowEvent.Close }
90-
.filterIsInstance<FlowEvent.Data>()
91-
.map { it.data }
92-
} else {
93-
emptyFlow()
94-
}.collect(this)
95-
}
86+
87+
/**
88+
* A hot flow of events in this session that can be subscribed to.
89+
*/
90+
public val events: Flow<Event> = _events
91+
.onSubscription { if (!_isOpen.load()) emit(FlowEvent.Close) }
92+
.takeWhile { it !is FlowEvent.Close }
93+
.filterIsInstance<FlowEvent.Data>()
94+
.map { it.data }
9695

9796
/**
9897
* Sends a [Message] to the session event processor. Validates the message against the session context and updates

0 commit comments

Comments
 (0)