diff --git a/.github/workflows/release-binaries.yml b/.github/workflows/release-binaries.yml index 048777c..6fd62b4 100644 --- a/.github/workflows/release-binaries.yml +++ b/.github/workflows/release-binaries.yml @@ -13,6 +13,11 @@ on: permissions: contents: write +env: + # Force Node 20-pinned JavaScript actions onto Node 24 (silences the + # deprecation annotation until the actions update). + FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true + jobs: upload-assets: name: Build and upload binaries @@ -37,7 +42,7 @@ jobs: use_cross: false runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: ref: ${{ inputs.tag || github.ref }} diff --git a/.github/workflows/release-plz.yml b/.github/workflows/release-plz.yml index f28feac..279db1b 100644 --- a/.github/workflows/release-plz.yml +++ b/.github/workflows/release-plz.yml @@ -6,20 +6,29 @@ permissions: on: push: - tags: - - "v*" + branches: + - master workflow_dispatch: +env: + # Force Node 20-pinned JavaScript actions onto Node 24 (silences the + # deprecation annotation until the actions update). + FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true + +# Standard release-plz flow with per-crate tags (see release-plz.toml): +# - `release-plz-pr` keeps an open "release" PR with the pending version bumps +# and changelog. Merging it lands the new versions on master. +# - `release-plz-release` then publishes each crate whose version is ahead of +# crates.io and pushes its `{package}-v{version}` tag. jobs: - release-plz: - name: Release-plz + release-plz-release: + name: Release-plz release runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 0 - ref: master - name: Install Rust toolchain uses: dtolnay/rust-toolchain@stable @@ -27,10 +36,36 @@ jobs: - name: Cache dependencies uses: Swatinem/rust-cache@v2 - - name: Run release-plz + - name: Run release-plz (release) uses: release-plz/action@v0.5 with: command: release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} + + release-plz-pr: + name: Release-plz PR + runs-on: ubuntu-latest + concurrency: + group: release-plz-${{ github.ref }} + cancel-in-progress: false + steps: + - name: Checkout repository + uses: actions/checkout@v5 + with: + fetch-depth: 0 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + - name: Run release-plz (release-pr) + uses: release-plz/action@v0.5 + with: + command: release-pr + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 5f9705f..611faa1 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -1,107 +1,110 @@ -name: Rust CI - -on: - push: - branches: [ "master" ] - pull_request: - branches: [ "master" ] - -env: - CARGO_TERM_COLOR: always - RUST_BACKTRACE: 1 - CARGO_INCREMENTAL: 0 - -jobs: - build: - name: Build and Test - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Install Rust toolchain - uses: dtolnay/rust-toolchain@stable - - - name: Cache dependencies - uses: Swatinem/rust-cache@v2 - - - name: Build workspace - run: cargo build --workspace --verbose --locked - - - name: Run tests - run: cargo test --verbose --workspace --locked --all-targets - - clippy: - name: Clippy - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Install Rust toolchain - uses: dtolnay/rust-toolchain@stable - with: - components: clippy - - - name: Cache dependencies - uses: Swatinem/rust-cache@v2 - - # Run clippy on a2a-rs with all features (the core crate supports --all-features) - - name: Run clippy (a2a-rs, all features) - run: cargo clippy -p a2a-rs --all-targets --all-features -- -D warnings - - # Run clippy on other workspace crates that support --all-features - - name: Run clippy (a2a-ap2, a2a-web-client, a2a-agents-common) - run: cargo clippy -p a2a-ap2 -p a2a-web-client -p a2a-agents-common --all-targets --all-features -- -D warnings - - # Run clippy on a2a-agents with specific features (mcp-server/mcp-client - # depend on a local path that is not available in CI) - - name: Run clippy (a2a-agents) - run: cargo clippy -p a2a-agents --all-targets --features "reimbursement-agent,sqlx,ap2,auth" -- -D warnings - - rustfmt: - name: Format - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Install Rust toolchain - uses: dtolnay/rust-toolchain@stable - with: - components: rustfmt - - - name: Check formatting - run: cargo fmt --all -- --check - - docs: - name: Doc Check - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Install Rust toolchain - uses: dtolnay/rust-toolchain@stable - - - name: Cache dependencies - uses: Swatinem/rust-cache@v2 - - # Check docs per-crate to avoid mcp-related path dep issues - - name: Check documentation (a2a-rs) - run: cargo doc --no-deps --all-features -p a2a-rs - - - name: Check documentation (other crates) - run: cargo doc --no-deps -p a2a-ap2 -p a2a-web-client -p a2a-agents-common - - - name: Check documentation (a2a-agents) - run: cargo doc --no-deps -p a2a-agents --features "reimbursement-agent,sqlx,ap2,auth" - - audit: - name: Security Audit - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Install Rust toolchain - uses: dtolnay/rust-toolchain@stable - - name: Install cargo-audit - uses: taiki-e/install-action@cargo-audit - - name: Run cargo-audit - run: cargo audit --ignore RUSTSEC-2023-0071 - +name: Rust CI + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +env: + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + CARGO_INCREMENTAL: 0 + # Run JavaScript actions still pinned to the deprecated Node 20 runtime on + # Node 24, silencing the per-run deprecation annotation until they update. + FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true + +jobs: + build: + name: Build and Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + - name: Build workspace + run: cargo build --workspace --verbose --locked + + - name: Run tests + run: cargo test --verbose --workspace --locked --all-targets + + clippy: + name: Clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + # Run clippy on a2a-rs with all features (the core crate supports --all-features) + - name: Run clippy (a2a-rs, all features) + run: cargo clippy -p a2a-rs --all-targets --all-features -- -D warnings + + # Run clippy on other workspace crates that support --all-features + - name: Run clippy (a2a-ap2, a2a-web-client, a2a-agents-common) + run: cargo clippy -p a2a-ap2 -p a2a-web-client -p a2a-agents-common --all-targets --all-features -- -D warnings + + # Run clippy on a2a-agents with specific features (mcp-server/mcp-client + # depend on a local path that is not available in CI) + - name: Run clippy (a2a-agents) + run: cargo clippy -p a2a-agents --all-targets --features "reimbursement-agent,sqlx,ap2,auth" -- -D warnings + + rustfmt: + name: Format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - name: Check formatting + run: cargo fmt --all -- --check + + docs: + name: Doc Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + # Check docs per-crate to avoid mcp-related path dep issues + - name: Check documentation (a2a-rs) + run: cargo doc --no-deps --all-features -p a2a-rs + + - name: Check documentation (other crates) + run: cargo doc --no-deps -p a2a-ap2 -p a2a-web-client -p a2a-agents-common + + - name: Check documentation (a2a-agents) + run: cargo doc --no-deps -p a2a-agents --features "reimbursement-agent,sqlx,ap2,auth" + + audit: + name: Security Audit + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + - name: Install cargo-audit + uses: taiki-e/install-action@cargo-audit + - name: Run cargo-audit + run: cargo audit --ignore RUSTSEC-2023-0071 + diff --git a/.gitignore b/.gitignore index 7d16fef..2d52c37 100644 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,16 @@ -/target -*/target -**/.claude/settings.local.json -CLAUDE.md -a2a-client/leptos/ - -# Environment files -.env -.env.* -*.env - -# Database files -*.db +/target +*/target +**/.claude/settings.local.json +CLAUDE.md +a2a-client/leptos/ + +# Environment files +.env +.env.* +*.env + +# Database files +*.db /target */target **/.claude/settings.local.json @@ -30,6 +30,9 @@ a2a-client/leptos/ # Embedded git repos (a2a-mcp uses a local clone of the MCP Rust SDK) a2a-mcp/rust-sdk/ +# Local reference clone of the official upstream A2A Rust SDK +a2aproject/ + .vscode/* !.vscode/mcp.json .claude/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 6974563..96b40b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file. ## [Unreleased] - 2026-05-24 ### Added +- **Client-side `Transport` port + JSON-RPC 2.0 client + card-driven negotiation (`a2a-rs`)**: The client gained a hexagonal transport abstraction mirroring the server side, plus a wire-compatible JSON-RPC 2.0 client so it can talk to any standard A2A agent. + - `port::client::Transport` (re-exported as `a2a_rs::Transport`) is the outbound client port — the renamed, relocated `AsyncA2AClient` with an added `protocol()` discriminator. `HttpClient` (ConnectRPC) reports `"CONNECTRPC"`. + - `JsonRpcClient` (new `jsonrpc-client` feature) implements `Transport` over the spec JSON-RPC 2.0 wire format (single `POST`, SSE for streaming), reusing the generated ProtoJSON request/response types. Its method names, error codes, and envelopes come from a shared `adapter::transport::jsonrpc_wire` module extracted from the server adapter, so the two directions are byte-compatible (proven by `tests/jsonrpc_client_interop_test.rs`, an in-process client↔server round-trip over a real socket: send/get/list/cancel, push-config CRUD, SSE subscribe, typed error mapping). + - `TransportFactory` + `TransportNegotiator` + `connect(base_url, &negotiator)` select a transport from an agent card's `supported_interfaces`, ranked by client preference (factory registration order). `default_registry()` prefers CONNECTRPC then JSON-RPC. Unit tests in `tests/transport_negotiation_test.rs`. + - `a2a-web-client`'s `WebA2AClient` now holds a `Box` (field `transport`, was `http`); `auto_connect` performs real card-driven negotiation, falling back to a direct ConnectRPC client. +- **Wire-compatible JSON-RPC 2.0 + HTTP+JSON transport (`a2a-rs`)**: Added `JsonRpcAdapter`, a sibling of `ConnectRpcAdapter` that speaks the spec-mandated JSON-RPC 2.0 and HTTP+JSON (REST) bindings for interop with the canonical `a2aproject` SDK (and the Go/C#/Python SDKs). Behind the new `jsonrpc-server` feature. + - Wraps the same inner `TaskService`; mounted at the composition edge via the `jsonrpc_router` / `rest_router` free functions (see `examples/jsonrpc_server.rs`). + - JSON-RPC: single `POST /` with all 11 methods (`SendMessage`, `GetTask`, `ListTasks`, `CancelTask`, push-config CRUD, `GetExtendedAgentCard`), `A2AError` → spec error codes (`-32001`…, `-32700`/`-32601`/`-32602`), and SSE for the two streaming methods. + - REST: official-SDK paths (no `/v1` prefix) — `POST /message:send`, `GET /tasks/{id}`, `GET /extendedAgentCard`, push-config routes — with HTTP status mapped from `A2AError`. Task custom-verbs use slash-form aliases (`/tasks/{id}/cancel`) since axum's matchit router rejects a path-param + `:`-suffix in one segment. + - The wire body reuses the `buffa`-generated proto request/response types directly: verified ProtoJSON-clean (camelCase, SCREAMING_SNAKE enums, RFC3339 timestamps, base64 `bytes`, bare `Struct` metadata, tag-free field-presence unions), so no hand-written wire DTOs are needed. Golden + behavioral tests in `tests/jsonrpc_wire_test.rs` and `tests/jsonrpc_dispatch_test.rs`. + - End-to-end router tests in `tests/jsonrpc_router_test.rs` drive the real `jsonrpc_router`/`rest_router` via `tower::ServiceExt::oneshot`: REST round-trip, the `/tasks/{id}/cancel` slash alias, 404/error-status mapping, list-via-query, the JSON-RPC envelope + version rejection, and both SSE framings (JSON-RPC wraps each event in a response envelope; REST emits the bare ProtoJSON `StreamResponse`). +- **Agent-card transport negotiation (`a2a-rs`)**: `SimpleAgentInfo` gained `with_preferred_transport` and `add_interface` so a card can advertise multiple `supportedInterfaces` (e.g. `JSONRPC` + `HTTP+JSON`) — the metadata an off-the-shelf A2A client reads to negotiate a transport. `examples/jsonrpc_server.rs` advertises both bindings it mounts. - **Native LLM Tool Calling**: Added `LlmProvider` primitives (`ToolDefinition`, `ToolCall`) to `a2a-agents-common` for standardizing function calling across models (OpenAI, Gemini). - **LLM Streaming Support (SSE)**: Added `chat_completion_stream` to stream content and fully formed tool calls in real time. - **AI/LLM Integration (Phase 3)**: Integrated `McpClientManager` into `AgentBuilder` via the `mcp-client` feature. @@ -28,6 +40,12 @@ All notable changes to this project will be documented in this file. - Documented the metadata tool-call envelope used by `McpToA2ABridge` in `lib.rs` crate-level rustdoc and the bridge's struct docs. - Added an architecture diagram to the `a2a-mcp` README. - Converted rustdoc examples in `a2a-mcp`'s `lib.rs` from `rust,ignore` to real compile-checked `no_run` doctests. +- **Kitchen-sink complex agent example (`a2a-agents`)**: `examples/complex_agent.rs` (+ `complex_agent.toml`, behind `--features mcp-server`) — a "Research Assistant" wiring every major building block in one binary: declarative TOML config, optional LLM tool-calling via `LlmProvider` (with a keyless rule-based fallback), MCP tool consumption through `McpToA2ABridge` against an in-process tool server, live SSE streaming of progress artifacts, and native A2A task lifecycle via the `TaskStatusBroadcast` mixin. +- **Builder-level streaming wiring (`a2a-agents` + `a2a-rs`)**: `AgentBuilder::with_streaming` / `AgentRuntime::with_streaming` attach a shared streaming backend that the runtime injects into the transport (`ConnectRpcAdapter::with_streaming_handler`), so `tasks/subscribe` SSE streams finally observe the broadcasts a handler emits. Backed by a new forwarding blanket `impl AsyncStreamingHandler for Arc` in `a2a-rs`. **Fixes** a gap where the builder path defaulted to a no-op streaming handler and silently dropped handler broadcasts before they reached SSE clients. + +### Changed +- **`a2a-rs` transport**: Extracted the ConnectRPC adapter's request-decoding helpers (`decode_send_config`, `list_request_to_params`, `map_update_event`) to `pub(super)` so the new JSON-RPC adapter reuses them — both transports now share a single decode/encode path against the generated proto types. +- **BREAKING — client port renamed and relocated (`a2a-rs`)**: The client trait `services::client::AsyncA2AClient` is now `port::client::Transport` (re-exported as `a2a_rs::Transport`), with a new required `fn protocol(&self) -> &str` method. `StreamItem` moved alongside it (`a2a_rs::StreamItem`). The `services::client` module and the `services::{AsyncA2AClient, StreamItem}` re-exports are gone. Call sites import `a2a_rs::Transport` / `a2a_rs::StreamItem`; method names are unchanged. `a2a-web-client`'s `WebA2AClient.http: HttpClient` field became `transport: Box`. ### Removed - Removed the printf-only `examples/minimal_example.rs` in `a2a-mcp`. diff --git a/Cargo.lock b/Cargo.lock index 01ad914..ce2c7b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,7 +31,7 @@ dependencies = [ "serde", "serde_json", "shellexpand", - "thiserror 1.0.69", + "thiserror 2.0.18", "tokio", "toml", "tower-http", @@ -44,7 +44,6 @@ dependencies = [ name = "a2a-agents-common" version = "0.3.0" dependencies = [ - "a2a-agents", "async-stream", "async-trait", "chrono", @@ -55,7 +54,7 @@ dependencies = [ "reqwest", "serde", "serde_json", - "thiserror 1.0.69", + "thiserror 2.0.18", "tokio", "tracing", "uuid", @@ -70,7 +69,7 @@ dependencies = [ "buffa-types", "serde", "serde_json", - "thiserror 1.0.69", + "thiserror 2.0.18", ] [[package]] @@ -132,8 +131,9 @@ dependencies = [ "serde", "serde_json", "sqlx", - "thiserror 1.0.69", + "thiserror 2.0.18", "tokio", + "tower", "tracing", "tracing-subscriber", "url", @@ -157,7 +157,7 @@ dependencies = [ "reqwest", "serde", "serde_json", - "thiserror 1.0.69", + "thiserror 2.0.18", "tokio", "tracing", "uuid", @@ -778,6 +778,17 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "chacha20" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "rand_core 0.10.1", +] + [[package]] name = "chrono" version = "0.4.44" @@ -988,6 +999,15 @@ dependencies = [ "libc", ] +[[package]] +name = "cpufeatures" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" +dependencies = [ + "libc", +] + [[package]] name = "crc" version = "3.4.0" @@ -1126,7 +1146,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "curve25519-dalek-derive", "digest", "fiat-crypto", @@ -1732,6 +1752,7 @@ dependencies = [ "cfg-if", "libc", "r-efi 6.0.0", + "rand_core 0.10.1", "wasip2", "wasip3", ] @@ -3192,6 +3213,17 @@ dependencies = [ "rand_core 0.9.5", ] +[[package]] +name = "rand" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207" +dependencies = [ + "chacha20", + "getrandom 0.4.2", + "rand_core 0.10.1", +] + [[package]] name = "rand_chacha" version = "0.3.1" @@ -3230,6 +3262,12 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_core" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" + [[package]] name = "rand_xorshift" version = "0.4.0" @@ -3424,21 +3462,29 @@ checksum = "0810a9f717d9828f475fe1f629f4c305c8464b7f496c3a854b58d29e65f4058e" dependencies = [ "async-trait", "base64 0.22.1", + "bytes", "chrono", "futures", + "http", + "http-body", + "http-body-util", "pastey", "pin-project-lite", "process-wrap", + "rand 0.10.1", "rmcp-macros", "schemars 1.2.1", "serde", "serde_json", + "sse-stream", "thiserror 2.0.18", "tokio", "tokio-stream", "tokio-util", + "tower-service", "tracing", "url", + "uuid", ] [[package]] @@ -3781,7 +3827,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "digest", ] @@ -3792,7 +3838,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "digest", ] @@ -4100,6 +4146,19 @@ dependencies = [ "uuid", ] +[[package]] +name = "sse-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3962b63f038885f15bce2c6e02c0e7925c072f1ac86bb60fd44c5c6b762fb72" +dependencies = [ + "bytes", + "futures-util", + "http-body", + "http-body-util", + "pin-project-lite", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" diff --git a/Cargo.toml b/Cargo.toml index fb0e367..5f79ab6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,3 +2,22 @@ resolver = "2" members = ["a2a-rs", "a2a-agents", "a2a-client", "a2a-ap2", "a2a-agents-common", "a2a-mcp"] exclude = ["a2a-mcp/rust-sdk"] + +# Common dependencies shared across workspace members. Members reference these +# with `dep.workspace = true`, adding `features`/`optional`/`default-features` +# locally as needed (those compose; the version is pinned here once). Keeps the +# versions from drifting (the previous `thiserror = "1.0"` vs `"2"` split). +[workspace.dependencies] +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "2" +anyhow = "1.0" +chrono = { version = "0.4", features = ["serde"] } +uuid = { version = "1.4", features = ["v4"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } +async-trait = "0.1" +futures = "0.3" +reqwest = { version = "0.12", default-features = false } +bon = "2.3" +tokio = "1.32" diff --git a/README.md b/README.md index db04dc2..37ea94a 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ The core library uses Cargo feature flags so you only compile what you need: ```rust use a2a_rs::{HttpClient, Message}; -use a2a_rs::port::AsyncA2AClient; +use a2a_rs::Transport; #[tokio::main] async fn main() -> Result<(), Box> { @@ -153,7 +153,9 @@ Port traits define the contracts between layers. Implement `AsyncMessageHandler` ## Protocol coverage -Implements the full A2A v1.0.0 specification: +Implements the A2A v1.0.0 protocol surface — wire-compatible with the spec, with +a couple of small, documented and backward-compatible divergences (see +[`a2a-rs` → Spec compliance](a2a-rs/README.md#spec-compliance)): - `message/send` and `message/stream` (blocking and streaming message exchange) - `tasks/get`, `tasks/list`, `tasks/cancel`, `tasks/resubscribe` @@ -162,6 +164,16 @@ Implements the full A2A v1.0.0 specification: - Security schemes: HTTP bearer, API key, OAuth2, OpenID Connect, mTLS - Task states: submitted, working, input-required, completed, canceled, failed, rejected, auth-required +Notable enhancements beyond the spec (both opt-in / backward-compatible): + +- **ConnectRPC transport.** The spec names `JSONRPC`, `GRPC`, and `HTTP+JSON`; + a2a-rs adds **ConnectRPC** as the in-tree default (advertised under the + non-spec `CONNECTRPC` binding) alongside a spec-compliant JSON-RPC 2.0 + transport. Use the JSON-RPC transport for third-party interop. +- **Gap-free SSE stream resumption via `Last-Event-ID`** (W3C SSE standard, not + an A2A spec feature). Interoperable — spec clients fall back to standard + reconnect-from-current-state — but gap-free resume only applies a2a-rs ↔ a2a-rs. + ## Testing ```bash diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 0000000..35ecba4 --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,94 @@ +# A2A-RS Roadmap + +Deferred themes and not-yet-scheduled work. Pre-1.0 with only in-workspace +consumers: break cleanly and fix call sites in one PR — no deprecation shims. + +## 0.5 + +### CLI (`a2acli`) + empirical cross-SDK interop + +A small `a2acli` bin crate driving the client `Transport` port: `card`, `send`, +`stream`, `get`, `cancel`. + +- Depend on **`a2a-rs` directly** (the `client` / `http-client` / + `jsonrpc-client` features) — *not* on `a2a-client`, which drags in + axum/askama for zero CLI benefit. The reusable client core (`Transport`, + `JsonRpcClient`/`HttpClient`, transport negotiation, `subscribe_resilient`) + already lives in `a2a-rs`; `a2a-client`'s `WebA2AClient` is a thin wrapper + over it. The CLI and the web client are siblings on `a2a-rs`, not a stack. +- Promote the one ergonomic bit currently trapped in the web crate — + `auto_connect` (URL-validate → `connect` → ConnectRPC fallback) — down into + `a2a-rs` behind the `client` feature so both consumers share it. + +Doubles as the manual interop harness: point the official `a2aproject/a2acli` +at our `examples/jsonrpc_server.rs`, and/or our `JsonRpcClient` at a stock A2A +agent, to validate wire-compat against the canonical SDKs. +(`tests/jsonrpc_client_interop_test.rs` already proves our-client ↔ our-server +byte-compat; this validates against *other* SDKs.) + +### AP2 (Agent Payments Protocol) expansion + +- Expand `a2a-ap2` to fully support AP2 primitives (Payment Request, Receipt). +- Bridge AP2 with native LLM tool calling (let LLMs request and verify payments). +- Add robust tests and error handling for AP2 flows. + +### Multi-tenancy + +Thread a `tenant` through requests/storage. Today only placeholder fields exist +(`TaskPushNotificationConfig.tenant`, the proto `/{tenant}/…` routes). It +reshapes the storage/port/transport surface, so it warrants its own pass. Two +viable shapes: + +- **(a) edge tenant-routing** — a `TenantRouter` holding per-tenant storage, + resolving the tenant from the `/{tenant}/` path at the transport edge, keeping + domain/ports tenant-free (smallest blast radius, most hexagonal). +- **(b) per-request `tenant` param** threaded through every port method + + transport extraction + storage scoping, matching the official SDK exactly + (largest diff, touches every call site across all crates). + +### Durable streaming resumption + +The replay buffer is in-memory and bounded (256 events/task); beyond it, resume +falls back to the initial snapshot. A sqlx-backed event log would make +resumption survive restarts. + +### ConnectRPC SSE `Last-Event-ID` + +ConnectRPC transport has no SSE `Last-Event-ID`, so `RetryingTransport` over it +reconnects from scratch rather than resuming gap-free. + +## Release pipeline + +### aws-lc-sys + `cross` (blocked on upstream) + +`cross` is used only for `aarch64-unknown-linux-gnu` today (native cargo +elsewhere), and that works. Any *new* cross target (e.g. +`aarch64-unknown-linux-musl`) hits the `aws-lc-sys 0.41.0` "compiler bug +detected" panic. Root cause: `rustls 0.23` (pulled in by `connectrpc`, +`hyper-rustls`, `reqwest` defaults) re-enables the `aws_lc_rs` provider even +though `a2a-rs` only asks for `ring`. + +A feature-only "ring-only" fix is **blocked by `connectrpc 0.3.3`**: it exposes +no TLS feature flags and depends on `hyper-rustls`/`tokio-rustls` with their +default `aws-lc-rs` provider, so no combination of our flags removes +`aws-lc-sys`. (`sqlx` offers `tls-rustls-ring` and `reqwest` offers +`rustls-tls-*-no-provider`, but fixing only those leaves connectrpc still pulling +`aws-lc-rs`.) Cargo `[patch.crates-io]` swaps the *source*, not features, so it +can't flip connectrpc's `hyper-rustls` default either. Viable paths: + +- **(a)** upstream a `ring` feature into `connectrpc`, then set ring on + `connectrpc` + `reqwest` `rustls-tls-no-provider` + `sqlx` `tls-rustls-ring`; +- **(b)** fork/vendor `connectrpc` with + `hyper-rustls = { default-features = false, features = ["ring", …] }`; +- **(c)** leave `aws-lc-rs` in and make it cross-build — a `Cross.toml` whose + image has clang+cmake (and `AWS_LC_SYS_PREBUILT_NASM=1` on x86) — sidestepping + the provider question. Needs a reproducible `cross` env to validate. + +## Optional / nice-to-have + +- **Single bidirectional showcase** — fold `AgentToMcpBridge` (re-expose the + agent *as* MCP tools) into `complex_agent`. Already covered standalone by + `a2a-mcp/examples/bidirectional_demo.rs`; only worth it for one end-to-end demo. +- **MCP-native progress** — wire `McpToA2ABridge::with_streaming` + + `ProgressClientHandler` so downstream tool progress streams (the tool server + would need to emit `notify_progress`). Progress is handler-driven today. diff --git a/TODO.md b/TODO.md deleted file mode 100644 index 031bd2a..0000000 --- a/TODO.md +++ /dev/null @@ -1,23 +0,0 @@ -# A2A-RS Follow-Ups and Future Work - -## Agent Payments Protocol (AP2) Integration -- Expand `a2a-ap2` crate to fully support AP2 primitives (Payment Request, Payment Receipt). -- Bridge AP2 features with native LLM tool calling (allow LLMs to request and verify payments). -- Add robust tests and error handling for AP2 flows. - -## Complex Agent Example -- Create a comprehensive "kitchen-sink" example showcasing all components: - - LLM Provider integration (OpenAI/Gemini). - - MCP tool bridging (`AgentToMcpBridge` & `McpToA2ABridge`). - - Streaming interactions to a Web Client (`a2a-client`). - - Declarative TOML configuration. - - A2A native tasks and progress tracking. - -## Streaming Improvements -- Add support for partial/incremental tool call streaming (instead of waiting for the full JSON string to parse) to allow UIs to show function call progress in real time. -- Implement robust retry mechanisms and exponential backoff for SSE stream interruptions. -- Expand streaming integrations natively into the `a2a-client` framework. - -## General -- Refine existing Rustdoc examples and ensure they are all compile-checked. -- Resolve any remaining compilation warnings across the workspace. diff --git a/a2a-agents-common/Cargo.toml b/a2a-agents-common/Cargo.toml index 8804054..ac93273 100644 --- a/a2a-agents-common/Cargo.toml +++ b/a2a-agents-common/Cargo.toml @@ -14,29 +14,29 @@ categories = ["api-bindings", "text-processing"] # No circular dependency with a2a-agents allowed # Serialization -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde = { workspace = true } +serde_json = { workspace = true } # Text processing and regex regex = "1.10" # Date and time utilities -chrono = { version = "0.4", features = ["serde"] } +chrono = { workspace = true } # LLM support -reqwest = { version = "0.12", features = ["json", "rustls-tls", "stream"], default-features = false } -tracing = "0.1" -uuid = { version = "1.0", features = ["v4"] } -futures = "0.3" +reqwest = { workspace = true, features = ["json", "rustls-tls", "stream"] } +tracing = { workspace = true } +uuid = { workspace = true } +futures = { workspace = true } eventsource-stream = "0.2" async-stream = "0.3" # Async support -tokio = { version = "1.32", features = ["sync", "time"], optional = true } -async-trait = "0.1" +tokio = { workspace = true, features = ["sync", "time"], optional = true } +async-trait = { workspace = true } # Error handling -thiserror = "1.0" +thiserror = { workspace = true } # Caching (optional) moka = { version = "0.12", features = ["future"], optional = true } @@ -46,8 +46,7 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [dev-dependencies] -tokio = { version = "1.32", features = ["rt", "macros", "test-util"] } -a2a-agents = { path = "../a2a-agents" } +tokio = { workspace = true, features = ["rt", "macros", "test-util"] } [[test]] name = "integration_test" diff --git a/a2a-agents-common/src/llm/mod.rs b/a2a-agents-common/src/llm/mod.rs index ef7b1e4..ce340b3 100644 --- a/a2a-agents-common/src/llm/mod.rs +++ b/a2a-agents-common/src/llm/mod.rs @@ -5,6 +5,9 @@ use serde_json::Value; pub mod gemini; pub mod openai; +pub mod tool_call; + +pub use tool_call::{PartialToolCall, ToolCallAccumulator}; /// Represents an error returned by an LLM provider. #[derive(Debug, thiserror::Error)] diff --git a/a2a-agents-common/src/llm/tool_call.rs b/a2a-agents-common/src/llm/tool_call.rs new file mode 100644 index 0000000..e5e99ce --- /dev/null +++ b/a2a-agents-common/src/llm/tool_call.rs @@ -0,0 +1,152 @@ +//! Incremental tool-call assembly for streaming LLM responses. +//! +//! Providers stream a tool call as a sequence of +//! [`ToolCallChunk`](super::LlmStreamEvent::ToolCallChunk)s (an id, an optional +//! name, and a fragment of the JSON arguments) followed by a finalized +//! [`ToolCall`]. [`ToolCallAccumulator`] folds those chunks — keyed by call id, +//! so interleaved calls stay separate — into running [`PartialToolCall`]s that a +//! UI can render live, and reconciles the authoritative final call. + +use super::ToolCall; + +/// A tool call assembled so far from streamed chunks. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct PartialToolCall { + /// Provider-assigned call id (stable across this call's chunks). + pub id: String, + /// Function name, once a chunk has carried it. + pub name: Option, + /// JSON arguments accumulated so far (may be partial/unparseable mid-stream). + pub arguments: String, + /// Set once a finalized [`ToolCall`] has reconciled this entry. + pub complete: bool, +} + +impl PartialToolCall { + fn to_tool_call(&self) -> ToolCall { + ToolCall { + id: self.id.clone(), + name: self.name.clone().unwrap_or_default(), + arguments: self.arguments.clone(), + } + } +} + +/// Folds streamed [`ToolCallChunk`](super::LlmStreamEvent::ToolCallChunk)s into +/// complete [`ToolCall`]s, preserving first-seen order. +#[derive(Debug, Default)] +pub struct ToolCallAccumulator { + calls: Vec, +} + +impl ToolCallAccumulator { + /// Create an empty accumulator. + pub fn new() -> Self { + Self::default() + } + + fn index_of(&mut self, id: &str) -> usize { + if let Some(i) = self.calls.iter().position(|c| c.id == id) { + return i; + } + self.calls.push(PartialToolCall { + id: id.to_string(), + ..Default::default() + }); + self.calls.len() - 1 + } + + /// Apply one streamed chunk, returning the running partial for this id. A + /// non-empty `name` overrides; `args_delta` is appended. + pub fn push(&mut self, id: &str, name: Option<&str>, args_delta: &str) -> &PartialToolCall { + let idx = self.index_of(id); + let call = &mut self.calls[idx]; + if let Some(n) = name { + if !n.is_empty() { + call.name = Some(n.to_string()); + } + } + call.arguments.push_str(args_delta); + &self.calls[idx] + } + + /// Reconcile a finalized [`ToolCall`]: its name and arguments are + /// authoritative and replace whatever was accumulated, marking the entry + /// complete. + pub fn finalize(&mut self, call: ToolCall) { + let idx = self.index_of(&call.id); + let entry = &mut self.calls[idx]; + entry.name = Some(call.name); + entry.arguments = call.arguments; + entry.complete = true; + } + + /// The running partial for `id`, if any. + pub fn partial(&self, id: &str) -> Option<&PartialToolCall> { + self.calls.iter().find(|c| c.id == id) + } + + /// Calls reconciled by [`finalize`](Self::finalize), as concrete + /// [`ToolCall`]s, without clearing state. + pub fn completed(&self) -> Vec { + self.calls + .iter() + .filter(|c| c.complete) + .map(PartialToolCall::to_tool_call) + .collect() + } + + /// Drain every accumulated call as a [`ToolCall`], clearing the accumulator. + pub fn drain_completed(&mut self) -> Vec { + let out = self + .calls + .iter() + .map(PartialToolCall::to_tool_call) + .collect(); + self.calls.clear(); + out + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn folds_interleaved_calls_by_id() { + let mut acc = ToolCallAccumulator::new(); + acc.push("a", Some("add"), "{\"x\":"); + acc.push("b", Some("mul"), "{\"y\":"); + acc.push("a", None, "1}"); + acc.push("b", None, "2}"); + + assert_eq!(acc.partial("a").unwrap().name.as_deref(), Some("add")); + assert_eq!(acc.partial("a").unwrap().arguments, "{\"x\":1}"); + assert_eq!(acc.partial("b").unwrap().arguments, "{\"y\":2}"); + } + + #[test] + fn finalize_is_authoritative_and_marks_complete() { + let mut acc = ToolCallAccumulator::new(); + acc.push("a", Some("add"), "{\"x\":1"); // truncated mid-stream + assert!(acc.completed().is_empty()); + + acc.finalize(ToolCall { + id: "a".to_string(), + name: "add".to_string(), + arguments: "{\"x\":1,\"y\":2}".to_string(), + }); + + let done = acc.completed(); + assert_eq!(done.len(), 1); + assert_eq!(done[0].arguments, "{\"x\":1,\"y\":2}"); + } + + #[test] + fn drain_empties() { + let mut acc = ToolCallAccumulator::new(); + acc.push("a", Some("add"), "{}"); + assert_eq!(acc.drain_completed().len(), 1); + assert!(acc.drain_completed().is_empty()); + } +} diff --git a/a2a-agents/BUILDER_API.md b/a2a-agents/BUILDER_API.md index e0e6394..29ff419 100644 --- a/a2a-agents/BUILDER_API.md +++ b/a2a-agents/BUILDER_API.md @@ -99,7 +99,6 @@ url = "https://example.com" [server] host = "127.0.0.1" http_port = 8080 # Set to 0 to disable HTTP -ws_port = 8081 # Set to 0 to disable WebSocket [server.storage] type = "sqlx" @@ -170,7 +169,6 @@ Server configuration: - `host` (default: `127.0.0.1`): Host to bind to - `http_port` (default: `8080`): HTTP server port (0 to disable) -- `ws_port` (default: `8081`): WebSocket server port (0 to disable) #### `[server.storage]` - Optional (defaults to in-memory) diff --git a/a2a-agents/CHANGELOG.md b/a2a-agents/CHANGELOG.md index 6d4d78e..7ab994d 100644 --- a/a2a-agents/CHANGELOG.md +++ b/a2a-agents/CHANGELOG.md @@ -7,6 +7,42 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- **MCP server over Streamable HTTP** — `run_mcp_server` can now serve a + TOML-configured agent over MCP's Streamable HTTP transport (rmcp's + `StreamableHttpService` on an `axum` router) in addition to stdio. Configure it + via a new `[features.mcp_server.http]` section (`McpHttpConfig`: `enabled`, + `host`, `port`, `path`); when `http.enabled` it takes precedence over stdio. + DNS-rebinding protection defaults to loopback-only and is tunable via + `allowed_hosts` / `allowed_origins` (empty `allowed_hosts` disables `Host` + validation for proxy-fronted public binds). Enables the + `transport-streamable-http-server` rmcp feature. New `mcp_http_agent` example + (`examples/mcp_http_agent.{rs,toml}`) plus an end-to-end `initialize`-handshake + and `Host`-allow-list integration test (`tests/mcp_http_test.rs`). +- **`AgentBuilder::with_streaming` / `AgentRuntime::with_streaming`** — attach a + shared streaming backend so `tasks/subscribe` SSE streams observe the + broadcasts a handler emits (e.g. via the `TaskStatusBroadcast` mixin). Pass the + *same* `InMemoryStreamingHandler` your handler broadcasts to (clones share + their subscriber registry); the runtime injects it into the transport via + `ConnectRpcAdapter::with_streaming_handler` and logs "📡 Streaming backend + wired into transport" when active. +- **`complex_agent` example** (`examples/complex_agent.rs` + + `examples/complex_agent.toml`, behind `--features mcp-server`) — a kitchen-sink + "Research Assistant" that wires declarative TOML config, optional LLM + tool-calling (with a keyless, deterministic rule-based fallback), MCP tool + consumption via `McpToA2ABridge` (against an in-process tool server over + `tokio::io::duplex`), live SSE streaming of progress artifacts, and native A2A + task lifecycle through the broadcast mixin. + +### Fixed + +- **Streaming through the builder reached a no-op.** `AgentRuntime::start_http` + built its transport with `ConnectRpcAdapter::new(...)`, which defaults to a + `NoopStreamingHandler` — so broadcasts from a builder-constructed handler never + reached `tasks/subscribe` SSE clients. They now do when the streaming backend + is supplied via `with_streaming` (see Added). + ## [0.3.0](https://github.com/EmilLindfors/a2a-rs/compare/a2a-agents-v0.2.0...a2a-agents-v0.3.0) - 2026-05-27 ### Fixed diff --git a/a2a-agents/Cargo.toml b/a2a-agents/Cargo.toml index 288dbef..f93d367 100644 --- a/a2a-agents/Cargo.toml +++ b/a2a-agents/Cargo.toml @@ -16,37 +16,37 @@ a2a-agents-common = { path = "../a2a-agents-common", version = "0.3" } a2a-ap2 = { path = "../a2a-ap2", version = "0.3", optional = true } a2a-client = { package = "a2a-web-client", path = "../a2a-client", version = "0.3" } a2a-mcp = { path = "../a2a-mcp", version = "0.3", optional = true } -rmcp = { version = "1.7", features = ["server", "client", "transport-io", "transport-child-process"], optional = true } +rmcp = { version = "1.7", features = ["server", "client", "transport-io", "transport-child-process", "transport-streamable-http-server"], optional = true } # Core dependencies -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde = { workspace = true } +serde_json = { workspace = true } buffa = { version = "0.3.0", features = ["json"] } buffa-types = { version = "0.3.0", features = ["json"] } toml = "0.8" -chrono = { version = "0.4", features = ["serde"] } -thiserror = "1.0" -uuid = { version = "1.4", features = ["v4", "serde"] } -bon = "2.3" +chrono = { workspace = true } +thiserror = { workspace = true } +uuid = { workspace = true, features = ["v4", "serde"] } +bon = { workspace = true } dotenvy = "0.15.7" shellexpand = "3.1" # Async foundation -tokio = { version = "1.32", features = ["rt", "rt-multi-thread", "macros", "net", "io-util", "sync", "time"] } -async-trait = "0.1" +tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros", "net", "io-util", "sync", "time"] } +async-trait = { workspace = true } # Command line interface clap = { version = "4.4", features = ["derive"] } # Logging -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } # Required dependencies regex = "1.10" # Used for text parsing and config env var expansion # AI integration (OpenAI-compatible API) -reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false } +reqwest = { workspace = true, features = ["json", "rustls-tls"] } # OAuth2 support (optional, for auth feature) oauth2 = { version = "5.0", optional = true } @@ -57,8 +57,8 @@ askama = "0.12" askama_axum = "0.4" tower-http = { version = "0.6", features = ["fs", "cors"] } base64 = "0.22" -futures = "0.3" -anyhow = "1.0" +futures = { workspace = true } +anyhow = { workspace = true } [package.metadata.docs.rs] features = ["reimbursement-agent", "sqlx", "ap2", "auth"] @@ -73,10 +73,11 @@ auth = ["a2a-rs/auth", "dep:oauth2"] # MCP server: expose any TOML-configured agent as an MCP stdio server via # `a2a-mcp::AgentToMcpBridge`. Pulls in `a2a-mcp` and the matching `rmcp` build. mcp-server = ["dep:a2a-mcp", "dep:rmcp"] -# MCP client: lets agents call out to external MCP servers (e.g. spawned as -# child processes). The framework-level integration is still a work in progress -# — see a2a-mcp/TODO.md — but the low-level `McpClientManager` compiles and is -# usable directly. +# MCP client: lets an agent call out to external MCP servers (spawned as child +# processes), discover their tools, and invoke them while serving A2A requests. +# Build an `McpClientManager` from `[features.mcp_client]` config with +# `McpClientManager::connect`, hand it to your handler, and reach tools through +# the `McpToolsExt` trait. See `examples/mcp_client_agent.rs`. mcp-client = ["dep:rmcp"] [[bin]] @@ -89,6 +90,13 @@ name = "reimbursement_demo" path = "bin/reimbursement_demo.rs" required-features = ["reimbursement-agent"] +# Minimal MCP stdio server (echo + add) — fixture for the mcp-client example and +# integration test, so they can spawn a real MCP server with no external deps. +[[bin]] +name = "mcp_echo_server" +path = "bin/mcp_echo_server.rs" +required-features = ["mcp-client"] + [[example]] name = "test_handler" path = "examples/test_handler.rs" @@ -122,5 +130,24 @@ name = "mcp_server_agent" path = "examples/mcp_server_agent.rs" required-features = ["mcp-server"] +[[example]] +name = "mcp_http_agent" +path = "examples/mcp_http_agent.rs" +required-features = ["mcp-server"] + +# Agent that acts as an MCP *client*: connects to an MCP server (the bundled +# mcp_echo_server) from TOML config and calls its tools while serving A2A. +[[example]] +name = "mcp_client_agent" +path = "examples/mcp_client_agent.rs" +required-features = ["mcp-client"] + +# Kitchen-sink example: TOML config + optional LLM tool-calling + MCP tool +# consumption (McpToA2ABridge) + live SSE streaming + native task lifecycle. +[[example]] +name = "complex_agent" +path = "examples/complex_agent.rs" +required-features = ["mcp-server"] + diff --git a/a2a-agents/README.md b/a2a-agents/README.md index d3a2718..272040b 100644 --- a/a2a-agents/README.md +++ b/a2a-agents/README.md @@ -85,7 +85,7 @@ The original hexagonal architecture approach with manual wiring: ## 🔌 Model Context Protocol (MCP) Integration -You can expose any declarative A2A Agent as a Model Context Protocol (MCP) server over `stdio` transport. This allows MCP-compatible clients (like Claude Desktop) to invoke the agent's skills as local tools. +You can expose any declarative A2A Agent as a Model Context Protocol (MCP) server over `stdio` (for local clients like Claude Desktop) or **Streamable HTTP** (for networked clients) transport. Either way, MCP-compatible clients can invoke the agent's skills as tools. The bridge dispatches tool calls to the agent's message handler **in-process**, which means: - No backing HTTP server is required (you can set `http_port = 0` for a pure-stdio server). @@ -143,6 +143,105 @@ To connect Claude Desktop to your agent, add the following to your Claude Deskto } ``` +### 4. Streamable HTTP transport + +For networked MCP clients, serve the agent over MCP's Streamable HTTP transport +instead of stdio. Add a `[features.mcp_server.http]` section — when `enabled`, +it takes precedence over stdio: + +```toml +[features.mcp_server] +enabled = true +stdio = false + +[features.mcp_server.http] +enabled = true +host = "127.0.0.1" # default +port = 8000 # default +path = "/mcp" # default mount path +``` + +```bash +cargo run -p a2a-agents --features mcp-server --example mcp_http_agent +``` + +The server then accepts MCP requests at `http://127.0.0.1:8000/mcp`. + +**DNS-rebinding protection.** By default the transport only accepts inbound +`Host` headers for loopback (`localhost`, `127.0.0.1`, `::1`). For a public +bind, list the hostnames you serve under — and optionally restrict browser +origins: + +```toml +[features.mcp_server.http] +enabled = true +host = "0.0.0.0" +port = 8000 +allowed_hosts = ["mcp.example.com", "mcp.example.com:8000"] +allowed_origins = ["https://app.example.com"] # omit to disable Origin checks +``` + +Setting `allowed_hosts = []` disables `Host` validation entirely (accepts any +host) — only do this behind a trusted reverse proxy. + +### 5. MCP client (consume external MCP tools) + +The other direction: let your agent **call out** to MCP servers and use their +tools while it serves A2A requests. Enable the `mcp-client` Cargo feature and +declare the servers to connect to under `[features.mcp_client]`. Each server is +spawned as a child process: + +```toml +[features.mcp_client] +enabled = true + +[[features.mcp_client.servers]] +name = "echo" +command = "cargo" +args = ["run", "-q", "-p", "a2a-agents", "--features", "mcp-client", "--bin", "mcp_echo_server"] +# `env = { KEY = "value" }` and `cwd = "…"` are also supported. +``` + +In code, connect the config-declared servers into an `McpClientManager` and +hand it to the handler that will use the tools. The handler owns the manager and +reaches tools through the `McpToolsExt` trait: + +```rust +use a2a_agents::core::{AgentBuilder, AgentConfig, McpClientManager}; +use a2a_agents::traits::{McpToolsExt, extract_tool_result_text}; + +#[derive(Clone)] +struct MyHandler { mcp: McpClientManager } + +impl McpToolsExt for MyHandler { + fn mcp_client(&self) -> &McpClientManager { &self.mcp } +} + +// inside process_message: +// let result = self.call_mcp_tool("echo", "echo", Some(json!({ "text": text }))).await?; +// let reply = extract_tool_result_text(&result); + +let config = AgentConfig::from_file("agent.toml")?; +let mcp = McpClientManager::connect(&config.features.mcp_client).await?; // connects + discovers tools +AgentBuilder::new(config) + .with_handler(MyHandler { mcp }) + .with_storage(a2a_rs::InMemoryTaskStorage::new()) + .build()? + .run() + .await?; +``` + +Connection is lenient — a server that fails to start is logged and skipped, and +`connect` only errors if servers were configured but none could be reached. + +```bash +cargo run -p a2a-agents --features mcp-client --example mcp_client_agent +``` + +The example connects to the bundled `mcp_echo_server`, so it runs with no +external setup; point `command`/`args` at any MCP stdio server to talk to +something real. + ## Architecture ### ReimbursementMessageHandler @@ -284,13 +383,8 @@ This example implementation demonstrates the framework architecture but has simp ## Future Enhancements -See [TODO.md](./TODO.md) for the comprehensive modernization roadmap including: - -1. **Phase 2**: Production features (SQLx storage, authentication) -2. **Phase 3**: AI/LLM integration for natural language processing -3. **Phase 4**: Additional agent examples (document analysis, research assistant) -4. **Phase 5**: Comprehensive testing and documentation -5. **Phase 6**: Docker support and production deployment +See the workspace [ROADMAP.md](../ROADMAP.md) for deferred themes and planned +work. ## Framework Features Demonstrated diff --git a/a2a-agents/bin/a2a.rs b/a2a-agents/bin/a2a.rs index b1b304f..ff7204e 100644 --- a/a2a-agents/bin/a2a.rs +++ b/a2a-agents/bin/a2a.rs @@ -12,6 +12,7 @@ use a2a_agents_common::llm::gemini::{GeminiConfig, GeminiProvider}; use a2a_agents_common::llm::openai::{OpenAiConfig, OpenAiProvider}; use a2a_rs::{ + InMemoryStreamingHandler, domain::{A2AError, Message, Part, Role, Task, TaskState, TaskStatus}, port::AsyncMessageHandler, }; @@ -172,7 +173,14 @@ async fn main() -> anyhow::Result<()> { } } - let handler = ReimbursementHandler::with_llm(storage.clone(), llm_provider); + let streaming = InMemoryStreamingHandler::new(); + let push = storage.push_notifier(); + let handler = ReimbursementHandler::with_llm( + storage.clone(), + streaming, + push, + llm_provider, + ); // We use build() instead of build_with_auto_storage() since we created storage manually // Note: If mcp-client is enabled, we'd need to manually initialize it here, diff --git a/a2a-agents/bin/mcp_echo_server.rs b/a2a-agents/bin/mcp_echo_server.rs new file mode 100644 index 0000000..d3a00db --- /dev/null +++ b/a2a-agents/bin/mcp_echo_server.rs @@ -0,0 +1,148 @@ +//! Minimal MCP **server** over stdio — a fixture for the `mcp-client` story. +//! +//! Exposes two tools, `echo` and `add`, over MCP's stdio transport. It exists +//! so the [`mcp_client_agent`](../examples/mcp_client_agent.rs) example and the +//! `mcp_client_test` integration test have a real MCP server to spawn as a +//! child process — no external dependencies (npx, Node, …) required. +//! +//! Run it directly to poke at it with an MCP inspector: +//! +//! ```bash +//! cargo run -p a2a-agents --features mcp-client --bin mcp_echo_server +//! ``` +//! +//! Everything it logs goes to **stderr** — stdout is reserved for the MCP wire +//! protocol, so writing anything else there would corrupt the stream. + +use rmcp::{ + ErrorData as McpError, RoleServer, ServerHandler, ServiceExt, + model::{ + CallToolRequestParams, CallToolResult, Content, Implementation, JsonObject, + ListToolsResult, PaginatedRequestParams, ProtocolVersion, ServerCapabilities, ServerInfo, + Tool, + }, + service::RequestContext, + transport::stdio, +}; +use serde_json::json; +use std::sync::Arc; + +/// An MCP server exposing `echo` and `add`. +#[derive(Clone)] +struct EchoServer { + tools: Arc>, +} + +impl EchoServer { + fn new() -> Self { + let echo_arg: Arc = Arc::new( + serde_json::from_value(json!({ + "type": "object", + "properties": { "text": { "type": "string" } }, + "required": ["text"] + })) + .expect("valid JSON schema"), + ); + let number_pair: Arc = Arc::new( + serde_json::from_value(json!({ + "type": "object", + "properties": { + "a": { "type": "number" }, + "b": { "type": "number" } + }, + "required": ["a", "b"] + })) + .expect("valid JSON schema"), + ); + let tools = vec![ + Tool::new("echo", "Echo back the provided text", echo_arg), + Tool::new("add", "Add two numbers a + b", number_pair), + ]; + Self { + tools: Arc::new(tools), + } + } +} + +// rmcp's `ServerHandler` methods are RPITIT (`impl Future`), so they're written +// in that form here rather than with `async fn`. +#[allow(clippy::manual_async_fn)] +impl ServerHandler for EchoServer { + fn get_info(&self) -> ServerInfo { + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + .with_protocol_version(ProtocolVersion::V_2024_11_05) + .with_server_info(Implementation::new("mcp-echo-server", "0.1.0")) + .with_instructions("Echo and arithmetic tools for the a2a-agents mcp-client example") + } + + fn list_tools( + &self, + _request: Option, + _ctx: RequestContext, + ) -> impl std::future::Future> + Send + '_ { + async move { + Ok(ListToolsResult { + tools: (*self.tools).clone(), + next_cursor: None, + meta: None, + }) + } + } + + fn call_tool( + &self, + CallToolRequestParams { + name, arguments, .. + }: CallToolRequestParams, + _ctx: RequestContext, + ) -> impl std::future::Future> + Send + '_ { + async move { + let args = arguments.unwrap_or_default(); + let text = match name.as_ref() { + "echo" => args + .get("text") + .and_then(|v| v.as_str()) + .ok_or_else(|| McpError::invalid_params("missing 'text'", None))? + .to_string(), + "add" => { + let a = number_arg(&args, "a")?; + let b = number_arg(&args, "b")?; + (a + b).to_string() + } + other => { + return Err(McpError::invalid_params( + format!("unknown tool: {other}"), + None, + )); + } + }; + Ok(CallToolResult::success(vec![Content::text(text)])) + } + } +} + +fn number_arg( + args: &serde_json::Map, + key: &str, +) -> Result { + args.get(key) + .and_then(|v| v.as_f64()) + .ok_or_else(|| McpError::invalid_params(format!("missing or non-numeric '{key}'"), None)) +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Logs to stderr only — stdout carries the MCP protocol. + tracing_subscriber::fmt() + .with_writer(std::io::stderr) + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")), + ) + .init(); + + tracing::info!("mcp_echo_server starting on stdio"); + let running = EchoServer::new().serve(stdio()).await?; + running.waiting().await?; + Ok(()) +} diff --git a/a2a-agents/bin/reimbursement_demo.rs b/a2a-agents/bin/reimbursement_demo.rs index 29ea636..d208480 100644 --- a/a2a-agents/bin/reimbursement_demo.rs +++ b/a2a-agents/bin/reimbursement_demo.rs @@ -5,10 +5,7 @@ use a2a_client::{ WebA2AClient, components::{MessageView, TaskView, create_sse_stream}, }; -use a2a_rs::{ - domain::{ListTasksParams, TaskState, TaskStatusUpdateEvent}, - services::AsyncA2AClient, -}; +use a2a_rs::domain::{ListTasksParams, TaskState, TaskStatusUpdateEvent}; use askama::Template; use askama_axum::IntoResponse; use axum::{ @@ -342,7 +339,6 @@ fn load_agent_config(args: &Args) -> anyhow::Result { // Override config with command-line arguments config.host = args.host.clone(); config.http_port = args.agent_http_port; - config.ws_port = args.agent_ws_port; Ok(config) } @@ -350,7 +346,6 @@ fn load_agent_config(args: &Args) -> anyhow::Result { fn print_agent_info(config: &ServerConfig, args: &Args) { println!(" 📍 Host: {}", config.host); println!(" 🔌 HTTP Port: {}", config.http_port); - println!(" 📡 WebSocket Port: {}", config.ws_port); println!(" ⚙️ Transport: {}", args.transport); match &config.storage { @@ -495,7 +490,7 @@ async fn submit_expense( let response = state .client - .http + .transport .send_task_message(&task_id, &message, None, Some(50)) .await .map_err(|e| AppError(anyhow::anyhow!("Failed to submit expense: {}", e)))?; @@ -519,7 +514,7 @@ async fn submit_expense( match state .client - .http + .transport .set_task_push_notification(&push_config) .await { @@ -559,7 +554,7 @@ async fn tasks_page( let result = state .client - .http + .transport .list_tasks(¶ms) .await .map_err(|e| AppError(anyhow::anyhow!("Failed to list tasks: {}", e)))?; @@ -583,7 +578,7 @@ async fn chat_page( let max_retries = 3; let (messages, task_state) = loop { - match state.client.http.get_task(&task_id, Some(50)).await { + match state.client.transport.get_task(&task_id, Some(50)).await { Ok(task) => { info!( "Retrieved task {} with {} history items", @@ -702,7 +697,7 @@ async fn send_message( let response = state .client - .http + .transport .send_task_message(&task_id, &message, None, Some(50)) .await .map_err(|e| AppError(anyhow::anyhow!("Failed to send message: {}", e)))?; @@ -727,7 +722,7 @@ async fn send_message( match state .client - .http + .transport .set_task_push_notification(&push_config) .await { @@ -752,7 +747,7 @@ async fn cancel_task( ) -> Result { state .client - .http + .transport .cancel_task(&task_id) .await .map_err(|e| AppError(anyhow::anyhow!("Failed to cancel task: {}", e)))?; diff --git a/a2a-agents/config.apikey.example.json b/a2a-agents/config.apikey.example.json index a9ee674..1b1acaf 100644 --- a/a2a-agents/config.apikey.example.json +++ b/a2a-agents/config.apikey.example.json @@ -1,7 +1,6 @@ { "host": "127.0.0.1", "http_port": 8080, - "ws_port": 8081, "storage": { "type": "Sqlx", "url": "sqlite:authenticated_tasks.db", diff --git a/a2a-agents/config.auth.example.json b/a2a-agents/config.auth.example.json index 8dad6d2..43916f6 100644 --- a/a2a-agents/config.auth.example.json +++ b/a2a-agents/config.auth.example.json @@ -1,7 +1,6 @@ { "host": "127.0.0.1", "http_port": 8080, - "ws_port": 8081, "storage": { "type": "InMemory" }, diff --git a/a2a-agents/config.example.json b/a2a-agents/config.example.json index 7e06ed4..adde461 100644 --- a/a2a-agents/config.example.json +++ b/a2a-agents/config.example.json @@ -1,7 +1,6 @@ { "host": "127.0.0.1", "http_port": 8080, - "ws_port": 8081, "storage": { "type": "InMemory" } diff --git a/a2a-agents/config.sqlx.example.json b/a2a-agents/config.sqlx.example.json index e31906b..e3bad69 100644 --- a/a2a-agents/config.sqlx.example.json +++ b/a2a-agents/config.sqlx.example.json @@ -1,7 +1,6 @@ { "host": "127.0.0.1", "http_port": 8080, - "ws_port": 8081, "storage": { "type": "Sqlx", "url": "sqlite:reimbursement_tasks.db", diff --git a/a2a-agents/examples/complex_agent.rs b/a2a-agents/examples/complex_agent.rs new file mode 100644 index 0000000..c005df0 --- /dev/null +++ b/a2a-agents/examples/complex_agent.rs @@ -0,0 +1,594 @@ +//! Kitchen-sink complex agent — every major a2a-rs building block in one binary. +//! +//! Run with (the `mcp-server` feature pulls in `a2a-mcp` + `rmcp`): +//! +//! ```bash +//! cargo run -p a2a-agents --example complex_agent --features mcp-server +//! ``` +//! +//! Optionally export `OPENAI_API_KEY` or `GEMINI_API_KEY` first to let an LLM +//! drive tool selection and answer in natural language. With no key set the +//! agent still works — it falls back to a deterministic rule-based router so the +//! example runs end-to-end in CI without secrets. +//! +//! What it wires together: +//! +//! * **Declarative TOML config** (`complex_agent.toml`) — identity, skills, +//! transport port, storage, and the `streaming` feature flag. +//! * **An in-process MCP tool server** exposing `add`, `multiply`, and +//! `word_count`, reached over an in-memory `tokio::io::duplex` pipe (no +//! external process). +//! * **`McpToA2ABridge`** — the agent discovers those MCP tools +//! (`get_llm_tools`) and executes them (`execute_llm_tool_call`). +//! * **Optional LLM tool-calling** via `a2a-agents-common`'s `LlmProvider`. +//! * **Live streaming to web clients** — the handler broadcasts progress +//! artifacts through the `TaskStatusBroadcast` mixin and a shared +//! `InMemoryStreamingHandler`, which the runtime now injects into the +//! transport so `tasks/subscribe` SSE streams actually observe them. +//! * **A2A native tasks & progress** — every request creates/advances a task +//! through `Working` → `Completed`/`Failed`. +//! +//! Talk to it once running (separate shell): +//! +//! ```bash +//! # Agent card +//! curl -s http://127.0.0.1:8080/.well-known/agent-card.json | jq . +//! ``` +//! +//! …or point any A2A client (e.g. the `a2a-web-client`) at the same URL and +//! subscribe to a task to watch the progress artifacts stream in. + +use std::sync::Arc; + +use a2a_agents::core::AgentBuilder; +use a2a_agents_common::llm::{ + ChatMessage, LlmProvider, LlmRequest, MessageRole, ToolCall, ToolDefinition, +}; +use a2a_mcp::McpToA2ABridge; +use a2a_rs::Artifact; +use a2a_rs::application::{HasPushNotifier, HasStreaming, HasTaskLifecycle, TaskStatusBroadcast}; +use a2a_rs::domain::{ + A2AError, ContextId, Message, Part, Role, Task, TaskArtifactUpdateEvent, TaskId, TaskState, + part, +}; +use a2a_rs::port::{ + AsyncMessageHandler, AsyncPushNotifier, AsyncStreamingHandler, AsyncTaskLifecycle, +}; +use a2a_rs::{InMemoryStreamingHandler, InMemoryTaskStorage}; +use async_trait::async_trait; +use rmcp::{ + ErrorData as McpError, RoleServer, ServerHandler, ServiceExt, model::*, service::RequestContext, +}; +use serde_json::json; +use tracing_subscriber::EnvFilter; + +/// How many LLM ↔ tool round-trips to allow before giving up. +const MAX_TOOL_ROUNDS: usize = 4; + +const SYSTEM_PROMPT: &str = "You are a concise research assistant. You have tools \ +for arithmetic (add, multiply) and text analysis (word_count). Use a tool whenever \ +it gives an exact answer instead of guessing, then reply in one short sentence."; + +// --------------------------------------------------------------------------- +// 1. Downstream MCP tool server (in-process). +// --------------------------------------------------------------------------- + +/// A tiny MCP server exposing three tools. Mirrors the shape an external MCP +/// server (spawned as a child process) would have — here it just runs over an +/// in-memory duplex pipe so the example needs no external setup. +#[derive(Clone)] +struct ToolServer { + tools: Arc>, +} + +impl ToolServer { + fn new() -> Self { + let number_pair: Arc = Arc::new( + serde_json::from_value(json!({ + "type": "object", + "properties": { + "a": { "type": "number" }, + "b": { "type": "number" } + }, + "required": ["a", "b"] + })) + .expect("valid JSON schema"), + ); + let text_arg: Arc = Arc::new( + serde_json::from_value(json!({ + "type": "object", + "properties": { "text": { "type": "string" } }, + "required": ["text"] + })) + .expect("valid JSON schema"), + ); + + let tools = vec![ + Tool::new("add", "Add two numbers a + b", number_pair.clone()), + Tool::new("multiply", "Multiply two numbers a * b", number_pair), + Tool::new("word_count", "Count the words in a piece of text", text_arg), + ]; + Self { + tools: Arc::new(tools), + } + } +} + +// The rmcp `ServerHandler` methods are declared with explicit `impl Future` +// return types (RPITIT), so they're written in the same manual form here. +#[allow(clippy::manual_async_fn)] +impl ServerHandler for ToolServer { + fn get_info(&self) -> ServerInfo { + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + .with_protocol_version(ProtocolVersion::V_2024_11_05) + .with_server_info(Implementation::new("kitchen-sink-tools", "0.1.0")) + .with_instructions("Arithmetic and text-analysis tools for the complex agent example") + } + + fn list_tools( + &self, + _request: Option, + _ctx: RequestContext, + ) -> impl std::future::Future> + Send + '_ { + async move { + Ok(ListToolsResult { + tools: (*self.tools).clone(), + next_cursor: None, + meta: None, + }) + } + } + + fn call_tool( + &self, + CallToolRequestParams { + name, arguments, .. + }: CallToolRequestParams, + _ctx: RequestContext, + ) -> impl std::future::Future> + Send + '_ { + async move { + let args = arguments.unwrap_or_default(); + let result = match name.as_ref() { + "add" | "multiply" => { + let a = number_arg(&args, "a")?; + let b = number_arg(&args, "b")?; + if name == "add" { a + b } else { a * b }.to_string() + } + "word_count" => { + let text = args + .get("text") + .and_then(|v| v.as_str()) + .ok_or_else(|| McpError::invalid_params("missing 'text'", None))?; + text.split_whitespace().count().to_string() + } + other => { + return Err(McpError::invalid_params( + format!("unknown tool: {other}"), + None, + )); + } + }; + Ok(CallToolResult::success(vec![Content::text(result)])) + } + } +} + +fn number_arg( + args: &serde_json::Map, + key: &str, +) -> Result { + args.get(key) + .and_then(|v| v.as_f64()) + .ok_or_else(|| McpError::invalid_params(format!("missing or non-numeric '{key}'"), None)) +} + +// --------------------------------------------------------------------------- +// 2. The agent handler. +// --------------------------------------------------------------------------- + +/// Inner handler the bridge requires but this example never routes through — +/// we call `execute_llm_tool_call`/`get_llm_tools` on the bridge directly and +/// keep task-lifecycle ownership in [`ResearchAssistantHandler`]. It only +/// fires if a raw `a2a_rs_tool_call` envelope arrives, which we never send. +#[derive(Clone)] +struct UnusedInner; + +#[async_trait] +impl AsyncMessageHandler for UnusedInner { + async fn process_message( + &self, + _task_id: &str, + _message: &Message, + _session_id: Option<&str>, + ) -> Result { + Err(A2AError::UnsupportedOperation( + "inner handler is not used in the complex_agent example".to_string(), + )) + } +} + +/// The agent. Owns the task-lifecycle, streaming, and push ports (so it hosts +/// the `TaskStatusBroadcast` mixin), the MCP bridge, and an optional LLM. +#[derive(Clone)] +struct ResearchAssistantHandler { + lifecycle: Arc, + streaming: Arc, + push: Arc, + bridge: Arc>, + llm: Option>, +} + +// Accessors that surface the `TaskStatusBroadcast` mixin on this handler. Every +// status transition routed through `update_and_broadcast` reaches streaming +// subscribers and push targets — see `.claude/rules/hexagonal_architecture.md` §9. +impl HasTaskLifecycle for ResearchAssistantHandler { + fn lifecycle(&self) -> &dyn AsyncTaskLifecycle { + self.lifecycle.as_ref() + } +} +impl HasStreaming for ResearchAssistantHandler { + fn streaming(&self) -> &dyn AsyncStreamingHandler { + self.streaming.as_ref() + } +} +impl HasPushNotifier for ResearchAssistantHandler { + fn push_notifier(&self) -> &dyn AsyncPushNotifier { + self.push.as_ref() + } +} + +impl ResearchAssistantHandler { + fn new( + lifecycle: impl AsyncTaskLifecycle + 'static, + streaming: impl AsyncStreamingHandler + 'static, + push: Arc, + bridge: Arc>, + llm: Option>, + ) -> Self { + Self { + lifecycle: Arc::new(lifecycle), + streaming: Arc::new(streaming), + push, + bridge, + llm, + } + } + + /// Push an incremental progress artifact to any SSE subscriber. + async fn stream_progress(&self, task_id: &str, context_id: &str, text: &str) { + let artifact = Artifact { + artifact_id: format!("progress-{task_id}"), + name: "progress".to_string(), + description: String::new(), + parts: vec![Part::text(text.to_string())], + metadata: ::buffa::MessageField::none(), + extensions: Vec::new(), + ..Default::default() + }; + let event = TaskArtifactUpdateEvent { + task_id: task_id.to_string(), + context_id: context_id.to_string(), + kind: "artifact-update".to_string(), + artifact, + append: Some(true), + last_chunk: Some(false), + metadata: None, + }; + if let Err(e) = self + .streaming + .broadcast_artifact_update(task_id, event) + .await + { + tracing::warn!("failed to broadcast progress: {e}"); + } + } + + /// LLM path: let the model pick tools, execute them via the bridge, loop + /// until it answers in prose. + async fn run_with_llm( + &self, + llm: &dyn LlmProvider, + task_id: &str, + context_id: &str, + user_text: &str, + ) -> Result { + let tools: Vec = self.bridge.get_llm_tools(); + let mut messages = vec![ + ChatMessage::system(SYSTEM_PROMPT), + ChatMessage::user(user_text), + ]; + + for _round in 0..MAX_TOOL_ROUNDS { + let mut request = LlmRequest::new(messages.clone()).temperature(0.2); + if !tools.is_empty() { + request = request.tools(tools.clone()); + } + + let response = llm + .chat_completion(request) + .await + .map_err(|e| A2AError::Internal(format!("LLM error: {e}")))?; + + match response.tool_calls { + Some(calls) if !calls.is_empty() => { + // Record the assistant turn that requested the tools… + messages.push(ChatMessage { + role: MessageRole::Assistant, + content: response.content.clone(), + tool_calls: Some(calls.clone()), + tool_call_id: None, + name: None, + }); + // …then execute each tool against MCP and feed results back. + for call in &calls { + self.stream_progress( + task_id, + context_id, + &format!("🛠️ calling `{}`({})", call.name, call.arguments), + ) + .await; + let result = self + .bridge + .execute_llm_tool_call(task_id, call) + .await + .map_err(|e| e.to_a2a_error())?; + self.stream_progress( + task_id, + context_id, + &format!("✅ `{}` → {result}", call.name), + ) + .await; + messages.push(ChatMessage::tool_result( + call.id.clone(), + call.name.clone(), + result, + )); + } + } + _ => return Ok(response.content.unwrap_or_default()), + } + } + Ok("I couldn't converge on an answer within the tool-call budget.".to_string()) + } + + /// No-LLM fallback: a deterministic router so the example runs without keys. + async fn run_rule_based( + &self, + task_id: &str, + context_id: &str, + user_text: &str, + ) -> Result { + let lower = user_text.to_lowercase(); + let make = |name: &str, args: serde_json::Value| ToolCall { + id: format!("rule-{name}"), + name: name.to_string(), + arguments: args.to_string(), + }; + + let tool_call = + if lower.contains("multipl") || lower.contains("times") || lower.contains('*') { + parse_two_numbers(user_text) + .map(|(a, b)| make("multiply", json!({ "a": a, "b": b }))) + } else if lower.contains("add") || lower.contains("plus") || lower.contains("sum") { + parse_two_numbers(user_text).map(|(a, b)| make("add", json!({ "a": a, "b": b }))) + } else if lower.contains("word") || lower.contains("count") { + let text = user_text + .split_once(':') + .map(|(_, t)| t.trim().to_string()) + .unwrap_or_else(|| user_text.to_string()); + Some(make("word_count", json!({ "text": text }))) + } else { + None + }; + + match tool_call { + Some(call) => { + self.stream_progress( + task_id, + context_id, + &format!("🛠️ (rule-based) calling `{}`", call.name), + ) + .await; + let result = self + .bridge + .execute_llm_tool_call(task_id, &call) + .await + .map_err(|e| e.to_a2a_error())?; + Ok(format!( + "The `{}` tool returned **{result}**.\n\n(Set OPENAI_API_KEY or \ + GEMINI_API_KEY to let an LLM choose tools and answer in natural language.)", + call.name + )) + } + None => { + let names: Vec = self + .bridge + .tools() + .iter() + .map(|t| t.name.to_string()) + .collect(); + Ok(format!( + "I can do simple math and text stats via MCP tools ({}).\n\ + Try: `add 21 21`, `multiply 6 7`, or `count words: the quick brown fox`.", + names.join(", ") + )) + } + } + } +} + +#[async_trait] +impl AsyncMessageHandler for ResearchAssistantHandler { + async fn process_message( + &self, + task_id: &str, + message: &Message, + _session_id: Option<&str>, + ) -> Result { + let id: TaskId = task_id.parse()?; + + // Create the task on first contact. + if !self.lifecycle.exists(&id).await? { + let raw_ctx = if message.context_id.is_empty() { + uuid::Uuid::new_v4().to_string() + } else { + message.context_id.clone() + }; + let ctx: ContextId = raw_ctx.parse()?; + self.lifecycle.create(&id, &ctx).await?; + } + let context_id = self.lifecycle.get(&id, Some(1)).await?.context_id.clone(); + + // Record the user's message and move to Working — broadcast both. + self.update_and_broadcast(&id, TaskState::Working, Some(message.clone())) + .await?; + self.stream_progress(task_id, &context_id, "🔎 Analyzing your request…") + .await; + + let user_text = extract_text(message); + let outcome = match &self.llm { + Some(llm) => { + self.run_with_llm(llm.as_ref(), task_id, &context_id, &user_text) + .await + } + None => self.run_rule_based(task_id, &context_id, &user_text).await, + }; + + let (state, reply) = match outcome { + Ok(text) => (TaskState::Completed, text), + Err(e) => (TaskState::Failed, format!("Sorry — I hit an error: {e}")), + }; + + let response = Message::builder() + .role(Role::Agent) + .parts(vec![Part::text(reply)]) + .message_id(uuid::Uuid::new_v4().to_string()) + .context_id(context_id) + .build(); + + let final_task = self + .update_and_broadcast(&id, state, Some(response)) + .await?; + Ok(final_task) + } + + async fn validate_message(&self, message: &Message) -> Result<(), A2AError> { + if message.parts.is_empty() { + return Err(A2AError::ValidationError { + field: "message.parts".to_string(), + message: "Message must contain at least one part".to_string(), + }); + } + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// 3. Helpers. +// --------------------------------------------------------------------------- + +fn extract_text(message: &Message) -> String { + message + .parts + .iter() + .filter_map(|p| match &p.content { + Some(part::Content::Text(t)) => Some(t.clone()), + _ => None, + }) + .collect::>() + .join(" ") +} + +/// Pull the first two numbers out of free text (e.g. "what is 6 times 7" → 6,7). +fn parse_two_numbers(text: &str) -> Option<(f64, f64)> { + let nums: Vec = text + .split(|c: char| !(c.is_ascii_digit() || c == '.')) + .filter(|s| !s.is_empty()) + .filter_map(|s| s.parse::().ok()) + .collect(); + match nums.as_slice() { + [a, b, ..] => Some((*a, *b)), + _ => None, + } +} + +fn load_llm() -> Option> { + use a2a_agents_common::llm::{gemini::GeminiProvider, openai::OpenAiProvider}; + if let Ok(gemini) = GeminiProvider::from_env() { + tracing::info!("🤖 LLM: Gemini (tool-calling enabled)"); + return Some(Arc::new(gemini)); + } + if let Ok(openai) = OpenAiProvider::from_env() { + tracing::info!("🤖 LLM: OpenAI (tool-calling enabled)"); + return Some(Arc::new(openai)); + } + tracing::info!("🤖 LLM: none configured — using rule-based fallback"); + None +} + +// --------------------------------------------------------------------------- +// 4. Composition root. +// --------------------------------------------------------------------------- + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + dotenvy::dotenv().ok(); + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")), + ) + .init(); + + // (a) Start the in-process MCP tool server over a duplex pipe. + let (server_io, client_io) = tokio::io::duplex(8192); + let _tool_server = tokio::spawn(async move { + let running = ToolServer::new().serve(server_io).await?; + running.waiting().await?; + anyhow::Ok(()) + }); + + // (b) Connect an MCP client and hand its peer to the bridge. + // `mcp_client` is kept alive for the whole process (the run loop below + // never returns) so the peer stays connected. + let mcp_client = ().serve(client_io).await?; + let bridge = Arc::new(McpToA2ABridge::new(mcp_client.peer().clone(), UnusedInner).await?); + tracing::info!( + "🔧 MCP tools available: {}", + bridge + .tools() + .iter() + .map(|t| t.name.to_string()) + .collect::>() + .join(", ") + ); + + // (c) Shared storage + streaming. The SAME streaming instance goes to the + // handler (it broadcasts) and to the builder (the transport subscribes), + // so SSE clients see the handler's progress. Clones share the registry. + let storage = InMemoryTaskStorage::new(); + let streaming = InMemoryStreamingHandler::new(); + + // (d) Build the handler with optional LLM. + let handler = ResearchAssistantHandler::new( + storage.clone(), + streaming.clone(), + storage.push_notifier(), + bridge, + load_llm(), + ); + + // (e) Assemble from TOML and run. `with_streaming` is the new builder hook + // that bridges the handler's broadcasts to the transport's SSE streams. + println!("🚀 Complex agent listening on http://127.0.0.1:8080"); + println!(" Agent card: http://127.0.0.1:8080/.well-known/agent-card.json"); + AgentBuilder::from_file("examples/complex_agent.toml")? + .with_handler(handler) + .with_storage(storage) + .with_streaming(streaming) + .build()? + .run() + .await?; + + drop(mcp_client); + Ok(()) +} diff --git a/a2a-agents/examples/complex_agent.toml b/a2a-agents/examples/complex_agent.toml new file mode 100644 index 0000000..31f1bda --- /dev/null +++ b/a2a-agents/examples/complex_agent.toml @@ -0,0 +1,45 @@ +# Kitchen-sink "Research Assistant" agent. +# +# Drives examples/complex_agent.rs. Everything the agent advertises — identity, +# skills, transport, storage, streaming — is declared here; the Rust side only +# wires the handler, the MCP tool server, and (optionally) an LLM. + +[agent] +name = "Research Assistant (Kitchen Sink)" +description = "Answers math and text questions using MCP tools, with optional LLM-driven tool selection and live streaming progress." +version = "0.1.0" + +[agent.provider] +name = "a2a-rs examples" +url = "https://github.com/emillindfors/a2a-rs" + +[server] +host = "127.0.0.1" +http_port = 8080 + +# In-memory storage keeps the example self-contained (no DB needed). +# Swap for `type = "sqlx"` + a `url` to persist tasks across restarts. +[server.storage] +type = "inmemory" + +[[skills]] +id = "compute" +name = "Compute & analyze" +description = "Add or multiply numbers and count words via MCP tools. With an LLM key set, it answers in natural language and picks tools itself." +keywords = ["math", "add", "multiply", "word count", "tools", "mcp"] +examples = [ + "add 21 21", + "multiply 6 7", + "count words: the quick brown fox jumps over the lazy dog", + "What is 19 times 3?", +] +input_formats = ["text"] +output_formats = ["text", "data"] + +[features] +# Streaming is the headline of this example: the handler broadcasts progress +# artifacts that reach `tasks/subscribe` SSE clients because the runtime now +# shares the handler's streaming backend with the transport. +streaming = true +push_notifications = false +state_history = true diff --git a/a2a-agents/examples/jwt_auth.toml b/a2a-agents/examples/jwt_auth.toml index 68e270c..6156423 100644 --- a/a2a-agents/examples/jwt_auth.toml +++ b/a2a-agents/examples/jwt_auth.toml @@ -9,7 +9,6 @@ version = "1.0.0" [server] host = "127.0.0.1" http_port = 8080 -ws_port = 8081 # In-memory storage for this example [server.storage] diff --git a/a2a-agents/examples/keycloak_jwt.toml b/a2a-agents/examples/keycloak_jwt.toml index 272e650..d3f27b8 100644 --- a/a2a-agents/examples/keycloak_jwt.toml +++ b/a2a-agents/examples/keycloak_jwt.toml @@ -9,7 +9,6 @@ version = "1.0.0" [server] host = "127.0.0.1" http_port = 8080 -ws_port = 8081 # In-memory storage for this example [server.storage] diff --git a/a2a-agents/examples/keycloak_oauth2.toml b/a2a-agents/examples/keycloak_oauth2.toml index 24a77c9..d94f99e 100644 --- a/a2a-agents/examples/keycloak_oauth2.toml +++ b/a2a-agents/examples/keycloak_oauth2.toml @@ -9,7 +9,6 @@ version = "1.0.0" [server] host = "127.0.0.1" http_port = 8080 -ws_port = 8081 # In-memory storage for this example [server.storage] diff --git a/a2a-agents/examples/mcp_client_agent.rs b/a2a-agents/examples/mcp_client_agent.rs new file mode 100644 index 0000000..6eff1ef --- /dev/null +++ b/a2a-agents/examples/mcp_client_agent.rs @@ -0,0 +1,134 @@ +//! Agent as an MCP **client** — connect to an MCP server from TOML config and +//! call its tools while serving A2A requests. +//! +//! Run it (the `mcp-client` feature pulls in `rmcp`): +//! +//! ```bash +//! cargo run -p a2a-agents --example mcp_client_agent --features mcp-client +//! ``` +//! +//! The TOML (`mcp_client_agent.toml`) declares one downstream MCP server under +//! `[features.mcp_client]`; the framework spawns it as a child process. This +//! example points at the bundled [`mcp_echo_server`](../bin/mcp_echo_server.rs) +//! so it runs with no external setup. +//! +//! The flow that "finishes" the mcp-client integration: +//! +//! 1. Load config, then [`McpClientManager::connect`] it — this connects to +//! every configured server and discovers their tools. +//! 2. Hand the *connected* manager to the handler, which owns it and implements +//! [`McpToolsExt`] by returning a reference to it. +//! 3. The handler calls tools through `McpToolsExt` while processing messages. +//! +//! Talk to it once running (separate shell): +//! +//! ```bash +//! curl -s http://127.0.0.1:8080/.well-known/agent-card.json | jq . +//! ``` + +use a2a_agents::core::{AgentBuilder, AgentConfig, McpClientManager}; +use a2a_agents::traits::{McpToolsExt, extract_tool_result_text}; +use a2a_rs::{ + InMemoryTaskStorage, + domain::{A2AError, Message, Part, Role, Task, TaskState, TaskStatus}, + port::AsyncMessageHandler, +}; +use async_trait::async_trait; +use serde_json::json; +use uuid::Uuid; + +/// A handler that forwards each message to the downstream MCP `echo` tool. +/// +/// It owns the [`McpClientManager`] and surfaces [`McpToolsExt`] by handing out +/// a reference to it — that's all the wiring a handler needs to use MCP tools. +#[derive(Clone)] +struct McpEchoHandler { + mcp: McpClientManager, +} + +impl McpToolsExt for McpEchoHandler { + fn mcp_client(&self) -> &McpClientManager { + &self.mcp + } +} + +#[async_trait] +impl AsyncMessageHandler for McpEchoHandler { + async fn process_message( + &self, + task_id: &str, + message: &Message, + _session_id: Option<&str>, + ) -> Result { + let text = message + .parts + .iter() + .find_map(|p| p.get_text().map(str::to_string)) + .unwrap_or_else(|| "No text provided".to_string()); + + // Call the downstream MCP `echo` tool and surface its result. + let reply = match self + .call_mcp_tool("echo", "echo", Some(json!({ "text": text }))) + .await + { + Ok(result) => format!("MCP echo says: {}", extract_tool_result_text(&result)), + Err(e) => format!("MCP tool call failed: {e}"), + }; + + let response = Message::builder() + .role(Role::Agent) + .parts(vec![Part::text(reply)]) + .message_id(Uuid::new_v4().to_string()) + .context_id(message.context_id.clone()) + .build(); + + Ok(Task::builder() + .id(task_id.to_string()) + .context_id(message.context_id.clone()) + .status(TaskStatus::new( + TaskState::Completed, + Some(response.clone()), + )) + .history(vec![message.clone(), response]) + .build()) + } + + async fn validate_message(&self, message: &Message) -> Result<(), A2AError> { + if message.parts.is_empty() { + return Err(A2AError::ValidationError { + field: "parts".to_string(), + message: "Message must contain at least one part".to_string(), + }); + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")), + ) + .init(); + + // 1. Load config and connect to the MCP servers it declares. + let config = AgentConfig::from_file("examples/mcp_client_agent.toml")?; + let mcp = McpClientManager::connect(&config.features.mcp_client).await?; + tracing::info!("connected MCP servers: {:?}", mcp.connected_servers().await); + + // 2. Hand the connected manager to the handler. + let handler = McpEchoHandler { mcp }; + + // 3. Assemble and run. + println!("🚀 MCP echo client agent on http://127.0.0.1:8080"); + AgentBuilder::new(config) + .with_handler(handler) + .with_storage(InMemoryTaskStorage::new()) + .build()? + .run() + .await?; + + Ok(()) +} diff --git a/a2a-agents/examples/mcp_client_agent.toml b/a2a-agents/examples/mcp_client_agent.toml new file mode 100644 index 0000000..d5676f6 --- /dev/null +++ b/a2a-agents/examples/mcp_client_agent.toml @@ -0,0 +1,29 @@ +# Agent that acts as an MCP *client*: it connects to the bundled mcp_echo_server +# and exposes its `echo` tool as an A2A skill. +# +# The [[features.mcp_client.servers]] entry below spawns the server as a child +# process. It uses `cargo run` so the example works straight from the repo with +# no prior build step; point `command`/`args` at any MCP stdio server (an npx +# server, a compiled binary, …) to talk to something real. + +[agent] +name = "MCP Echo Client Agent" +description = "An A2A agent that forwards messages to an MCP echo tool" +version = "0.1.0" + +[server] +http_port = 8080 + +[[skills]] +id = "echo" +name = "Echo" +description = "Echoes your text back via a downstream MCP tool" +examples = ["Say hello"] + +[features.mcp_client] +enabled = true + +[[features.mcp_client.servers]] +name = "echo" +command = "cargo" +args = ["run", "-q", "-p", "a2a-agents", "--features", "mcp-client", "--bin", "mcp_echo_server"] diff --git a/a2a-agents/examples/mcp_http_agent.rs b/a2a-agents/examples/mcp_http_agent.rs new file mode 100644 index 0000000..82d2e90 --- /dev/null +++ b/a2a-agents/examples/mcp_http_agent.rs @@ -0,0 +1,82 @@ +//! Expose a declarative A2A agent as an MCP server over Streamable HTTP. +//! +//! This example flips both `features.mcp_server.enabled` and +//! `features.mcp_server.http.enabled` in TOML and lets `AgentBuilder` / +//! `AgentRuntime` do the rest: it serves an MCP Streamable HTTP endpoint that +//! dispatches calls to the agent handler in-process. The agent's skills are +//! callable as MCP tools by any networked MCP client. +//! +//! Requires the `mcp-server` feature: +//! +//! ```text +//! cargo run --example mcp_http_agent -p a2a-agents --features mcp-server +//! ``` +//! +//! The server then listens on the `host:port` / `path` from the TOML +//! (`http://127.0.0.1:8000/mcp` by default). Point an MCP Streamable HTTP +//! client at that URL. + +use a2a_agents::core::{AgentBuilder, BuildError}; +use a2a_rs::{ + InMemoryTaskStorage, + domain::{A2AError, Message, Part, Role, Task, TaskState, TaskStatus}, + port::AsyncMessageHandler, +}; +use async_trait::async_trait; +use uuid::Uuid; + +#[derive(Clone)] +struct EchoHandler; + +#[async_trait] +impl AsyncMessageHandler for EchoHandler { + async fn process_message( + &self, + task_id: &str, + message: &Message, + _session_id: Option<&str>, + ) -> Result { + let text = message + .parts + .iter() + .find_map(|p| p.get_text()) + .unwrap_or("") + .to_string(); + + let response = Message::builder() + .role(Role::Agent) + .parts(vec![Part::text(format!("echo: {text}"))]) + .message_id(Uuid::new_v4().to_string()) + .build(); + + Ok(Task::builder() + .id(task_id.to_string()) + .context_id(message.context_id.clone()) + .status(TaskStatus::new( + TaskState::Completed, + Some(response.clone()), + )) + .history(vec![message.clone(), response]) + .build()) + } +} + +#[tokio::main] +async fn main() -> Result<(), BuildError> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive(tracing::Level::INFO.into()), + ) + .init(); + + AgentBuilder::from_file("examples/mcp_http_agent.toml")? + .with_handler(EchoHandler) + .with_storage(InMemoryTaskStorage::new()) + .build()? + .run() + .await + .map_err(|e| BuildError::RuntimeError(e.to_string()))?; + + Ok(()) +} diff --git a/a2a-agents/examples/mcp_http_agent.toml b/a2a-agents/examples/mcp_http_agent.toml new file mode 100644 index 0000000..710fc55 --- /dev/null +++ b/a2a-agents/examples/mcp_http_agent.toml @@ -0,0 +1,36 @@ +# MCP server agent example — Streamable HTTP transport +# Flip features.mcp_server.enabled = true and features.mcp_server.http.enabled = +# true and the runtime serves this agent over MCP's Streamable HTTP transport +# (using in-process handler dispatch, no backing A2A HTTP server required). + +[agent] +name = "Echo Agent (MCP/HTTP)" +description = "An echo agent exposed as MCP tools over Streamable HTTP" +version = "0.1.0" + +[server] +host = "127.0.0.1" +http_port = 0 + +[features.mcp_server] +enabled = true +# When http.enabled is true the HTTP transport takes precedence over stdio. +stdio = false + +[features.mcp_server.http] +enabled = true +host = "127.0.0.1" +port = 8000 +path = "/mcp" +# DNS-rebinding protection defaults to loopback-only. For a public bind, list +# the hostnames you serve under (and optionally restrict browser origins): +# allowed_hosts = ["mcp.example.com"] +# allowed_origins = ["https://app.example.com"] +# allowed_hosts = [] disables Host validation entirely (proxy-only). + +[[skills]] +id = "echo" +name = "Echo" +description = "Echoes back whatever you send" +keywords = ["echo", "repeat", "test"] +examples = ["Echo this message"] diff --git a/a2a-agents/examples/mcp_server_agent.toml b/a2a-agents/examples/mcp_server_agent.toml index 40980c3..218478d 100644 --- a/a2a-agents/examples/mcp_server_agent.toml +++ b/a2a-agents/examples/mcp_server_agent.toml @@ -10,7 +10,6 @@ version = "0.1.0" [server] host = "127.0.0.1" http_port = 0 -ws_port = 0 [features.mcp_server] enabled = true diff --git a/a2a-agents/examples/oauth2_auth.toml b/a2a-agents/examples/oauth2_auth.toml index c0a31a3..8b2e40f 100644 --- a/a2a-agents/examples/oauth2_auth.toml +++ b/a2a-agents/examples/oauth2_auth.toml @@ -9,7 +9,6 @@ version = "1.0.0" [server] host = "127.0.0.1" http_port = 8080 -ws_port = 8081 # In-memory storage for this example [server.storage] diff --git a/a2a-agents/examples/oauth2_client_credentials.toml b/a2a-agents/examples/oauth2_client_credentials.toml index d76515b..0a50d19 100644 --- a/a2a-agents/examples/oauth2_client_credentials.toml +++ b/a2a-agents/examples/oauth2_client_credentials.toml @@ -9,7 +9,6 @@ version = "1.0.0" [server] host = "127.0.0.1" http_port = 8080 -ws_port = 8081 # SQLx storage for production use [server.storage] diff --git a/a2a-agents/examples/reimbursement_builder.rs b/a2a-agents/examples/reimbursement_builder.rs index d056022..ad2e4a4 100644 --- a/a2a-agents/examples/reimbursement_builder.rs +++ b/a2a-agents/examples/reimbursement_builder.rs @@ -36,8 +36,11 @@ async fn main() -> Result<(), Box> { .await .map_err(|e| format!("Failed to create storage: {}", e))?; - // Create the handler - let handler = ReimbursementHandler::new(storage.clone()); + // Create the handler with a dedicated streaming handler and the store's + // push notifier (streaming and push are separate ports from storage). + let streaming = a2a_rs::InMemoryStreamingHandler::new(); + let push = storage.push_notifier(); + let handler = ReimbursementHandler::new(storage.clone(), streaming, push); // Build and run the agent - this is where the magic happens! // The configuration file defines all the metadata, skills, and features @@ -50,11 +53,6 @@ async fn main() -> Result<(), Box> { config.server.http_port = port_num; } } - if let Ok(port) = env::var("WS_PORT") { - if let Ok(port_num) = port.parse() { - config.server.ws_port = port_num; - } - } }) .with_handler(handler) .with_storage(storage) diff --git a/a2a-agents/examples/test_config_demo.rs b/a2a-agents/examples/test_config_demo.rs index 739d57d..6d17e7f 100644 --- a/a2a-agents/examples/test_config_demo.rs +++ b/a2a-agents/examples/test_config_demo.rs @@ -12,7 +12,6 @@ async fn main() -> Result<(), Box> { let config1 = ServerConfig { host: "127.0.0.1".to_string(), http_port: 8080, - ws_port: 8081, storage: StorageConfig::InMemory, auth: AuthConfig::None, }; @@ -25,7 +24,6 @@ async fn main() -> Result<(), Box> { let config2 = ServerConfig { host: "127.0.0.1".to_string(), http_port: 8080, - ws_port: 8081, storage: StorageConfig::Sqlx { url: "sqlite://reimbursement.db".to_string(), max_connections: 10, @@ -41,7 +39,6 @@ async fn main() -> Result<(), Box> { let config3 = ServerConfig { host: "127.0.0.1".to_string(), http_port: 8080, - ws_port: 8081, storage: StorageConfig::InMemory, auth: AuthConfig::BearerToken { tokens: vec![ @@ -59,7 +56,6 @@ async fn main() -> Result<(), Box> { let config4 = ServerConfig { host: "0.0.0.0".to_string(), http_port: 8080, - ws_port: 8081, storage: StorageConfig::Sqlx { url: "postgres://user:password@localhost/reimbursement_prod".to_string(), max_connections: 50, diff --git a/a2a-agents/examples/test_handler.rs b/a2a-agents/examples/test_handler.rs index 60eb70a..d8efe52 100644 --- a/a2a-agents/examples/test_handler.rs +++ b/a2a-agents/examples/test_handler.rs @@ -1,4 +1,5 @@ use a2a_agents::agents::reimbursement::handler::ReimbursementHandler; +use a2a_rs::InMemoryStreamingHandler; use a2a_rs::adapter::storage::InMemoryTaskStorage; use a2a_rs::domain::{Message, Part, Role}; use a2a_rs::port::message_handler::AsyncMessageHandler; @@ -10,9 +11,11 @@ async fn main() -> Result<(), Box> { // Initialize logging tracing_subscriber::fmt().with_env_filter("debug").init(); - // Create handler with in-memory task storage + // Create handler with in-memory task storage and a dedicated streaming + // handler (streaming and push are separate ports). let task_storage = InMemoryTaskStorage::new(); - let handler = ReimbursementHandler::new(task_storage); + let push = task_storage.push_notifier(); + let handler = ReimbursementHandler::new(task_storage, InMemoryStreamingHandler::new(), push); println!("=== Testing Reimbursement Handler ===\n"); diff --git a/a2a-agents/examples/test_metadata.rs b/a2a-agents/examples/test_metadata.rs index 125116a..7435be3 100644 --- a/a2a-agents/examples/test_metadata.rs +++ b/a2a-agents/examples/test_metadata.rs @@ -2,15 +2,18 @@ use serde_json::{Map, Value, json}; use uuid::Uuid; use a2a_agents::agents::reimbursement::handler::ReimbursementHandler; +use a2a_rs::InMemoryStreamingHandler; use a2a_rs::adapter::storage::InMemoryTaskStorage; use a2a_rs::domain::{Message, Part, Role}; use a2a_rs::port::message_handler::AsyncMessageHandler; #[tokio::main] async fn main() -> Result<(), Box> { - // Initialize the handler with in-memory task storage + // Initialize the handler with in-memory task storage and a dedicated + // streaming handler (streaming and push are separate ports). let task_storage = InMemoryTaskStorage::new(); - let handler = ReimbursementHandler::new(task_storage); + let push = task_storage.push_notifier(); + let handler = ReimbursementHandler::new(task_storage, InMemoryStreamingHandler::new(), push); // Example 1: Text part with metadata hints println!("=== Example 1: Text with metadata ==="); diff --git a/a2a-agents/examples/test_metrics.rs b/a2a-agents/examples/test_metrics.rs index ca36d55..6dfcf85 100644 --- a/a2a-agents/examples/test_metrics.rs +++ b/a2a-agents/examples/test_metrics.rs @@ -2,6 +2,7 @@ use serde_json::{Map, Value}; use uuid::Uuid; use a2a_agents::agents::reimbursement::handler::ReimbursementHandler; +use a2a_rs::InMemoryStreamingHandler; use a2a_rs::adapter::storage::InMemoryTaskStorage; use a2a_rs::domain::{Message, Part, Role}; use a2a_rs::port::message_handler::AsyncMessageHandler; @@ -16,9 +17,11 @@ async fn main() -> Result<(), Box> { .with_line_number(true) .init(); - // Initialize the handler with in-memory task storage + // Initialize the handler with in-memory task storage and a dedicated + // streaming handler (streaming and push are separate ports). let task_storage = InMemoryTaskStorage::new(); - let handler = ReimbursementHandler::new(task_storage); + let push = task_storage.push_notifier(); + let handler = ReimbursementHandler::new(task_storage, InMemoryStreamingHandler::new(), push); println!("=== Testing Metrics and Logging ===\n"); diff --git a/a2a-agents/examples/test_sqlx_storage.rs b/a2a-agents/examples/test_sqlx_storage.rs index 762c1e4..c707627 100644 --- a/a2a-agents/examples/test_sqlx_storage.rs +++ b/a2a-agents/examples/test_sqlx_storage.rs @@ -14,7 +14,6 @@ async fn main() -> Result<(), Box> { let config = ServerConfig { host: "127.0.0.1".to_string(), http_port: 8080, - ws_port: 8081, storage: StorageConfig::Sqlx { url: "sqlite://reimbursement_test.db".to_string(), max_connections: 5, diff --git a/a2a-agents/reimbursement.toml b/a2a-agents/reimbursement.toml index c0c8b79..ded5ab9 100644 --- a/a2a-agents/reimbursement.toml +++ b/a2a-agents/reimbursement.toml @@ -14,7 +14,6 @@ url = "https://example.org" [server] host = "127.0.0.1" http_port = 8080 -ws_port = 8081 # Storage backend [server.storage] diff --git a/a2a-agents/src/agents/reimbursement/config.rs b/a2a-agents/src/agents/reimbursement/config.rs index 98e4d78..d6dd4e6 100644 --- a/a2a-agents/src/agents/reimbursement/config.rs +++ b/a2a-agents/src/agents/reimbursement/config.rs @@ -55,9 +55,6 @@ pub struct ServerConfig { /// Port for HTTP server #[serde(default = "default_http_port")] pub http_port: u16, - /// Port for WebSocket server - #[serde(default = "default_ws_port")] - pub ws_port: u16, /// Storage backend configuration #[serde(default)] pub storage: StorageConfig, @@ -71,7 +68,6 @@ impl Default for ServerConfig { Self { host: default_host(), http_port: default_http_port(), - ws_port: default_ws_port(), storage: StorageConfig::default(), auth: AuthConfig::default(), } @@ -87,10 +83,6 @@ impl ServerConfig { .ok() .and_then(|s| s.parse().ok()) .unwrap_or_else(default_http_port), - ws_port: env::var("WS_PORT") - .ok() - .and_then(|s| s.parse().ok()) - .unwrap_or_else(default_ws_port), storage: StorageConfig::from_env(), auth: AuthConfig::from_env(), } @@ -118,10 +110,6 @@ fn default_http_port() -> u16 { 8080 } -fn default_ws_port() -> u16 { - 8081 -} - /// Authentication configuration #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(tag = "type")] diff --git a/a2a-agents/src/agents/reimbursement/handler.rs b/a2a-agents/src/agents/reimbursement/handler.rs index c583157..1f22175 100644 --- a/a2a-agents/src/agents/reimbursement/handler.rs +++ b/a2a-agents/src/agents/reimbursement/handler.rs @@ -8,15 +8,28 @@ use std::sync::{Arc, Mutex}; use tracing::{debug, error, info, instrument, warn}; use uuid::Uuid; -use a2a_rs::domain::{A2AError, Message, Part, Role, Task, TaskState, part}; +use a2a_rs::application::{HasPushNotifier, HasStreaming, HasTaskLifecycle, TaskStatusBroadcast}; +use a2a_rs::domain::{A2AError, ContextId, Message, Part, Role, Task, TaskId, TaskState, part}; use a2a_rs::port::message_handler::AsyncMessageHandler; use super::types::*; -use a2a_agents_common::llm::{ChatMessage, LlmProvider, LlmRequest}; +use a2a_agents_common::llm::{ChatMessage, LlmProvider, LlmRequest, ToolCallAccumulator}; -// NOTE: Task storage is handled by DefaultRequestProcessor + SQLx/InMemory storage +// NOTE: Task storage is handled by ConnectRpcAdapter + SQLx/InMemory storage // This handler is stateless and only processes messages +/// Build the typed metadata that marks an artifact as a (possibly partial) +/// tool-call delta, so a UI can distinguish it from a text chunk and reassemble +/// the call's arguments by `call_id`. +fn tool_call_metadata(call_id: &str, tool_name: &str, partial: bool) -> Map { + let mut m = Map::new(); + m.insert("type".to_string(), json!("tool-call")); + m.insert("call_id".to_string(), json!(call_id)); + m.insert("tool_name".to_string(), json!(tool_name)); + m.insert("partial".to_string(), json!(partial)); + m +} + /// Metrics for tracking handler performance #[derive(Debug, Default, Clone)] pub struct HandlerMetrics { @@ -80,16 +93,13 @@ impl HandlerMetrics { /// Reimbursement message handler with proper JSON parsing and validation /// Reimbursement handler that manages task history through a task manager #[derive(Clone)] -pub struct ReimbursementHandler -where - T: a2a_rs::port::AsyncTaskManager - + a2a_rs::port::AsyncStreamingHandler - + Clone - + Send - + Sync - + 'static, -{ - task_manager: T, +pub struct ReimbursementHandler { + /// Per-task lifecycle port (create/get/update_status) + task_lifecycle: Arc, + /// Streaming port for broadcasting artifact updates + streaming: Arc, + /// Push-notifier port for out-of-band webhook delivery on each transition + push_notifier: Arc, validation_rules: ValidationRules, #[allow(dead_code)] file_metadata_store: Arc>>>, @@ -97,17 +107,38 @@ where llm_provider: Option>, } +// The handler holds both the lifecycle and streaming ports, so it hosts the +// `TaskStatusBroadcast` mixin (see `.claude/rules/hexagonal_architecture.md` +// §9): every status transition it drives goes through `update_and_broadcast`, +// announcing the change to streaming subscribers. Storage is persistence-only +// and does not self-broadcast on mutation, so without these accessors the +// agent's updates — including the background worker's final state — would +// silently stop reaching subscribers and push notifications. +impl HasTaskLifecycle for ReimbursementHandler { + fn lifecycle(&self) -> &dyn a2a_rs::port::AsyncTaskLifecycle { + self.task_lifecycle.as_ref() + } +} + +impl HasStreaming for ReimbursementHandler { + fn streaming(&self) -> &dyn a2a_rs::port::AsyncStreamingHandler { + self.streaming.as_ref() + } +} + +impl HasPushNotifier for ReimbursementHandler { + fn push_notifier(&self) -> &dyn a2a_rs::port::AsyncPushNotifier { + self.push_notifier.as_ref() + } +} + #[allow(dead_code)] -impl ReimbursementHandler -where - T: a2a_rs::port::AsyncTaskManager - + a2a_rs::port::AsyncStreamingHandler - + Clone - + Send - + Sync - + 'static, -{ - pub fn new(task_manager: T) -> Self { +impl ReimbursementHandler { + pub fn new( + task_lifecycle: impl a2a_rs::port::AsyncTaskLifecycle + 'static, + streaming: impl a2a_rs::port::AsyncStreamingHandler + 'static, + push_notifier: impl a2a_rs::port::AsyncPushNotifier + 'static, + ) -> Self { // Try to initialize AI client from environment let llm_provider: Option> = if let Ok(gemini) = a2a_agents_common::llm::gemini::GeminiProvider::from_env() @@ -122,12 +153,19 @@ where None }; - Self::with_llm(task_manager, llm_provider) + Self::with_llm(task_lifecycle, streaming, push_notifier, llm_provider) } - pub fn with_llm(task_manager: T, llm_provider: Option>) -> Self { + pub fn with_llm( + task_lifecycle: impl a2a_rs::port::AsyncTaskLifecycle + 'static, + streaming: impl a2a_rs::port::AsyncStreamingHandler + 'static, + push_notifier: impl a2a_rs::port::AsyncPushNotifier + 'static, + llm_provider: Option>, + ) -> Self { Self { - task_manager, + task_lifecycle: Arc::new(task_lifecycle), + streaming: Arc::new(streaming), + push_notifier: Arc::new(push_notifier), validation_rules: ValidationRules::default(), file_metadata_store: Arc::new(Mutex::new(HashMap::new())), metrics: Arc::new(Mutex::new(HandlerMetrics::default())), @@ -398,6 +436,7 @@ Example response when asking for info: let mut ai_response = String::new(); let artifact_id = uuid::Uuid::new_v4().to_string(); + let mut tool_calls = ToolCallAccumulator::new(); use futures::StreamExt; @@ -428,24 +467,28 @@ Example response when asking for info: }; let _ = self - .task_manager + .streaming .broadcast_artifact_update(task_id, update_event) .await; } Ok(a2a_agents_common::llm::LlmStreamEvent::ToolCallChunk { - arguments, + id, name, - .. + arguments, }) => { - // Similar to above, but for tool calls + // Stream the tool-call delta as a typed, structured artifact so + // a UI can distinguish it from text and reassemble arguments by + // call id (see the `tool-call` metadata marker). ai_response.push_str(&arguments); + let partial = tool_calls.push(&id, name.as_deref(), &arguments); + let tool_name = partial + .name + .clone() + .unwrap_or_else(|| "unknown".to_string()); let artifact = Artifact { artifact_id: artifact_id.clone(), - name: format!( - "Tool Call: {}", - name.unwrap_or_else(|| "Unknown".to_string()) - ), + name: format!("Tool Call: {tool_name}"), description: String::new(), parts: vec![Part::text(arguments)], metadata: buffa::MessageField::none(), @@ -460,16 +503,44 @@ Example response when asking for info: artifact, append: Some(true), last_chunk: Some(false), - metadata: None, + metadata: Some(tool_call_metadata(&id, &tool_name, true)), }; let _ = self - .task_manager + .streaming .broadcast_artifact_update(task_id, update_event) .await; } - Ok(a2a_agents_common::llm::LlmStreamEvent::ToolCall(_)) => { - // Ignore final tool call structure for now + Ok(a2a_agents_common::llm::LlmStreamEvent::ToolCall(call)) => { + // Reconcile the authoritative final call and emit a terminal, + // non-partial artifact carrying the complete arguments. + let metadata = tool_call_metadata(&call.id, &call.name, false); + let artifact = Artifact { + artifact_id: artifact_id.clone(), + name: format!("Tool Call: {}", call.name), + description: String::new(), + parts: vec![Part::text(call.arguments.clone())], + metadata: buffa::MessageField::none(), + extensions: Vec::new(), + ..Default::default() + }; + + let update_event = a2a_rs::domain::TaskArtifactUpdateEvent { + task_id: task_id.to_string(), + context_id: current_message.context_id.clone(), + kind: "artifact-update".to_string(), + artifact, + append: Some(false), + last_chunk: Some(true), + metadata: Some(metadata), + }; + + tool_calls.finalize(call); + + let _ = self + .streaming + .broadcast_artifact_update(task_id, update_event) + .await; } Err(e) => { tracing::error!("LLM Stream error: {}", e); @@ -1151,7 +1222,7 @@ Example response when asking for info: ProcessingStatus::Pending }; - // NOTE: Task storage is handled by DefaultRequestProcessor + // NOTE: Task storage is handled by ConnectRpcAdapter // We just process the message and return a response // (receipts metadata is still tracked in file_metadata_store if needed) @@ -1177,7 +1248,7 @@ Example response when asking for info: }) } ReimbursementRequest::StatusQuery { request_id } => { - // NOTE: Actual task status is managed by DefaultRequestProcessor/SQLx + // NOTE: Actual task status is managed by ConnectRpcAdapter/SQLx // This handler doesn't have access to task storage directly // The client should use tasks/get instead for real status { @@ -1366,15 +1437,7 @@ Example response when asking for info: } #[async_trait] -impl AsyncMessageHandler for ReimbursementHandler -where - T: a2a_rs::port::AsyncTaskManager - + a2a_rs::port::AsyncStreamingHandler - + Clone - + Send - + Sync - + 'static, -{ +impl AsyncMessageHandler for ReimbursementHandler { #[instrument(skip(self, message), fields( task_id = %task_id, message_id = %message.message_id, @@ -1393,9 +1456,11 @@ where ); info!("Processing reimbursement request"); + let id: TaskId = task_id.parse()?; + // Check if task exists and get its current state - let existing_task = if self.task_manager.task_exists(task_id).await? { - Some(self.task_manager.get_task(task_id, Some(50)).await?) + let existing_task = if self.task_lifecycle.exists(&id).await? { + Some(self.task_lifecycle.get(&id, Some(50)).await?) } else { None }; @@ -1407,7 +1472,8 @@ where } else { message.context_id.clone() }; - self.task_manager.create_task(task_id, &context_id).await?; + let context_id: ContextId = context_id.parse()?; + self.task_lifecycle.create(&id, &context_id).await?; } // Check if this task already has a completed/approved expense @@ -1416,8 +1482,7 @@ where if task.status.state == TaskState::Completed { // This is a follow-up to a completed task // Add user message to history - self.task_manager - .update_task_status(task_id, TaskState::Working, Some(message.clone())) + self.update_and_broadcast(&id, TaskState::Working, Some(message.clone())) .await?; // Send a simple acknowledgment @@ -1430,8 +1495,7 @@ where // Add agent response and return let final_task = self - .task_manager - .update_task_status(task_id, TaskState::Completed, Some(response_message)) + .update_and_broadcast(&id, TaskState::Completed, Some(response_message)) .await?; return Ok(final_task); @@ -1439,8 +1503,7 @@ where } // Add the user's message to history first - self.task_manager - .update_task_status(task_id, TaskState::Working, Some(message.clone())) + self.update_and_broadcast(&id, TaskState::Working, Some(message.clone())) .await?; // Send immediate acknowledgment @@ -1451,13 +1514,13 @@ where .context_id(message.context_id.clone()) .build(); - self.task_manager - .update_task_status(task_id, TaskState::Working, Some(ack_message)) + self.update_and_broadcast(&id, TaskState::Working, Some(ack_message)) .await?; // Clone what we need for the background task let handler = self.clone(); let task_id_owned = task_id.to_string(); + let id_owned = id.clone(); let message_owned = message.clone(); let context_id = message.context_id.clone(); @@ -1472,11 +1535,7 @@ where let text_content = handler.extract_text_from_message(&message_owned); // Get current task for context - let current_task = match handler - .task_manager - .get_task(&task_id_owned, Some(50)) - .await - { + let current_task = match handler.task_lifecycle.get(&id_owned, Some(50)).await { Ok(task) => { info!(task_id = %task_id_owned, history_count = task.history.len(), "Retrieved task for AI processing"); Some(task) @@ -1551,8 +1610,7 @@ where // Update task with AI response info!(task_id = %task_id_owned, new_state = ?task_state, "Updating task with AI response"); match handler - .task_manager - .update_task_status(&task_id_owned, task_state, Some(response_message)) + .update_and_broadcast(&id_owned, task_state, Some(response_message)) .await { Ok(updated_task) => { @@ -1578,7 +1636,7 @@ where info!(task_id = %task_id, "Returning immediate acknowledgment, AI processing in background"); // Get the updated task with the acknowledgment message - let final_task = self.task_manager.get_task(task_id, Some(50)).await?; + let final_task = self.task_lifecycle.get(&id, Some(50)).await?; Ok(final_task) } @@ -1677,3 +1735,128 @@ where Ok(()) } } + +#[cfg(test)] +mod tool_call_streaming_tests { + use super::*; + use a2a_agents_common::llm::{ + LlmError, LlmProvider, LlmResponse, LlmStreamEvent, ToolCall as LlmToolCall, + }; + use a2a_rs::adapter::{InMemoryStreamingHandler, InMemoryTaskStorage}; + use a2a_rs::port::{AsyncStreamingHandler, NoopPushNotifier, UpdateEvent}; + use futures::StreamExt; + use futures::stream::{self, BoxStream}; + + /// A fake provider that streams a tool call as a partial chunk followed by the + /// finalized call — the exact shape the handler turns into `tool-call` + /// metadata on broadcast artifact updates. + struct ToolCallProvider; + + #[async_trait] + impl LlmProvider for ToolCallProvider { + async fn chat_completion(&self, _request: LlmRequest) -> Result { + Err(LlmError::ProviderError("unused in this test".to_string())) + } + + async fn chat_completion_stream( + &self, + _request: LlmRequest, + ) -> Result>, LlmError> { + let events = vec![ + Ok(LlmStreamEvent::ToolCallChunk { + id: "call_1".to_string(), + name: Some("create_reimbursement".to_string()), + arguments: "{\"amount\":".to_string(), + }), + Ok(LlmStreamEvent::ToolCall(LlmToolCall { + id: "call_1".to_string(), + name: "create_reimbursement".to_string(), + arguments: "{\"amount\":50}".to_string(), + })), + ]; + Ok(stream::iter(events).boxed()) + } + } + + /// End-to-end: a tool call streamed by the provider becomes broadcast artifact + /// updates whose `tool-call` metadata reaches a `combined_update_stream` + /// subscriber — the same stream the SSE transport serializes to clients. The + /// accumulator and `tool_call_metadata` builder are unit-tested in isolation; + /// this covers the broadcast wiring between them, previously only + /// compile-checked. + #[tokio::test] + async fn tool_call_metadata_reaches_stream_subscribers() { + let streaming = InMemoryStreamingHandler::new(); + let handler = ReimbursementHandler::with_llm( + InMemoryTaskStorage::new(), + streaming.clone(), + NoopPushNotifier, + Some(Arc::new(ToolCallProvider) as Arc), + ); + + let task_id = "task-tc"; + let message = Message::user_text("reimburse me".to_string(), "m1".to_string()); + + // Drive the AI streaming path directly. The final JSON parse may fail (tool + // arguments aren't a ReimbursementResponse), but the artifact broadcasts — + // the wiring under test — happen during the stream, before that parse. + let _ = handler + .process_with_ai(task_id, "reimburse me", None, &message) + .await; + + // Replay everything buffered for the task (id > 0): the SSE source. + let mut updates = streaming + .combined_update_stream(task_id, Some(0)) + .await + .unwrap(); + + let mut tool_call_updates = Vec::new(); + while let Ok(Some(Ok(seq))) = + tokio::time::timeout(std::time::Duration::from_secs(2), updates.next()).await + { + if let UpdateEvent::ArtifactUpdate(ev) = seq.event { + let is_tool_call = ev + .metadata + .as_ref() + .and_then(|m| m.get("type")) + .and_then(Value::as_str) + == Some("tool-call"); + if is_tool_call { + tool_call_updates.push(ev); + } + } + } + + // Both the partial chunk and the finalized call surfaced as tool-call + // metadata, carrying the call id and tool name end-to-end. + assert_eq!( + tool_call_updates.len(), + 2, + "expected partial + final tool-call artifact updates" + ); + for ev in &tool_call_updates { + let meta = ev.metadata.as_ref().unwrap(); + assert_eq!(meta.get("call_id").and_then(Value::as_str), Some("call_1")); + assert_eq!( + meta.get("tool_name").and_then(Value::as_str), + Some("create_reimbursement") + ); + } + let partials: Vec = tool_call_updates + .iter() + .map(|ev| { + ev.metadata + .as_ref() + .unwrap() + .get("partial") + .and_then(Value::as_bool) + .unwrap() + }) + .collect(); + assert_eq!( + partials, + vec![true, false], + "partial streamed chunk, then the finalized non-partial call" + ); + } +} diff --git a/a2a-agents/src/agents/reimbursement/plugin.rs b/a2a-agents/src/agents/reimbursement/plugin.rs index b15f370..479c453 100644 --- a/a2a-agents/src/agents/reimbursement/plugin.rs +++ b/a2a-agents/src/agents/reimbursement/plugin.rs @@ -10,15 +10,7 @@ use super::handler::ReimbursementHandler; /// Implement AgentPlugin for ReimbursementHandler with InMemoryTaskStorage #[async_trait] -impl AgentPlugin for ReimbursementHandler -where - T: a2a_rs::port::AsyncTaskManager - + a2a_rs::port::AsyncStreamingHandler - + Clone - + Send - + Sync - + 'static, -{ +impl AgentPlugin for ReimbursementHandler { fn name(&self) -> &str { "Reimbursement Agent" } diff --git a/a2a-agents/src/agents/reimbursement/server.rs b/a2a-agents/src/agents/reimbursement/server.rs index 73e6c03..1f027a7 100644 --- a/a2a-agents/src/agents/reimbursement/server.rs +++ b/a2a-agents/src/agents/reimbursement/server.rs @@ -1,8 +1,12 @@ +use std::sync::Arc; + use a2a_rs::adapter::{ - BearerTokenAuthenticator, DefaultRequestProcessor, HttpPushNotificationSender, HttpServer, - InMemoryTaskStorage, SimpleAgentInfo, + BearerTokenAuthenticator, ConnectRpcAdapter, HttpPushNotificationSender, HttpServer, + InMemoryStreamingHandler, InMemoryTaskStorage, SimpleAgentInfo, +}; +use a2a_rs::port::{ + AsyncNotificationManager, AsyncPushNotifier, AsyncTaskLifecycle, AsyncTaskQuery, }; -use a2a_rs::port::{AsyncNotificationManager, AsyncTaskManager}; // SQLx storage support (feature-gated) #[cfg(feature = "sqlx")] @@ -22,7 +26,6 @@ impl ReimbursementServer { let config = ServerConfig { host, http_port: port, - ws_port: port + 1, storage: StorageConfig::default(), auth: AuthConfig::default(), }; @@ -76,7 +79,8 @@ impl ReimbursementServer { match &self.config.storage { StorageConfig::InMemory => { let storage = self.create_in_memory_storage(); - self.start(storage).await + let push = storage.push_notifier(); + self.start(storage, push).await } #[cfg(feature = "sqlx")] StorageConfig::Sqlx { @@ -87,7 +91,8 @@ impl ReimbursementServer { let storage = self .create_sqlx_storage(url, *max_connections, *enable_logging) .await?; - self.start(storage).await + let push = storage.push_notifier(); + self.start(storage, push).await } #[cfg(not(feature = "sqlx"))] StorageConfig::Sqlx { .. } => { @@ -96,20 +101,33 @@ impl ReimbursementServer { } } - /// Start HTTP server - pub async fn start(&self, storage: S) -> Result<(), Box> + /// Start HTTP server. + /// + /// Streaming fan-out lives in a dedicated [`InMemoryStreamingHandler`] shared + /// between the message handler (which broadcasts) and the transport processor + /// (which registers subscribers), so a streaming client sees the handler's + /// transitions. Push delivery uses the store's own notifier so configs set + /// via the notification API are honored. + pub async fn start( + &self, + storage: S, + push: Arc, + ) -> Result<(), Box> where - S: AsyncTaskManager + S: AsyncTaskLifecycle + + AsyncTaskQuery + AsyncNotificationManager - + a2a_rs::port::AsyncStreamingHandler + Clone + Send + Sync + 'static, { - // Create message handler with storage for history management - let message_handler = ReimbursementHandler::new(storage.clone()); - self.start_with_handler(message_handler, storage).await + let streaming = InMemoryStreamingHandler::new(); + // Create message handler sharing the streaming + push ports. + let message_handler = + ReimbursementHandler::new(storage.clone(), streaming.clone(), push.clone()); + self.start_with_handler(message_handler, storage, streaming, push) + .await } /// Start HTTP server with specific handler @@ -117,11 +135,13 @@ impl ReimbursementServer { &self, message_handler: H, storage: S, + streaming: InMemoryStreamingHandler, + push: Arc, ) -> Result<(), Box> where - S: AsyncTaskManager + S: AsyncTaskLifecycle + + AsyncTaskQuery + AsyncNotificationManager - + a2a_rs::port::AsyncStreamingHandler + Clone + Send + Sync @@ -162,13 +182,16 @@ impl ReimbursementServer { Some(vec!["text".to_string(), "data".to_string()]), ); - // Create processor with separate handlers and agent info - let processor = DefaultRequestProcessor::new( + // Create processor with separate handlers and agent info, sharing the + // streaming handler with the message handler and the store's push notifier. + let processor = ConnectRpcAdapter::new( message_handler, - storage.clone(), // storage implements AsyncTaskManager + storage.clone(), // storage implements AsyncTaskLifecycle + AsyncTaskQuery storage, // storage also implements AsyncNotificationManager agent_info.clone(), - ); + ) + .with_streaming_handler(streaming) + .with_push_notifier(push); // Create HTTP server let bind_address = format!("{}:{}", self.config.host, self.config.http_port); diff --git a/a2a-agents/src/core/builder.rs b/a2a-agents/src/core/builder.rs index 09f0565..837e7f0 100644 --- a/a2a-agents/src/core/builder.rs +++ b/a2a-agents/src/core/builder.rs @@ -3,26 +3,17 @@ //! Provides a fluent API for building agents from configuration files //! or programmatically with minimal boilerplate. -#[cfg(feature = "mcp-client")] -use crate::core::McpClientManager; use crate::core::config::{AgentConfig, ConfigError, StorageConfig}; use crate::core::runtime::AgentRuntime; -use a2a_rs::domain::{ - A2AError, Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig, TaskState, - TaskStatusUpdateEvent, -}; +use a2a_rs::domain::{A2AError, ContextId, Task, TaskId, TaskPushNotificationConfig, TaskState}; use a2a_rs::port::{ - AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager, - StreamingSubscriber, UpdateEvent, + AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskLifecycle, + AsyncTaskQuery, }; use a2a_rs::{HttpPushNotificationSender, InMemoryTaskStorage}; use async_trait::async_trait; -use futures::Stream; use std::path::Path; -use std::pin::Pin; use std::sync::Arc; -#[cfg(feature = "mcp-client")] -use tracing::info; #[cfg(feature = "sqlx")] use a2a_rs::adapter::storage::SqlxTaskStorage; @@ -37,82 +28,110 @@ pub enum AutoStorage { } #[async_trait] -impl AsyncTaskManager for AutoStorage { - async fn create_task(&self, task_id: &str, context_id: &str) -> Result { +impl AsyncTaskLifecycle for AutoStorage { + async fn create(&self, id: &TaskId, context_id: &ContextId) -> Result { match self { - AutoStorage::InMemory(s) => s.create_task(task_id, context_id).await, + AutoStorage::InMemory(s) => s.create(id, context_id).await, #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.create_task(task_id, context_id).await, + AutoStorage::Sqlx(s) => s.create(id, context_id).await, } } - async fn get_task(&self, task_id: &str, history_length: Option) -> Result { + async fn get(&self, id: &TaskId, history_length: Option) -> Result { match self { - AutoStorage::InMemory(s) => s.get_task(task_id, history_length).await, + AutoStorage::InMemory(s) => s.get(id, history_length).await, #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.get_task(task_id, history_length).await, + AutoStorage::Sqlx(s) => s.get(id, history_length).await, } } - async fn update_task_status( + async fn update_status( &self, - task_id: &str, + id: &TaskId, state: TaskState, message: Option, ) -> Result { match self { - AutoStorage::InMemory(s) => s.update_task_status(task_id, state, message).await, + AutoStorage::InMemory(s) => s.update_status(id, state, message).await, #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.update_task_status(task_id, state, message).await, + AutoStorage::Sqlx(s) => s.update_status(id, state, message).await, } } - async fn cancel_task(&self, task_id: &str) -> Result { + async fn cancel(&self, id: &TaskId) -> Result { match self { - AutoStorage::InMemory(s) => s.cancel_task(task_id).await, + AutoStorage::InMemory(s) => s.cancel(id).await, #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.cancel_task(task_id).await, + AutoStorage::Sqlx(s) => s.cancel(id).await, } } - async fn task_exists(&self, task_id: &str) -> Result { + async fn exists(&self, id: &TaskId) -> Result { match self { - AutoStorage::InMemory(s) => s.task_exists(task_id).await, + AutoStorage::InMemory(s) => s.exists(id).await, #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.task_exists(task_id).await, + AutoStorage::Sqlx(s) => s.exists(id).await, + } + } +} + +#[async_trait] +impl AsyncTaskQuery for AutoStorage { + async fn list( + &self, + params: &a2a_rs::domain::ListTasksParams, + ) -> Result { + match self { + AutoStorage::InMemory(s) => s.list(params).await, + #[cfg(feature = "sqlx")] + AutoStorage::Sqlx(s) => s.list(params).await, } } } #[async_trait] impl AsyncNotificationManager for AutoStorage { - async fn set_task_notification( + async fn set_config( &self, config: &TaskPushNotificationConfig, ) -> Result { match self { - AutoStorage::InMemory(s) => s.set_task_notification(config).await, + AutoStorage::InMemory(s) => s.set_config(config).await, #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.set_task_notification(config).await, + AutoStorage::Sqlx(s) => s.set_config(config).await, } } - async fn get_task_notification( + async fn get_config( &self, - task_id: &str, + params: &a2a_rs::domain::GetTaskPushNotificationConfigParams, ) -> Result { match self { - AutoStorage::InMemory(s) => s.get_task_notification(task_id).await, + AutoStorage::InMemory(s) => s.get_config(params).await, #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.get_task_notification(task_id).await, + AutoStorage::Sqlx(s) => s.get_config(params).await, } } - async fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError> { + async fn list_configs( + &self, + params: &a2a_rs::domain::ListTaskPushNotificationConfigsParams, + ) -> Result, A2AError> { match self { - AutoStorage::InMemory(s) => s.remove_task_notification(task_id).await, + AutoStorage::InMemory(s) => s.list_configs(params).await, #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.remove_task_notification(task_id).await, + AutoStorage::Sqlx(s) => s.list_configs(params).await, + } + } + + async fn delete_config( + &self, + params: &a2a_rs::domain::DeleteTaskPushNotificationConfigParams, + ) -> Result<(), A2AError> { + match self { + AutoStorage::InMemory(s) => s.delete_config(params).await, + #[cfg(feature = "sqlx")] + AutoStorage::Sqlx(s) => s.delete_config(params).await, } } } @@ -152,116 +171,13 @@ impl AutoStorage { )), } } -} - -#[async_trait] -impl AsyncStreamingHandler for AutoStorage { - async fn add_status_subscriber( - &self, - task_id: &str, - subscriber: Box + Send + Sync>, - ) -> Result { - match self { - AutoStorage::InMemory(s) => s.add_status_subscriber(task_id, subscriber).await, - #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.add_status_subscriber(task_id, subscriber).await, - } - } - - async fn add_artifact_subscriber( - &self, - task_id: &str, - subscriber: Box + Send + Sync>, - ) -> Result { - match self { - AutoStorage::InMemory(s) => s.add_artifact_subscriber(task_id, subscriber).await, - #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.add_artifact_subscriber(task_id, subscriber).await, - } - } - - async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError> { - match self { - AutoStorage::InMemory(s) => s.remove_subscription(subscription_id).await, - #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.remove_subscription(subscription_id).await, - } - } - - async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> { - match self { - AutoStorage::InMemory(s) => s.remove_task_subscribers(task_id).await, - #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.remove_task_subscribers(task_id).await, - } - } - - async fn get_subscriber_count(&self, task_id: &str) -> Result { - match self { - AutoStorage::InMemory(s) => s.get_subscriber_count(task_id).await, - #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.get_subscriber_count(task_id).await, - } - } - - async fn broadcast_status_update( - &self, - task_id: &str, - update: TaskStatusUpdateEvent, - ) -> Result<(), A2AError> { - match self { - AutoStorage::InMemory(s) => s.broadcast_status_update(task_id, update).await, - #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.broadcast_status_update(task_id, update).await, - } - } - - async fn broadcast_artifact_update( - &self, - task_id: &str, - update: TaskArtifactUpdateEvent, - ) -> Result<(), A2AError> { - match self { - AutoStorage::InMemory(s) => s.broadcast_artifact_update(task_id, update).await, - #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.broadcast_artifact_update(task_id, update).await, - } - } - - async fn status_update_stream( - &self, - task_id: &str, - ) -> Result> + Send>>, A2AError> - { - match self { - AutoStorage::InMemory(s) => s.status_update_stream(task_id).await, - #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.status_update_stream(task_id).await, - } - } - async fn artifact_update_stream( - &self, - task_id: &str, - ) -> Result< - Pin> + Send>>, - A2AError, - > { - match self { - AutoStorage::InMemory(s) => s.artifact_update_stream(task_id).await, - #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.artifact_update_stream(task_id).await, - } - } - - async fn combined_update_stream( - &self, - task_id: &str, - ) -> Result> + Send>>, A2AError> { + /// Hand out the inner store's push notifier (shares its config registry). + pub fn push_notifier(&self) -> Arc { match self { - AutoStorage::InMemory(s) => s.combined_update_stream(task_id).await, + AutoStorage::InMemory(s) => s.push_notifier(), #[cfg(feature = "sqlx")] - AutoStorage::Sqlx(s) => s.combined_update_stream(task_id).await, + AutoStorage::Sqlx(s) => s.push_notifier(), } } } @@ -271,6 +187,7 @@ pub struct AgentBuilder { config: AgentConfig, handler: Option, storage: Option, + streaming: Option>, } impl AgentBuilder<(), ()> { @@ -281,6 +198,7 @@ impl AgentBuilder<(), ()> { config, handler: None, storage: None, + streaming: None, }) } @@ -291,6 +209,7 @@ impl AgentBuilder<(), ()> { config, handler: None, storage: None, + streaming: None, }) } @@ -300,6 +219,7 @@ impl AgentBuilder<(), ()> { config, handler: None, storage: None, + streaming: None, } } } @@ -314,21 +234,42 @@ impl AgentBuilder { config: self.config, handler: Some(handler), storage: self.storage, + streaming: self.streaming, } } /// Set custom storage for this agent pub fn with_storage(self, storage: NewS) -> AgentBuilder where - NewS: AsyncTaskManager + AsyncNotificationManager + Clone + Send + Sync + 'static, + NewS: AsyncTaskLifecycle + + AsyncTaskQuery + + AsyncNotificationManager + + Clone + + Send + + Sync + + 'static, { AgentBuilder { config: self.config, handler: self.handler, storage: Some(storage), + streaming: self.streaming, } } + /// Attach a shared streaming backend for real-time updates. + /// + /// Pass the *same* [`AsyncStreamingHandler`] instance your handler + /// broadcasts to (clones of an `InMemoryStreamingHandler` share their + /// subscriber registry). The built [`AgentRuntime`] injects it into the + /// transport so `tasks/subscribe` SSE streams observe those broadcasts — + /// without it, the transport defaults to a no-op and updates never reach + /// clients. + pub fn with_streaming(mut self, streaming: impl AsyncStreamingHandler + 'static) -> Self { + self.streaming = Some(Arc::new(streaming)); + self + } + /// Access the configuration pub fn config(&self) -> &AgentConfig { &self.config @@ -347,18 +288,24 @@ impl AgentBuilder { impl AgentBuilder where H: AsyncMessageHandler + Clone + Send + Sync + 'static, - S: AsyncTaskManager + AsyncNotificationManager + Clone + Send + Sync + 'static, + S: AsyncTaskLifecycle + + AsyncTaskQuery + + AsyncNotificationManager + + Clone + + Send + + Sync + + 'static, { /// Build the agent runtime pub fn build(self) -> Result, BuildError> { let handler = self.handler.ok_or(BuildError::MissingHandler)?; let storage = self.storage.ok_or(BuildError::MissingStorage)?; - Ok(AgentRuntime::new( - self.config, - Arc::new(handler), - Arc::new(storage), - )) + let mut runtime = AgentRuntime::new(self.config, Arc::new(handler), Arc::new(storage)); + if let Some(streaming) = self.streaming { + runtime = runtime.with_streaming(streaming); + } + Ok(runtime) } } @@ -371,6 +318,7 @@ where /// based on what's configured in the TOML file pub async fn build_with_auto_storage(self) -> Result, BuildError> { let handler = self.handler.ok_or(BuildError::MissingHandler)?; + let streaming = self.streaming; let storage = match &self.config.server.storage { StorageConfig::InMemory => { @@ -404,36 +352,11 @@ where } }; - // Initialize MCP client if configured - #[cfg(feature = "mcp-client")] - if self.config.features.mcp_client.enabled { - info!("Initializing MCP client..."); - let mcp_client = McpClientManager::new(); - - // Initialize connections to configured servers - if let Err(e) = mcp_client - .initialize(&self.config.features.mcp_client) - .await - { - return Err(BuildError::RuntimeError(format!( - "Failed to initialize MCP client: {}", - e - ))); - } - - return Ok(AgentRuntime::with_mcp_client( - self.config, - Arc::new(handler), - Arc::new(storage), - mcp_client, - )); + let mut runtime = AgentRuntime::new(self.config, Arc::new(handler), Arc::new(storage)); + if let Some(streaming) = streaming { + runtime = runtime.with_streaming(streaming); } - - Ok(AgentRuntime::new( - self.config, - Arc::new(handler), - Arc::new(storage), - )) + Ok(runtime) } /// Create storage from configuration with custom migrations @@ -444,6 +367,7 @@ where migrations: &'static [&'static str], ) -> Result, BuildError> { let handler = self.handler.ok_or(BuildError::MissingHandler)?; + let streaming = self.streaming; let storage = match &self.config.server.storage { StorageConfig::InMemory => { @@ -475,36 +399,11 @@ where } }; - // Initialize MCP client if configured - #[cfg(feature = "mcp-client")] - if self.config.features.mcp_client.enabled { - info!("Initializing MCP client..."); - let mcp_client = McpClientManager::new(); - - // Initialize connections to configured servers - if let Err(e) = mcp_client - .initialize(&self.config.features.mcp_client) - .await - { - return Err(BuildError::RuntimeError(format!( - "Failed to initialize MCP client: {}", - e - ))); - } - - return Ok(AgentRuntime::with_mcp_client( - self.config, - Arc::new(handler), - Arc::new(storage), - mcp_client, - )); + let mut runtime = AgentRuntime::new(self.config, Arc::new(handler), Arc::new(storage)); + if let Some(streaming) = streaming { + runtime = runtime.with_streaming(streaming); } - - Ok(AgentRuntime::new( - self.config, - Arc::new(handler), - Arc::new(storage), - )) + Ok(runtime) } } diff --git a/a2a-agents/src/core/config.rs b/a2a-agents/src/core/config.rs index 1f16add..2db8564 100644 --- a/a2a-agents/src/core/config.rs +++ b/a2a-agents/src/core/config.rs @@ -79,13 +79,9 @@ impl AgentConfig { )); } - if !self.features.mcp_server.enabled - && self.server.http_port == 0 - && self.server.ws_port == 0 - { + if !self.features.mcp_server.enabled && self.server.http_port == 0 { return Err(ConfigError::ValidationError( - "At least one server port must be configured when MCP server is disabled" - .to_string(), + "The HTTP server port must be configured when MCP server is disabled".to_string(), )); } @@ -153,10 +149,6 @@ pub struct ServerConfig { #[serde(default = "default_http_port")] pub http_port: u16, - /// WebSocket server port (0 to disable) - #[serde(default = "default_ws_port")] - pub ws_port: u16, - /// Storage configuration #[serde(default)] pub storage: StorageConfig, @@ -171,7 +163,6 @@ impl Default for ServerConfig { Self { host: default_host(), http_port: default_http_port(), - ws_port: default_ws_port(), storage: StorageConfig::default(), auth: AuthConfig::default(), } @@ -387,10 +378,18 @@ pub struct McpServerConfig { #[serde(default)] pub enabled: bool, - /// Use stdio transport (for Claude Desktop integration) + /// Use stdio transport (for Claude Desktop integration). + /// + /// Ignored when [`http.enabled`](McpHttpConfig::enabled) is set — the HTTP + /// (Streamable HTTP) transport takes precedence, since a single process + /// cannot own stdin/stdout for stdio and bind a socket at the same time. #[serde(default = "default_true")] pub stdio: bool, + /// Streamable HTTP transport (for networked MCP clients). + #[serde(default)] + pub http: McpHttpConfig, + /// Server name (defaults to agent name) #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, @@ -405,16 +404,86 @@ impl Default for McpServerConfig { Self { enabled: false, stdio: true, + http: McpHttpConfig::default(), name: None, version: None, } } } +/// Streamable HTTP transport configuration for the MCP server. +/// +/// When [`enabled`](Self::enabled), the agent is served over MCP's Streamable +/// HTTP transport (`rmcp`'s `StreamableHttpService`) instead of stdio, mounted +/// at [`path`](Self::path) on `host:port`. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpHttpConfig { + /// Serve the MCP server over Streamable HTTP rather than stdio. + #[serde(default)] + pub enabled: bool, + + /// Host/interface to bind to. + #[serde(default = "default_mcp_http_host")] + pub host: String, + + /// TCP port to bind to. + #[serde(default = "default_mcp_http_port")] + pub port: u16, + + /// URL path the Streamable HTTP endpoint is mounted at. + #[serde(default = "default_mcp_http_path")] + pub path: String, + + /// Hostnames / `host:port` authorities accepted in the inbound `Host` + /// header (DNS-rebinding protection). + /// + /// * Omitted → the secure default: loopback only (`localhost`, `127.0.0.1`, + /// `::1`). + /// * `[]` → disable `Host` validation entirely (allow any host — required + /// for public binds, but **not recommended** without an upstream proxy). + /// * Non-empty → only the listed authorities are accepted. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub allowed_hosts: Option>, + + /// Browser `Origin` values accepted on inbound requests. + /// + /// * Omitted (or `[]`) → `Origin` validation disabled (the rmcp default). + /// * Non-empty → requests carrying an `Origin` must match one of these per + /// RFC 6454 `(scheme, host, port)`; entries must include a scheme (e.g. + /// `https://app.example.com`). Requests without an `Origin` still pass. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub allowed_origins: Option>, +} + +impl Default for McpHttpConfig { + fn default() -> Self { + Self { + enabled: false, + host: default_mcp_http_host(), + port: default_mcp_http_port(), + path: default_mcp_http_path(), + allowed_hosts: None, + allowed_origins: None, + } + } +} + fn default_true() -> bool { true } +fn default_mcp_http_host() -> String { + "127.0.0.1".to_string() +} + +fn default_mcp_http_port() -> u16 { + 8000 +} + +fn default_mcp_http_path() -> String { + "/mcp".to_string() +} + /// MCP client configuration #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct McpClientConfig { @@ -462,13 +531,6 @@ fn default_http_port() -> u16 { .unwrap_or(8080) } -fn default_ws_port() -> u16 { - std::env::var("WS_PORT") - .ok() - .and_then(|s| s.parse().ok()) - .unwrap_or(8081) -} - fn default_max_connections() -> u32 { 10 } @@ -547,7 +609,6 @@ mod tests { [server] host = "0.0.0.0" http_port = 3000 - ws_port = 3001 [server.storage] type = "sqlx" @@ -699,6 +760,122 @@ mod tests { assert!(ap2.required); } + #[test] + fn test_mcp_http_config() { + let toml = r#" + [agent] + name = "HTTP MCP Agent" + + [server] + http_port = 0 + + [features.mcp_server] + enabled = true + stdio = false + + [features.mcp_server.http] + enabled = true + host = "0.0.0.0" + port = 9000 + path = "/rpc" + "#; + + let config = AgentConfig::from_toml(toml).unwrap(); + let http = &config.features.mcp_server.http; + assert!(http.enabled); + assert_eq!(http.host, "0.0.0.0"); + assert_eq!(http.port, 9000); + assert_eq!(http.path, "/rpc"); + // Security knobs omitted → None (keep rmcp's loopback-only default). + assert!(http.allowed_hosts.is_none()); + assert!(http.allowed_origins.is_none()); + } + + #[test] + fn test_mcp_http_security_knobs() { + let toml = r#" + [agent] + name = "Public MCP Agent" + + [server] + http_port = 0 + + [features.mcp_server] + enabled = true + + [features.mcp_server.http] + enabled = true + allowed_hosts = ["mcp.example.com", "mcp.example.com:8000"] + allowed_origins = ["https://app.example.com"] + "#; + + let config = AgentConfig::from_toml(toml).unwrap(); + let http = &config.features.mcp_server.http; + assert_eq!( + http.allowed_hosts.as_deref(), + Some( + [ + "mcp.example.com".to_string(), + "mcp.example.com:8000".to_string() + ] + .as_slice() + ) + ); + assert_eq!( + http.allowed_origins.as_deref(), + Some(["https://app.example.com".to_string()].as_slice()) + ); + } + + #[test] + fn test_mcp_http_disable_host_validation() { + // An explicit empty list parses as Some([]) — distinct from omission — + // and disables Host validation at the transport layer. + let toml = r#" + [agent] + name = "Open MCP Agent" + + [server] + http_port = 0 + + [features.mcp_server] + enabled = true + + [features.mcp_server.http] + enabled = true + allowed_hosts = [] + "#; + + let config = AgentConfig::from_toml(toml).unwrap(); + assert_eq!( + config.features.mcp_server.http.allowed_hosts.as_deref(), + Some([].as_slice()) + ); + } + + #[test] + fn test_mcp_http_config_defaults() { + // Omitting [features.mcp_server.http] leaves HTTP disabled with sane defaults. + let toml = r#" + [agent] + name = "Stdio MCP Agent" + + [server] + http_port = 0 + + [features.mcp_server] + enabled = true + "#; + + let config = AgentConfig::from_toml(toml).unwrap(); + let mcp = &config.features.mcp_server; + assert!(mcp.stdio); + assert!(!mcp.http.enabled); + assert_eq!(mcp.http.host, "127.0.0.1"); + assert_eq!(mcp.http.port, 8000); + assert_eq!(mcp.http.path, "/mcp"); + } + #[test] fn test_ap2_extension_config_optional() { let toml = r#" diff --git a/a2a-agents/src/core/mcp.rs b/a2a-agents/src/core/mcp.rs index 4f86d14..53fc0b8 100644 --- a/a2a-agents/src/core/mcp.rs +++ b/a2a-agents/src/core/mcp.rs @@ -15,13 +15,19 @@ use tracing::info; #[cfg(feature = "mcp-server")] use crate::core::config::McpServerConfig; -/// Run agent as MCP server via stdio transport. +/// Run agent as an MCP server over the configured transport. /// -/// Bridges the in-process [`AsyncMessageHandler`] into an MCP server using -/// stdin/stdout — no loopback HTTP server is involved, so there is no -/// auth-config-ignored caveat and tool calls don't pay the round-trip cost. -/// This is the standard way to integrate with Claude Desktop and other MCP -/// stdio clients. +/// Bridges the in-process [`AsyncMessageHandler`] into an MCP server. Two +/// transports are supported, selected by [`McpServerConfig`]: +/// +/// * **stdio** (default) — stdin/stdout, the standard way to integrate with +/// Claude Desktop and other local MCP clients. No socket is bound, so tool +/// calls don't pay a round-trip cost and there is no auth-config caveat. +/// * **Streamable HTTP** ([`McpHttpConfig::enabled`]) — serves the bridge over +/// `rmcp`'s `StreamableHttpService` on `host:port`, for networked clients. +/// Takes precedence over stdio when enabled. +/// +/// [`McpHttpConfig::enabled`]: crate::core::config::McpHttpConfig::enabled #[cfg(feature = "mcp-server")] pub async fn run_mcp_server( config: &McpServerConfig, @@ -29,7 +35,7 @@ pub async fn run_mcp_server( handler: H, ) -> Result<(), Box> where - H: AsyncMessageHandler + Send + Sync + 'static, + H: AsyncMessageHandler + Clone + Send + Sync + 'static, { if !config.enabled { return Ok(()); @@ -37,40 +43,107 @@ where info!("Starting MCP server for agent: {}", agent_card.name); + if config.http.enabled { + return run_streamable_http(config, agent_card, handler).await; + } + + if !config.stdio { + info!( + "No MCP transport enabled (set features.mcp_server.stdio or features.mcp_server.http.enabled)" + ); + return Ok(()); + } + + info!("Starting stdio transport for MCP server"); + // Bridge the A2A agent into an MCP server handler. The bridge calls the // handler in-process; tool-name namespace is derived from agent_card.url. - let bridge = AgentToMcpBridge::with_handler(handler, agent_card.clone()) + let bridge = AgentToMcpBridge::with_handler(handler, agent_card) .with_mcp_metadata(config.name.clone(), config.version.clone()); - if config.stdio { - info!("Starting stdio transport for MCP server"); - - // Get stdio transport - let (read, write) = stdio(); - - // Create and run the MCP service - // serve_directly runs the service in a background task - let _running = - rmcp::service::serve_directly::(bridge, (read, write), None); - - // Keep the service running - wait for Ctrl+C or stdio to close - // The stdio transport will handle shutdown when the connection closes - tokio::select! { - _ = tokio::signal::ctrl_c() => { - info!("Received Ctrl+C, shutting down"); - } - _ = tokio::time::sleep(tokio::time::Duration::MAX) => { - // Never completes normally - only via Ctrl+C or process termination - } + // Get stdio transport + let (read, write) = stdio(); + + // Create and run the MCP service + // serve_directly runs the service in a background task + let _running = + rmcp::service::serve_directly::(bridge, (read, write), None); + + // Keep the service running - wait for Ctrl+C or stdio to close + // The stdio transport will handle shutdown when the connection closes + tokio::select! { + _ = tokio::signal::ctrl_c() => { + info!("Received Ctrl+C, shutting down"); } + _ = tokio::time::sleep(tokio::time::Duration::MAX) => { + // Never completes normally - only via Ctrl+C or process termination + } + } + + info!("MCP server shutdown gracefully"); + Ok(()) +} + +/// Serve the agent bridge over MCP's Streamable HTTP transport. +/// +/// Mounts a fresh [`AgentToMcpBridge`] per session (via the service factory) on +/// an `axum` router at [`McpHttpConfig::path`], backed by an in-memory +/// [`LocalSessionManager`]. Runs until the process is terminated. +/// +/// [`McpHttpConfig::path`]: crate::core::config::McpHttpConfig::path +#[cfg(feature = "mcp-server")] +async fn run_streamable_http( + config: &McpServerConfig, + agent_card: AgentCard, + handler: H, +) -> Result<(), Box> +where + H: AsyncMessageHandler + Clone + Send + Sync + 'static, +{ + use std::sync::Arc; + + use rmcp::transport::streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, + }; - info!("MCP server shutdown gracefully"); - Ok(()) - } else { - // Future: could support other transports (HTTP SSE, WebSocket) - info!("Only stdio transport is currently supported for MCP server"); - Ok(()) + let http = &config.http; + let addr = format!("{}:{}", http.host, http.port); + + // Start from the secure defaults (loopback-only Host, no Origin check) and + // override only what the TOML specifies. An empty `allowed_hosts` disables + // Host validation entirely (allow any host). + let mut server_config = StreamableHttpServerConfig::default(); + if let Some(hosts) = &http.allowed_hosts { + server_config = server_config.with_allowed_hosts(hosts.clone()); + } + if let Some(origins) = &http.allowed_origins { + server_config = server_config.with_allowed_origins(origins.clone()); } + + // The factory is invoked once per MCP session; each gets its own bridge + // wrapping clones of the shared handler and agent card. + let name = config.name.clone(); + let version = config.version.clone(); + let service = StreamableHttpService::new( + move || { + Ok( + AgentToMcpBridge::with_handler(handler.clone(), agent_card.clone()) + .with_mcp_metadata(name.clone(), version.clone()), + ) + }, + Arc::new(LocalSessionManager::default()), + server_config, + ); + + let router = axum::Router::new().nest_service(&http.path, service); + let listener = tokio::net::TcpListener::bind(&addr).await?; + info!( + "MCP Streamable HTTP server listening on http://{}{}", + addr, http.path + ); + + axum::serve(listener, router).await?; + Ok(()) } /// Check if MCP server mode is enabled @@ -86,7 +159,7 @@ pub async fn run_mcp_server( _handler: H, ) -> Result<(), Box> where - H: a2a_rs::port::AsyncMessageHandler + Send + Sync + 'static, + H: a2a_rs::port::AsyncMessageHandler + Clone + Send + Sync + 'static, { tracing::warn!("MCP server feature not enabled. Compile with --features mcp-server"); Ok(()) diff --git a/a2a-agents/src/core/mcp_client.rs b/a2a-agents/src/core/mcp_client.rs index 8c50b70..6e03178 100644 --- a/a2a-agents/src/core/mcp_client.rs +++ b/a2a-agents/src/core/mcp_client.rs @@ -1,143 +1,213 @@ -//! MCP client integration for A2A agents +//! MCP client integration for A2A agents. //! -//! This module provides functionality for A2A agents to connect to MCP servers -//! and use their tools as part of the agent's capabilities. +//! Lets an A2A agent act as an MCP *client*: connect to one or more MCP servers +//! (spawned as child processes), discover their tools, and invoke them while +//! serving A2A requests. +//! +//! [`McpClientManager`] is the adapter that owns those connections. Build one +//! from the agent's `[features.mcp_client]` config with [`McpClientManager::connect`] +//! and hand it to your [`AsyncMessageHandler`](a2a_rs::port::AsyncMessageHandler); +//! the handler then reaches its tools through the +//! [`McpToolsExt`](crate::traits::McpToolsExt) convenience trait: +//! +//! ```rust,ignore +//! let config = AgentConfig::from_file("agent.toml")?; +//! let mcp = McpClientManager::connect(&config.features.mcp_client).await?; +//! let handler = MyHandler::new(mcp); // impls McpToolsExt by returning &self.mcp +//! AgentBuilder::new(config) +//! .with_handler(handler) +//! .build_with_auto_storage() +//! .await? +//! .run() +//! .await?; +//! ``` + +#![cfg(feature = "mcp-client")] -#[cfg(feature = "mcp-client")] use crate::core::config::{McpClientConfig, McpServerConnection}; -#[cfg(feature = "mcp-client")] use rmcp::{ Peer, RoleClient, ServiceExt, model::{ - CallToolRequestParams, ClientCapabilities, ClientInfo, Implementation, ProtocolVersion, - Tool, + CallToolRequestParams, CallToolResult, ClientCapabilities, ClientInfo, Implementation, + ProtocolVersion, Tool, }, + service::RunningService, transport::TokioChildProcess, }; -#[cfg(feature = "mcp-client")] use std::collections::HashMap; -#[cfg(feature = "mcp-client")] use std::sync::Arc; -#[cfg(feature = "mcp-client")] use tokio::process::Command; -#[cfg(feature = "mcp-client")] use tracing::{debug, error, info}; -/// Manager for MCP client connections -#[cfg(feature = "mcp-client")] +/// Errors raised while connecting to or calling out to MCP servers. +#[derive(Debug, thiserror::Error)] +pub enum McpClientError { + /// The child process backing an MCP server could not be spawned. + #[error("failed to spawn MCP server '{server}': {source}")] + Spawn { + server: String, + #[source] + source: std::io::Error, + }, + + /// The MCP handshake or initial tool listing failed. + #[error("failed to connect to MCP server '{server}': {message}")] + Connect { server: String, message: String }, + + /// A tool was requested on a server that isn't connected. + #[error("MCP server '{server}' is not connected")] + NotConnected { server: String }, + + /// The remote tool invocation failed. + #[error("tool '{tool}' on MCP server '{server}' failed: {message}")] + ToolCall { + server: String, + tool: String, + message: String, + }, +} + +/// Manages connections to external MCP servers and exposes their tools. +/// +/// Cheap to clone — the connection registry lives behind an [`Arc`], so a +/// handler can hold one and the framework can share it freely. #[derive(Clone)] pub struct McpClientManager { - /// Connected MCP servers and their peers + /// Connected MCP servers and their peers. servers: Arc>>, } -#[cfg(feature = "mcp-client")] struct McpServerInfo { + /// The live service handle. Dropping it tears down the transport (and the + /// child process), so it's held for as long as the server is registered — + /// the [`Peer`] below is only usable while this is alive. + _service: RunningService, peer: Peer, tools: Vec, } -#[cfg(feature = "mcp-client")] impl Default for McpClientManager { fn default() -> Self { Self::new() } } -#[cfg(feature = "mcp-client")] impl McpClientManager { - /// Create a new MCP client manager + /// Create an empty manager with no connections. + /// + /// Prefer [`connect`](Self::connect) to build and wire up a manager from + /// configuration in one step. pub fn new() -> Self { Self { servers: Arc::new(tokio::sync::RwLock::new(HashMap::new())), } } - /// Initialize connections to MCP servers from configuration - pub async fn initialize( - &self, - config: &McpClientConfig, - ) -> Result<(), Box> { + /// Build a manager and connect to every server in `config`. + /// + /// Connection is lenient: a server that fails to start is logged and + /// skipped so one bad entry doesn't take down the agent. The call only + /// fails if servers were configured but *none* could be reached — a clear + /// startup error rather than a tool call that mysteriously fails later. + /// When [`config.enabled`](McpClientConfig::enabled) is false this returns + /// an empty manager. + pub async fn connect(config: &McpClientConfig) -> Result { + let manager = Self::new(); + manager.initialize(config).await?; + Ok(manager) + } + + /// Connect to the servers in `config`, adding them to this manager. + /// + /// See [`connect`](Self::connect) for the leniency contract. + pub async fn initialize(&self, config: &McpClientConfig) -> Result<(), McpClientError> { if !config.enabled { info!("MCP client is disabled"); return Ok(()); } info!( - "Initializing MCP client with {} servers", + "Initializing MCP client with {} server(s)", config.servers.len() ); + let mut connected = 0usize; + let mut last_err = None; for server_config in &config.servers { match self.connect_to_server(server_config).await { - Ok(_) => { - info!( - "Successfully connected to MCP server: {}", - server_config.name - ); + Ok(()) => { + connected += 1; + info!("Connected to MCP server '{}'", server_config.name); } Err(e) => { error!( - "Failed to connect to MCP server '{}': {}", - server_config.name, e + "Failed to connect to MCP server '{}': {e}", + server_config.name ); - // Continue with other servers even if one fails + last_err = Some(e); } } } + if connected == 0 && !config.servers.is_empty() { + return Err(last_err.expect("a non-empty server list reports a failure")); + } + Ok(()) } - /// Connect to a single MCP server - async fn connect_to_server( - &self, - config: &McpServerConnection, - ) -> Result<(), Box> { - debug!("Connecting to MCP server: {}", config.name); + /// Connect to a single MCP server and register its tools. + async fn connect_to_server(&self, config: &McpServerConnection) -> Result<(), McpClientError> { + debug!("Connecting to MCP server '{}'", config.name); debug!("Command: {} {:?}", config.command, config.args); - // Build the command let mut cmd = Command::new(&config.command); cmd.args(&config.args); - - // Set environment variables for (key, value) in &config.env { cmd.env(key, value); } - - // Set working directory if let Some(ref cwd) = config.cwd { cmd.current_dir(cwd); } - // Create transport from the child process - let (transport, _stderr) = TokioChildProcess::builder(cmd).spawn()?; - - // Create MCP client with custom client info. `ClientInfo` and - // `Implementation` are `#[non_exhaustive]` in rmcp 1.7 — use the - // typed builders rather than struct literals. + // Spawn the server as a child process and talk to it over its stdio. + let (transport, _stderr) = + TokioChildProcess::builder(cmd) + .spawn() + .map_err(|source| McpClientError::Spawn { + server: config.name.clone(), + source, + })?; + + // `ClientInfo` and `Implementation` are `#[non_exhaustive]` in rmcp — + // use the typed builders rather than struct literals. let implementation = Implementation::new(format!("a2a-agent-{}", config.name), "0.1.0"); let client_info = ClientInfo::new(ClientCapabilities::default(), implementation) .with_protocol_version(ProtocolVersion::V_2024_11_05); - // Start the client service - let service = client_info.serve(transport).await?; + let service = client_info + .serve(transport) + .await + .map_err(|e| McpClientError::Connect { + server: config.name.clone(), + message: e.to_string(), + })?; let peer = service.peer().clone(); - // List available tools - debug!("Listing tools from MCP server: {}", config.name); + debug!("Listing tools from MCP server '{}'", config.name); let tools_result = peer .list_tools(None) .await - .map_err(|e| format!("Failed to list tools: {}", e))?; + .map_err(|e| McpClientError::Connect { + server: config.name.clone(), + message: format!("failed to list tools: {e}"), + })?; info!( - "MCP server '{}' has {} tools", + "MCP server '{}' exposes {} tool(s)", config.name, tools_result.tools.len() ); - for tool in &tools_result.tools { let desc = tool .description @@ -147,114 +217,74 @@ impl McpClientManager { debug!(" - {} ({})", tool.name, desc); } - // Store server info let server_info = McpServerInfo { + _service: service, peer, tools: tools_result.tools, }; - - let mut servers = self.servers.write().await; - servers.insert(config.name.clone(), server_info); + self.servers + .write() + .await + .insert(config.name.clone(), server_info); Ok(()) } - /// Call an MCP tool + /// Call a tool on a connected MCP server. pub async fn call_tool( &self, server_name: &str, tool_name: &str, arguments: Option, - ) -> Result> { + ) -> Result { let servers = self.servers.read().await; - let server = servers .get(server_name) - .ok_or_else(|| format!("MCP server '{}' not found", server_name))?; + .ok_or_else(|| McpClientError::NotConnected { + server: server_name.to_string(), + })?; - debug!( - "Calling tool '{}' on MCP server '{}'", - tool_name, server_name - ); + debug!("Calling tool '{tool_name}' on MCP server '{server_name}'"); - // Convert arguments to Map if provided let args_map = arguments.and_then(|v| v.as_object().cloned()); - let mut params = CallToolRequestParams::new(tool_name.to_string()); if let Some(map) = args_map { params = params.with_arguments(map); } - let result = server + server .peer .call_tool(params) .await - .map_err(|e| format!("Tool call failed: {}", e))?; - - Ok(result) + .map_err(|e| McpClientError::ToolCall { + server: server_name.to_string(), + tool: tool_name.to_string(), + message: e.to_string(), + }) } - /// List all available tools from all connected servers + /// List every tool across all connected servers as `(server, tool)` pairs. pub async fn list_all_tools(&self) -> Vec<(String, Tool)> { let servers = self.servers.read().await; - let mut all_tools = Vec::new(); - - for (server_name, server_info) in servers.iter() { - for tool in &server_info.tools { - all_tools.push((server_name.clone(), tool.clone())); - } - } - - all_tools + servers + .iter() + .flat_map(|(name, info)| info.tools.iter().map(move |t| (name.clone(), t.clone()))) + .collect() } - /// Get tools from a specific server + /// Get the tools exposed by a specific server, if connected. pub async fn list_server_tools(&self, server_name: &str) -> Option> { let servers = self.servers.read().await; servers.get(server_name).map(|s| s.tools.clone()) } - /// Check if a server is connected + /// Whether a server with the given name is connected. pub async fn is_connected(&self, server_name: &str) -> bool { - let servers = self.servers.read().await; - servers.contains_key(server_name) + self.servers.read().await.contains_key(server_name) } - /// Get names of all connected servers + /// Names of all connected servers. pub async fn connected_servers(&self) -> Vec { - let servers = self.servers.read().await; - servers.keys().cloned().collect() - } -} - -#[cfg(not(feature = "mcp-client"))] -#[derive(Clone, Default)] -pub struct McpClientManager; - -#[cfg(not(feature = "mcp-client"))] -impl McpClientManager { - pub fn new() -> Self { - Self - } - - pub async fn initialize( - &self, - _config: &crate::core::config::McpClientConfig, - ) -> Result<(), Box> { - tracing::warn!("MCP client feature not enabled. Compile with --features mcp-client"); - Ok(()) - } - - pub async fn call_tool( - &self, - _server_name: &str, - _tool_name: &str, - _arguments: Option, - ) -> Result> { - Err("MCP client feature not enabled".into()) - } - - pub async fn list_all_tools(&self) -> Vec<(String, serde_json::Value)> { - Vec::new() + self.servers.read().await.keys().cloned().collect() } } diff --git a/a2a-agents/src/core/mod.rs b/a2a-agents/src/core/mod.rs index c1393db..4090f9d 100644 --- a/a2a-agents/src/core/mod.rs +++ b/a2a-agents/src/core/mod.rs @@ -34,5 +34,6 @@ pub use config::{ AgentConfig, Ap2ExtensionConfig, AuthConfig, ConfigError, ExtensionsConfig, McpClientConfig, McpServerConfig, McpServerConnection, ServerConfig, StorageConfig, }; -pub use mcp_client::McpClientManager; +#[cfg(feature = "mcp-client")] +pub use mcp_client::{McpClientError, McpClientManager}; pub use runtime::{AgentRuntime, RuntimeError}; diff --git a/a2a-agents/src/core/runtime.rs b/a2a-agents/src/core/runtime.rs index 03eda8a..f5aec5f 100644 --- a/a2a-agents/src/core/runtime.rs +++ b/a2a-agents/src/core/runtime.rs @@ -3,14 +3,11 @@ //! The runtime handles starting HTTP/WebSocket servers, wiring components, //! and managing the agent lifecycle based on configuration. -#[cfg(feature = "mcp-client")] -use crate::core::McpClientManager; use crate::core::config::{AgentConfig, AuthConfig, StorageConfig}; -use a2a_rs::adapter::{ - BearerTokenAuthenticator, DefaultRequestProcessor, HttpServer, SimpleAgentInfo, -}; +use a2a_rs::adapter::{BearerTokenAuthenticator, ConnectRpcAdapter, HttpServer, SimpleAgentInfo}; use a2a_rs::port::{ - AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager, + AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskLifecycle, + AsyncTaskQuery, }; use std::sync::Arc; use tracing::{info, warn}; @@ -27,14 +24,23 @@ pub struct AgentRuntime { config: AgentConfig, handler: Arc, storage: Arc, - #[cfg(feature = "mcp-client")] - mcp_client: Option, + /// Optional shared streaming backend. When set, it is injected into the + /// transport adapter so SSE subscribers see the same broadcasts the handler + /// emits (e.g. via the `TaskStatusBroadcast` mixin). Without it, the adapter + /// defaults to a no-op streaming handler and updates never reach clients. + streaming: Option>, } impl AgentRuntime where H: AsyncMessageHandler + Clone + Send + Sync + 'static, - S: AsyncTaskManager + AsyncNotificationManager + Clone + Send + Sync + 'static, + S: AsyncTaskLifecycle + + AsyncTaskQuery + + AsyncNotificationManager + + Clone + + Send + + Sync + + 'static, { /// Create a new runtime pub fn new(config: AgentConfig, handler: Arc, storage: Arc) -> Self { @@ -42,31 +48,19 @@ where config, handler, storage, - #[cfg(feature = "mcp-client")] - mcp_client: None, + streaming: None, } } - /// Create a new runtime with MCP client - #[cfg(feature = "mcp-client")] - pub fn with_mcp_client( - config: AgentConfig, - handler: Arc, - storage: Arc, - mcp_client: McpClientManager, - ) -> Self { - Self { - config, - handler, - storage, - mcp_client: Some(mcp_client), - } - } - - /// Get the MCP client manager (if enabled) - #[cfg(feature = "mcp-client")] - pub fn mcp_client(&self) -> Option<&McpClientManager> { - self.mcp_client.as_ref() + /// Attach a shared streaming backend. + /// + /// Pass the *same* [`AsyncStreamingHandler`] instance the message handler + /// broadcasts to (clones of an `InMemoryStreamingHandler` share their + /// subscriber registry). The runtime injects it into the transport adapter + /// so `tasks/subscribe` SSE streams observe those broadcasts. + pub fn with_streaming(mut self, streaming: Arc) -> Self { + self.streaming = Some(streaming); + self } /// Build agent info from configuration @@ -252,13 +246,21 @@ where ); let agent_info = self.build_agent_info(base_url); - let processor = DefaultRequestProcessor::new( + let mut processor = ConnectRpcAdapter::new( (*self.handler).clone(), (*self.storage).clone(), (*self.storage).clone(), agent_info.clone(), ); + // Share the handler's streaming backend with the transport so SSE + // subscribers observe the same broadcasts. Without this the adapter + // keeps its default no-op streaming handler and updates never surface. + if let Some(streaming) = &self.streaming { + processor = processor.with_streaming_handler(streaming.clone()); + info!("📡 Streaming backend wired into transport (SSE subscribers live)"); + } + let bind_address = format!( "{}:{}", self.config.server.host, self.config.server.http_port @@ -427,10 +429,7 @@ where } /// Start the appropriate server(s) based on configuration - pub async fn run(self) -> Result<(), RuntimeError> - where - S: AsyncStreamingHandler, - { + pub async fn run(self) -> Result<(), RuntimeError> { // Check if MCP server mode is enabled if self.config.features.mcp_server.enabled { return self.run_as_mcp_server().await; diff --git a/a2a-agents/src/lib.rs b/a2a-agents/src/lib.rs index aa176ce..c3384cf 100644 --- a/a2a-agents/src/lib.rs +++ b/a2a-agents/src/lib.rs @@ -83,14 +83,15 @@ pub mod traits; pub mod utils; // Example agent implementations -// Note: Currently public for binaries/examples, will be private in Phase 3 +// Note: public for binaries/examples; intended to become private once agents +// are extracted into their own crates. pub mod agents; // Convenience re-exports for the most commonly used types pub use core::{AgentBuilder, AgentConfig, AgentRuntime, BuildError, ConfigError, RuntimeError}; pub use traits::{AgentPlugin, SkillDefinition}; -// Re-export the reimbursement agent for backward compatibility -// (This will be removed in Phase 3 when agents are extracted) +// Re-export the reimbursement agent as a convenience +// (intended to be removed once agents are extracted into their own crates) #[cfg(feature = "reimbursement-agent")] pub use agents::reimbursement::ReimbursementHandler; diff --git a/a2a-agents/src/traits/mcp_tools.rs b/a2a-agents/src/traits/mcp_tools.rs index 54fbebd..c78a55b 100644 --- a/a2a-agents/src/traits/mcp_tools.rs +++ b/a2a-agents/src/traits/mcp_tools.rs @@ -1,37 +1,47 @@ //! Traits and helpers for using MCP tools in message handlers #[cfg(feature = "mcp-client")] -use crate::core::McpClientManager; +use crate::core::{McpClientError, McpClientManager}; #[cfg(feature = "mcp-client")] use rmcp::model::CallToolResult; #[cfg(feature = "mcp-client")] use serde_json::Value; -/// Extension trait for message handlers to easily call MCP tools +/// Extension trait giving any handler that holds an [`McpClientManager`] +/// ergonomic access to its MCP tools. +/// +/// Implement it by returning a reference to the manager your handler owns; the +/// default methods then forward to it: +/// +/// ```rust,ignore +/// impl McpToolsExt for MyHandler { +/// fn mcp_client(&self) -> &McpClientManager { &self.mcp } +/// } +/// ``` #[cfg(feature = "mcp-client")] #[allow(async_fn_in_trait)] pub trait McpToolsExt { - /// Get the MCP client manager + /// The MCP client manager this handler calls tools through. fn mcp_client(&self) -> &McpClientManager; - /// Call an MCP tool with JSON arguments + /// Call an MCP tool with JSON arguments. async fn call_mcp_tool( &self, server_name: &str, tool_name: &str, arguments: Option, - ) -> Result> { + ) -> Result { self.mcp_client() .call_tool(server_name, tool_name, arguments) .await } - /// Call an MCP tool with no arguments + /// Call an MCP tool with no arguments. async fn call_mcp_tool_simple( &self, server_name: &str, tool_name: &str, - ) -> Result> { + ) -> Result { self.call_mcp_tool(server_name, tool_name, None).await } diff --git a/a2a-agents/tests/mcp_client_test.rs b/a2a-agents/tests/mcp_client_test.rs new file mode 100644 index 0000000..47c1b07 --- /dev/null +++ b/a2a-agents/tests/mcp_client_test.rs @@ -0,0 +1,100 @@ +//! End-to-end test for the `mcp-client` framework integration. +//! +//! Spawns the bundled `mcp_echo_server` as a real child-process MCP server, +//! connects to it through the same `McpClientManager::connect` path the +//! framework uses, and exercises tool discovery + invocation. This proves the +//! loop the integration closes: config → connected manager → tool call. + +#![cfg(feature = "mcp-client")] + +use a2a_agents::core::{McpClientManager, config::McpClientConfig, config::McpServerConnection}; +use a2a_agents::traits::extract_tool_result_text; +use serde_json::json; + +/// Build a config that spawns the compiled `mcp_echo_server` binary directly +/// (no nested `cargo`, so the test is fast and deterministic). +fn echo_server_config() -> McpClientConfig { + McpClientConfig { + enabled: true, + servers: vec![McpServerConnection { + name: "echo".to_string(), + command: env!("CARGO_BIN_EXE_mcp_echo_server").to_string(), + args: Vec::new(), + env: Default::default(), + cwd: None, + }], + } +} + +#[tokio::test] +async fn connect_discovers_tools() { + let mcp = McpClientManager::connect(&echo_server_config()) + .await + .expect("connect to echo server"); + + assert!(mcp.is_connected("echo").await); + assert_eq!(mcp.connected_servers().await, vec!["echo".to_string()]); + + let tools = mcp.list_server_tools("echo").await.expect("echo tools"); + let names: Vec<&str> = tools.iter().map(|t| t.name.as_ref()).collect(); + assert!( + names.contains(&"echo"), + "expected `echo` tool, got {names:?}" + ); + assert!(names.contains(&"add"), "expected `add` tool, got {names:?}"); +} + +#[tokio::test] +async fn call_echo_tool_round_trips() { + let mcp = McpClientManager::connect(&echo_server_config()) + .await + .expect("connect to echo server"); + + let result = mcp + .call_tool("echo", "echo", Some(json!({ "text": "hello mcp" }))) + .await + .expect("echo tool call"); + + assert_eq!(extract_tool_result_text(&result), "hello mcp"); +} + +#[tokio::test] +async fn call_add_tool_computes() { + let mcp = McpClientManager::connect(&echo_server_config()) + .await + .expect("connect to echo server"); + + let result = mcp + .call_tool("echo", "add", Some(json!({ "a": 2, "b": 40 }))) + .await + .expect("add tool call"); + assert_eq!(extract_tool_result_text(&result), "42"); +} + +#[tokio::test] +async fn call_on_unknown_server_is_not_connected() { + let mcp = McpClientManager::connect(&echo_server_config()) + .await + .expect("connect to echo server"); + + let err = mcp + .call_tool("does-not-exist", "echo", None) + .await + .expect_err("calling an unconnected server must fail"); + assert!( + matches!(err, a2a_agents::core::McpClientError::NotConnected { .. }), + "expected NotConnected, got {err:?}" + ); +} + +#[tokio::test] +async fn disabled_config_yields_empty_manager() { + let cfg = McpClientConfig { + enabled: false, + servers: Vec::new(), + }; + let mcp = McpClientManager::connect(&cfg) + .await + .expect("empty manager"); + assert!(mcp.connected_servers().await.is_empty()); +} diff --git a/a2a-agents/tests/mcp_http_test.rs b/a2a-agents/tests/mcp_http_test.rs new file mode 100644 index 0000000..28ad96e --- /dev/null +++ b/a2a-agents/tests/mcp_http_test.rs @@ -0,0 +1,224 @@ +//! End-to-end tests for the MCP Streamable HTTP transport. +//! +//! Boots a TOML-configured agent in MCP-server mode with the HTTP transport +//! enabled, then exercises real behavior over the wire: +//! * a full MCP `initialize` handshake (happy path), and +//! * the `allowed_hosts` DNS-rebinding knob (reject vs. allow), driven over a +//! raw socket so the `Host` header is fully under test control. + +#![cfg(feature = "mcp-server")] + +use a2a_agents::core::AgentBuilder; +use a2a_rs::{ + InMemoryTaskStorage, + domain::{A2AError, Message, Part, Role, Task, TaskState, TaskStatus}, + port::AsyncMessageHandler, +}; +use async_trait::async_trait; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +#[derive(Clone)] +struct EchoHandler; + +#[async_trait] +impl AsyncMessageHandler for EchoHandler { + async fn process_message( + &self, + task_id: &str, + message: &Message, + _session_id: Option<&str>, + ) -> Result { + let text = message + .parts + .iter() + .find_map(|p| p.get_text()) + .unwrap_or("") + .to_string(); + let response = Message::builder() + .role(Role::Agent) + .parts(vec![Part::text(format!("echo: {text}"))]) + .message_id(uuid::Uuid::new_v4().to_string()) + .build(); + Ok(Task::builder() + .id(task_id.to_string()) + .context_id(message.context_id.clone()) + .status(TaskStatus::new( + TaskState::Completed, + Some(response.clone()), + )) + .history(vec![message.clone(), response]) + .build()) + } +} + +/// Grab a free TCP port by binding to :0 and immediately releasing it. +fn free_port() -> u16 { + std::net::TcpListener::bind("127.0.0.1:0") + .expect("bind ephemeral port") + .local_addr() + .expect("local_addr") + .port() +} + +/// Build + spawn an MCP/HTTP agent from a TOML fragment on the given port. +fn spawn_agent(http_section: &str, port: u16) -> tokio::task::JoinHandle<()> { + let toml_content = format!( + r#" + [agent] + name = "HTTP MCP Agent" + version = "0.1.0" + + [server] + host = "127.0.0.1" + http_port = 0 + + [features.mcp_server] + enabled = true + stdio = false + + [features.mcp_server.http] + enabled = true + host = "127.0.0.1" + port = {port} + path = "/mcp" + {http_section} + + [[skills]] + id = "echo" + name = "Echo" + description = "Echoes input" + "# + ); + + let runtime = AgentBuilder::from_toml(&toml_content) + .expect("build builder") + .with_handler(EchoHandler) + .with_storage(InMemoryTaskStorage::new()) + .build() + .expect("build runtime"); + + tokio::spawn(async move { + let _ = runtime.run().await; + }) +} + +/// Poll until the server accepts TCP connections (or give up). +async fn wait_listening(port: u16) { + for _ in 0..50 { + if tokio::net::TcpStream::connect(("127.0.0.1", port)) + .await + .is_ok() + { + return; + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + panic!("server on port {port} never started listening"); +} + +const INIT_BODY: &str = r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"a2a-test-client","version":"0.1.0"}}}"#; + +/// Send a raw HTTP/1.1 `initialize` POST with a chosen `Host` header and return +/// the first response line (e.g. `HTTP/1.1 200 OK`). +async fn raw_initialize(port: u16, host_header: &str) -> String { + let mut stream = tokio::net::TcpStream::connect(("127.0.0.1", port)) + .await + .expect("connect"); + let request = format!( + "POST /mcp HTTP/1.1\r\n\ + Host: {host}\r\n\ + Accept: application/json, text/event-stream\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {len}\r\n\ + Connection: close\r\n\ + \r\n\ + {body}", + host = host_header, + len = INIT_BODY.len(), + body = INIT_BODY, + ); + stream + .write_all(request.as_bytes()) + .await + .expect("write request"); + + // The status line + headers arrive promptly; read one chunk with a timeout + // (a 200 reply opens an SSE stream that would otherwise keep us reading). + let mut buf = vec![0u8; 1024]; + let n = tokio::time::timeout(std::time::Duration::from_secs(5), stream.read(&mut buf)) + .await + .expect("response within timeout") + .expect("read response"); + let text = String::from_utf8_lossy(&buf[..n]); + text.lines().next().unwrap_or_default().trim().to_string() +} + +#[tokio::test] +async fn streamable_http_initialize_handshake() { + let port = free_port(); + let server = spawn_agent("", port); + wait_listening(port).await; + + let url = format!("http://127.0.0.1:{port}/mcp"); + let client = reqwest::Client::new(); + let init_body: serde_json::Value = serde_json::from_str(INIT_BODY).unwrap(); + + let response = client + .post(&url) + .header("Content-Type", "application/json") + .header("Accept", "application/json, text/event-stream") + .json(&init_body) + .send() + .await + .expect("initialize request"); + + assert!( + response.status().is_success(), + "initialize should return 2xx, got {}", + response.status() + ); + assert!( + response.headers().contains_key("mcp-session-id"), + "stateful server must return an Mcp-Session-Id header" + ); + + let body = response.text().await.expect("read body"); + assert!( + body.contains("\"result\"") && body.contains("serverInfo"), + "initialize response should carry a JSON-RPC result with serverInfo, got: {body}" + ); + + server.abort(); +} + +#[tokio::test] +async fn default_config_rejects_non_loopback_host() { + // No allowed_hosts override → secure default (loopback only). + let port = free_port(); + let server = spawn_agent("", port); + wait_listening(port).await; + + let status = raw_initialize(port, "evil.example.com").await; + assert!( + status.contains("403"), + "non-loopback Host should be rejected with 403, got: {status:?}" + ); + + server.abort(); +} + +#[tokio::test] +async fn empty_allowed_hosts_permits_any_host() { + // allowed_hosts = [] disables Host validation entirely. + let port = free_port(); + let server = spawn_agent("allowed_hosts = []", port); + wait_listening(port).await; + + let status = raw_initialize(port, "evil.example.com").await; + assert!( + status.contains("200"), + "with Host validation disabled any Host should be accepted, got: {status:?}" + ); + + server.abort(); +} diff --git a/a2a-agents/tests/mcp_smoke_test.rs b/a2a-agents/tests/mcp_smoke_test.rs index ea976f0..d4d8ed0 100644 --- a/a2a-agents/tests/mcp_smoke_test.rs +++ b/a2a-agents/tests/mcp_smoke_test.rs @@ -54,7 +54,6 @@ mod tests { [server] host = "127.0.0.1" http_port = 0 - ws_port = 0 auth = { type = "bearer", tokens = ["secret-token-123"] } [features.mcp_server] diff --git a/a2a-ap2/Cargo.toml b/a2a-ap2/Cargo.toml index 9c0ab7c..34c14b2 100644 --- a/a2a-ap2/Cargo.toml +++ b/a2a-ap2/Cargo.toml @@ -11,9 +11,9 @@ keywords = ["a2a", "ap2", "payments", "agent", "commerce"] categories = ["api-bindings", "network-programming"] [dependencies] -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -thiserror = "1.0" +serde = { workspace = true } +serde_json = { workspace = true } +thiserror = { workspace = true } a2a-rs = { path = "../a2a-rs", version = "0.3" } buffa = { version = "0.3.0", features = ["json"] } buffa-types = { version = "0.3.0", features = ["json"] } @@ -23,7 +23,7 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [dev-dependencies] -serde_json = "1.0" +serde_json = { workspace = true } [[example]] name = "payment_flow" diff --git a/a2a-client/Cargo.toml b/a2a-client/Cargo.toml index 3c0dbf9..e6fa797 100644 --- a/a2a-client/Cargo.toml +++ b/a2a-client/Cargo.toml @@ -16,47 +16,47 @@ name = "a2a_client" [dependencies] # A2A integration # Note: We need "server" feature for the port traits even though this is a client library -a2a-rs = { path = "../a2a-rs", version = "0.3", features = ["http-client", "server", "tracing"], default-features = false } +a2a-rs = { path = "../a2a-rs", version = "0.3", features = ["http-client", "jsonrpc-client", "server", "tracing"], default-features = false } buffa = { version = "0.3.0", features = ["json"] } # Async runtime -tokio = { version = "1", features = ["time"] } +tokio = { workspace = true, features = ["time"] } # Web framework axum = { version = "0.7", optional = true } # Serialization -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde = { workspace = true } +serde_json = { workspace = true } # Error handling -anyhow = "1.0" -thiserror = "1.0" +anyhow = { workspace = true } +thiserror = { workspace = true } # Reqwest for URL parsing and handling -reqwest = { version = "0.12", default-features = false } +reqwest = { workspace = true } # Logging -tracing = "0.1" +tracing = { workspace = true } # Async streams -futures = "0.3" +futures = { workspace = true } async-stream = { version = "0.3", optional = true } # UUID generation -uuid = { version = "1", features = ["v4"] } +uuid = { workspace = true } [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] [dev-dependencies] -tokio = { version = "1", features = ["full"] } +tokio = { workspace = true, features = ["full"] } wiremock = "0.6" axum = "0.7" futures-util = "0.3" a2a-agents = { path = "../a2a-agents" } -async-trait = "0.1" +async-trait = { workspace = true } [features] default = ["axum-components"] diff --git a/a2a-client/README.md b/a2a-client/README.md index 59bb2c7..121211d 100644 --- a/a2a-client/README.md +++ b/a2a-client/README.md @@ -246,7 +246,7 @@ See the `examples/` directory for complete working examples: ## Roadmap -See [TODO.md](TODO.md) for planned features and improvements. +See the workspace [ROADMAP.md](../ROADMAP.md) for planned features and improvements. ## Contributing diff --git a/a2a-client/examples/basic_client.rs b/a2a-client/examples/basic_client.rs index ce5d85a..5554700 100644 --- a/a2a-client/examples/basic_client.rs +++ b/a2a-client/examples/basic_client.rs @@ -18,7 +18,6 @@ use a2a_client::WebA2AClient; use a2a_rs::domain::{Message, Part}; -use a2a_rs::services::AsyncA2AClient; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -47,7 +46,7 @@ async fn main() -> anyhow::Result<()> { let task_id = uuid::Uuid::new_v4().to_string(); match client - .http + .transport .send_task_message(&task_id, &message, None, None) .await { @@ -59,7 +58,7 @@ async fn main() -> anyhow::Result<()> { // Retrieve the task to see the agent's response println!("Retrieving task to see agent response..."); - match client.http.get_task(&task.id, None).await { + match client.transport.get_task(&task.id, None).await { Ok(updated_task) => { println!("✓ Retrieved task successfully!"); let history = &updated_task.history; diff --git a/a2a-client/examples/sse_streaming.rs b/a2a-client/examples/sse_streaming.rs index fe82ecd..519d6a7 100644 --- a/a2a-client/examples/sse_streaming.rs +++ b/a2a-client/examples/sse_streaming.rs @@ -22,7 +22,6 @@ use a2a_client::{WebA2AClient, components::create_sse_stream}; use a2a_rs::domain::{Message, Part, Role}; -use a2a_rs::services::AsyncA2AClient; use axum::{ Json, Router, extract::{Path, State}, @@ -99,7 +98,7 @@ async fn send_message_handler( let task_id = uuid::Uuid::new_v4().to_string(); match client - .http + .transport .send_task_message(&task_id, &message, None, None) .await { diff --git a/a2a-client/src/components/streaming.rs b/a2a-client/src/components/streaming.rs index 9515b34..5f00203 100644 --- a/a2a-client/src/components/streaming.rs +++ b/a2a-client/src/components/streaming.rs @@ -1,138 +1,57 @@ //! Server-Sent Events (SSE) streaming components -use a2a_rs::services::{AsyncA2AClient, StreamItem}; +use a2a_rs::{RetryPolicy, StreamItem}; use axum::response::sse::{Event, KeepAlive, Sse}; use futures::StreamExt; -use std::{convert::Infallible, sync::Arc, time::Duration}; -use tracing::{error, info, warn}; +use std::{convert::Infallible, sync::Arc}; +use tracing::{error, warn}; use crate::WebA2AClient; -/// Create an SSE stream for task updates +/// Create an SSE stream of task updates for an Axum endpoint. /// -/// This function handles: -/// - WebSocket streaming if available -/// - Fallback to HTTP polling -/// - Automatic retry logic -/// - Serialization to JSON events +/// This is a thin serialization adapter over +/// [`WebA2AClient::subscribe_resilient`]: the reusable core owns reconnect + +/// exponential backoff and `Last-Event-ID` resumption, so this function only +/// maps each [`StreamItem`] to a typed Axum [`Event`] (`task-update` / +/// `task-status` / `artifact`), tagging it with the server event id so a browser +/// `EventSource` resumes automatically. The stream ends when the task reaches a +/// terminal state (or retries are exhausted). pub fn create_sse_stream( client: Arc, task_id: String, ) -> Sse>> { - let stream = async_stream::stream! { - info!("Attempting to subscribe to task {} via HTTP stream", task_id); - - let mut retry_count: u32 = 0; - let max_retries = 15; // Covers ~2 minutes using exponential backoff - let base_delay = Duration::from_millis(500); - let max_delay = Duration::from_secs(10); - let mut is_terminal = false; - - loop { - match client.http.subscribe_to_task(&task_id, Some(50)).await { - Ok(mut event_stream) => { - info!("Successfully subscribed to task {} via HTTP stream", task_id); - retry_count = 0; // Reset retries on successful connection - - while let Some(result) = event_stream.next().await { - match result { - Ok(stream_item) => { - use a2a_rs::domain::TaskStateExt; - let (event_type, event_data) = match &stream_item { - StreamItem::Task(task) => { - if let Some(status) = task.status.as_option() { - if status.state.is_terminal() { - is_terminal = true; - } - } - match serde_json::to_string(task) { - Ok(json) => ("task-update", json), - Err(e) => { - error!("Failed to serialize task: {}", e); - continue; - } - } - } - StreamItem::StatusUpdate(status) => { - if status.status.state.is_terminal() { - is_terminal = true; - } - match serde_json::to_string(status) { - Ok(json) => ("task-status", json), - Err(e) => { - error!("Failed to serialize status: {}", e); - continue; - } - } - } - StreamItem::ArtifactUpdate(artifact) => { - match serde_json::to_string(artifact) { - Ok(json) => ("artifact", json), - Err(e) => { - error!("Failed to serialize artifact: {}", e); - continue; - } - } - } - }; - - yield Ok(Event::default() - .event(event_type) - .data(event_data)); - } - Err(e) => { - warn!("Stream error (continuing): {}", e); - continue; - } - } - } - - if is_terminal { - info!("Task {} reached terminal state. Ending stream gracefully.", task_id); - break; - } else { - warn!("Stream ended prematurely for task {}. Retrying...", task_id); - // Do not break; it will loop around and reconnect. - } - } - Err(e) => { - retry_count += 1; - - if retry_count <= max_retries { - // Exponential delay: base_delay * 2^(retry_count - 1) - let factor = 2u64.saturating_pow(retry_count.saturating_sub(1).min(6)); - let delay = base_delay.saturating_mul(factor as u32); - - // Jitter calculation (0-200ms) based on task_id bytes and system time - let jitter_ms = { - let mut state = 0u64; - for &b in task_id.as_bytes() { - state = state.wrapping_mul(6364136223846793005).wrapping_add(b as u64); - } - let time_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis() as u64; - state = state.wrapping_mul(6364136223846793005).wrapping_add(time_ms); - state % 200 - }; - let final_delay = delay.saturating_add(Duration::from_millis(jitter_ms)).min(max_delay); - - warn!( - "Failed to subscribe to task {} (attempt {}/{}): {}. Retrying in {:?}...", - task_id, retry_count, max_retries, e, final_delay - ); - - tokio::time::sleep(final_delay).await; - continue; - } else { - warn!("Failed to subscribe after {} retries: {}, aborting stream", max_retries, e); - break; - } + let updates = client.subscribe_resilient(&task_id, RetryPolicy::default()); + + let stream = updates.filter_map(|result| async move { + let event = match result { + Ok(event) => event, + Err(e) => { + warn!("Stream error (ending): {}", e); + return None; + } + }; + + let (event_type, data) = match &event.item { + StreamItem::Task(task) => ("task-update", serde_json::to_string(task)), + StreamItem::StatusUpdate(status) => ("task-status", serde_json::to_string(status)), + StreamItem::ArtifactUpdate(artifact) => ("artifact", serde_json::to_string(artifact)), + }; + + match data { + Ok(json) => { + let mut sse = Event::default().event(event_type).data(json); + if let Some(id) = event.event_id { + sse = sse.id(id.to_string()); } + Some(Ok(sse)) + } + Err(e) => { + error!("Failed to serialize {event_type}: {e}"); + None } } - }; + }); Sse::new(stream).keep_alive(KeepAlive::default()) } diff --git a/a2a-client/src/lib.rs b/a2a-client/src/lib.rs index 2e32654..a4103ec 100644 --- a/a2a-client/src/lib.rs +++ b/a2a-client/src/lib.rs @@ -23,7 +23,7 @@ //! ```rust,no_run //! use a2a_client::WebA2AClient; //! use a2a_rs::domain::Message; -//! use a2a_rs::services::AsyncA2AClient; +//! use a2a_rs::Transport; //! //! # #[tokio::main] //! # async fn main() -> anyhow::Result<()> { @@ -33,7 +33,7 @@ //! // Send a message //! let message = Message::user_text("Hello, agent!".to_string(), "msg-1".to_string()); //! -//! let task = client.http.send_task_message("task-1", &message, None, None).await?; +//! let task = client.transport.send_task_message("task-1", &message, None, None).await?; //! println!("Task ID: {}", task.id); //! # Ok(()) //! # } @@ -110,7 +110,12 @@ pub mod utils; // Re-export commonly used types pub use error::{ClientError, Result}; -use a2a_rs::HttpClient; +use std::pin::Pin; +use std::sync::Arc; + +use a2a_rs::domain::A2AError; +use a2a_rs::{HttpClient, RetryPolicy, StreamEvent, Transport, subscribe_resilient}; +use futures::Stream; /// Web-friendly A2A client that wraps both HTTP and WebSocket clients. /// @@ -148,8 +153,12 @@ use a2a_rs::HttpClient; /// # } /// ``` pub struct WebA2AClient { - /// HTTP client for A2A requests and streaming - pub http: HttpClient, + /// The negotiated transport for A2A requests and streaming. + /// + /// Held behind an `Arc` so the client is agnostic to the + /// underlying wire protocol (ConnectRPC, JSON-RPC 2.0, …) and can share the + /// transport with a reconnecting subscription stream. + pub transport: Arc, } impl WebA2AClient { @@ -183,15 +192,17 @@ impl WebA2AClient { /// ``` pub fn new_http(base_url: String) -> Self { Self { - http: HttpClient::new(base_url), + transport: Arc::new(HttpClient::new(base_url)), } } - /// Auto-connect to an agent, attempting to detect available transports. + /// Auto-connect to an agent by fetching its card and negotiating a transport. /// - /// Probes for WebSocket support by fetching the agent card from the server. - /// Falls back to HTTP-only if agent card fetching fails or if WebSocket - /// is not supported. + /// Fetches the agent card from the well-known endpoint and selects a transport + /// from the card's `supported_interfaces` (ConnectRPC preferred, JSON-RPC 2.0 + /// as interop fallback). If the card can't be fetched or none of its interfaces + /// match a compiled-in transport, falls back to a ConnectRPC client on + /// `base_url` so the call still works against a bare agent URL. /// /// # Arguments /// @@ -209,13 +220,64 @@ impl WebA2AClient { /// # } /// ``` pub async fn auto_connect(base_url: &str) -> Result { - // Validate URL format + // Validate URL format up front so a malformed URL is a hard error. let _ = reqwest::Url::parse(base_url).map_err(|e| ClientError::InvalidUrl { url: base_url.to_string(), reason: e.to_string(), })?; - Ok(Self::new_http(base_url.to_string())) + match a2a_rs::connect(base_url, &a2a_rs::default_registry()).await { + Ok(transport) => Ok(Self { + transport: Arc::from(transport), + }), + // Card fetch / negotiation failed — fall back to a direct ConnectRPC client. + Err(_) => Ok(Self::new_http(base_url.to_string())), + } + } + + /// Subscribe to a task's updates as a protocol-neutral stream of + /// [`StreamEvent`]s. + /// + /// This is the **spec-compliant** path: a single A2A `SubscribeToTask` round + /// trip with no reconnection and no `Last-Event-ID` — what any A2A agent + /// expects. For automatic reconnection (and gap-free resume against an + /// a2a-rs server) use + /// [`subscribe_resilient`](WebA2AClient::subscribe_resilient). + pub async fn subscribe( + &self, + task_id: &str, + ) -> Result> + Send>>> + { + self.transport + .subscribe_to_task(task_id, None, None) + .await + .map_err(Into::into) + } + + /// Subscribe to a task's updates with automatic reconnect + exponential + /// backoff. The stream ends when the task reaches a terminal state (or + /// retries are exhausted). + /// + /// Reconnection itself is spec-compliant (it re-issues `SubscribeToTask`). + /// Resuming *without gaps* via `Last-Event-ID` is an **a2a-rs enhancement** + /// beyond the A2A v1.0 spec: it works against an a2a-rs server and degrades + /// gracefully (reconnect-from-current-state) against any spec-compliant one. + /// + /// This is the reusable core that [`create_sse_stream`](components::create_sse_stream) + /// builds on; framework-agnostic, so non-Axum frontends can consume it + /// directly. + pub fn subscribe_resilient( + &self, + task_id: &str, + policy: RetryPolicy, + ) -> Pin> + Send>> { + subscribe_resilient( + self.transport.clone(), + task_id.to_string(), + None, + None, + policy, + ) } } /// Application state for Axum web applications. diff --git a/a2a-client/tests/e2e_framework_lifecycle_test.rs b/a2a-client/tests/e2e_framework_lifecycle_test.rs index 158f772..78227d0 100644 --- a/a2a-client/tests/e2e_framework_lifecycle_test.rs +++ b/a2a-client/tests/e2e_framework_lifecycle_test.rs @@ -1,14 +1,13 @@ use a2a_agents::{AgentPlugin, SkillDefinition}; use a2a_client::WebA2AClient; -use a2a_rs::adapter::{DefaultRequestProcessor, HttpServer, InMemoryTaskStorage, SimpleAgentInfo}; -use a2a_rs::domain::{A2AError, Message, Task, TaskState}; +use a2a_rs::adapter::{ConnectRpcAdapter, HttpServer, InMemoryTaskStorage, SimpleAgentInfo}; +use a2a_rs::domain::{A2AError, ContextId, Message, Task, TaskId, TaskState}; use a2a_rs::port::AsyncMessageHandler; -use a2a_rs::services::AsyncA2AClient; use async_trait::async_trait; use std::time::Duration; use tokio::sync::oneshot; -use a2a_rs::port::AsyncTaskManager; +use a2a_rs::port::AsyncTaskLifecycle; /// Simple mock agent for E2E tests #[derive(Clone)] @@ -54,16 +53,18 @@ impl AsyncMessageHandler for EchoAgent { }; // Create or get the task using storage - let task = if !self.storage.task_exists(task_id).await? { - self.storage.create_task(task_id, "context-1").await? + let id: TaskId = task_id.parse()?; + let task = if !self.storage.exists(&id).await? { + let context_id: ContextId = "context-1".parse()?; + self.storage.create(&id, &context_id).await? } else { - self.storage.get_task(task_id, None).await? + self.storage.get(&id, None).await? }; let reply_msg = Message::agent_text(reply_text, "msg-res-1".to_string()); self.storage - .update_task_status(task_id, TaskState::Completed, Some(reply_msg.clone())) + .update_status(&id, TaskState::Completed, Some(reply_msg.clone())) .await?; let mut t = task; @@ -81,7 +82,7 @@ async fn test_framework_lifecycle_e2e() { ); let storage = InMemoryTaskStorage::new(); - let processor = DefaultRequestProcessor::new( + let processor = ConnectRpcAdapter::new( EchoAgent { storage: storage.clone(), }, @@ -109,7 +110,7 @@ async fn test_framework_lifecycle_e2e() { let message = Message::user_text("Hello Framework!".to_string(), "msg-1".to_string()); let task = client - .http + .transport .send_task_message(&task_id, &message, None, None) .await .expect("Failed to send task"); @@ -122,7 +123,7 @@ async fn test_framework_lifecycle_e2e() { // 4. Client fetches final result let final_task = client - .http + .transport .get_task(&task_id, None) .await .expect("Failed to fetch task"); diff --git a/a2a-client/tests/sse_streaming_test.rs b/a2a-client/tests/sse_streaming_test.rs index 38cf2a9..670beb5 100644 --- a/a2a-client/tests/sse_streaming_test.rs +++ b/a2a-client/tests/sse_streaming_test.rs @@ -2,10 +2,9 @@ use a2a_agents::core::AgentBuilder; use a2a_client::WebA2AClient; use a2a_client::components::create_sse_stream; use a2a_rs::{ - InMemoryTaskStorage, + InMemoryStreamingHandler, InMemoryTaskStorage, domain::{A2AError, Message, Part, Role, Task, TaskState, TaskStatusUpdateEvent}, port::{AsyncMessageHandler, AsyncStreamingHandler}, - services::client::AsyncA2AClient, }; use async_trait::async_trait; use futures_util::StreamExt; @@ -15,7 +14,7 @@ use tokio::time::sleep; #[derive(Clone)] struct StreamingHandler { - storage: Arc, + streaming: InMemoryStreamingHandler, } #[async_trait] @@ -41,12 +40,12 @@ impl AsyncMessageHandler for StreamingHandler { .build(); // Spawn a background task to simulate a streaming delay - let storage = self.storage.clone(); + let streaming = self.streaming.clone(); let tid = task_id.to_string(); tokio::spawn(async move { sleep(Duration::from_millis(1000)).await; - let _ = storage + let _ = streaming .broadcast_status_update( &tid, TaskStatusUpdateEvent { @@ -88,7 +87,7 @@ async fn test_sse_stream_success() { let storage = Arc::new(InMemoryTaskStorage::new()); let handler = StreamingHandler { - storage: storage.clone(), + streaming: InMemoryStreamingHandler::new(), }; let runtime = AgentBuilder::from_toml(toml_config) @@ -116,7 +115,7 @@ async fn test_sse_stream_success() { .build(); let task: Task = client - .http + .transport .send_task_message("test-task-1", &message, None, None) .await .unwrap(); @@ -124,8 +123,8 @@ async fn test_sse_stream_success() { // Check if subscribe_to_task works natively let mut native_stream = client - .http - .subscribe_to_task("test-task-1", None) + .transport + .subscribe_to_task("test-task-1", None, None) .await .unwrap(); if let Some(item) = native_stream.next().await { @@ -137,7 +136,7 @@ async fn test_sse_stream_success() { // Now test SSE Stream // Note: Re-subscribing might only get new events, so we might need another task for SSE testing let task2 = client - .http + .transport .send_task_message("test-task-2", &message, None, None) .await .unwrap(); diff --git a/a2a-mcp/Cargo.toml b/a2a-mcp/Cargo.toml index d21eb22..bcba327 100644 --- a/a2a-mcp/Cargo.toml +++ b/a2a-mcp/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "a2a-mcp" version = "0.3.0" -edition = "2021" +edition = "2024" rust-version = "1.85" authors = ["Emil Lindfors "] description = "Bidirectional integration between A2A Protocol and Model Context Protocol (MCP)" @@ -22,35 +22,35 @@ buffa = { version = "0.3.0", features = ["json"] } buffa-types = { version = "0.3.0", features = ["json"] } # Async runtime and utilities -tokio = { version = "1", features = ["full"] } -async-trait = "0.1" -futures = "0.3" +tokio = { workspace = true, features = ["full"] } +async-trait = { workspace = true } +futures = { workspace = true } # Serialization -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde = { workspace = true } +serde_json = { workspace = true } base64 = "0.22" # Error handling -thiserror = "2" -anyhow = "1.0" +thiserror = { workspace = true } +anyhow = { workspace = true } # Logging -tracing = "0.1" +tracing = { workspace = true } # Schema generation (matching rmcp's version) schemars = { version = "1.0", features = ["chrono04"] } # Date/time -chrono = { version = "0.4", features = ["serde"] } +chrono = { workspace = true } # UUID generation -uuid = { version = "1", features = ["v4", "serde"] } +uuid = { workspace = true, features = ["v4", "serde"] } [dev-dependencies] -tokio = { version = "1", features = ["full"] } -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -anyhow = "1.0" +tokio = { workspace = true, features = ["full"] } +tracing-subscriber = { workspace = true } +anyhow = { workspace = true } [features] default = [] diff --git a/a2a-mcp/README.md b/a2a-mcp/README.md index dbe6052..cc3d72f 100644 --- a/a2a-mcp/README.md +++ b/a2a-mcp/README.md @@ -56,4 +56,4 @@ flowchart TD ## Development Status -See [TODO.md](TODO.md) for current implementation status and next steps. \ No newline at end of file +See the workspace [ROADMAP.md](../ROADMAP.md) for deferred themes and next steps. \ No newline at end of file diff --git a/a2a-mcp/examples/a2a_as_mcp_server.rs b/a2a-mcp/examples/a2a_as_mcp_server.rs index ecdafaf..c1d82b9 100644 --- a/a2a-mcp/examples/a2a_as_mcp_server.rs +++ b/a2a-mcp/examples/a2a_as_mcp_server.rs @@ -14,17 +14,18 @@ use std::time::Duration; use a2a_mcp::AgentToMcpBridge; use a2a_rs::{ adapter::{ - business::{DefaultMessageHandler, DefaultRequestProcessor}, - storage::InMemoryTaskStorage, - transport::http::HttpClient, HttpServer, SimpleAgentInfo, + business::ResponderMessageHandler, + storage::InMemoryTaskStorage, + streaming::InMemoryStreamingHandler, + transport::{connectrpc::ConnectRpcAdapter, http::HttpClient}, }, - domain::{error::A2AError, Message, Part, Role, Task, TaskState, TaskStatus}, + domain::{Message, Part, Role, Task, TaskState, TaskStatus, error::A2AError}, port::AsyncMessageHandler, services::AgentInfoProvider, }; use async_trait::async_trait; -use rmcp::{model::CallToolRequestParams, ServiceExt}; +use rmcp::{ServiceExt, model::CallToolRequestParams}; use tracing_subscriber::EnvFilter; const AGENT_ADDR: &str = "127.0.0.1:18182"; @@ -32,12 +33,13 @@ const AGENT_URL: &str = "http://127.0.0.1:18182"; /// Minimal A2A handler that echoes incoming text. /// -/// Wraps a `DefaultMessageHandler` to satisfy the storage-touching bits +/// Wraps a `ResponderMessageHandler` to satisfy the storage-touching bits /// (task creation, history persistence) and overrides response generation /// with a simple echo. #[derive(Clone)] struct EchoHandler { storage: Arc, + streaming: InMemoryStreamingHandler, } #[async_trait] @@ -48,9 +50,13 @@ impl AsyncMessageHandler for EchoHandler { message: &Message, session_id: Option<&str>, ) -> Result { - // Delegate to DefaultMessageHandler for proper storage semantics, then + // Delegate to ResponderMessageHandler for proper storage semantics, then // synthesize an echo response on top of whatever it returned. - let inner = DefaultMessageHandler::new((*self.storage).clone()); + let inner = ResponderMessageHandler::echo( + (*self.storage).clone(), + self.streaming.clone(), + self.storage.push_notifier(), + ); let mut task = inner.process_message(task_id, message, session_id).await?; let echoed = message @@ -84,6 +90,7 @@ async fn main() -> anyhow::Result<()> { let storage = Arc::new(InMemoryTaskStorage::new()); let handler = EchoHandler { storage: storage.clone(), + streaming: InMemoryStreamingHandler::new(), }; let agent_info = SimpleAgentInfo::new("Echo Agent".to_string(), AGENT_URL.to_string()) @@ -94,7 +101,7 @@ async fn main() -> anyhow::Result<()> { Some("Repeat the input back".to_string()), ); - let processor = DefaultRequestProcessor::new( + let processor = ConnectRpcAdapter::new( handler, (*storage).clone(), (*storage).clone(), diff --git a/a2a-mcp/examples/a2a_with_mcp_tools.rs b/a2a-mcp/examples/a2a_with_mcp_tools.rs index 312268c..0627d81 100644 --- a/a2a-mcp/examples/a2a_with_mcp_tools.rs +++ b/a2a-mcp/examples/a2a_with_mcp_tools.rs @@ -10,14 +10,14 @@ use std::sync::Arc; -use a2a_mcp::{create_tool_call_message, McpToA2ABridge}; +use a2a_mcp::{McpToA2ABridge, create_tool_call_message}; use a2a_rs::{ - domain::{error::A2AError, Message, Part, Role, Task, TaskState, TaskStatus}, + domain::{Message, Part, Role, Task, TaskState, TaskStatus, error::A2AError}, port::AsyncMessageHandler, }; use async_trait::async_trait; use rmcp::{ - model::*, service::RequestContext, ErrorData as McpError, RoleServer, ServerHandler, ServiceExt, + ErrorData as McpError, RoleServer, ServerHandler, ServiceExt, model::*, service::RequestContext, }; use serde_json::json; use tracing_subscriber::EnvFilter; diff --git a/a2a-mcp/examples/bidirectional_demo.rs b/a2a-mcp/examples/bidirectional_demo.rs index d56c359..d895aa1 100644 --- a/a2a-mcp/examples/bidirectional_demo.rs +++ b/a2a-mcp/examples/bidirectional_demo.rs @@ -31,21 +31,22 @@ use std::sync::Arc; use std::time::Duration; -use a2a_mcp::{create_tool_call_message, AgentToMcpBridge, McpToA2ABridge}; +use a2a_mcp::{AgentToMcpBridge, McpToA2ABridge, create_tool_call_message}; use a2a_rs::{ adapter::{ - business::{DefaultMessageHandler, DefaultRequestProcessor}, - storage::InMemoryTaskStorage, - transport::http::HttpClient, HttpServer, SimpleAgentInfo, + business::ResponderMessageHandler, + storage::InMemoryTaskStorage, + streaming::InMemoryStreamingHandler, + transport::{connectrpc::ConnectRpcAdapter, http::HttpClient}, }, - domain::{error::A2AError, Message, Part, Role, Task, TaskState, TaskStatus}, + domain::{Message, Part, Role, Task, TaskState, TaskStatus, error::A2AError}, port::AsyncMessageHandler, services::AgentInfoProvider, }; use async_trait::async_trait; use rmcp::{ - model::*, service::RequestContext, ErrorData as McpError, RoleServer, ServerHandler, ServiceExt, + ErrorData as McpError, RoleServer, ServerHandler, ServiceExt, model::*, service::RequestContext, }; use serde_json::json; use tracing_subscriber::EnvFilter; @@ -145,6 +146,7 @@ impl ServerHandler for CalcServer { #[derive(Clone)] struct EchoHandler { storage: Arc, + streaming: InMemoryStreamingHandler, } #[async_trait] @@ -155,7 +157,11 @@ impl AsyncMessageHandler for EchoHandler { message: &Message, session_id: Option<&str>, ) -> Result { - let inner = DefaultMessageHandler::new((*self.storage).clone()); + let inner = ResponderMessageHandler::echo( + (*self.storage).clone(), + self.streaming.clone(), + self.storage.push_notifier(), + ); let mut task = inner.process_message(task_id, message, session_id).await?; let echoed = extract_text(message); @@ -253,6 +259,7 @@ async fn main() -> anyhow::Result<()> { let storage = Arc::new(InMemoryTaskStorage::new()); let echo = EchoHandler { storage: storage.clone(), + streaming: InMemoryStreamingHandler::new(), }; let mcp_to_a2a = Arc::new(McpToA2ABridge::new(calc_peer, echo).await?); let math_handler = MathHandler { bridge: mcp_to_a2a }; @@ -269,7 +276,7 @@ async fn main() -> anyhow::Result<()> { ), ); - let processor = DefaultRequestProcessor::new( + let processor = ConnectRpcAdapter::new( math_handler, (*storage).clone(), (*storage).clone(), diff --git a/a2a-mcp/src/bridge/agent_to_mcp.rs b/a2a-mcp/src/bridge/agent_to_mcp.rs index e78cc08..353301b 100644 --- a/a2a-mcp/src/bridge/agent_to_mcp.rs +++ b/a2a-mcp/src/bridge/agent_to_mcp.rs @@ -6,13 +6,13 @@ use crate::{ }; use a2a_rs::{ adapter::transport::http::HttpClient, - domain::{error::A2AError, AgentCard, Message, Part, Role, Task}, + domain::{AgentCard, Message, Part, Role, Task, error::A2AError}, port::AsyncMessageHandler, - services::client::AsyncA2AClient, + port::client::Transport, }; use async_trait::async_trait; use futures::{Stream, StreamExt}; -use rmcp::{model::*, service::RequestContext, ErrorData as McpError, RoleServer, ServerHandler}; +use rmcp::{ErrorData as McpError, RoleServer, ServerHandler, model::*, service::RequestContext}; use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; @@ -41,12 +41,7 @@ pub trait BridgeBackend: Send + Sync { _task_id: &str, ) -> std::result::Result< Option< - Pin< - Box< - dyn Stream> - + Send, - >, - >, + Pin> + Send>>, >, A2AError, > { @@ -164,24 +159,19 @@ where task_id: &str, ) -> std::result::Result< Option< - Pin< - Box< - dyn Stream> - + Send, - >, - >, + Pin> + Send>>, >, A2AError, > { if let Some(ref sh) = self.streaming_handler { - let stream = sh.combined_update_stream(task_id).await?; + let stream = sh.combined_update_stream(task_id, None).await?; let mapped = stream.map(|res| { - res.map(|event| match event { + res.map(|seq| match seq.event { a2a_rs::port::UpdateEvent::StatusUpdate(status) => { - a2a_rs::services::StreamItem::StatusUpdate(status) + a2a_rs::StreamItem::StatusUpdate(status) } a2a_rs::port::UpdateEvent::ArtifactUpdate(artifact) => { - a2a_rs::services::StreamItem::ArtifactUpdate(artifact) + a2a_rs::StreamItem::ArtifactUpdate(artifact) } }) }); @@ -523,7 +513,7 @@ impl AgentToMcpBridge { let item = item_res.map_err(|e| A2aMcpError::AgentCommunication(e.to_string()))?; match item { - a2a_rs::services::StreamItem::Task(t) => { + a2a_rs::StreamItem::Task(t) => { debug!("Stream initial task for {}: {:?}", t.id, t.status.state); task = t; self.tasks_cache @@ -588,7 +578,10 @@ impl AgentToMcpBridge { let sampling_res = match sampling_res_result { Ok(res) => res, Err(e) => { - debug!("Sampling failed or unavailable: {e}. Suspending task {} and returning to LLM.", task.id); + debug!( + "Sampling failed or unavailable: {e}. Suspending task {} and returning to LLM.", + task.id + ); break; } }; @@ -634,7 +627,7 @@ impl AgentToMcpBridge { break; } } - a2a_rs::services::StreamItem::StatusUpdate(event) => { + a2a_rs::StreamItem::StatusUpdate(event) => { debug!( "Stream status update for {}: {:?}", task.id, event.status.state @@ -703,7 +696,10 @@ impl AgentToMcpBridge { let sampling_res = match sampling_res_result { Ok(res) => res, Err(e) => { - debug!("Sampling failed or unavailable: {e}. Suspending task {} and returning to LLM.", task.id); + debug!( + "Sampling failed or unavailable: {e}. Suspending task {} and returning to LLM.", + task.id + ); break; } }; @@ -749,7 +745,7 @@ impl AgentToMcpBridge { break; } } - a2a_rs::services::StreamItem::ArtifactUpdate(event) => { + a2a_rs::StreamItem::ArtifactUpdate(event) => { debug!( "Stream artifact update for {}: {}", task.id, event.artifact.artifact_id @@ -832,7 +828,10 @@ impl AgentToMcpBridge { let sampling_res = match sampling_res_result { Ok(res) => res, Err(e) => { - debug!("Sampling failed or unavailable: {e}. Suspending task {} and returning to LLM.", task.id); + debug!( + "Sampling failed or unavailable: {e}. Suspending task {} and returning to LLM.", + task.id + ); break; } }; @@ -1006,12 +1005,12 @@ impl ServerHandler for AgentToMcpBridge { let mut extensions = ExtensionCapabilities::new(); for scheme in self.agent_card.security_schemes.values() { if let Some(a2a_rs::domain::generated::security_scheme::Scheme::Oauth2SecurityScheme( - ref oauth2_scheme, + oauth2_scheme, )) = &scheme.scheme { if let Some(flows) = oauth2_scheme.flows.as_option() { if let Some(a2a_rs::domain::generated::o_auth_flows::Flow::ClientCredentials( - ref cc, + cc, )) = &flows.flow { let mut cc_settings = serde_json::Map::new(); @@ -1448,7 +1447,7 @@ impl ServerHandler for AgentToMcpBridge { return Err(McpError::internal_error( format!("Failed to serialize data: {}", e), None, - )) + )); } }; contents.push(ResourceContents::TextResourceContents { diff --git a/a2a-mcp/src/bridge/mcp_to_a2a.rs b/a2a-mcp/src/bridge/mcp_to_a2a.rs index ccea9d9..9f6000a 100644 --- a/a2a-mcp/src/bridge/mcp_to_a2a.rs +++ b/a2a-mcp/src/bridge/mcp_to_a2a.rs @@ -1,7 +1,7 @@ //! Bridge that provides MCP tools as capabilities to A2A agents use crate::{ - converters::{llm_tool::LlmToolConverter, MessageConverter}, + converters::{MessageConverter, llm_tool::LlmToolConverter}, error::{A2aMcpError, Result}, }; use a2a_agents_common::llm::{ToolCall, ToolDefinition}; @@ -11,10 +11,10 @@ use a2a_rs::{ }; use async_trait::async_trait; use rmcp::{ + Peer, RoleClient, handler::client::progress::ProgressDispatcher, model::*, service::{NotificationContext, PeerRequestOptions}, - Peer, RoleClient, }; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; @@ -71,7 +71,10 @@ impl Drop for RequestCancelGuard { let peer = self.peer.clone(); let request_id = self.request_id.clone(); tokio::spawn(async move { - debug!("RequestCancelGuard triggered: notifying server of cancellation for request: {:?}", request_id); + debug!( + "RequestCancelGuard triggered: notifying server of cancellation for request: {:?}", + request_id + ); let _ = peer .notify_cancelled(CancelledNotificationParam { request_id, @@ -408,7 +411,7 @@ impl McpToA2ABridge { Ok(_) => { return Err(A2aMcpError::McpServer( "Unexpected response from MCP server".to_string(), - )) + )); } Err(e) => return Err(e.into()), }; @@ -468,7 +471,7 @@ impl McpToA2ABridge { Ok(_) => { return Err(A2aMcpError::McpServer( "Unexpected response from MCP server".to_string(), - )) + )); } Err(e) => return Err(e.into()), }; diff --git a/a2a-mcp/src/client.rs b/a2a-mcp/src/client.rs index 85daa76..cde10d6 100644 --- a/a2a-mcp/src/client.rs +++ b/a2a-mcp/src/client.rs @@ -3,19 +3,19 @@ use crate::adapter::AgentToToolAdapter; use crate::error::{Error, Result}; use a2a_rs::domain::agent::AgentCard; -use a2a_rs::port::client::AsyncA2AClient; +use a2a_rs::port::client::Transport; use rmcp::{Tool, ToolCall, ToolResponse}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; use tracing::{info, debug, error}; /// A client that accesses A2A agents as RMCP tools -pub struct A2aRmcpClient { +pub struct A2aRmcpClient { a2a_client: C, adapter: Arc>, } -impl A2aRmcpClient { +impl A2aRmcpClient { /// Create a new client that discovers A2A agents pub fn new(a2a_client: C) -> Self { Self { diff --git a/a2a-mcp/src/converters/task_result.rs b/a2a-mcp/src/converters/task_result.rs index 79a632f..caef3af 100644 --- a/a2a-mcp/src/converters/task_result.rs +++ b/a2a-mcp/src/converters/task_result.rs @@ -156,11 +156,13 @@ mod tests { .id("task-1".to_string()) .context_id("ctx-1".to_string()) .status(TaskStatus::new(TaskState::TASK_STATE_COMPLETED, None)) - .history(vec![Message::builder() - .role(Role::Agent) - .parts(vec![Part::text("Result text".to_string())]) - .message_id("msg-1".to_string()) - .build()]) + .history(vec![ + Message::builder() + .role(Role::Agent) + .parts(vec![Part::text("Result text".to_string())]) + .message_id("msg-1".to_string()) + .build(), + ]) .artifacts(vec![Artifact { artifact_id: "art-1".to_string(), name: "Test Artifact".to_string(), @@ -183,11 +185,13 @@ mod tests { .id("task-2".to_string()) .context_id("ctx-2".to_string()) .status(TaskStatus::new(TaskState::TASK_STATE_FAILED, None)) - .history(vec![Message::builder() - .role(Role::Agent) - .parts(vec![Part::text("Error details".to_string())]) - .message_id("msg-2".to_string()) - .build()]) + .history(vec![ + Message::builder() + .role(Role::Agent) + .parts(vec![Part::text("Error details".to_string())]) + .message_id("msg-2".to_string()) + .build(), + ]) .build(); let result = TaskResultConverter::task_to_result(&task).unwrap(); diff --git a/a2a-mcp/src/lib.rs b/a2a-mcp/src/lib.rs index 6e2a19e..fca9a6b 100644 --- a/a2a-mcp/src/lib.rs +++ b/a2a-mcp/src/lib.rs @@ -144,7 +144,7 @@ pub mod error; // Re-export key types pub use bridge::mcp_to_a2a::{ - attach_tool_call, create_tool_call_message, McpToolCall, MCP_TOOL_CALL_METADATA_KEY, + MCP_TOOL_CALL_METADATA_KEY, McpToolCall, attach_tool_call, create_tool_call_message, }; pub use bridge::{AgentToMcpBridge, McpToA2ABridge}; pub use converters::{MessageConverter, SkillToolConverter, TaskResultConverter}; diff --git a/a2a-mcp/tests/agent_to_mcp_integration.rs b/a2a-mcp/tests/agent_to_mcp_integration.rs index 446258a..e3eac81 100644 --- a/a2a-mcp/tests/agent_to_mcp_integration.rs +++ b/a2a-mcp/tests/agent_to_mcp_integration.rs @@ -4,8 +4,8 @@ use std::pin::Pin; use std::sync::{ - atomic::{AtomicUsize, Ordering}, Arc, Mutex, + atomic::{AtomicUsize, Ordering}, }; use a2a_mcp::bridge::agent_to_mcp::AgentToMcpBridge; @@ -13,14 +13,14 @@ use a2a_mcp::converters::skill_tool::SkillToolConverter; use a2a_rs::adapter::transport::http::HttpClient; use a2a_rs::domain::core::agent::{AgentCapabilities, AgentCard, AgentSkill}; use a2a_rs::domain::{ - error::A2AError, Message, Part, Role, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus, - TaskStatusUpdateEvent, + Message, Part, Role, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus, + TaskStatusUpdateEvent, error::A2AError, }; use a2a_rs::port::streaming_handler::Subscriber; -use a2a_rs::port::{AsyncMessageHandler, AsyncStreamingHandler, UpdateEvent}; +use a2a_rs::port::{AsyncMessageHandler, AsyncStreamingHandler, SeqEvent, UpdateEvent}; use async_trait::async_trait; use rmcp::service::{NotificationContext, RequestContext}; -use rmcp::{model::*, ClientHandler, ErrorData as McpError, RoleClient, ServerHandler, ServiceExt}; +use rmcp::{ClientHandler, ErrorData as McpError, RoleClient, ServerHandler, ServiceExt, model::*}; #[tokio::test] async fn test_agent_skills_as_mcp_tools() { @@ -467,42 +467,49 @@ impl AsyncStreamingHandler for MockStreamingHandler { async fn combined_update_stream( &self, task_id: &str, - ) -> Result> + Send>>, A2AError> + _from_event_id: Option, + ) -> Result> + Send>>, A2AError> { let task_id = task_id.to_string(); let events = vec![ - Ok(UpdateEvent::StatusUpdate(TaskStatusUpdateEvent { - task_id: task_id.clone(), - context_id: "ctx-1".to_string(), - kind: "status-update".to_string(), - status: TaskStatus::new( - TaskState::Working, - Some( - Message::builder() - .role(Role::Agent) - .parts(vec![Part::text("Doing step 1".to_string())]) - .message_id("step-1-msg".to_string()) - .build(), + Ok(SeqEvent::new( + 1, + UpdateEvent::StatusUpdate(TaskStatusUpdateEvent { + task_id: task_id.clone(), + context_id: "ctx-1".to_string(), + kind: "status-update".to_string(), + status: TaskStatus::new( + TaskState::Working, + Some( + Message::builder() + .role(Role::Agent) + .parts(vec![Part::text("Doing step 1".to_string())]) + .message_id("step-1-msg".to_string()) + .build(), + ), ), - ), - metadata: None, - })), - Ok(UpdateEvent::StatusUpdate(TaskStatusUpdateEvent { - task_id: task_id.clone(), - context_id: "ctx-1".to_string(), - kind: "status-update".to_string(), - status: TaskStatus::new( - TaskState::InputRequired, - Some( - Message::builder() - .role(Role::Agent) - .parts(vec![Part::text("Please provide input:".to_string())]) - .message_id("elicitation-msg".to_string()) - .build(), + metadata: None, + }), + )), + Ok(SeqEvent::new( + 2, + UpdateEvent::StatusUpdate(TaskStatusUpdateEvent { + task_id: task_id.clone(), + context_id: "ctx-1".to_string(), + kind: "status-update".to_string(), + status: TaskStatus::new( + TaskState::InputRequired, + Some( + Message::builder() + .role(Role::Agent) + .parts(vec![Part::text("Please provide input:".to_string())]) + .message_id("elicitation-msg".to_string()) + .build(), + ), ), - ), - metadata: None, - })), + metadata: None, + }), + )), ]; Ok(Box::pin(futures::stream::iter(events))) @@ -751,14 +758,7 @@ impl a2a_mcp::bridge::agent_to_mcp::BridgeBackend for MockPollingBackend { &self, _task_id: &str, ) -> Result< - Option< - Pin< - Box< - dyn futures::Stream> - + Send, - >, - >, - >, + Option> + Send>>>, A2AError, > { // Return Ok(None) to force polling fallback diff --git a/a2a-mcp/tests/bidirectional_integration.rs b/a2a-mcp/tests/bidirectional_integration.rs index 73e6256..cfb2160 100644 --- a/a2a-mcp/tests/bidirectional_integration.rs +++ b/a2a-mcp/tests/bidirectional_integration.rs @@ -85,11 +85,13 @@ async fn test_error_handling_bidirectional() { .id("task-failed".to_string()) .context_id("ctx-failed".to_string()) .status(TaskStatus::new(TaskState::Failed, None)) - .history(vec![Message::builder() - .role(Role::Agent) - .parts(vec![Part::text("Error: Something went wrong".to_string())]) - .message_id("msg-error".to_string()) - .build()]) + .history(vec![ + Message::builder() + .role(Role::Agent) + .parts(vec![Part::text("Error: Something went wrong".to_string())]) + .message_id("msg-error".to_string()) + .build(), + ]) .build(); // Convert to MCP result @@ -116,15 +118,17 @@ async fn test_skill_tool_bidirectional_metadata() { .capabilities(Default::default()) .default_input_modes(vec!["text".to_string()]) .default_output_modes(vec!["text".to_string()]) - .skills(vec![AgentSkill::new( - "test_skill".to_string(), - "Test Skill".to_string(), - "A skill for testing metadata preservation".to_string(), - vec!["test".to_string(), "metadata".to_string()], - ) - .with_examples(vec!["Example usage".to_string()]) - .with_input_modes(vec!["text".to_string()]) - .with_output_modes(vec!["text".to_string()])]) + .skills(vec![ + AgentSkill::new( + "test_skill".to_string(), + "Test Skill".to_string(), + "A skill for testing metadata preservation".to_string(), + vec!["test".to_string(), "metadata".to_string()], + ) + .with_examples(vec!["Example usage".to_string()]) + .with_input_modes(vec!["text".to_string()]) + .with_output_modes(vec!["text".to_string()]), + ]) .build(); let client = HttpClient::new("https://example.com/agent".to_string()); diff --git a/a2a-mcp/tests/mcp_to_a2a_integration.rs b/a2a-mcp/tests/mcp_to_a2a_integration.rs index 096eec6..4e22142 100644 --- a/a2a-mcp/tests/mcp_to_a2a_integration.rs +++ b/a2a-mcp/tests/mcp_to_a2a_integration.rs @@ -3,7 +3,7 @@ //! This test verifies that MCP tools and prompts can be successfully exposed as A2A agent skills use a2a_mcp::bridge::mcp_to_a2a::{ - create_prompt_call_message, create_tool_call_message, McpToA2ABridge, ProgressClientHandler, + McpToA2ABridge, ProgressClientHandler, create_prompt_call_message, create_tool_call_message, }; use a2a_rs::domain::core::agent::AgentCard; use a2a_rs::domain::{ @@ -11,11 +11,11 @@ use a2a_rs::domain::{ TaskStatusUpdateEvent, }; use a2a_rs::port::streaming_handler::Subscriber; -use a2a_rs::port::{AsyncMessageHandler, AsyncStreamingHandler, UpdateEvent}; +use a2a_rs::port::{AsyncMessageHandler, AsyncStreamingHandler, SeqEvent}; use async_trait::async_trait; use rmcp::{ - handler::client::progress::ProgressDispatcher, model::*, service::RequestContext, ErrorData as McpError, RoleServer, ServerHandler, ServiceExt, + handler::client::progress::ProgressDispatcher, model::*, service::RequestContext, }; use std::pin::Pin; use std::sync::{Arc, Mutex}; @@ -331,10 +331,11 @@ impl AsyncStreamingHandler for TestStreamingHandler { async fn combined_update_stream( &self, _task_id: &str, + _from_event_id: Option, ) -> Result< Pin< Box< - dyn futures::Stream> + dyn futures::Stream> + Send, >, >, diff --git a/a2a-rs/CHANGELOG.md b/a2a-rs/CHANGELOG.md index 30d007f..04e64ec 100644 --- a/a2a-rs/CHANGELOG.md +++ b/a2a-rs/CHANGELOG.md @@ -7,6 +7,222 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- **`impl AsyncStreamingHandler for Arc`** — a + forwarding blanket impl so a type-erased, shared streaming backend can be + passed wherever an `impl AsyncStreamingHandler` is expected (e.g. + `TaskService::with_streaming_handler`). This lets one streaming instance be + injected into both a message handler and a transport adapter without naming + its concrete type, so handler broadcasts and SSE subscribers share a registry. + +### Breaking Changes — Port capability decomposition + +The server-side `AsyncTaskManager` port trait carried 17 methods spanning four +distinct capabilities. It has been **removed** and split into focused capability +traits. All consumers are in-workspace; there is no deprecation shim. + +#### Task ports + +- **Removed** `AsyncTaskManager`. +- **Added** `AsyncTaskLifecycle` — per-task CRUD: `create`, `get`, `update_status`, + `cancel`, `exists`. +- **Added** `AsyncTaskQuery` — cross-task listing: `list`. +- **Added** `AsyncTaskLifecycleExt` (blanket-implemented) — validation + conveniences: `get_validated`, `cancel_validated`. +- Express requirements at the use site (e.g. `T: AsyncTaskLifecycle + AsyncTaskQuery`); + there is no umbrella trait. + +Method renames (the noun-prefix is redundant once the trait carries it): + +| Old (`AsyncTaskManager`) | New | +|----------------------------|------------------------------------| +| `create_task(id, ctx)` | `AsyncTaskLifecycle::create` | +| `get_task(id, hist)` | `AsyncTaskLifecycle::get` | +| `update_task_status(...)` | `AsyncTaskLifecycle::update_status` | +| `cancel_task(id)` | `AsyncTaskLifecycle::cancel` | +| `task_exists(id)` | `AsyncTaskLifecycle::exists` | +| `list_tasks_v3(params)` | `AsyncTaskQuery::list` | + +- **Removed** the dead `get_task_metadata` and legacy `list_tasks(context, limit)` + methods (never called). + +#### Push-notification ports + +- The four v1.0.0 push-config methods moved **off** `AsyncTaskManager` and were + reconciled into `AsyncNotificationManager`, now expressed in terms of the + richer multi-config model: `set_config`, `get_config`, `list_configs`, + `delete_config`. +- **Added** `AsyncNotificationManagerExt` (blanket-implemented): `validate_config`, + `set_validated`. +- **Removed** the drifting single-config methods (`set_task_notification`, + `get_task_notification`, `remove_task_notification`, `has_task_notification`, + `send_test_notification`) and the unused `notify_task_status_update` / + `notify_task_artifact_update` stubs from the async trait. The synchronous + `NotificationManager` trait is unchanged. + +#### Strongly-typed identifiers + +- **Added** `TaskId`, `ContextId`, `PushConfigId` newtypes (`domain::ids`, + re-exported from the crate root). Each validates non-emptiness on construction + (`FromStr`/`TryFrom`), making argument-order mix-ups a compile error. They + appear in the new port signatures; conversion from wire strings happens once at + the RPC boundary. `#[serde(transparent)]` deserialization bypasses validation + by design (validated at the boundary). + +#### Dispatch — ports held as `Arc` at the composition edge + +The composition-edge structs no longer carry viral generic parameters; they hold +their ports as `Arc` trait objects. Dispatch goes through the vtable — +one indirect call per RPC, negligible on the I/O-bound port boundary — and the +generic noise disappears from every type that holds a processor or handler. + +- **`DefaultRequestProcessor`** lost its five generic parameters + (``). It is now a plain non-generic struct with + `Arc`, `Arc`, + `Arc`, `Arc`, + `Arc`, and `Arc` fields. + Constructors (`new`, `with_handler`, `with_streaming_handler`) now take + `impl Trait` arguments, so call sites are unchanged. +- **`DefaultMessageHandler`** lost its `` parameter; it holds + `Arc` and its constructor takes + `impl AsyncTaskLifecycle + 'static`. +- **`ReimbursementHandler`** (in `a2a-agents`) lost its `` parameter; it holds + `Arc` + `Arc`. The `Clone` + bound it forced on storage is gone (cloning an `Arc` is a refcount bump). + +#### Migration + +- Construction is source-compatible: the de-generic'd constructors accept the + same arguments via `impl Trait`, so existing `DefaultRequestProcessor::new(…)` + / `ReimbursementHandler::new(…)` call sites compile unchanged. +- Code that named the processor's generic parameters + (`DefaultRequestProcessor`) must drop the type arguments — the + type is now non-generic. +- The **HTTP client** API (`HttpClient::get_task`, `cancel_task`, etc.) is + unaffected — those names belong to the client surface, not the server port. + +### Added — cross-port `TaskStatusBroadcast` mixin + +The capability-mixin pattern from `.claude/rules/hexagonal_architecture.md` §9, +applied at the port boundary (`application::task_status_broadcast`, behind the +`server` feature): + +- **Added** accessor ingredients `HasTaskLifecycle` and `HasStreaming` — each + hands out a `&dyn` **port**, never a concrete adapter. +- **Added** `TaskStatusBroadcast`, a blanket-implemented mixin giving any host + that exposes both ingredients an `update_and_broadcast` ("commit the status + through the lifecycle port, then announce it through the streaming port") + method for free. A host exposing only one ingredient does not get the method — + a `compile_fail` doc test pins that guarantee. +- `TaskService` implements both accessors (see below), so it gains + `update_and_broadcast` without coupling its lifecycle and streaming ports. + +This is additive (no behavior change to existing call paths). Consuming it in +the request flow — and shedding the storage adapter's internal self-broadcast — +is deferred (`REFACTORING_PLAN.md` §4.0.2). + +### Added — application/transport split (`REFACTORING_PLAN.md` §4.2) + +`DefaultRequestProcessor` previously did two jobs: orchestrating the ports and +serving as the ConnectRPC transport adapter. Those layers are now separated. + +- **Added** `application::TaskService` (behind the `server` feature) — the inner + application service. It owns the six ports as `Arc` and holds all + use-case orchestration (`send_message`, `send_streaming_message`, `get`, + `list`, `cancel`, `subscribe`, push-config CRUD, `extended_agent_card`), + speaking only domain types and `A2AError`. It hosts the `HasTaskLifecycle` / + `HasStreaming` accessors, so it owns `update_and_broadcast`. +- **`DefaultRequestProcessor`** is now a thin ConnectRPC transport adapter that + decodes `buffa` wire views, delegates to a `TaskService`, and re-encodes the + results. Its public constructors (`new`, `with_handler`, + `with_streaming_handler`) are unchanged, so all call sites compile as before. + `map_*` helpers and `NoopStreamingHandler` remain transport-side. + +### Changed — storage no longer self-broadcasts (`REFACTORING_PLAN.md` §4.0.2) + +Persistence and streaming are now decoupled in the adapters; "commit then +announce" is owned by the orchestration layer via the `TaskStatusBroadcast` +mixin. + +- **`InMemoryTaskStorage` / `SqlxTaskStorage`** `update_status` and `cancel` are + now persistence-only — they no longer call `broadcast_status_update` as a side + effect. (Both structs still implement `AsyncStreamingHandler`; that is where + streaming subscribers live. Shedding that role entirely is a later struct + split, not done here.) +- **Added** `TaskStatusBroadcast::cancel_and_broadcast`, the cancellation + counterpart to `update_and_broadcast`. `TaskService::cancel` now routes through + it, so cancellations still reach subscribers. +- **`DefaultMessageHandler`** now hosts the broadcast mixin: it holds a streaming + port in addition to the lifecycle port and routes every transition in + `process_message` through `update_and_broadcast`. **Breaking:** its + constructor takes a streaming port (and a responder — see below); use + `DefaultMessageHandler::echo(lifecycle, streaming)` for the previous behavior. +- **`ReimbursementHandler`** (in `a2a-agents`) implements `HasTaskLifecycle` / + `HasStreaming` and broadcasts at all five transition sites, including the + background AI worker — its updates and push notifications no longer depend on a + storage side effect. +- Behavioral note: an agent that drives `update_status`/`cancel` directly on + storage no longer streams as a side effect. To announce transitions, host the + `TaskStatusBroadcast` mixin (hold both ports) or use `DefaultMessageHandler`. + +### Breaking — storage/streaming/push struct-split (`REFACTORING_PLAN.md` §4.3, final) + +The storage adapters shed their two non-persistence jobs. `InMemoryTaskStorage` +and `SqlxTaskStorage` previously implemented persistence **and** streaming +fan-out **and** fired push notifications inside their broadcast helpers. Each of +those is now its own adapter behind its own port, wired at the composition edge. + +- **Removed** the `AsyncStreamingHandler` impl (and the internal `subscribers` + map) from `InMemoryTaskStorage` and `SqlxTaskStorage`. They now implement only + `AsyncTaskLifecycle` + `AsyncTaskQuery` + `AsyncNotificationManager` + (persistence and push-config CRUD). +- **Added** `adapter::streaming::InMemoryStreamingHandler` — the in-memory + subscriber registry and broadcast fan-out, extracted out of the storage + structs. Re-exported from the crate root. +- **Added** the `AsyncPushNotifier` port (`port::notification_manager`) — the + out-of-band webhook **delivery** capability, separate from config CRUD + (`AsyncNotificationManager`) and from streaming. `PushNotificationRegistry` + implements it (the `PushNotificationSender` trait remains the pluggable backend + seam: HTTP, no-op, custom). **Added** `NoopPushNotifier`, and a deref-forwarding + impl so `Arc` satisfies `impl AsyncPushNotifier`. +- **Added** `InMemoryTaskStorage::push_notifier()` / `SqlxTaskStorage::push_notifier()` + returning the store's registry as an `Arc` — so a config + written via `set_config` is visible to the notifier at the composition edge. +- **`TaskStatusBroadcast`** gained a third ingredient `HasPushNotifier`: the + mixin now fires push delivery (best-effort, logged on failure) alongside the + streaming broadcast, and gained a `broadcast_artifact` method. Every host + (`TaskService`, `ReimbursementMessageHandler`, `ResponderMessageHandler`) now + also exposes `HasPushNotifier`. +- **Breaking constructors:** `TaskService::new`/`with_handler`, + `ResponderMessageHandler::new`/`echo`, and `ReimbursementHandler::new`/`with_llm` + take a separate `impl AsyncPushNotifier`; `ResponderMessageHandler` and + `ReimbursementHandler` also take the streaming port separately (no longer + requiring the storage to be the streaming handler). The transport adapters + (`ConnectRpcAdapter`, `JsonRpcAdapter`) default to `NoopPushNotifier` and gained + a `with_push_notifier` builder method. +- **Behavior change — no replay on subscribe:** `add_status_subscriber` / + `add_artifact_subscriber` no longer replay the task's current state to a new + subscriber (the streaming adapter has no task access). This is spec-compliant — + the initial `Task` snapshot is delivered by `TaskService::subscribe` / + `send_streaming_message` and emitted by the transport before stream items. + +### Added — injected `Responder` on `DefaultMessageHandler` + +`DefaultMessageHandler` now separates lifecycle/streaming plumbing from the +business decision of what to reply. + +- **Added** the `Responder` trait (`adapter::business`) — + `async fn respond(&self, message, task) -> Result<(Message, TaskState)>`. The + handler does create-if-absent, history append, and broadcasting; the responder + only decides the reply and the resulting state, getting streaming for free. +- **Added** `EchoResponder`, the reference implementation (echoes the input, + stays `Working`). +- **`DefaultMessageHandler::new(lifecycle, streaming, responder)`** takes a + custom responder; **`DefaultMessageHandler::echo(lifecycle, streaming)`** wires + `EchoResponder`. Agents needing "ack now, finish later" semantics still + implement `AsyncMessageHandler` directly. + ## [0.3.0](https://github.com/EmilLindfors/a2a-rs/compare/a2a-rs-v0.2.0...a2a-rs-v0.3.0) - 2026-05-27 ### Fixed diff --git a/a2a-rs/Cargo.toml b/a2a-rs/Cargo.toml index f86b85c..96edfbc 100644 --- a/a2a-rs/Cargo.toml +++ b/a2a-rs/Cargo.toml @@ -13,25 +13,28 @@ categories = ["api-bindings", "network-programming"] [dependencies] # Core dependencies -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -chrono = { version = "0.4", features = ["serde"] } -thiserror = "1.0" -uuid = { version = "1.4", features = ["v4", "serde"] } +serde = { workspace = true } +serde_json = { workspace = true } +chrono = { workspace = true } +thiserror = { workspace = true } +uuid = { workspace = true, features = ["v4", "serde"] } base64 = "0.21" url = { version = "2.4", features = ["serde"] } -bon = "2.3" +bon = { workspace = true } # Database - optional sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "chrono", "uuid", "json"], optional = true } -# Async foundation - optional -tokio = { version = "1.32", features = ["rt", "rt-multi-thread", "macros", "net", "io-util", "sync", "time"], optional = true } -async-trait = { version = "0.1", optional = true } -futures = { version = "0.3", optional = true } +# Async foundation +# async-trait and futures are non-optional: the port layer (e.g. the always-on +# Authenticator trait) uses them unconditionally, so domain + port must compile +# with zero features. tokio stays optional — only adapters need a runtime. +tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros", "net", "io-util", "sync", "time"], optional = true } +async-trait = { workspace = true } +futures = { workspace = true } # HTTP client - optional -reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false, optional = true } +reqwest = { workspace = true, features = ["json", "rustls-tls", "stream"], optional = true } # HTTP server - optional @@ -43,8 +46,8 @@ oauth2 = { version = "5.0", optional = true } openidconnect = { version = "4.0", optional = true } # Logging - optional -tracing = { version = "0.1", optional = true } -tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"], optional = true } +tracing = { workspace = true, optional = true } +tracing-subscriber = { workspace = true, optional = true } connectrpc = { version = "0.3.2", features = ["tls", "axum"] } rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } webpki-roots = "0.26" @@ -62,14 +65,22 @@ proptest-derive = "0.5" jsonschema = "0.22" criterion = { version = "0.5", features = ["html_reports"] } arbitrary = { version = "1.3", features = ["derive"] } +# Drives the axum routers via `ServiceExt::oneshot` in the JSON-RPC router tests. +tower = { version = "0.5", features = ["util"] } [features] default = ["server", "tracing"] -client = ["dep:tokio", "dep:async-trait", "dep:futures"] +client = ["dep:tokio"] http-client = ["client", "dep:reqwest"] +# Wire-compatible JSON-RPC 2.0 client adapter (reqwest over the spec-mandated +# JSON-RPC + SSE wire format). Lets our client talk to any standard A2A agent. +jsonrpc-client = ["client", "dep:reqwest"] -server = ["dep:tokio", "dep:async-trait", "dep:futures"] +server = ["dep:tokio"] http-server = ["server", "dep:axum"] +# Wire-compatible JSON-RPC 2.0 + HTTP+JSON (REST) transport adapter. Needs axum +# for its routers; the dispatch core itself only needs `server`. +jsonrpc-server = ["server", "dep:axum"] tracing = ["dep:tracing", "dep:tracing-subscriber"] auth = ["dep:jsonwebtoken", "dep:oauth2", "dep:openidconnect", "dep:reqwest"] @@ -77,7 +88,7 @@ sqlx-storage = ["server", "dep:sqlx"] sqlite = ["sqlx-storage", "sqlx/sqlite"] postgres = ["sqlx-storage", "sqlx/postgres"] mysql = ["sqlx-storage", "sqlx/mysql"] -full = ["http-client", "http-server", "tracing", "auth", "sqlite", "postgres"] +full = ["http-client", "http-server", "jsonrpc-server", "jsonrpc-client", "tracing", "auth", "sqlite", "postgres"] [package.metadata.docs.rs] @@ -89,6 +100,16 @@ name = "http_client_server" path = "examples/http_client_server.rs" required-features = ["http-server", "http-client"] +[[example]] +name = "jsonrpc_server" +path = "examples/jsonrpc_server.rs" +required-features = ["jsonrpc-server"] + +[[example]] +name = "jsonrpc_client" +path = "examples/jsonrpc_client.rs" +required-features = ["jsonrpc-client"] + [[example]] name = "sqlx_storage_demo" diff --git a/a2a-rs/README.md b/a2a-rs/README.md index 9b66831..7f73700 100644 --- a/a2a-rs/README.md +++ b/a2a-rs/README.md @@ -8,12 +8,12 @@ A Rust implementation of the Agent-to-Agent (A2A) Protocol v1.0.0, providing a t ## Features -- 🚀 **A2A Protocol v1.0.0** - Full support for the latest A2A specification including: +- 🚀 **A2A Protocol v1.0.0** - Implements the A2A specification (see [Spec compliance](#spec-compliance) for the small, documented divergences), including: - Enhanced push notification management with listing and deletion - Task listing with comprehensive filtering and pagination - Authenticated extended card support - Protocol extensions framework - - Multi-transport support (GRPC, HTTP+JSON) + - Multi-transport support: spec-compliant JSON-RPC 2.0 and HTTP+JSON, plus ConnectRPC (see [Spec compliance](#spec-compliance)) - 🔄 **Multiple Transport Options** - HTTP support - 📡 **Streaming Updates** - Real-time task and artifact updates - 🔐 **Authentication & Security** - JWT, OAuth2, OpenID Connect support with agent card signatures @@ -44,7 +44,7 @@ a2a-rs = { version = "0.1.0", features = ["full"] } ```rust use a2a_rs::{HttpClient, Message}; -use a2a_rs::services::AsyncA2AClient; +use a2a_rs::Transport; #[tokio::main] async fn main() -> Result<(), Box> { @@ -100,6 +100,36 @@ This library follows a hexagonal architecture pattern: - **Ports**: Trait definitions for external dependencies - **Adapters**: Concrete implementations for different transports and storage +## Spec compliance + +`a2a-rs` targets **A2A Protocol v1.0.0** and is wire-compatible with the +specification: the domain types, transports, and `StreamResponse`/JSON-RPC +payloads follow the spec, so off-the-shelf A2A clients and servers interoperate. +There are a couple of small, deliberate divergences, all backward-compatible: + +- **`Last-Event-ID` stream resumption is an opt-in enhancement, not a spec + feature.** The A2A spec reconnects a dropped stream by re-issuing the subscribe + call (resuming from the task's *current* state). On top of that, `a2a-rs` adds + gap-free resumption using the **W3C SSE-standard** `id:` field and + `Last-Event-ID` header (`RetryingTransport` / `WebA2AClient::subscribe_resilient` + on the client; buffered replay on the server). This is fully interoperable — + spec clients ignore the `id:` field and never send the header, getting standard + reconnect-from-current-state behavior — but **gap-free resume only works + a2a-rs ↔ a2a-rs**, not against third-party agents. For strictly spec-shaped + streaming, use `WebA2AClient::subscribe` (or `subscribe_to_task` with + `last_event_id = None`). +- **ConnectRPC is offered as an additional transport.** The spec names three + transport bindings — `JSONRPC`, `GRPC`, and `HTTP+JSON`. `a2a-rs` adds + **ConnectRPC** as the in-tree default (advertised in the agent card under the + non-spec `CONNECTRPC` binding), alongside a spec-compliant **JSON-RPC 2.0** + transport and HTTP+JSON/REST. For interop with third-party A2A agents use the + JSON-RPC transport (`JsonRpcClient` / `jsonrpc_router`); ConnectRPC is the + preferred path a2a-rs ↔ a2a-rs. +- **JSON-RPC method names follow the proto RPC names** (`SubscribeToTask`, + `SendStreamingMessage`, …) rather than the canonical JSON-RPC strings + (`tasks/resubscribe`, `message/stream`); the request/response bodies are + spec-shaped ProtoJSON. + ## Feature Flags - `client` - Client-side functionality diff --git a/a2a-rs/examples/common/simple_agent_handler.rs b/a2a-rs/examples/common/simple_agent_handler.rs index 05e3ffa..36b9adf 100644 --- a/a2a-rs/examples/common/simple_agent_handler.rs +++ b/a2a-rs/examples/common/simple_agent_handler.rs @@ -12,14 +12,14 @@ use async_trait::async_trait; use a2a_rs::{ adapter::storage::InMemoryTaskStorage, + adapter::streaming::InMemoryStreamingHandler, domain::{ - A2AError, Message, Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig, TaskState, - TaskStatusUpdateEvent, + A2AError, ContextId, Message, Task, TaskArtifactUpdateEvent, TaskId, + TaskPushNotificationConfig, TaskState, TaskStatusUpdateEvent, }, port::{ - AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager, - MessageHandler, NotificationManager, StreamingHandler, TaskManager, - streaming_handler::Subscriber, + AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskLifecycle, + AsyncTaskQuery, streaming_handler::Subscriber, }, }; @@ -33,11 +33,14 @@ use a2a_rs::{ /// - Agents that don't need custom message processing /// /// For production agents with custom business logic, implement your own -/// `AsyncMessageHandler` and compose it with storage using `DefaultRequestProcessor`. +/// `AsyncMessageHandler` and compose it with storage using `ConnectRpcAdapter`. #[derive(Clone)] pub struct SimpleAgentHandler { - /// Task storage that implements all the business capabilities + /// Task storage (persistence + push-config CRUD) storage: Arc, + /// Dedicated streaming fan-out, shared between this handler's broadcasts and + /// its subscriber registry. + streaming: InMemoryStreamingHandler, } impl SimpleAgentHandler { @@ -45,6 +48,7 @@ impl SimpleAgentHandler { pub fn new() -> Self { Self { storage: Arc::new(InMemoryTaskStorage::new()), + streaming: InMemoryStreamingHandler::new(), } } @@ -52,6 +56,7 @@ impl SimpleAgentHandler { pub fn with_storage(storage: InMemoryTaskStorage) -> Self { Self { storage: Arc::new(storage), + streaming: InMemoryStreamingHandler::new(), } } @@ -68,123 +73,6 @@ impl Default for SimpleAgentHandler { } } -// Synchronous trait implementations - not supported since we use async storage -impl MessageHandler for SimpleAgentHandler { - fn process_message( - &self, - _task_id: &str, - _message: &Message, - _session_id: Option<&str>, - ) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous message processing not supported. Use async version.".to_string(), - )) - } -} - -impl TaskManager for SimpleAgentHandler { - fn create_task(&self, _task_id: &str, _context_id: &str) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous task creation not supported. Use async version.".to_string(), - )) - } - - fn get_task(&self, _task_id: &str, _history_length: Option) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous task retrieval not supported. Use async version.".to_string(), - )) - } - - fn update_task_status( - &self, - _task_id: &str, - _state: TaskState, - _message: Option, - ) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous task status update not supported. Use async version.".to_string(), - )) - } - - fn cancel_task(&self, _task_id: &str) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous task cancellation not supported. Use async version.".to_string(), - )) - } - - fn task_exists(&self, _task_id: &str) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous task existence check not supported. Use async version.".to_string(), - )) - } -} - -impl NotificationManager for SimpleAgentHandler { - fn set_task_notification( - &self, - _config: &TaskPushNotificationConfig, - ) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous notification setup not supported. Use async version.".to_string(), - )) - } - - fn get_task_notification( - &self, - _task_id: &str, - ) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous notification retrieval not supported. Use async version.".to_string(), - )) - } - - fn remove_task_notification(&self, _task_id: &str) -> Result<(), A2AError> { - Err(A2AError::UnsupportedOperation( - "Synchronous notification removal not supported. Use async version.".to_string(), - )) - } -} - -impl StreamingHandler for SimpleAgentHandler { - fn add_status_subscriber( - &self, - _task_id: &str, - _subscriber: Box + Send + Sync>, - ) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous streaming subscription not supported. Use async version.".to_string(), - )) - } - - fn add_artifact_subscriber( - &self, - _task_id: &str, - _subscriber: Box + Send + Sync>, - ) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous streaming subscription not supported. Use async version.".to_string(), - )) - } - - fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> { - Err(A2AError::UnsupportedOperation( - "Synchronous streaming unsubscription not supported. Use async version.".to_string(), - )) - } - - fn remove_task_subscribers(&self, _task_id: &str) -> Result<(), A2AError> { - Err(A2AError::UnsupportedOperation( - "Synchronous streaming unsubscription not supported. Use async version.".to_string(), - )) - } - - fn get_subscriber_count(&self, _task_id: &str) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous subscriber count not supported. Use async version.".to_string(), - )) - } -} - // Asynchronous trait implementations - delegate to storage #[async_trait] @@ -195,9 +83,13 @@ impl AsyncMessageHandler for SimpleAgentHandler { message: &Message, session_id: Option<&str>, ) -> Result { - // Create a message handler and delegate - let message_handler = - a2a_rs::adapter::business::DefaultMessageHandler::new((*self.storage).clone()); + // Create a message handler and delegate, sharing the streaming handler so + // the echo handler's broadcasts reach this handler's subscribers. + let message_handler = a2a_rs::adapter::business::ResponderMessageHandler::echo( + (*self.storage).clone(), + self.streaming.clone(), + self.storage.push_notifier(), + ); message_handler .process_message(task_id, message, session_id) .await @@ -205,53 +97,71 @@ impl AsyncMessageHandler for SimpleAgentHandler { } #[async_trait] -impl AsyncTaskManager for SimpleAgentHandler { - async fn create_task(&self, task_id: &str, context_id: &str) -> Result { - self.storage.create_task(task_id, context_id).await +impl AsyncTaskLifecycle for SimpleAgentHandler { + async fn create(&self, id: &TaskId, context_id: &ContextId) -> Result { + self.storage.create(id, context_id).await } - async fn get_task(&self, task_id: &str, history_length: Option) -> Result { - self.storage.get_task(task_id, history_length).await + async fn get(&self, id: &TaskId, history_length: Option) -> Result { + self.storage.get(id, history_length).await } - async fn update_task_status( + async fn update_status( &self, - task_id: &str, + id: &TaskId, state: TaskState, message: Option, ) -> Result { - self.storage - .update_task_status(task_id, state, message) - .await + self.storage.update_status(id, state, message).await } - async fn cancel_task(&self, task_id: &str) -> Result { - self.storage.cancel_task(task_id).await + async fn cancel(&self, id: &TaskId) -> Result { + self.storage.cancel(id).await } - async fn task_exists(&self, task_id: &str) -> Result { - self.storage.task_exists(task_id).await + async fn exists(&self, id: &TaskId) -> Result { + self.storage.exists(id).await + } +} + +#[async_trait] +impl AsyncTaskQuery for SimpleAgentHandler { + async fn list( + &self, + params: &a2a_rs::domain::ListTasksParams, + ) -> Result { + self.storage.list(params).await } } #[async_trait] impl AsyncNotificationManager for SimpleAgentHandler { - async fn set_task_notification( + async fn set_config( &self, config: &TaskPushNotificationConfig, ) -> Result { - self.storage.set_task_notification(config).await + self.storage.set_config(config).await } - async fn get_task_notification( + async fn get_config( &self, - task_id: &str, + params: &a2a_rs::domain::GetTaskPushNotificationConfigParams, ) -> Result { - self.storage.get_task_notification(task_id).await + self.storage.get_config(params).await } - async fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError> { - self.storage.remove_task_notification(task_id).await + async fn list_configs( + &self, + params: &a2a_rs::domain::ListTaskPushNotificationConfigsParams, + ) -> Result, A2AError> { + self.storage.list_configs(params).await + } + + async fn delete_config( + &self, + params: &a2a_rs::domain::DeleteTaskPushNotificationConfigParams, + ) -> Result<(), A2AError> { + self.storage.delete_config(params).await } } @@ -262,7 +172,7 @@ impl AsyncStreamingHandler for SimpleAgentHandler { task_id: &str, subscriber: Box + Send + Sync>, ) -> Result { - self.storage + self.streaming .add_status_subscriber(task_id, subscriber) .await } @@ -272,21 +182,21 @@ impl AsyncStreamingHandler for SimpleAgentHandler { task_id: &str, subscriber: Box + Send + Sync>, ) -> Result { - self.storage + self.streaming .add_artifact_subscriber(task_id, subscriber) .await } async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError> { - self.storage.remove_subscription(subscription_id).await + self.streaming.remove_subscription(subscription_id).await } async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> { - self.storage.remove_task_subscribers(task_id).await + self.streaming.remove_task_subscribers(task_id).await } async fn get_subscriber_count(&self, task_id: &str) -> Result { - self.storage.get_subscriber_count(task_id).await + self.streaming.get_subscriber_count(task_id).await } async fn broadcast_status_update( @@ -294,7 +204,9 @@ impl AsyncStreamingHandler for SimpleAgentHandler { task_id: &str, update: TaskStatusUpdateEvent, ) -> Result<(), A2AError> { - self.storage.broadcast_status_update(task_id, update).await + self.streaming + .broadcast_status_update(task_id, update) + .await } async fn broadcast_artifact_update( @@ -302,7 +214,7 @@ impl AsyncStreamingHandler for SimpleAgentHandler { task_id: &str, update: TaskArtifactUpdateEvent, ) -> Result<(), A2AError> { - self.storage + self.streaming .broadcast_artifact_update(task_id, update) .await } @@ -316,7 +228,7 @@ impl AsyncStreamingHandler for SimpleAgentHandler { >, A2AError, > { - self.storage.status_update_stream(task_id).await + self.streaming.status_update_stream(task_id).await } async fn artifact_update_stream( @@ -328,22 +240,25 @@ impl AsyncStreamingHandler for SimpleAgentHandler { >, A2AError, > { - self.storage.artifact_update_stream(task_id).await + self.streaming.artifact_update_stream(task_id).await } async fn combined_update_stream( &self, task_id: &str, + from_event_id: Option, ) -> Result< std::pin::Pin< Box< dyn futures::Stream< - Item = Result, + Item = Result, > + Send, >, >, A2AError, > { - self.storage.combined_update_stream(task_id).await + self.streaming + .combined_update_stream(task_id, from_event_id) + .await } } diff --git a/a2a-rs/examples/http_client_server.rs b/a2a-rs/examples/http_client_server.rs index 5c98402..d992c6b 100644 --- a/a2a-rs/examples/http_client_server.rs +++ b/a2a-rs/examples/http_client_server.rs @@ -4,14 +4,14 @@ use std::time::Duration; use tokio::time::sleep; use a2a_rs::adapter::{ - BearerTokenAuthenticator, DefaultRequestProcessor, HttpClient, HttpServer, InMemoryTaskStorage, + BearerTokenAuthenticator, ConnectRpcAdapter, HttpClient, HttpServer, InMemoryTaskStorage, NoopPushNotificationSender, SimpleAgentInfo, }; mod common; +use a2a_rs::Transport; use a2a_rs::domain::{Message, Part, Role}; use a2a_rs::observability; -use a2a_rs::services::AsyncA2AClient; use common::SimpleAgentHandler; #[tokio::main] @@ -57,7 +57,7 @@ async fn run_server() -> Result<(), Box> { "test-agent".to_string(), "http://localhost:8080".to_string(), ); - let processor = DefaultRequestProcessor::with_handler(handler, test_agent_info); + let processor = ConnectRpcAdapter::with_handler(handler, test_agent_info); // Create agent info let agent_info = SimpleAgentInfo::new( diff --git a/a2a-rs/examples/jsonrpc_client.rs b/a2a-rs/examples/jsonrpc_client.rs new file mode 100644 index 0000000..3ae6ba0 --- /dev/null +++ b/a2a-rs/examples/jsonrpc_client.rs @@ -0,0 +1,92 @@ +//! A wire-compatible JSON-RPC 2.0 A2A client — the counterpart to +//! [`jsonrpc_server`]. +//! +//! It demonstrates the **auto-connect** path: fetch the agent card, negotiate a +//! transport from the interfaces it advertises ([`connect`] + +//! [`default_registry`]), and fall back to a direct [`JsonRpcClient`] if the card +//! can't be fetched or negotiated. Then it drives the negotiated +//! [`Transport`](a2a_rs::Transport) port through a full task lifecycle: send a +//! message, read the task back, subscribe to its SSE stream, and cancel it. +//! +//! Start the server in one terminal: +//! ```sh +//! cargo run -p a2a-rs --example jsonrpc_server --features jsonrpc-server +//! ``` +//! …then the client in another (default target `http://127.0.0.1:8137`, or pass +//! a base URL): +//! ```sh +//! cargo run -p a2a-rs --example jsonrpc_client --features jsonrpc-client +//! cargo run -p a2a-rs --example jsonrpc_client --features jsonrpc-client -- http://127.0.0.1:8137 +//! ``` + +use std::time::Duration; + +use futures::StreamExt; + +use a2a_rs::domain::Message; +use a2a_rs::{JsonRpcClient, StreamItem, Transport, connect, default_registry}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let base_url = std::env::args() + .nth(1) + .unwrap_or_else(|| "http://127.0.0.1:8137".to_string()); + + // 1. Auto-connect: fetch the card and let the negotiator pick a transport the + // client compiled in, ranked by client preference. If the card can't be + // fetched or negotiated, fall back to a direct JSON-RPC client. + let transport: Box = match connect(&base_url, &default_registry()).await { + Ok(t) => { + println!("✅ negotiated transport: {}", t.protocol()); + t + } + Err(e) => { + println!("⚠️ card negotiation failed ({e}); falling back to direct JSON-RPC client"); + Box::new(JsonRpcClient::new(base_url.clone())) + } + }; + + // 2. Send a message — the server creates (or updates) the task and echoes it. + let task = transport + .send_task_message( + "demo-task", + &Message::user_text("hello".to_string(), "m1".to_string()), + None, + None, + ) + .await?; + println!("📨 sent message; task id = {}", task.id); + + // 3. Read the task back. + let fetched = transport.get_task(&task.id, None).await?; + println!("📥 get_task → state {:?}", fetched.status.state); + + // 4. Subscribe to the task's SSE stream and print the first few events. The + // first event is the initial task snapshot; live updates follow. + let mut stream = transport.subscribe_to_task(&task.id, None, None).await?; + println!("📡 subscribing (up to 3 events / 5s)…"); + for _ in 0..3 { + match tokio::time::timeout(Duration::from_secs(5), stream.next()).await { + Ok(Some(Ok(event))) => match &event.item { + StreamItem::Task(t) => { + println!(" • snapshot: task {} ({:?})", t.id, t.status.state) + } + StreamItem::StatusUpdate(u) => println!(" • status: {:?}", u.status.state), + StreamItem::ArtifactUpdate(_) => println!(" • artifact update"), + }, + Ok(Some(Err(e))) => { + println!(" • stream error: {e}"); + break; + } + Ok(None) => break, // stream ended + Err(_) => break, // timed out waiting for the next event + } + } + drop(stream); + + // 5. Cancel the task. + let canceled = transport.cancel_task(&task.id).await?; + println!("🛑 canceled; final state {:?}", canceled.status.state); + + Ok(()) +} diff --git a/a2a-rs/examples/jsonrpc_server.rs b/a2a-rs/examples/jsonrpc_server.rs new file mode 100644 index 0000000..3b7e9d7 --- /dev/null +++ b/a2a-rs/examples/jsonrpc_server.rs @@ -0,0 +1,95 @@ +//! A wire-compatible JSON-RPC 2.0 + HTTP+JSON (REST) A2A server. +//! +//! Unlike [`http_client_server`], which speaks ConnectRPC, this example mounts +//! the [`JsonRpcAdapter`] so off-the-shelf A2A clients (the official `a2acli`, +//! the Go/C#/Python SDKs) can talk to it. Composition happens here at the edge: +//! `jsonrpc_router(adapter).merge(rest_router(adapter))` plus the well-known +//! agent-card route, all on one `axum::serve`. +//! +//! Run it: +//! ```sh +//! cargo run -p a2a-rs --example jsonrpc_server --features jsonrpc-server +//! ``` +//! Then exercise it with curl (JSON-RPC): +//! ```sh +//! curl -s localhost:8137/ -d '{"jsonrpc":"2.0","id":1,"method":"SendMessage", +//! "params":{"message":{"messageId":"m1","role":"ROLE_USER", +//! "parts":[{"text":"hello"}],"taskId":"t1"}}}' +//! curl -s localhost:8137/.well-known/agent-card.json +//! ``` +//! …or REST: +//! ```sh +//! curl -s localhost:8137/message:send -d '{"message":{"messageId":"m1", +//! "role":"ROLE_USER","parts":[{"text":"hi"}],"taskId":"t1"}}' +//! curl -s localhost:8137/tasks/t1 +//! ``` +//! …or the official CLI (clone at ./a2aproject/a2a-rs/a2acli): +//! ```sh +//! cargo run --bin a2acli -- --base-url http://localhost:8137 card +//! cargo run --bin a2acli -- --base-url http://localhost:8137 send "hello" +//! ``` + +use std::sync::Arc; + +use axum::{Json, Router, extract::State, response::IntoResponse, routing::get}; + +use a2a_rs::adapter::{ + InMemoryTaskStorage, JsonRpcAdapter, SimpleAgentInfo, jsonrpc_router, rest_router, +}; +use a2a_rs::services::server::AgentInfoProvider; + +mod common; +use common::SimpleAgentHandler; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let address = "127.0.0.1:8137"; + + // 1. Inner application: an in-memory handler behind the JSON-RPC adapter. + let handler = SimpleAgentHandler::with_storage(InMemoryTaskStorage::new()); + let adapter_card = + SimpleAgentInfo::new("jsonrpc-agent".to_string(), format!("http://{address}")); + let adapter = Arc::new(JsonRpcAdapter::with_handler(handler, adapter_card)); + + // 2. Agent card served for client transport negotiation. The primary + // interface (from `new`) already advertises JSON-RPC at `base`; we add the + // REST binding so an official client reading `supportedInterfaces` can + // negotiate to either endpoint this server mounts. + let base = format!("http://{address}"); + let card_info = Arc::new( + SimpleAgentInfo::new("Example JSON-RPC A2A Agent".to_string(), base.clone()) + .with_description("Wire-compatible JSON-RPC 2.0 + HTTP+JSON A2A server".to_string()) + .with_preferred_transport("JSONRPC".to_string()) + .add_interface(base, "HTTP+JSON".to_string()) + .add_skill( + "echo".to_string(), + "Echo".to_string(), + Some("Echoes input".to_string()), + ), + ); + + // 3. Composition at the edge: both transports + the agent card on one server. + // Each sub-router carries its own state, so they merge as `Router<()>`. + let card_router = Router::new() + .route("/.well-known/agent-card.json", get(agent_card)) + .with_state(card_info); + let app: Router = jsonrpc_router(adapter.clone()) + .merge(rest_router(adapter)) + .merge(card_router); + + println!("🚀 JSON-RPC + REST A2A server on http://{address}"); + let listener = tokio::net::TcpListener::bind(address).await?; + axum::serve(listener, app).await?; + Ok(()) +} + +async fn agent_card(State(info): State>) -> impl IntoResponse { + match info.get_agent_card().await { + Ok(card) => Json(card).into_response(), + Err(e) => ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ "error": e.to_string() })), + ) + .into_response(), + } +} diff --git a/a2a-rs/examples/sqlx_storage_demo.rs b/a2a-rs/examples/sqlx_storage_demo.rs index ef3fc74..c83623a 100644 --- a/a2a-rs/examples/sqlx_storage_demo.rs +++ b/a2a-rs/examples/sqlx_storage_demo.rs @@ -21,7 +21,16 @@ use a2a_rs::adapter::storage::{DatabaseConfig, SqlxTaskStorage}; #[cfg(feature = "sqlx-storage")] use a2a_rs::domain::TaskState; #[cfg(feature = "sqlx-storage")] -use a2a_rs::port::AsyncTaskManager; +use a2a_rs::port::AsyncTaskLifecycle; + +#[cfg(feature = "sqlx-storage")] +fn tid(s: &str) -> a2a_rs::domain::TaskId { + s.parse().unwrap() +} +#[cfg(feature = "sqlx-storage")] +fn cid(s: &str) -> a2a_rs::domain::ContextId { + s.parse().unwrap() +} #[cfg(feature = "sqlx-storage")] #[tokio::main] @@ -65,7 +74,7 @@ async fn main() -> Result<(), Box> { let task_ids = vec!["demo-task-1", "demo-task-2", "demo-task-3"]; for task_id in &task_ids { - let task = storage.create_task(task_id, "demo-context").await?; + let task = storage.create(&tid(task_id), &cid("demo-context")).await?; println!( " ✓ Created task: {} (status: {:?})", task.id, task.status.state @@ -76,27 +85,27 @@ async fn main() -> Result<(), Box> { // Demo: Update task statuses println!("🔄 Updating task statuses..."); storage - .update_task_status("demo-task-1", TaskState::Working, None) + .update_status(&tid("demo-task-1"), TaskState::Working, None) .await?; println!(" ✓ Updated demo-task-1 to Working"); storage - .update_task_status("demo-task-2", TaskState::Working, None) + .update_status(&tid("demo-task-2"), TaskState::Working, None) .await?; storage - .update_task_status("demo-task-2", TaskState::Completed, None) + .update_status(&tid("demo-task-2"), TaskState::Completed, None) .await?; println!(" ✓ Updated demo-task-2 to Working, then Completed"); storage - .update_task_status("demo-task-3", TaskState::Working, None) + .update_status(&tid("demo-task-3"), TaskState::Working, None) .await?; println!(" ✓ Updated demo-task-3 to Working"); println!(); // Demo: Cancel a task println!("❌ Canceling a task..."); - let canceled_task = storage.cancel_task("demo-task-3").await?; + let canceled_task = storage.cancel(&tid("demo-task-3")).await?; println!( " ✓ Canceled task: {} (status: {:?})", canceled_task.id, canceled_task.status.state @@ -106,7 +115,7 @@ async fn main() -> Result<(), Box> { // Demo: Retrieve tasks and show history println!("📖 Retrieving tasks with history..."); for task_id in &task_ids { - let task = storage.get_task(task_id, Some(10)).await?; + let task = storage.get(&tid(task_id), Some(10)).await?; println!(" 📋 Task: {} (status: {:?})", task.id, task.status.state); let history = &task.history; @@ -125,11 +134,11 @@ async fn main() -> Result<(), Box> { // Demo: Task existence checks println!("🔍 Checking task existence..."); for task_id in &task_ids { - let exists = storage.task_exists(task_id).await?; + let exists = storage.exists(&tid(task_id)).await?; println!(" {} exists: {}", task_id, exists); } - let exists = storage.task_exists("non-existent-task").await?; + let exists = storage.exists(&tid("non-existent-task")).await?; println!(" non-existent-task exists: {}", exists); println!(); diff --git a/a2a-rs/examples/storage_comparison.rs b/a2a-rs/examples/storage_comparison.rs index 45fe57e..d7c546e 100644 --- a/a2a-rs/examples/storage_comparison.rs +++ b/a2a-rs/examples/storage_comparison.rs @@ -16,11 +16,18 @@ use std::time::Duration; use a2a_rs::adapter::storage::InMemoryTaskStorage; use a2a_rs::domain::TaskState; -use a2a_rs::port::AsyncTaskManager; +use a2a_rs::port::AsyncTaskLifecycle; #[cfg(feature = "sqlx-storage")] use a2a_rs::adapter::storage::{DatabaseConfig, SqlxTaskStorage}; +fn tid(s: &str) -> a2a_rs::domain::TaskId { + s.parse().unwrap() +} +fn cid(s: &str) -> a2a_rs::domain::ContextId { + s.parse().unwrap() +} + #[tokio::main] async fn main() -> Result<(), Box> { // Initialize tracing @@ -117,26 +124,26 @@ async fn main() -> Result<(), Box> { Ok(()) } -async fn run_storage_tests( +async fn run_storage_tests( storage: &T, storage_name: &str, ) -> Result<(), Box> { let task_id = format!("test-task-{}", storage_name.to_lowercase()); // Test 1: Create task - let task = storage.create_task(&task_id, "test-context").await?; + let task = storage.create(&tid(&task_id), &cid("test-context")).await?; println!( " ✓ Created task: {} (status: {:?})", task.id, task.status.state ); // Test 2: Check existence - let exists = storage.task_exists(&task_id).await?; + let exists = storage.exists(&tid(&task_id)).await?; println!(" ✓ Task exists: {}", exists); // Test 3: Update status let updated_task = storage - .update_task_status(&task_id, TaskState::Working, None) + .update_status(&tid(&task_id), TaskState::Working, None) .await?; println!( " ✓ Updated to Working (status: {:?})", @@ -144,13 +151,13 @@ async fn run_storage_tests( ); // Test 4: Get task with history - let task_with_history = storage.get_task(&task_id, Some(10)).await?; + let task_with_history = storage.get(&tid(&task_id), Some(10)).await?; let history_count = task_with_history.history.len(); println!(" ✓ Retrieved task with {} history entries", history_count); // Test 5: Complete the task let completed_task = storage - .update_task_status(&task_id, TaskState::Completed, None) + .update_status(&tid(&task_id), TaskState::Completed, None) .await?; println!( " ✓ Completed task (status: {:?})", @@ -158,18 +165,20 @@ async fn run_storage_tests( ); // Test 6: Try to cancel completed task (should fail) - match storage.cancel_task(&task_id).await { + match storage.cancel(&tid(&task_id)).await { Ok(_) => println!(" ❌ Unexpected: was able to cancel completed task"), Err(_) => println!(" ✓ Correctly prevented canceling completed task"), } // Test 7: Create and cancel a working task let cancel_task_id = format!("cancel-test-{}", storage_name.to_lowercase()); - storage.create_task(&cancel_task_id, "test-context").await?; storage - .update_task_status(&cancel_task_id, TaskState::Working, None) + .create(&tid(&cancel_task_id), &cid("test-context")) + .await?; + storage + .update_status(&tid(&cancel_task_id), TaskState::Working, None) .await?; - let canceled_task = storage.cancel_task(&cancel_task_id).await?; + let canceled_task = storage.cancel(&tid(&cancel_task_id)).await?; println!( " ✓ Canceled working task (status: {:?})", canceled_task.status.state @@ -178,7 +187,7 @@ async fn run_storage_tests( Ok(()) } -async fn measure_performance( +async fn measure_performance( storage: &T, ) -> Result> { let start = std::time::Instant::now(); @@ -187,14 +196,14 @@ async fn measure_performance( let task_id = format!("perf-task-{}", i); // Create, update, and retrieve task - storage.create_task(&task_id, "perf-context").await?; + storage.create(&tid(&task_id), &cid("perf-context")).await?; storage - .update_task_status(&task_id, TaskState::Working, None) + .update_status(&tid(&task_id), TaskState::Working, None) .await?; storage - .update_task_status(&task_id, TaskState::Completed, None) + .update_status(&tid(&task_id), TaskState::Completed, None) .await?; - storage.get_task(&task_id, Some(5)).await?; + storage.get(&tid(&task_id), Some(5)).await?; } Ok(start.elapsed()) diff --git a/a2a-rs/migrations/003_task_version.sql b/a2a-rs/migrations/003_task_version.sql new file mode 100644 index 0000000..dda46a2 --- /dev/null +++ b/a2a-rs/migrations/003_task_version.sql @@ -0,0 +1,6 @@ +-- v0.4.0 Migration: add an optimistic-concurrency version column to tasks. +-- The version is a monotonic counter bumped on every task mutation; conditional +-- updates (AsyncTaskVersioning::update_status_checked) compare it to detect and +-- reject lost updates. + +ALTER TABLE tasks ADD COLUMN version INTEGER NOT NULL DEFAULT 1; diff --git a/a2a-rs/src/adapter/business/agent_info.rs b/a2a-rs/src/adapter/business/agent_info.rs index 3a43f4f..0605f11 100644 --- a/a2a-rs/src/adapter/business/agent_info.rs +++ b/a2a-rs/src/adapter/business/agent_info.rs @@ -7,7 +7,10 @@ use async_trait::async_trait; use std::collections::HashMap; use crate::{ - domain::{A2AError, AgentCard, AgentExtension, AgentProvider, AgentSkill, SecurityScheme}, + domain::{ + A2AError, AgentCard, AgentExtension, AgentInterface, AgentProvider, AgentSkill, + SecurityScheme, + }, services::server::AgentInfoProvider, }; @@ -60,6 +63,44 @@ impl SimpleAgentInfo { self } + /// Set the transport protocol a client should prefer when connecting + /// (e.g. `"JSONRPC"`, `"HTTP+JSON"`, `"GRPC"`). + /// + /// The card has no standalone "preferred transport" field — the *first* + /// entry in `supportedInterfaces` is the preferred one (that is what + /// [`AgentCard::preferred_transport`] reads). This sets the protocol binding + /// of that primary interface (creating one if the card has none), which a + /// card-driven A2A client uses to rank transports during negotiation. + pub fn with_preferred_transport(mut self, transport: String) -> Self { + if let Some(primary) = self.card.supported_interfaces.first_mut() { + primary.protocol_binding = transport; + } else { + self.card.supported_interfaces.push(AgentInterface { + protocol_binding: transport, + protocol_version: "1.0".to_string(), + ..Default::default() + }); + } + self + } + + /// Advertise an additional transport interface — a `(url, protocol_binding)` + /// pair — on the agent card so card-driven clients can negotiate to it. + /// + /// A server mounting both the JSON-RPC and REST routers advertises both: the + /// primary interface (from [`SimpleAgentInfo::new`]) already carries the + /// JSON-RPC binding, so add the REST one with + /// `.add_interface(base, "HTTP+JSON")`. + pub fn add_interface(mut self, url: String, protocol_binding: String) -> Self { + self.card.supported_interfaces.push(AgentInterface { + url, + protocol_binding, + protocol_version: "1.0".to_string(), + ..Default::default() + }); + self + } + /// Enable streaming capability pub fn with_streaming(mut self) -> Self { self.card.capabilities.get_or_insert_default().streaming = Some(true); diff --git a/a2a-rs/src/adapter/business/message_handler.rs b/a2a-rs/src/adapter/business/message_handler.rs index 679b6c2..1bb6964 100644 --- a/a2a-rs/src/adapter/business/message_handler.rs +++ b/a2a-rs/src/adapter/business/message_handler.rs @@ -1,85 +1,237 @@ -//! Default message handler implementation +//! Default message handler implementation. +//! +//! `ResponderMessageHandler` owns the *plumbing* of turning an incoming message +//! into a task — parse the id, create the task if absent, append the message to +//! history, broadcast each transition — and delegates the *business decision* +//! (what to reply, and what state the task should end in) to an injected +//! [`Responder`]. The built-in [`EchoResponder`] echoes the message back; a +//! caller that wants AI behaviour implements `Responder` and keeps all of the +//! lifecycle + streaming wiring for free. +//! +//! This split keeps the broadcasting in one place: because the handler holds +//! both the lifecycle and streaming ports it hosts the [`TaskStatusBroadcast`] +//! mixin, so every transition it drives — the incoming-message append *and* the +//! responder's reply — goes through [`update_and_broadcast`], announcing to +//! streaming subscribers. Storage mutators are persistence-only and do not +//! self-broadcast, so a `Responder` author never has to think about streaming +//! at all. +//! +//! `Responder` is synchronous-shaped (`message + task → reply + state`); agents +//! that need "acknowledge now, finish later" semantics implement +//! [`AsyncMessageHandler`](crate::port::AsyncMessageHandler) directly and host +//! the mixin themselves (the reimbursement agent does this). +//! +//! [`update_and_broadcast`]: TaskStatusBroadcast::update_and_broadcast use std::sync::Arc; use async_trait::async_trait; use crate::{ - domain::{A2AError, Message, Task, TaskState}, - port::{AsyncMessageHandler, AsyncTaskManager}, + application::{HasPushNotifier, HasStreaming, HasTaskLifecycle, TaskStatusBroadcast}, + domain::{A2AError, ContextId, Message, Part, Role, Task, TaskId, TaskState}, + port::{AsyncMessageHandler, AsyncPushNotifier, AsyncStreamingHandler, AsyncTaskLifecycle}, }; -/// Default message handler that processes messages and delegates to task manager +/// The business decision behind a message handler: given the incoming `message` +/// and the `task` as it now stands (already in `Working` with the message +/// appended to history), produce the agent's reply and the state the task +/// should transition to. +/// +/// Implement this to plug custom logic (an LLM call, a rules engine, …) into +/// [`ResponderMessageHandler`] without re-implementing task lifecycle or +/// streaming. Implementations must be cheap to share (`Send + Sync`): the +/// handler holds the responder behind an `Arc`. +#[async_trait] +pub trait Responder: Send + Sync { + /// Produce the reply message and the resulting task state. + async fn respond( + &self, + message: &Message, + task: &Task, + ) -> Result<(Message, TaskState), A2AError>; +} + +/// The reference [`Responder`]: echoes the incoming text back and leaves the +/// task in `Working`. Useful for smoke tests, examples, and as the default for +/// [`ResponderMessageHandler::echo`]. +#[derive(Clone, Debug, Default)] +pub struct EchoResponder; + +#[async_trait] +impl Responder for EchoResponder { + async fn respond( + &self, + message: &Message, + task: &Task, + ) -> Result<(Message, TaskState), A2AError> { + let echoed = message + .parts + .iter() + .filter_map(|p| p.get_text()) + .collect::>() + .join(" "); + + let reply = Message::builder() + .role(Role::Agent) + .parts(vec![Part::text(format!("Echo: {}", echoed))]) + .message_id(uuid::Uuid::new_v4().to_string()) + .task_id(task.id.clone()) + .context_id(message.context_id.clone()) + .build(); + + // The reference handler keeps the task Working; real agents pick a + // terminal state appropriate to their processing. + Ok((reply, TaskState::Working)) + } +} + +/// A message handler that owns task-lifecycle plumbing and streaming +/// announcements, delegating the reply to an injected [`Responder`]. +/// +/// Holds its ports as `Arc` trait objects (injected at the composition +/// edge), so the handler carries no generic parameter. Because it holds both the +/// lifecycle and streaming ports it is a host for the [`TaskStatusBroadcast`] +/// capability mixin. #[derive(Clone)] -pub struct DefaultMessageHandler -where - T: AsyncTaskManager + Send + Sync + 'static, -{ - /// Task manager for handling task operations - task_manager: Arc, +pub struct ResponderMessageHandler { + /// Task lifecycle port for handling task operations + task_lifecycle: Arc, + /// Streaming port for announcing status transitions to subscribers + streaming: Arc, + /// Push-notifier port for out-of-band webhook delivery on each transition + push_notifier: Arc, + /// The business decision: what to reply and which state to end in + responder: Arc, } -impl DefaultMessageHandler -where - T: AsyncTaskManager + Send + Sync + 'static, -{ - /// Create a new message handler with the given task manager - pub fn new(task_manager: T) -> Self { +impl ResponderMessageHandler { + /// Create a handler with a custom [`Responder`]. + /// + /// The lifecycle, streaming, and push-notifier ports are accepted separately + /// so the handler depends only on the capabilities it uses; at the + /// composition edge the streaming and push ports typically come from a + /// dedicated streaming adapter and the store's `push_notifier()`. + pub fn new( + task_lifecycle: impl AsyncTaskLifecycle + 'static, + streaming: impl AsyncStreamingHandler + 'static, + push_notifier: impl AsyncPushNotifier + 'static, + responder: impl Responder + 'static, + ) -> Self { Self { - task_manager: Arc::new(task_manager), + task_lifecycle: Arc::new(task_lifecycle), + streaming: Arc::new(streaming), + push_notifier: Arc::new(push_notifier), + responder: Arc::new(responder), } } + + /// Create the reference echo handler ([`EchoResponder`]). + pub fn echo( + task_lifecycle: impl AsyncTaskLifecycle + 'static, + streaming: impl AsyncStreamingHandler + 'static, + push_notifier: impl AsyncPushNotifier + 'static, + ) -> Self { + Self::new(task_lifecycle, streaming, push_notifier, EchoResponder) + } +} + +impl HasTaskLifecycle for ResponderMessageHandler { + fn lifecycle(&self) -> &dyn AsyncTaskLifecycle { + self.task_lifecycle.as_ref() + } +} + +impl HasStreaming for ResponderMessageHandler { + fn streaming(&self) -> &dyn AsyncStreamingHandler { + self.streaming.as_ref() + } +} + +impl HasPushNotifier for ResponderMessageHandler { + fn push_notifier(&self) -> &dyn AsyncPushNotifier { + self.push_notifier.as_ref() + } } #[async_trait] -impl AsyncMessageHandler for DefaultMessageHandler -where - T: AsyncTaskManager + Send + Sync + 'static, -{ +impl AsyncMessageHandler for ResponderMessageHandler { async fn process_message( &self, task_id: &str, message: &Message, session_id: Option<&str>, ) -> Result { - // Check if task exists - let task_exists = self.task_manager.task_exists(task_id).await?; + let id: TaskId = task_id.parse()?; - if !task_exists { - // Create a new task - let context_id = session_id.unwrap_or("default"); - self.task_manager.create_task(task_id, context_id).await?; + // Create the task on first contact. + if !self.task_lifecycle.exists(&id).await? { + let context_id: ContextId = session_id.unwrap_or("default").parse()?; + self.task_lifecycle.create(&id, &context_id).await?; } - // First, update the task with the incoming message to add it to history - self.task_manager - .update_task_status(task_id, TaskState::Working, Some(message.clone())) + // Append the incoming message to history (Working), announcing the + // transition to any streaming subscribers. + let task = self + .update_and_broadcast(&id, TaskState::Working, Some(message.clone())) .await?; - // Create a simple echo response - let response_message = Message::builder() - .role(crate::domain::Role::Agent) - .parts(vec![crate::domain::Part::text(format!( - "Echo: {}", - message - .parts - .iter() - .filter_map(|p| p.get_text()) - .collect::>() - .join(" ") - ))]) - .message_id(uuid::Uuid::new_v4().to_string()) - .task_id(task_id.to_string()) - .context_id(message.context_id.clone()) - .build(); - - // For the default handler, we'll add the response message to history but keep the task in Working state - // Real agents would process the message and determine the appropriate final state - let final_task = self - .task_manager - .update_task_status(task_id, TaskState::Working, Some(response_message)) - .await?; + // Delegate the business decision to the responder, then commit and + // announce its reply. + let (reply, state) = self.responder.respond(message, &task).await?; + let final_task = self.update_and_broadcast(&id, state, Some(reply)).await?; Ok(final_task) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::adapter::storage::InMemoryTaskStorage; + use crate::adapter::streaming::InMemoryStreamingHandler; + + /// A responder that ignores the input and drives the task to a terminal + /// state with a fixed reply — proof that the injected responder, not the + /// handler, owns the reply text and the final state. + struct FixedResponder; + + #[async_trait] + impl Responder for FixedResponder { + async fn respond( + &self, + _message: &Message, + task: &Task, + ) -> Result<(Message, TaskState), A2AError> { + let reply = Message::builder() + .role(Role::Agent) + .parts(vec![Part::text("done".to_string())]) + .message_id("fixed-1".to_string()) + .task_id(task.id.clone()) + .build(); + Ok((reply, TaskState::Completed)) + } + } + + #[tokio::test] + async fn injected_responder_controls_reply_and_state() { + let storage = InMemoryTaskStorage::new(); + let streaming = InMemoryStreamingHandler::new(); + let push = storage.push_notifier(); + let handler = ResponderMessageHandler::new(storage, streaming, push, FixedResponder); + + let message = Message::user_text("anything".to_string(), "m1".to_string()); + let task = handler.process_message("t1", &message, None).await.unwrap(); + + // The responder chose the terminal state... + assert_eq!(task.status.state, TaskState::Completed); + // ...and its reply landed in history (after the appended user message). + let replied = task.history.iter().any(|m| { + m.parts + .iter() + .filter_map(|p| p.get_text()) + .any(|t| t == "done") + }); + assert!(replied, "responder reply should be in task history"); + } +} diff --git a/a2a-rs/src/adapter/business/mod.rs b/a2a-rs/src/adapter/business/mod.rs index 214733b..bf1aba5 100644 --- a/a2a-rs/src/adapter/business/mod.rs +++ b/a2a-rs/src/adapter/business/mod.rs @@ -6,19 +6,15 @@ pub mod agent_info; pub mod message_handler; #[cfg(feature = "server")] pub mod push_notification; -#[cfg(feature = "server")] -pub mod request_processor; // Re-export business implementations #[cfg(feature = "server")] pub use agent_info::SimpleAgentInfo; #[cfg(feature = "server")] -pub use message_handler::DefaultMessageHandler; +pub use message_handler::{EchoResponder, Responder, ResponderMessageHandler}; #[cfg(all(feature = "server", feature = "http-client"))] pub use push_notification::HttpPushNotificationSender; #[cfg(feature = "server")] pub use push_notification::{ NoopPushNotificationSender, PushNotificationRegistry, PushNotificationSender, }; -#[cfg(feature = "server")] -pub use request_processor::DefaultRequestProcessor; diff --git a/a2a-rs/src/adapter/business/push_notification.rs b/a2a-rs/src/adapter/business/push_notification.rs index 3b502ce..716eb06 100644 --- a/a2a-rs/src/adapter/business/push_notification.rs +++ b/a2a-rs/src/adapter/business/push_notification.rs @@ -15,6 +15,7 @@ use tokio::sync::Mutex; use crate::domain::{ A2AError, TaskArtifactUpdateEvent, TaskPushNotificationConfig, TaskStatusUpdateEvent, }; +use crate::port::AsyncPushNotifier; /// Interface for a push notification sender #[async_trait] @@ -431,3 +432,26 @@ impl PushNotificationRegistry { } } } + +/// The registry is the in-house [`AsyncPushNotifier`] adapter: it looks up the +/// per-task config it holds and dispatches through its pluggable +/// [`PushNotificationSender`] backend. The two trait methods are exactly the +/// existing inherent `send_*` methods. +#[async_trait] +impl AsyncPushNotifier for PushNotificationRegistry { + async fn notify_status( + &self, + task_id: &str, + event: &TaskStatusUpdateEvent, + ) -> Result<(), A2AError> { + self.send_status_update(task_id, event).await + } + + async fn notify_artifact( + &self, + task_id: &str, + event: &TaskArtifactUpdateEvent, + ) -> Result<(), A2AError> { + self.send_artifact_update(task_id, event).await + } +} diff --git a/a2a-rs/src/adapter/error/client.rs b/a2a-rs/src/adapter/error/client.rs index 0b699a4..7580fb5 100644 --- a/a2a-rs/src/adapter/error/client.rs +++ b/a2a-rs/src/adapter/error/client.rs @@ -6,7 +6,7 @@ use thiserror::Error; /// Error type for HTTP client adapter #[derive(Error, Debug)] -#[cfg(feature = "http-client")] +#[cfg(any(feature = "http-client", feature = "jsonrpc-client"))] pub enum HttpClientError { /// Reqwest client error #[error("HTTP client error: {0}")] @@ -30,7 +30,7 @@ pub enum HttpClientError { } // Conversion from adapter errors to domain errors -#[cfg(feature = "http-client")] +#[cfg(any(feature = "http-client", feature = "jsonrpc-client"))] impl From for A2AError { fn from(error: HttpClientError) -> Self { match error { diff --git a/a2a-rs/src/adapter/error/mod.rs b/a2a-rs/src/adapter/error/mod.rs index 0cc18e7..4db5675 100644 --- a/a2a-rs/src/adapter/error/mod.rs +++ b/a2a-rs/src/adapter/error/mod.rs @@ -1,13 +1,13 @@ //! Error types for adapter implementations -#[cfg(feature = "client")] +#[cfg(any(feature = "http-client", feature = "jsonrpc-client"))] pub mod client; #[cfg(feature = "server")] pub mod server; // Re-export client error types -#[cfg(feature = "http-client")] +#[cfg(any(feature = "http-client", feature = "jsonrpc-client"))] pub use client::HttpClientError; // Re-export server error types diff --git a/a2a-rs/src/adapter/interceptor.rs b/a2a-rs/src/adapter/interceptor.rs new file mode 100644 index 0000000..657d7cd --- /dev/null +++ b/a2a-rs/src/adapter/interceptor.rs @@ -0,0 +1,42 @@ +//! Built-in [`CallInterceptor`](crate::port::CallInterceptor) adapters. +//! +//! Concrete interceptors live in the adapter layer (the port is just the trait). +//! They attach to either transport via `with_interceptor`. + +#[cfg(feature = "tracing")] +use async_trait::async_trait; + +#[cfg(feature = "tracing")] +use crate::domain::A2AError; +#[cfg(feature = "tracing")] +use crate::port::{CallContext, CallInterceptor}; + +/// A [`CallInterceptor`](crate::port::CallInterceptor) that logs each call's +/// start and outcome via `tracing`. +/// +/// Register it on a client or server transport to get one structured log line +/// per call boundary (method, side) plus a success/failure line with the error. +/// A drop-in for the official SDK's logging interceptor. +#[cfg(feature = "tracing")] +#[derive(Debug, Clone, Default)] +pub struct LoggingInterceptor; + +#[cfg(feature = "tracing")] +#[async_trait] +impl CallInterceptor for LoggingInterceptor { + async fn before(&self, ctx: &CallContext) -> Result<(), A2AError> { + tracing::debug!(method = %ctx.method, side = ?ctx.side, "A2A call started"); + Ok(()) + } + + async fn after(&self, ctx: &CallContext, outcome: Result<(), &A2AError>) { + match outcome { + Ok(()) => { + tracing::debug!(method = %ctx.method, side = ?ctx.side, "A2A call succeeded") + } + Err(e) => { + tracing::warn!(method = %ctx.method, side = ?ctx.side, error = %e, "A2A call failed") + } + } + } +} diff --git a/a2a-rs/src/adapter/mod.rs b/a2a-rs/src/adapter/mod.rs index 1bac013..d5f8cc4 100644 --- a/a2a-rs/src/adapter/mod.rs +++ b/a2a-rs/src/adapter/mod.rs @@ -12,7 +12,10 @@ pub mod auth; pub mod business; pub mod error; +pub mod interceptor; pub mod storage; +#[cfg(feature = "server")] +pub mod streaming; pub mod transport; // Legacy re-exports for backward compatibility @@ -21,6 +24,14 @@ pub mod transport; // Client re-exports (from transport) #[cfg(feature = "http-client")] pub use transport::http::HttpClient; +#[cfg(feature = "jsonrpc-client")] +pub use transport::jsonrpc_client::JsonRpcClient; +#[cfg(feature = "client")] +pub use transport::negotiation::{TransportFactory, TransportNegotiator, default_registry}; +#[cfg(any(feature = "http-client", feature = "jsonrpc-client"))] +pub use transport::negotiation::{connect, fetch_agent_card}; +#[cfg(feature = "client")] +pub use transport::retry::{RetryingTransport, subscribe_resilient}; // Server re-exports (from various modules) #[cfg(feature = "http-server")] @@ -32,16 +43,28 @@ pub use auth::{JwtAuthenticator, OAuth2Authenticator, OpenIdConnectAuthenticator #[cfg(all(feature = "server", feature = "http-client"))] pub use business::HttpPushNotificationSender; #[cfg(feature = "server")] -pub use business::{DefaultRequestProcessor, SimpleAgentInfo}; +pub use business::SimpleAgentInfo; #[cfg(feature = "server")] pub use business::{NoopPushNotificationSender, PushNotificationRegistry, PushNotificationSender}; #[cfg(feature = "server")] pub use storage::InMemoryTaskStorage; +#[cfg(feature = "server")] +pub use streaming::InMemoryStreamingHandler; +#[cfg(feature = "server")] +pub use transport::connectrpc::ConnectRpcAdapter; +#[cfg(feature = "server")] +pub use transport::connectrpc::NoopStreamingHandler; #[cfg(feature = "http-server")] pub use transport::http::HttpServer; +#[cfg(feature = "jsonrpc-server")] +pub use transport::jsonrpc::{JsonRpcAdapter, jsonrpc_router, rest_router}; + +// Interceptor re-exports +#[cfg(feature = "tracing")] +pub use interceptor::LoggingInterceptor; // Error re-exports -#[cfg(feature = "http-client")] +#[cfg(any(feature = "http-client", feature = "jsonrpc-client"))] pub use error::HttpClientError; #[cfg(feature = "http-server")] pub use error::HttpServerError; diff --git a/a2a-rs/src/adapter/storage/sqlx_storage.rs b/a2a-rs/src/adapter/storage/sqlx_storage.rs index f8e1e9b..56b1f53 100644 --- a/a2a-rs/src/adapter/storage/sqlx_storage.rs +++ b/a2a-rs/src/adapter/storage/sqlx_storage.rs @@ -3,9 +3,6 @@ //! This module provides a persistent storage solution using SQLx, supporting //! SQLite, PostgreSQL, and MySQL databases. -#[cfg(feature = "sqlx-storage")] -use std::collections::HashMap; - #[cfg(feature = "sqlx-storage")] use async_trait::async_trait; #[cfg(feature = "sqlx-storage")] @@ -27,50 +24,30 @@ use crate::adapter::business::push_notification::NoopPushNotificationSender; #[cfg(feature = "sqlx-storage")] use crate::domain::{ - A2AError, Artifact, Message, Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig, - TaskState, TaskStatus, TaskStatusUpdateEvent, + A2AError, ContextId, Message, Task, TaskId, TaskPushNotificationConfig, TaskState, TaskStatus, + VersionedTask, }; #[cfg(feature = "sqlx-storage")] use crate::port::{ - AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager, - streaming_handler::Subscriber, + AsyncNotificationManager, AsyncPushNotifier, AsyncTaskLifecycle, AsyncTaskQuery, + AsyncTaskVersioning, }; #[cfg(feature = "sqlx-storage")] use std::sync::Arc; -#[cfg(feature = "sqlx-storage")] -use tokio::sync::Mutex; - -#[cfg(feature = "sqlx-storage")] -type StatusSubscribers = Vec + Send + Sync>>; -#[cfg(feature = "sqlx-storage")] -type ArtifactSubscribers = Vec + Send + Sync>>; - -#[cfg(feature = "sqlx-storage")] -/// Structure to hold subscribers for a task -pub(crate) struct TaskSubscribers { - status: StatusSubscribers, - artifacts: ArtifactSubscribers, -} - -#[cfg(feature = "sqlx-storage")] -impl TaskSubscribers { - fn new() -> Self { - Self { - status: Vec::new(), - artifacts: Vec::new(), - } - } -} #[cfg(feature = "sqlx-storage")] -/// SQLx-based task storage for persistent storage +/// SQLx-based task storage for persistent storage. +/// +/// Persistence-only: streaming fan-out lives in +/// [`InMemoryStreamingHandler`](crate::adapter::InMemoryStreamingHandler) and +/// push-webhook delivery behind the [`AsyncPushNotifier`] port (handed out via +/// [`push_notifier`](Self::push_notifier)). The store still owns push-config +/// CRUD ([`AsyncNotificationManager`]) — that is config persistence. pub struct SqlxTaskStorage { /// Database pool pool: SqlitePool, - /// Subscribers for task updates (in-memory for now) - subscribers: Arc>>, - /// Push notification registry + /// Push notification registry (config store + delivery backend) push_notification_registry: Arc, } @@ -121,7 +98,6 @@ impl SqlxTaskStorage { Ok(Self { pool, - subscribers: Arc::new(Mutex::new(HashMap::new())), push_notification_registry: Arc::new(push_registry), }) } @@ -146,7 +122,6 @@ impl SqlxTaskStorage { Ok(Self { pool, - subscribers: Arc::new(Mutex::new(HashMap::new())), push_notification_registry: Arc::new(push_registry), }) } @@ -180,7 +155,6 @@ impl SqlxTaskStorage { Ok(Self { pool, - subscribers: Arc::new(Mutex::new(HashMap::new())), push_notification_registry: Arc::new(push_registry), }) } @@ -199,6 +173,21 @@ impl SqlxTaskStorage { .await .map_err(|e| A2AError::DatabaseError(format!("Migration 002 failed: {}", e)))?; + // Migration 003 is an `ALTER TABLE ADD COLUMN`, which SQLite cannot + // express idempotently. Since base migrations re-run on every `new()`, + // tolerate the "duplicate column name" error on an already-migrated DB. + if let Err(e) = sqlx::query(include_str!("../../../migrations/003_task_version.sql")) + .execute(pool) + .await + { + let msg = e.to_string(); + if !msg.contains("duplicate column name") { + return Err(A2AError::DatabaseError(format!( + "Migration 003 failed: {msg}" + ))); + } + } + Ok(()) } @@ -386,112 +375,23 @@ impl SqlxTaskStorage { Ok(()) } - /// Look up the context_id for a task from the database - async fn get_task_context_id(&self, task_id: &str) -> String { - sqlx::query_scalar::<_, String>("SELECT context_id FROM tasks WHERE id = ?") - .bind(task_id) - .fetch_optional(&self.pool) - .await - .ok() - .flatten() - .unwrap_or_else(|| "default".to_string()) - } - - /// Send a status update to all subscribers for a task - pub(crate) async fn broadcast_status_update( - &self, - task_id: &str, - status: TaskStatus, - ) -> Result<(), A2AError> { - let context_id = self.get_task_context_id(task_id).await; - - // Create the update event - let event = TaskStatusUpdateEvent { - task_id: task_id.to_string(), - context_id, - kind: "status-update".to_string(), - status, - metadata: None, - }; - - // Get all subscribers for this task and notify them - { - let subscribers_guard = self.subscribers.lock().await; - - if let Some(task_subscribers) = subscribers_guard.get(task_id) { - // Clone the subscribers so we don't hold the lock during notification - for subscriber in task_subscribers.status.iter() { - if let Err(e) = subscriber.on_update(event.clone()).await { - eprintln!("Failed to notify subscriber: {}", e); - } - } - } - }; // Lock is dropped here - - // Send push notification if configured - if let Err(e) = self - .push_notification_registry - .send_status_update(task_id, &event) - .await - { - eprintln!("Failed to send push notification: {}", e); - } - - Ok(()) - } - - /// Send an artifact update to all subscribers for a task - pub(crate) async fn broadcast_artifact_update( - &self, - task_id: &str, - artifact: Artifact, - _index: Option, - _final: bool, - ) -> Result<(), A2AError> { - let context_id = self.get_task_context_id(task_id).await; - - // Create the update event - let event = TaskArtifactUpdateEvent { - task_id: task_id.to_string(), - context_id, - kind: "artifact-update".to_string(), - artifact, - append: None, - last_chunk: None, - metadata: None, - }; - - // Get all subscribers for this task - { - let subscribers_guard = self.subscribers.lock().await; - - if let Some(task_subscribers) = subscribers_guard.get(task_id) { - // Clone the subscribers so we don't hold the lock during notification - for subscriber in task_subscribers.artifacts.iter() { - if let Err(e) = subscriber.on_update(event.clone()).await { - eprintln!("Failed to notify subscriber: {}", e); - } - } - } - }; // Lock is dropped here - - // Send push notification if configured - if let Err(e) = self - .push_notification_registry - .send_artifact_update(task_id, &event) - .await - { - eprintln!("Failed to send push notification: {}", e); - } - - Ok(()) + /// Hand out this store's push-notification registry as an + /// [`AsyncPushNotifier`]. + /// + /// The returned notifier shares the same config registry the store writes to + /// via [`AsyncNotificationManager::set_config`], so a config registered on + /// the store is immediately visible to the notifier at the composition edge. + pub fn push_notifier(&self) -> Arc { + self.push_notification_registry.clone() } } #[cfg(feature = "sqlx-storage")] #[async_trait] -impl AsyncTaskManager for SqlxTaskStorage { - async fn create_task(&self, task_id: &str, context_id: &str) -> Result { +impl AsyncTaskLifecycle for SqlxTaskStorage { + async fn create(&self, id: &TaskId, context_id: &ContextId) -> Result { + let task_id = id.as_str(); + let context_id = context_id.as_str(); // Check if task already exists let existing = sqlx::query("SELECT id FROM tasks WHERE id = ?") .bind(task_id) @@ -542,12 +442,13 @@ impl AsyncTaskManager for SqlxTaskStorage { Ok(task) } - async fn update_task_status( + async fn update_status( &self, - task_id: &str, + id: &TaskId, state: TaskState, message: Option, ) -> Result { + let task_id = id.as_str(); // Convert state to string let state_str = match state { TaskState::Submitted => "submitted", @@ -561,13 +462,16 @@ impl AsyncTaskManager for SqlxTaskStorage { TaskState::Unknown => "unknown", }; - // Update task in database - let result = sqlx::query("UPDATE tasks SET status_state = ? WHERE id = ?") - .bind(state_str) - .bind(task_id) - .execute(&self.pool) - .await - .map_err(|e| A2AError::DatabaseError(format!("Failed to update task status: {}", e)))?; + // Update task in database (bump the optimistic-concurrency version) + let result = + sqlx::query("UPDATE tasks SET status_state = ?, version = version + 1 WHERE id = ?") + .bind(state_str) + .bind(task_id) + .execute(&self.pool) + .await + .map_err(|e| { + A2AError::DatabaseError(format!("Failed to update task status: {}", e)) + })?; if result.rows_affected() == 0 { return Err(A2AError::TaskNotFound(task_id.to_string())); @@ -576,19 +480,14 @@ impl AsyncTaskManager for SqlxTaskStorage { // Add to history self.add_to_history(task_id, state, message).await?; - // Get updated task - let task = self.get_task(task_id, None).await?; - - // Clone status before broadcasting to avoid double clone - let status = task.status.clone().take().unwrap_or_default(); - - // Broadcast status update - self.broadcast_status_update(task_id, status).await?; - - Ok(task) + // Persistence only: announcing the change to streaming subscribers is + // the orchestration layer's job (see `TaskStatusBroadcast`), not a side + // effect of the mutator. + self.get(id, None).await } - async fn task_exists(&self, task_id: &str) -> Result { + async fn exists(&self, id: &TaskId) -> Result { + let task_id = id.as_str(); let row = sqlx::query("SELECT id FROM tasks WHERE id = ?") .bind(task_id) .fetch_optional(&self.pool) @@ -600,7 +499,8 @@ impl AsyncTaskManager for SqlxTaskStorage { Ok(row.is_some()) } - async fn get_task(&self, task_id: &str, history_length: Option) -> Result { + async fn get(&self, id: &TaskId, history_length: Option) -> Result { + let task_id = id.as_str(); // Get task from database let row = sqlx::query("SELECT * FROM tasks WHERE id = ?") .bind(task_id) @@ -623,9 +523,10 @@ impl AsyncTaskManager for SqlxTaskStorage { Ok(task) } - async fn cancel_task(&self, task_id: &str) -> Result { + async fn cancel(&self, id: &TaskId) -> Result { + let task_id = id.as_str(); // Get current task - let task = self.get_task(task_id, None).await?; + let task = self.get(id, None).await?; // Only working tasks can be canceled if task.status.state != TaskState::Working { @@ -643,8 +544,8 @@ impl AsyncTaskManager for SqlxTaskStorage { cancel_message.task_id = task_id.to_string(); cancel_message.context_id = task.context_id.clone(); - // Update task status - sqlx::query("UPDATE tasks SET status_state = ? WHERE id = ?") + // Update task status (bump the optimistic-concurrency version) + sqlx::query("UPDATE tasks SET status_state = ?, version = version + 1 WHERE id = ?") .bind("canceled") .bind(task_id) .execute(&self.pool) @@ -655,21 +556,106 @@ impl AsyncTaskManager for SqlxTaskStorage { self.add_to_history(task_id, TaskState::Canceled, Some(cancel_message)) .await?; - // Get updated task - let updated_task = self.get_task(task_id, None).await?; + // Persistence only: the orchestration layer announces the cancellation + // to streaming subscribers (see `TaskStatusBroadcast`). + self.get(id, None).await + } +} - // Clone status before broadcasting to avoid double clone - let status = updated_task.status.clone().take().unwrap_or_default(); +#[cfg(feature = "sqlx-storage")] +impl SqlxTaskStorage { + /// Read the current stored version of a task, or `None` if it doesn't exist. + async fn current_version(&self, task_id: &str) -> Result, A2AError> { + let row = sqlx::query("SELECT version FROM tasks WHERE id = ?") + .bind(task_id) + .fetch_optional(&self.pool) + .await + .map_err(|e| A2AError::DatabaseError(format!("Failed to read task version: {}", e)))?; + match row { + Some(row) => { + let v: i64 = row.try_get("version").map_err(|e| { + A2AError::DatabaseError(format!("Failed to get version column: {}", e)) + })?; + Ok(Some(v as u64)) + } + None => Ok(None), + } + } +} - // Broadcast status update (with final flag set to true) - self.broadcast_status_update(task_id, status).await?; +#[cfg(feature = "sqlx-storage")] +#[async_trait] +impl AsyncTaskVersioning for SqlxTaskStorage { + async fn version(&self, id: &TaskId) -> Result { + self.current_version(id.as_str()) + .await? + .ok_or_else(|| A2AError::TaskNotFound(id.as_str().to_string())) + } - Ok(updated_task) + async fn get_versioned( + &self, + id: &TaskId, + history_length: Option, + ) -> Result { + let task = self.get(id, history_length).await?; + let version = self.version(id).await?; + Ok(VersionedTask::new(task, version)) } - // ===== v1.0.0 Methods ===== + async fn update_status_checked( + &self, + id: &TaskId, + expected: u64, + state: TaskState, + message: Option, + ) -> Result { + let task_id = id.as_str(); + let state_str = match state { + TaskState::Submitted => "submitted", + TaskState::Working => "working", + TaskState::InputRequired => "input-required", + TaskState::Completed => "completed", + TaskState::Canceled => "canceled", + TaskState::Failed => "failed", + TaskState::Rejected => "rejected", + TaskState::AuthRequired => "auth-required", + TaskState::Unknown => "unknown", + }; + + // Conditional update: SQLite applies it atomically, so the row count + // tells us whether the version matched without a separate lock. + let result = sqlx::query( + "UPDATE tasks SET status_state = ?, version = version + 1 WHERE id = ? AND version = ?", + ) + .bind(state_str) + .bind(task_id) + .bind(expected as i64) + .execute(&self.pool) + .await + .map_err(|e| A2AError::DatabaseError(format!("Failed to update task status: {}", e)))?; + + if result.rows_affected() == 0 { + // No row matched: either the task is gone or the version moved on. + return match self.current_version(task_id).await? { + Some(actual) => Err(A2AError::VersionConflict { + id: task_id.to_string(), + expected, + actual, + }), + None => Err(A2AError::TaskNotFound(task_id.to_string())), + }; + } + + self.add_to_history(task_id, state, message).await?; + let task = self.get(id, None).await?; + Ok(VersionedTask::new(task, expected + 1)) + } +} - async fn list_tasks_v3( +#[cfg(feature = "sqlx-storage")] +#[async_trait] +impl AsyncTaskQuery for SqlxTaskStorage { + async fn list( &self, params: &crate::domain::ListTasksParams, ) -> Result { @@ -836,22 +822,30 @@ impl AsyncTaskManager for SqlxTaskStorage { next_page_token, }) } +} - async fn get_push_notification_config( +#[cfg(feature = "sqlx-storage")] +#[async_trait] +impl AsyncNotificationManager for SqlxTaskStorage { + async fn get_config( &self, params: &crate::domain::GetTaskPushNotificationConfigParams, ) -> Result { - // Query the database for the specific config - // Note: push_notification_config_id filtering requires migration 002 to be applied - let config_id = params.push_notification_config_id.as_ref().ok_or_else(|| { - A2AError::TaskNotFound("push_notification_config_id is required".to_string()) - })?; - - let row = sqlx::query( - "SELECT id, task_id, url, token, authentication FROM push_notification_configs WHERE task_id = ? AND id = ?" - ) - .bind(¶ms.id) - .bind(config_id) + // When a specific config id is supplied, filter by it; otherwise fall + // back to the task's config (single-config-per-task convenience, matching + // the in-memory adapter and the v1.0.0 single-config helpers). + // Note: push_notification_config_id filtering requires migration 002 to be applied. + let row = match params.push_notification_config_id.as_ref() { + Some(config_id) => sqlx::query( + "SELECT id, task_id, url, token, authentication FROM push_notification_configs WHERE task_id = ? AND id = ?" + ) + .bind(¶ms.id) + .bind(config_id), + None => sqlx::query( + "SELECT id, task_id, url, token, authentication FROM push_notification_configs WHERE task_id = ? ORDER BY id LIMIT 1" + ) + .bind(¶ms.id), + } .fetch_optional(&self.pool) .await .map_err(|e| A2AError::DatabaseError(format!("Failed to get push config: {}", e)))?; @@ -883,13 +877,18 @@ impl AsyncTaskManager for SqlxTaskStorage { }) } else { Err(A2AError::TaskNotFound(format!( - "Push notification config not found for task {} with id {}", - params.id, config_id + "Push notification config not found for task {}{}", + params.id, + params + .push_notification_config_id + .as_ref() + .map(|id| format!(" with id {}", id)) + .unwrap_or_default() ))) } } - async fn list_push_notification_configs( + async fn list_configs( &self, params: &crate::domain::ListTaskPushNotificationConfigsParams, ) -> Result, A2AError> { @@ -931,30 +930,30 @@ impl AsyncTaskManager for SqlxTaskStorage { Ok(configs) } - async fn delete_push_notification_config( + async fn delete_config( &self, params: &crate::domain::DeleteTaskPushNotificationConfigParams, ) -> Result<(), A2AError> { - // Delete the specific config - let _result = + // Delete the specific config when an id is supplied; otherwise delete all + // configs for the task (single-config-per-task convenience, matching the + // in-memory adapter). + let query = if params.push_notification_config_id.is_empty() { + sqlx::query("DELETE FROM push_notification_configs WHERE task_id = ?").bind(¶ms.id) + } else { sqlx::query("DELETE FROM push_notification_configs WHERE task_id = ? AND id = ?") .bind(¶ms.id) .bind(¶ms.push_notification_config_id) - .execute(&self.pool) - .await - .map_err(|e| { - A2AError::DatabaseError(format!("Failed to delete push config: {}", e)) - })?; + }; + let _result = query + .execute(&self.pool) + .await + .map_err(|e| A2AError::DatabaseError(format!("Failed to delete push config: {}", e)))?; // Idempotent - don't error if already deleted (v1.0.0 spec behavior) Ok(()) } -} -#[cfg(feature = "sqlx-storage")] -#[async_trait] -impl AsyncNotificationManager for SqlxTaskStorage { - async fn set_task_notification( + async fn set_config( &self, config: &TaskPushNotificationConfig, ) -> Result { @@ -996,225 +995,6 @@ impl AsyncNotificationManager for SqlxTaskStorage { result_config.id = config_id; Ok(result_config) } - - async fn get_task_notification( - &self, - task_id: &str, - ) -> Result { - // Get from database (get first config for backwards compatibility) - let row = - sqlx::query("SELECT id, url, token, authentication FROM push_notification_configs WHERE task_id = ? LIMIT 1") - .bind(task_id) - .fetch_optional(&self.pool) - .await - .map_err(|e| { - A2AError::DatabaseError(format!( - "Failed to get push notification config: {}", - e - )) - })?; - - if let Some(row) = row { - let id: String = row - .try_get("id") - .map_err(|e| A2AError::DatabaseError(format!("Failed to get id: {}", e)))?; - let url: String = row - .try_get("url") - .map_err(|e| A2AError::DatabaseError(format!("Failed to get url: {}", e)))?; - let token: Option = row.try_get("token").ok(); - let auth_json: Option = row.try_get("authentication").ok(); - - let auth_info = if let Some(auth_str) = auth_json { - serde_json::from_str(&auth_str).ok() - } else { - None - }; - - Ok(TaskPushNotificationConfig { - task_id: task_id.to_string(), - id, - url, - token: token.unwrap_or_default(), - authentication: auth_info.into(), - tenant: "".to_string(), - ..Default::default() - }) - } else { - Err(A2AError::TaskNotFound(format!( - "No push notification config found for task {}", - task_id - ))) - } - } - - async fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError> { - // Remove from database - sqlx::query("DELETE FROM push_notification_configs WHERE task_id = ?") - .bind(task_id) - .execute(&self.pool) - .await - .map_err(|e| { - A2AError::DatabaseError(format!("Failed to remove push notification config: {}", e)) - })?; - - // Unregister from registry - self.push_notification_registry.unregister(task_id).await?; - Ok(()) - } -} - -#[cfg(feature = "sqlx-storage")] -#[async_trait] -impl AsyncStreamingHandler for SqlxTaskStorage { - async fn add_status_subscriber( - &self, - task_id: &str, - subscriber: Box + Send + Sync>, - ) -> Result { - // Add the subscriber - { - let mut subscribers_guard = self.subscribers.lock().await; - - let task_subscribers = subscribers_guard - .entry(task_id.to_string()) - .or_insert_with(TaskSubscribers::new); - - task_subscribers.status.push(subscriber); - } // Lock is dropped here - - // Try to get the current status to send as an initial update - // But don't fail if the task doesn't exist yet - the subscriber will get updates when it's created - if let Ok(task) = self.get_task(task_id, None).await { - let _ = self - .broadcast_status_update(task_id, (*task.status).clone()) - .await; - } - - Ok(format!("status-{}-{}", task_id, uuid::Uuid::new_v4())) - } - - async fn add_artifact_subscriber( - &self, - task_id: &str, - subscriber: Box + Send + Sync>, - ) -> Result { - // Add the subscriber - { - let mut subscribers_guard = self.subscribers.lock().await; - - let task_subscribers = subscribers_guard - .entry(task_id.to_string()) - .or_insert_with(TaskSubscribers::new); - - task_subscribers.artifacts.push(subscriber); - } // Lock is dropped here - - // If there are existing artifacts, broadcast them - // But don't fail if the task doesn't exist yet - the subscriber will get updates when it's created - if let Ok(task) = self.get_task(task_id, None).await { - for artifact in task.artifacts { - let _ = self - .broadcast_artifact_update(task_id, artifact, None, false) - .await; - } - } - - Ok(format!("artifact-{}-{}", task_id, uuid::Uuid::new_v4())) - } - - async fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> { - Err(A2AError::UnsupportedOperation( - "Subscription removal by ID requires storage layer refactoring".to_string(), - )) - } - - async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> { - // Remove all subscribers - { - let mut subscribers_guard = self.subscribers.lock().await; - subscribers_guard.remove(task_id); - } // Lock is dropped here - - Ok(()) - } - - async fn get_subscriber_count(&self, task_id: &str) -> Result { - let subscribers_guard = self.subscribers.lock().await; - - if let Some(task_subscribers) = subscribers_guard.get(task_id) { - Ok(task_subscribers.status.len() + task_subscribers.artifacts.len()) - } else { - Ok(0) - } - } - - async fn broadcast_status_update( - &self, - task_id: &str, - update: TaskStatusUpdateEvent, - ) -> Result<(), A2AError> { - self.broadcast_status_update(task_id, update.status).await - } - - async fn broadcast_artifact_update( - &self, - task_id: &str, - update: TaskArtifactUpdateEvent, - ) -> Result<(), A2AError> { - self.broadcast_artifact_update( - task_id, - update.artifact, - None, - update.last_chunk.unwrap_or(false), - ) - .await - } - - async fn status_update_stream( - &self, - _task_id: &str, - ) -> Result< - std::pin::Pin< - Box> + Send>, - >, - A2AError, - > { - Err(A2AError::UnsupportedOperation( - "Status update stream requires storage layer refactoring".to_string(), - )) - } - - async fn artifact_update_stream( - &self, - _task_id: &str, - ) -> Result< - std::pin::Pin< - Box> + Send>, - >, - A2AError, - > { - Err(A2AError::UnsupportedOperation( - "Artifact update stream requires storage layer refactoring".to_string(), - )) - } - - async fn combined_update_stream( - &self, - _task_id: &str, - ) -> Result< - std::pin::Pin< - Box< - dyn futures::Stream< - Item = Result, - > + Send, - >, - >, - A2AError, - > { - Err(A2AError::UnsupportedOperation( - "Combined update stream requires storage layer refactoring".to_string(), - )) - } } #[cfg(feature = "sqlx-storage")] @@ -1222,7 +1002,6 @@ impl Clone for SqlxTaskStorage { fn clone(&self) -> Self { Self { pool: self.pool.clone(), - subscribers: self.subscribers.clone(), push_notification_registry: self.push_notification_registry.clone(), } } diff --git a/a2a-rs/src/adapter/storage/task_storage.rs b/a2a-rs/src/adapter/storage/task_storage.rs index 25651cc..f4cb1f8 100644 --- a/a2a-rs/src/adapter/storage/task_storage.rs +++ b/a2a-rs/src/adapter/storage/task_storage.rs @@ -17,39 +17,32 @@ use crate::adapter::business::push_notification::HttpPushNotificationSender; #[cfg(not(feature = "http-client"))] use crate::adapter::business::push_notification::NoopPushNotificationSender; use crate::domain::{ - A2AError, Artifact, Message, Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig, - TaskState, TaskStatus, TaskStatusUpdateEvent, + A2AError, ContextId, Message, Task, TaskId, TaskPushNotificationConfig, TaskState, + VersionedTask, }; use crate::port::{ - AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager, - streaming_handler::Subscriber, + AsyncNotificationManager, AsyncPushNotifier, AsyncTaskLifecycle, AsyncTaskQuery, + AsyncTaskVersioning, }; -type StatusSubscribers = Vec + Send + Sync>>; -type ArtifactSubscribers = Vec + Send + Sync>>; - -/// Structure to hold subscribers for a task -pub(crate) struct TaskSubscribers { - status: StatusSubscribers, - artifacts: ArtifactSubscribers, -} - -impl TaskSubscribers { - fn new() -> Self { - Self { - status: Vec::new(), - artifacts: Vec::new(), - } - } -} - -/// Simple in-memory task storage for testing and example purposes +/// Simple in-memory task storage for testing and example purposes. +/// +/// Persistence-only: streaming fan-out lives in +/// [`InMemoryStreamingHandler`](crate::adapter::InMemoryStreamingHandler) and +/// push-webhook delivery behind the [`AsyncPushNotifier`] port (this struct hands +/// out its registry via [`push_notifier`](Self::push_notifier)). The store still +/// owns push-config CRUD ([`AsyncNotificationManager`]) because that is config +/// *persistence*. pub struct InMemoryTaskStorage { /// Tasks stored by ID pub(crate) tasks: Arc>>, - /// Subscribers for task updates - pub(crate) subscribers: Arc>>, - /// Push notification registry + /// Per-task optimistic-concurrency version, bumped on every mutation. + /// + /// A separate map keyed by the same task id. Mutators always lock `tasks` + /// first and `versions` second, so the two stay consistent and never + /// deadlock (see [`AsyncTaskVersioning`]). + pub(crate) versions: Arc>>, + /// Push notification registry (config store + delivery backend) pub(crate) push_notification_registry: Arc, } @@ -66,7 +59,7 @@ impl InMemoryTaskStorage { Self { tasks: Arc::new(Mutex::new(HashMap::new())), - subscribers: Arc::new(Mutex::new(HashMap::new())), + versions: Arc::new(Mutex::new(HashMap::new())), push_notification_registry: Arc::new(push_registry), } } @@ -77,31 +70,29 @@ impl InMemoryTaskStorage { Self { tasks: Arc::new(Mutex::new(HashMap::new())), - subscribers: Arc::new(Mutex::new(HashMap::new())), + versions: Arc::new(Mutex::new(HashMap::new())), push_notification_registry: Arc::new(push_registry), } } - /// Add a status update subscriber for streaming (convenience method) - pub async fn add_status_subscriber_legacy( - &self, - task_id: &str, - subscriber: Box + Send + Sync>, - ) -> Result<(), A2AError> { - self.add_status_subscriber(task_id, subscriber) - .await - .map(|_| ()) + /// Bump (or initialize) the stored version for a task, returning the new + /// value. Callers already hold the `tasks` lock; this acquires `versions` + /// second, preserving the global lock order. + async fn bump_version(&self, task_id: &str) -> u64 { + let mut versions = self.versions.lock().await; + let v = versions.entry(task_id.to_string()).or_insert(0); + *v += 1; + *v } - /// Add an artifact update subscriber for streaming (convenience method) - pub async fn add_artifact_subscriber_legacy( - &self, - task_id: &str, - subscriber: Box + Send + Sync>, - ) -> Result<(), A2AError> { - self.add_artifact_subscriber(task_id, subscriber) - .await - .map(|_| ()) + /// Hand out this store's push-notification registry as an + /// [`AsyncPushNotifier`]. + /// + /// The returned notifier shares the same config registry the store writes to + /// via [`AsyncNotificationManager::set_config`], so a config registered on + /// the store is immediately visible to the notifier at the composition edge. + pub fn push_notifier(&self) -> Arc { + self.push_notification_registry.clone() } } @@ -111,158 +102,11 @@ impl Default for InMemoryTaskStorage { } } -impl InMemoryTaskStorage { - /// Look up the context_id for a task - async fn get_task_context_id(&self, task_id: &str) -> String { - let tasks_guard = self.tasks.lock().await; - tasks_guard - .get(task_id) - .map(|t| t.context_id.clone()) - .unwrap_or_else(|| "default".to_string()) - } - - /// Send a status update to all subscribers for a task - pub(crate) async fn broadcast_status_update( - &self, - task_id: &str, - status: TaskStatus, - ) -> Result<(), A2AError> { - let context_id = self.get_task_context_id(task_id).await; - - // Create the update event - let event = TaskStatusUpdateEvent { - task_id: task_id.to_string(), - context_id, - kind: "status-update".to_string(), - status: status.clone(), - metadata: None, - }; - - #[cfg(feature = "tracing")] - tracing::debug!( - task_id = %task_id, - state = ?status.state, - "📡 Broadcasting status update to subscribers" - ); - - // Get all subscribers for this task and notify them - let subscriber_count = { - let subscribers_guard = self.subscribers.lock().await; - - if let Some(task_subscribers) = subscribers_guard.get(task_id) { - let count = task_subscribers.status.len(); - #[cfg(feature = "tracing")] - tracing::info!( - task_id = %task_id, - subscriber_count = count, - state = ?status.state, - "📡 Notifying WebSocket subscribers of status update" - ); - - // Clone the subscribers so we don't hold the lock during notification - for (i, subscriber) in task_subscribers.status.iter().enumerate() { - if let Err(e) = subscriber.on_update(event.clone()).await { - #[cfg(feature = "tracing")] - tracing::error!( - task_id = %task_id, - subscriber_index = i, - error = %e, - "❌ Failed to notify subscriber" - ); - eprintln!("Failed to notify subscriber: {}", e); - } else { - #[cfg(feature = "tracing")] - tracing::debug!( - task_id = %task_id, - subscriber_index = i, - "✅ Successfully notified subscriber" - ); - } - } - count - } else { - // No subscribers is the steady-state for tasks created via - // message/send (no streaming requested). Only worth tracing - // at DEBUG — WARN here floods logs in normal operation. - #[cfg(feature = "tracing")] - tracing::debug!( - task_id = %task_id, - "no WebSocket subscribers for task; status broadcast skipped" - ); - 0 - } - }; // Lock is dropped here - - #[cfg(feature = "tracing")] - tracing::debug!( - task_id = %task_id, - notified_count = subscriber_count, - "📡 Finished broadcasting to WebSocket subscribers" - ); - - // Send push notification if configured - if let Err(e) = self - .push_notification_registry - .send_status_update(task_id, &event) - .await - { - eprintln!("Failed to send push notification: {}", e); - } - - Ok(()) - } - - /// Send an artifact update to all subscribers for a task - pub(crate) async fn broadcast_artifact_update( - &self, - task_id: &str, - artifact: Artifact, - _index: Option, - _final: bool, - ) -> Result<(), A2AError> { - let context_id = self.get_task_context_id(task_id).await; - - // Create the update event - let event = TaskArtifactUpdateEvent { - task_id: task_id.to_string(), - context_id, - kind: "artifact-update".to_string(), - artifact, - append: None, - last_chunk: None, - metadata: None, - }; - - // Get all subscribers for this task - { - let subscribers_guard = self.subscribers.lock().await; - - if let Some(task_subscribers) = subscribers_guard.get(task_id) { - // Clone the subscribers so we don't hold the lock during notification - for subscriber in task_subscribers.artifacts.iter() { - if let Err(e) = subscriber.on_update(event.clone()).await { - eprintln!("Failed to notify subscriber: {}", e); - } - } - } - }; // Lock is dropped here - - // Send push notification if configured - if let Err(e) = self - .push_notification_registry - .send_artifact_update(task_id, &event) - .await - { - eprintln!("Failed to send push notification: {}", e); - } - - Ok(()) - } -} - #[async_trait] -impl AsyncTaskManager for InMemoryTaskStorage { - async fn create_task(&self, task_id: &str, context_id: &str) -> Result { +impl AsyncTaskLifecycle for InMemoryTaskStorage { + async fn create(&self, id: &TaskId, context_id: &ContextId) -> Result { + let task_id = id.as_str(); + let context_id = context_id.as_str(); let mut tasks_guard = self.tasks.lock().await; if tasks_guard.contains_key(task_id) { @@ -274,16 +118,18 @@ impl AsyncTaskManager for InMemoryTaskStorage { let task = Task::new(task_id.to_string(), context_id.to_string()); tasks_guard.insert(task_id.to_string(), task.clone()); + self.bump_version(task_id).await; // version 0 -> 1 Ok(task) } - async fn update_task_status( + async fn update_status( &self, - task_id: &str, + id: &TaskId, state: TaskState, message: Option, ) -> Result { + let task_id = id.as_str(); let mut tasks_guard = self.tasks.lock().await; let task = tasks_guard @@ -292,27 +138,23 @@ impl AsyncTaskManager for InMemoryTaskStorage { // Update the task status with the optional message task.update_status(state, message); + let updated = task.clone(); + self.bump_version(task_id).await; - // Clone status before cloning the entire task to avoid double clone - let status_for_broadcast = task.status.clone().into_option().unwrap_or_default(); - let updated_task = task.clone(); - - // Release the lock before broadcasting - drop(tasks_guard); - - // Broadcast status update - self.broadcast_status_update(task_id, status_for_broadcast) - .await?; - - Ok(updated_task) + // Persistence only: announcing the change to streaming subscribers is + // the orchestration layer's job (see `TaskStatusBroadcast`), not a side + // effect of the mutator. + Ok(updated) } - async fn task_exists(&self, task_id: &str) -> Result { + async fn exists(&self, id: &TaskId) -> Result { + let task_id = id.as_str(); let tasks_guard = self.tasks.lock().await; Ok(tasks_guard.contains_key(task_id)) } - async fn get_task(&self, task_id: &str, history_length: Option) -> Result { + async fn get(&self, id: &TaskId, history_length: Option) -> Result { + let task_id = id.as_str(); // Get the task let task = { let tasks_guard = self.tasks.lock().await; @@ -328,64 +170,109 @@ impl AsyncTaskManager for InMemoryTaskStorage { Ok(task) } - async fn cancel_task(&self, task_id: &str) -> Result { - // Get and update the task - let (task, status_for_broadcast) = { - let mut tasks_guard = self.tasks.lock().await; + async fn cancel(&self, id: &TaskId) -> Result { + let task_id = id.as_str(); + let mut tasks_guard = self.tasks.lock().await; - let Some(task) = tasks_guard.get(task_id) else { - return Err(A2AError::TaskNotFound(task_id.to_string())); - }; + let Some(task) = tasks_guard.get(task_id) else { + return Err(A2AError::TaskNotFound(task_id.to_string())); + }; - let mut updated_task = task.clone(); + let mut updated_task = task.clone(); - // Only working tasks can be canceled - if updated_task.status.state != TaskState::Working { - return Err(A2AError::TaskNotCancelable(format!( - "Task {} is in state {:?} and cannot be canceled", - task_id, updated_task.status.state - ))); - } + // Only working tasks can be canceled + if updated_task.status.state != TaskState::Working { + return Err(A2AError::TaskNotCancelable(format!( + "Task {} is in state {:?} and cannot be canceled", + task_id, updated_task.status.state + ))); + } - // Create a cancellation message to add to history - let cancel_message = Message { - role: ::buffa::EnumValue::from(crate::domain::Role::Agent), - parts: vec![crate::domain::Part::text(format!( - "Task {} canceled.", - task_id - ))], - message_id: uuid::Uuid::new_v4().to_string(), - task_id: task_id.to_string(), - context_id: updated_task.context_id.clone(), - ..Default::default() - }; + // Create a cancellation message to add to history + let cancel_message = Message { + role: ::buffa::EnumValue::from(crate::domain::Role::Agent), + parts: vec![crate::domain::Part::text(format!( + "Task {} canceled.", + task_id + ))], + message_id: uuid::Uuid::new_v4().to_string(), + task_id: task_id.to_string(), + context_id: updated_task.context_id.clone(), + ..Default::default() + }; - // Update the status with the cancellation message to track in history - updated_task.update_status(TaskState::Canceled, Some(cancel_message)); + // Update the status with the cancellation message to track in history + updated_task.update_status(TaskState::Canceled, Some(cancel_message)); + tasks_guard.insert(task_id.to_string(), updated_task.clone()); + self.bump_version(task_id).await; - // Clone status before updating storage to avoid cloning task twice - let status_for_broadcast = updated_task - .status - .clone() - .into_option() - .unwrap_or_default(); - tasks_guard.insert(task_id.to_string(), updated_task.clone()); - - // Drop guard early and return status for use after broadcasting - drop(tasks_guard); - (updated_task, status_for_broadcast) - }; // Lock is dropped here + // Persistence only: the orchestration layer announces the cancellation + // to streaming subscribers (see `TaskStatusBroadcast`). + Ok(updated_task) + } +} - // Broadcast status update (with final flag set to true) - self.broadcast_status_update(task_id, status_for_broadcast) - .await?; +#[async_trait] +impl AsyncTaskVersioning for InMemoryTaskStorage { + async fn version(&self, id: &TaskId) -> Result { + let task_id = id.as_str(); + let tasks_guard = self.tasks.lock().await; + if !tasks_guard.contains_key(task_id) { + return Err(A2AError::TaskNotFound(task_id.to_string())); + } + let versions = self.versions.lock().await; + Ok(versions.get(task_id).copied().unwrap_or(0)) + } - Ok(task) + async fn get_versioned( + &self, + id: &TaskId, + history_length: Option, + ) -> Result { + let task_id = id.as_str(); + let tasks_guard = self.tasks.lock().await; + let Some(task) = tasks_guard.get(task_id) else { + return Err(A2AError::TaskNotFound(task_id.to_string())); + }; + let task = task.with_limited_history(history_length); + let versions = self.versions.lock().await; + let version = versions.get(task_id).copied().unwrap_or(0); + Ok(VersionedTask::new(task, version)) } - // ===== v1.0.0 New Methods ===== + async fn update_status_checked( + &self, + id: &TaskId, + expected: u64, + state: TaskState, + message: Option, + ) -> Result { + let task_id = id.as_str(); + // Lock order: tasks, then versions — the compare-and-swap holds both so + // the check and the bump are atomic against every other mutator. + let mut tasks_guard = self.tasks.lock().await; + let task = tasks_guard + .get_mut(task_id) + .ok_or_else(|| A2AError::TaskNotFound(task_id.to_string()))?; + let mut versions = self.versions.lock().await; + let current = versions.get(task_id).copied().unwrap_or(0); + if current != expected { + return Err(A2AError::VersionConflict { + id: task_id.to_string(), + expected, + actual: current, + }); + } + task.update_status(state, message); + let new_version = current + 1; + versions.insert(task_id.to_string(), new_version); + Ok(VersionedTask::new(task.clone(), new_version)) + } +} - async fn list_tasks_v3( +#[async_trait] +impl AsyncTaskQuery for InMemoryTaskStorage { + async fn list( &self, params: &crate::domain::ListTasksParams, ) -> Result { @@ -487,46 +374,15 @@ impl AsyncTaskManager for InMemoryTaskStorage { next_page_token, }) } - - async fn get_push_notification_config( - &self, - params: &crate::domain::GetTaskPushNotificationConfigParams, - ) -> Result { - // For in-memory storage, we don't support multiple configs per task yet - // Just use the existing get_task_notification method - self.get_task_notification(¶ms.id).await - } - - async fn list_push_notification_configs( - &self, - params: &crate::domain::ListTaskPushNotificationConfigsParams, - ) -> Result, A2AError> { - // For in-memory storage, we only support one config per task - // Return it as a single-item vec - match self - .push_notification_registry - .get_config(¶ms.id) - .await? - { - Some(config) => Ok(vec![config]), - None => Ok(vec![]), - } - } - - async fn delete_push_notification_config( - &self, - params: &crate::domain::DeleteTaskPushNotificationConfigParams, - ) -> Result<(), A2AError> { - // For in-memory storage, just remove the single config - // In a full implementation, would need to handle config_id - self.remove_task_notification(¶ms.id).await - } } -// AsyncNotificationManager implementation +// AsyncNotificationManager implementation. +// +// In-memory storage keeps a single config per task in the push-notification +// registry, so the multi-config CRUD surface is expressed in those terms. #[async_trait] impl AsyncNotificationManager for InMemoryTaskStorage { - async fn set_task_notification( + async fn set_config( &self, config: &TaskPushNotificationConfig, ) -> Result { @@ -551,199 +407,109 @@ impl AsyncNotificationManager for InMemoryTaskStorage { Ok(config.clone()) } - async fn get_task_notification( + async fn get_config( &self, - task_id: &str, + params: &crate::domain::GetTaskPushNotificationConfigParams, ) -> Result { - // Get the push notification config from the registry - match self.push_notification_registry.get_config(task_id).await? { + match self + .push_notification_registry + .get_config(¶ms.id) + .await? + { Some(config) => Ok(config), None => Err(A2AError::PushNotificationNotSupported), } } - async fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError> { - self.push_notification_registry.unregister(task_id).await?; - Ok(()) - } -} - -// AsyncStreamingHandler implementation -#[async_trait] -impl AsyncStreamingHandler for InMemoryTaskStorage { - async fn add_status_subscriber( + async fn list_configs( &self, - task_id: &str, - subscriber: Box + Send + Sync>, - ) -> Result { - #[cfg(feature = "tracing")] - tracing::info!( - task_id = %task_id, - "✅ Adding WebSocket subscriber for status updates" - ); - - // Add the subscriber + params: &crate::domain::ListTaskPushNotificationConfigsParams, + ) -> Result, A2AError> { + // In-memory storage supports one config per task; return it as a + // single-item vec (or empty if none registered). + match self + .push_notification_registry + .get_config(¶ms.id) + .await? { - let mut subscribers_guard = self.subscribers.lock().await; - - let task_subscribers = subscribers_guard - .entry(task_id.to_string()) - .or_insert_with(TaskSubscribers::new); - - task_subscribers.status.push(subscriber); - - #[cfg(feature = "tracing")] - tracing::info!( - task_id = %task_id, - subscriber_count = task_subscribers.status.len(), - "✅ WebSocket subscriber added successfully" - ); - } // Lock is dropped here - - // Try to get the current status to send as an initial update - // But don't fail if the task doesn't exist yet - the subscriber will get updates when it's created - if let Ok(task) = self.get_task(task_id, None).await { - let _ = self - .broadcast_status_update( - task_id, - task.status.clone().into_option().unwrap_or_default(), - ) - .await; + Some(config) => Ok(vec![config]), + None => Ok(vec![]), } - - Ok(format!("status-{}-{}", task_id, uuid::Uuid::new_v4())) } - async fn add_artifact_subscriber( + async fn delete_config( &self, - task_id: &str, - subscriber: Box + Send + Sync>, - ) -> Result { - // Add the subscriber - { - let mut subscribers_guard = self.subscribers.lock().await; - - let task_subscribers = subscribers_guard - .entry(task_id.to_string()) - .or_insert_with(TaskSubscribers::new); - - task_subscribers.artifacts.push(subscriber); - } // Lock is dropped here - - // If there are existing artifacts, broadcast them - // But don't fail if the task doesn't exist yet - the subscriber will get updates when it's created - if let Ok(task) = self.get_task(task_id, None).await { - for artifact in &task.artifacts { - let _ = self - .broadcast_artifact_update(task_id, artifact.clone(), None, false) - .await; - } - } - - Ok(format!("artifact-{}-{}", task_id, uuid::Uuid::new_v4())) - } - - async fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> { - Err(A2AError::UnsupportedOperation( - "Subscription removal by ID requires storage layer refactoring".to_string(), - )) - } - - async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> { - // Remove all subscribers - { - let mut subscribers_guard = self.subscribers.lock().await; - subscribers_guard.remove(task_id); - } // Lock is dropped here - + params: &crate::domain::DeleteTaskPushNotificationConfigParams, + ) -> Result<(), A2AError> { + // In-memory storage keeps a single config per task, so config_id is + // not used for lookup. Idempotent per the v1.0.0 spec. + self.push_notification_registry + .unregister(¶ms.id) + .await?; Ok(()) } +} - async fn get_subscriber_count(&self, task_id: &str) -> Result { - let subscribers_guard = self.subscribers.lock().await; - - if let Some(task_subscribers) = subscribers_guard.get(task_id) { - Ok(task_subscribers.status.len() + task_subscribers.artifacts.len()) - } else { - Ok(0) +impl Clone for InMemoryTaskStorage { + fn clone(&self) -> Self { + Self { + tasks: self.tasks.clone(), + versions: self.versions.clone(), + push_notification_registry: self.push_notification_registry.clone(), } } +} - async fn broadcast_status_update( - &self, - task_id: &str, - update: TaskStatusUpdateEvent, - ) -> Result<(), A2AError> { - self.broadcast_status_update(task_id, update.status).await - } +#[cfg(test)] +mod tests { + use super::*; + use crate::domain::ContextId; - async fn broadcast_artifact_update( - &self, - task_id: &str, - update: TaskArtifactUpdateEvent, - ) -> Result<(), A2AError> { - self.broadcast_artifact_update( - task_id, - update.artifact, - None, - update.last_chunk.unwrap_or(false), - ) - .await + fn tid(s: &str) -> TaskId { + s.parse().unwrap() } - - async fn status_update_stream( - &self, - _task_id: &str, - ) -> Result< - std::pin::Pin< - Box> + Send>, - >, - A2AError, - > { - Err(A2AError::UnsupportedOperation( - "Status update stream requires storage layer refactoring".to_string(), - )) + fn cid(s: &str) -> ContextId { + s.parse().unwrap() } - async fn artifact_update_stream( - &self, - _task_id: &str, - ) -> Result< - std::pin::Pin< - Box> + Send>, - >, - A2AError, - > { - Err(A2AError::UnsupportedOperation( - "Artifact update stream requires storage layer refactoring".to_string(), - )) - } + #[tokio::test] + async fn versioning_tracks_and_guards_mutations() { + let store = InMemoryTaskStorage::new(); + store.create(&tid("t1"), &cid("c1")).await.unwrap(); + assert_eq!(store.version(&tid("t1")).await.unwrap(), 1); - async fn combined_update_stream( - &self, - _task_id: &str, - ) -> Result< - std::pin::Pin< - Box< - dyn futures::Stream< - Item = Result, - > + Send, - >, - >, - A2AError, - > { - Err(A2AError::UnsupportedOperation( - "Combined update stream requires storage layer refactoring".to_string(), - )) - } -} + // Unversioned mutations bump the version, keeping the two views in sync. + store + .update_status(&tid("t1"), TaskState::Working, None) + .await + .unwrap(); + let snap = store.get_versioned(&tid("t1"), None).await.unwrap(); + assert_eq!(snap.version, 2); -impl Clone for InMemoryTaskStorage { - fn clone(&self) -> Self { - Self { - tasks: self.tasks.clone(), - subscribers: self.subscribers.clone(), - push_notification_registry: self.push_notification_registry.clone(), - } + // Stale conditional update is rejected and leaves the task unchanged. + let err = store + .update_status_checked(&tid("t1"), 1, TaskState::Completed, None) + .await + .unwrap_err(); + assert!(matches!( + err, + A2AError::VersionConflict { + expected: 1, + actual: 2, + .. + } + )); + assert_eq!( + store.get(&tid("t1"), None).await.unwrap().status.state, + TaskState::Working + ); + + // Current-version conditional update succeeds and bumps. + let ok = store + .update_status_checked(&tid("t1"), 2, TaskState::Completed, None) + .await + .unwrap(); + assert_eq!(ok.version, 3); + assert_eq!(ok.task.status.state, TaskState::Completed); } } diff --git a/a2a-rs/src/adapter/streaming/in_memory.rs b/a2a-rs/src/adapter/streaming/in_memory.rs new file mode 100644 index 0000000..cdeae99 --- /dev/null +++ b/a2a-rs/src/adapter/streaming/in_memory.rs @@ -0,0 +1,385 @@ +//! In-memory streaming fan-out adapter. +//! +//! `InMemoryStreamingHandler` is the [`AsyncStreamingHandler`] adapter. It owns +//! **only** the per-task fan-out state — a broadcast channel plus a bounded +//! replay buffer, and an optional set of synchronous callback subscribers — and +//! fans broadcast events out to live `combined_update_stream` readers and to +//! those subscribers. It deliberately does *not*: +//! +//! - touch the task store (so it cannot replay current task state on subscribe — +//! the initial `Task` snapshot is delivered by the application service before +//! stream items, which is spec-compliant), nor +//! - fire push-webhook notifications (that is the [`AsyncPushNotifier`] port's +//! job, orchestrated by the +//! [`TaskStatusBroadcast`](crate::application::TaskStatusBroadcast) mixin). +//! +//! Each broadcast event is assigned a per-task monotonic id and retained in a +//! bounded ring buffer, so a reconnecting client can resume after a disconnect +//! by passing the last id it observed (`from_event_id`); the handler replays the +//! buffered tail with a greater id before switching to live updates. +//! +//! [`AsyncPushNotifier`]: crate::port::AsyncPushNotifier + +use std::collections::HashMap; +use std::collections::VecDeque; +use std::pin::Pin; +use std::sync::Arc; + +use async_trait::async_trait; +use futures::{Stream, StreamExt}; +use tokio::sync::Mutex; +use tokio::sync::broadcast; + +use crate::domain::{A2AError, TaskArtifactUpdateEvent, TaskStatusUpdateEvent}; +use crate::port::AsyncStreamingHandler; +use crate::port::streaming_handler::{SeqEvent, Subscriber, UpdateEvent}; + +type StatusSubscribers = Vec + Send + Sync>>; +type ArtifactSubscribers = Vec + Send + Sync>>; + +/// Capacity of the per-task broadcast channel and replay ring buffer. +const CHANNEL_CAPACITY: usize = 256; +const RING_CAPACITY: usize = 256; + +/// Per-task fan-out state: a broadcast channel for live readers, a bounded +/// replay buffer keyed by monotonic id, and any synchronous callback +/// subscribers. +struct TaskChannel { + sender: broadcast::Sender, + next_id: u64, + buffer: VecDeque, + status: StatusSubscribers, + artifacts: ArtifactSubscribers, +} + +impl TaskChannel { + fn new() -> Self { + let (sender, _) = broadcast::channel(CHANNEL_CAPACITY); + Self { + sender, + next_id: 0, + buffer: VecDeque::with_capacity(RING_CAPACITY), + status: Vec::new(), + artifacts: Vec::new(), + } + } + + /// Assign the next id, retain the event for replay, and publish it to live + /// readers. Returns the sequenced event for any further fan-out. + fn publish(&mut self, event: UpdateEvent) -> SeqEvent { + self.next_id += 1; + let seq = SeqEvent::new(self.next_id, event); + if self.buffer.len() == RING_CAPACITY { + self.buffer.pop_front(); + } + self.buffer.push_back(seq.clone()); + // A send error just means there are no live receivers; the buffer still + // retains the event for a later resume, so the error is ignored. + let _ = self.sender.send(seq.clone()); + seq + } + + /// Buffered events with an id strictly greater than `from`, in order. + fn replay_after(&self, from: u64) -> Vec { + self.buffer + .iter() + .filter(|e| e.id > from) + .cloned() + .collect() + } +} + +/// In-memory [`AsyncStreamingHandler`]: per-task broadcast fan-out with a +/// bounded replay buffer for Last-Event-ID resumption. +/// +/// Cloning shares the underlying per-task state (an `Arc>`), so a clone +/// observes the same channels and subscribers. +#[derive(Clone, Default)] +pub struct InMemoryStreamingHandler { + tasks: Arc>>, +} + +impl InMemoryStreamingHandler { + /// Create an empty streaming handler. + pub fn new() -> Self { + Self::default() + } +} + +#[async_trait] +impl AsyncStreamingHandler for InMemoryStreamingHandler { + async fn add_status_subscriber( + &self, + task_id: &str, + subscriber: Box + Send + Sync>, + ) -> Result { + #[cfg(feature = "tracing")] + tracing::info!( + task_id = %task_id, + "✅ Adding subscriber for status updates" + ); + + let mut guard = self.tasks.lock().await; + guard + .entry(task_id.to_string()) + .or_insert_with(TaskChannel::new) + .status + .push(subscriber); + + Ok(format!("status-{}-{}", task_id, uuid::Uuid::new_v4())) + } + + async fn add_artifact_subscriber( + &self, + task_id: &str, + subscriber: Box + Send + Sync>, + ) -> Result { + let mut guard = self.tasks.lock().await; + guard + .entry(task_id.to_string()) + .or_insert_with(TaskChannel::new) + .artifacts + .push(subscriber); + + Ok(format!("artifact-{}-{}", task_id, uuid::Uuid::new_v4())) + } + + async fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> { + Err(A2AError::UnsupportedOperation( + "Subscription removal by ID is not supported by the in-memory streaming handler" + .to_string(), + )) + } + + async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> { + let mut guard = self.tasks.lock().await; + guard.remove(task_id); + Ok(()) + } + + async fn get_subscriber_count(&self, task_id: &str) -> Result { + let guard = self.tasks.lock().await; + Ok(guard + .get(task_id) + .map(|c| c.status.len() + c.artifacts.len() + c.sender.receiver_count()) + .unwrap_or(0)) + } + + async fn broadcast_status_update( + &self, + task_id: &str, + update: TaskStatusUpdateEvent, + ) -> Result<(), A2AError> { + #[cfg(feature = "tracing")] + tracing::debug!( + task_id = %task_id, + state = ?update.status.state, + "📡 Broadcasting status update to subscribers" + ); + + let mut guard = self.tasks.lock().await; + let channel = guard + .entry(task_id.to_string()) + .or_insert_with(TaskChannel::new); + channel.publish(UpdateEvent::StatusUpdate(update.clone())); + for subscriber in channel.status.iter() { + if let Err(e) = subscriber.on_update(update.clone()).await { + #[cfg(feature = "tracing")] + tracing::error!(task_id = %task_id, error = %e, "❌ Failed to notify subscriber"); + #[cfg(not(feature = "tracing"))] + let _ = e; + } + } + Ok(()) + } + + async fn broadcast_artifact_update( + &self, + task_id: &str, + update: TaskArtifactUpdateEvent, + ) -> Result<(), A2AError> { + let mut guard = self.tasks.lock().await; + let channel = guard + .entry(task_id.to_string()) + .or_insert_with(TaskChannel::new); + channel.publish(UpdateEvent::ArtifactUpdate(update.clone())); + for subscriber in channel.artifacts.iter() { + if let Err(e) = subscriber.on_update(update.clone()).await { + #[cfg(feature = "tracing")] + tracing::error!(task_id = %task_id, error = %e, "❌ Failed to notify subscriber"); + #[cfg(not(feature = "tracing"))] + let _ = e; + } + } + Ok(()) + } + + async fn status_update_stream( + &self, + _task_id: &str, + ) -> Result> + Send>>, A2AError> + { + Err(A2AError::UnsupportedOperation( + "Status-only update stream is not supported; use combined_update_stream".to_string(), + )) + } + + async fn artifact_update_stream( + &self, + _task_id: &str, + ) -> Result< + Pin> + Send>>, + A2AError, + > { + Err(A2AError::UnsupportedOperation( + "Artifact-only update stream is not supported; use combined_update_stream".to_string(), + )) + } + + async fn combined_update_stream( + &self, + task_id: &str, + from_event_id: Option, + ) -> Result> + Send>>, A2AError> { + let mut guard = self.tasks.lock().await; + let channel = guard + .entry(task_id.to_string()) + .or_insert_with(TaskChannel::new); + let receiver = channel.sender.subscribe(); + let replay = from_event_id + .map(|from| channel.replay_after(from)) + .unwrap_or_default(); + drop(guard); + + let live = futures::stream::unfold(receiver, |mut rx| async move { + match rx.recv().await { + Ok(event) => Some((Ok(event), rx)), + // Reader fell behind the ring buffer: surface an error so a + // resilient client reconnects and resumes from its last id. + Err(broadcast::error::RecvError::Lagged(n)) => Some(( + Err(A2AError::Internal(format!( + "streaming reader lagged, dropped {n} events" + ))), + rx, + )), + Err(broadcast::error::RecvError::Closed) => None, + } + }); + + let stream = futures::stream::iter(replay.into_iter().map(Ok)).chain(live); + Ok(Box::pin(stream)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::domain::{TaskState, TaskStatus, TaskStatusUpdateEvent}; + + fn status_event(task_id: &str, state: TaskState) -> TaskStatusUpdateEvent { + TaskStatusUpdateEvent { + task_id: task_id.to_string(), + context_id: "ctx".to_string(), + kind: "status-update".to_string(), + status: TaskStatus::new(state, None), + metadata: None, + } + } + + fn seq_state(seq: &SeqEvent) -> ::buffa::EnumValue { + match &seq.event { + UpdateEvent::StatusUpdate(e) => e.status.state, + UpdateEvent::ArtifactUpdate(_) => panic!("expected status update"), + } + } + + /// A live `combined_update_stream` reader receives broadcasts in order, each + /// tagged with a monotonic id starting at 1. + #[tokio::test] + async fn live_stream_delivers_in_order_with_ids() { + let handler = InMemoryStreamingHandler::new(); + let mut stream = handler.combined_update_stream("t1", None).await.unwrap(); + + handler + .broadcast_status_update("t1", status_event("t1", TaskState::Working)) + .await + .unwrap(); + handler + .broadcast_status_update("t1", status_event("t1", TaskState::Completed)) + .await + .unwrap(); + + let first = stream.next().await.unwrap().unwrap(); + let second = stream.next().await.unwrap().unwrap(); + assert_eq!(first.id, 1); + assert_eq!( + seq_state(&first), + ::buffa::EnumValue::from(TaskState::Working) + ); + assert_eq!(second.id, 2); + assert_eq!( + seq_state(&second), + ::buffa::EnumValue::from(TaskState::Completed) + ); + } + + /// Subscribing with `from_event_id` replays the buffered tail with a greater + /// id before any live updates. + #[tokio::test] + async fn resume_replays_buffered_tail() { + let handler = InMemoryStreamingHandler::new(); + // Emit two events with no live reader; they are retained in the buffer. + handler + .broadcast_status_update("t1", status_event("t1", TaskState::Working)) + .await + .unwrap(); + handler + .broadcast_status_update("t1", status_event("t1", TaskState::Completed)) + .await + .unwrap(); + + // Resume from id 1: only event 2 should replay. + let mut stream = handler.combined_update_stream("t1", Some(1)).await.unwrap(); + let replayed = stream.next().await.unwrap().unwrap(); + assert_eq!(replayed.id, 2); + assert_eq!( + seq_state(&replayed), + ::buffa::EnumValue::from(TaskState::Completed) + ); + } + + /// A synchronous callback subscriber still receives broadcasts (the push API + /// rides alongside the broadcast channel). + #[tokio::test] + async fn callback_subscriber_still_notified() { + use std::sync::Mutex as StdMutex; + + #[derive(Default, Clone)] + struct Recorder { + seen: Arc>>>, + } + #[async_trait] + impl Subscriber for Recorder { + async fn on_update(&self, update: TaskStatusUpdateEvent) -> Result<(), A2AError> { + self.seen.lock().unwrap().push(update.status.state); + Ok(()) + } + } + + let handler = InMemoryStreamingHandler::new(); + let recorder = Recorder::default(); + handler + .add_status_subscriber("t1", Box::new(recorder.clone())) + .await + .unwrap(); + handler + .broadcast_status_update("t1", status_event("t1", TaskState::Working)) + .await + .unwrap(); + + assert_eq!( + *recorder.seen.lock().unwrap(), + vec![::buffa::EnumValue::from(TaskState::Working)] + ); + } +} diff --git a/a2a-rs/src/adapter/streaming/mod.rs b/a2a-rs/src/adapter/streaming/mod.rs new file mode 100644 index 0000000..4612f7a --- /dev/null +++ b/a2a-rs/src/adapter/streaming/mod.rs @@ -0,0 +1,12 @@ +//! Streaming adapters: real-time fan-out of task updates to subscribers. +//! +//! This is the technical-concern bucket for the [`AsyncStreamingHandler`] port +//! (`.claude/rules/hexagonal_architecture.md` §3). It holds the in-process +//! subscriber registry — distinct from the storage adapters, which are +//! persistence-only and do not fan out updates. +//! +//! [`AsyncStreamingHandler`]: crate::port::AsyncStreamingHandler + +mod in_memory; + +pub use in_memory::InMemoryStreamingHandler; diff --git a/a2a-rs/src/adapter/transport/codec.rs b/a2a-rs/src/adapter/transport/codec.rs new file mode 100644 index 0000000..52edcce --- /dev/null +++ b/a2a-rs/src/adapter/transport/codec.rs @@ -0,0 +1,26 @@ +//! Shared client-side wire decoding helpers. +//! +//! Both transport client adapters (ConnectRPC's `HttpClient` and the JSON-RPC +//! `JsonRpcClient`) receive the same generated [`StreamResponse`] union on a +//! subscription and must map it to the protocol-neutral [`StreamItem`] the +//! [`Transport`](crate::port::Transport) port yields. Keeping that mapping here +//! ensures both directions agree. + +use crate::domain::generated::{StreamResponse, stream_response}; +use crate::port::StreamItem; + +/// Map a wire [`StreamResponse`] (tag-free field-presence union) onto the +/// protocol-neutral [`StreamItem`]. Returns `None` for an empty/unrecognized +/// payload. +pub fn stream_response_to_item(resp: StreamResponse) -> Option { + match resp.payload { + Some(stream_response::Payload::Task(task)) => Some(StreamItem::Task(*task)), + Some(stream_response::Payload::StatusUpdate(update)) => { + Some(StreamItem::StatusUpdate((*update).into())) + } + Some(stream_response::Payload::ArtifactUpdate(update)) => { + Some(StreamItem::ArtifactUpdate((*update).into())) + } + _ => None, + } +} diff --git a/a2a-rs/src/adapter/business/request_processor.rs b/a2a-rs/src/adapter/transport/connectrpc.rs similarity index 60% rename from a2a-rs/src/adapter/business/request_processor.rs rename to a2a-rs/src/adapter/transport/connectrpc.rs index 52b07d7..00d0d9d 100644 --- a/a2a-rs/src/adapter/business/request_processor.rs +++ b/a2a-rs/src/adapter/transport/connectrpc.rs @@ -1,20 +1,30 @@ -//! A default request processor implementation +//! The ConnectRPC transport adapter. +//! +//! `ConnectRpcAdapter` is the **outer** half of the service/transport split: a +//! thin transport adapter that implements the generated [`A2aService`] surface. +//! Its only job is to decode `buffa` wire views into domain values, delegate to +//! the inner [`TaskService`], and re-encode the domain results (and map +//! [`A2AError`] onto ConnectRPC error codes). All use-case orchestration lives +//! in [`TaskService`]; this layer holds no port traits directly. +//! +//! The public constructors (`new`, `with_handler`, `with_streaming_handler`) +//! each build the inner service and wrap it. use async_trait::async_trait; use buffa::Enumeration; use std::pin::Pin; -use std::sync::Arc; use crate::{ + application::TaskService, domain::{ - A2AError, AgentCard, Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig, + A2AError, AgentCard, Task, TaskArtifactUpdateEvent, TaskId, TaskPushNotificationConfig, TaskStatusUpdateEvent, generated::{ A2aService, CancelTaskRequestView, DeleteTaskPushNotificationConfigRequestView, GetExtendedAgentCardRequestView, GetTaskPushNotificationConfigRequestView, GetTaskRequestView, ListTaskPushNotificationConfigsRequestView, - ListTaskPushNotificationConfigsResponse, ListTasksRequestView, ListTasksResponse, - SendMessageRequestView, SendMessageResponse, StreamResponse, + ListTaskPushNotificationConfigsResponse, ListTasksRequest, ListTasksRequestView, + ListTasksResponse, SendMessageRequestView, SendMessageResponse, StreamResponse, SubscribeToTaskRequestView, TaskArtifactUpdateEvent as GenTaskArtifactUpdateEvent, TaskPushNotificationConfigView, TaskState, TaskStatusUpdateEvent as GenTaskStatusUpdateEvent, send_message_response, @@ -22,103 +32,82 @@ use crate::{ }, }, port::{ - AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager, - UpdateEvent, streaming_handler::Subscriber, + AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskLifecycle, + AsyncTaskQuery, SeqEvent, UpdateEvent, streaming_handler::Subscriber, }, services::server::AgentInfoProvider, }; -/// Default implementation of a request processor that routes ConnectRPC requests to business handlers +/// ConnectRPC transport adapter over a [`TaskService`]. +/// +/// Holds no ports directly — it owns the inner application service and forwards +/// decoded requests to it. Dispatch into the service goes through the service's +/// `Arc` fields, which is a cold path against the I/O each call performs. #[derive(Clone)] -pub struct DefaultRequestProcessor< - M, - T, - N, - A = crate::adapter::SimpleAgentInfo, - S = NoopStreamingHandler, -> where - M: AsyncMessageHandler + Send + Sync + 'static, - T: AsyncTaskManager + Send + Sync + 'static, - N: AsyncNotificationManager + Send + Sync + 'static, - A: AgentInfoProvider + Send + Sync + 'static, - S: AsyncStreamingHandler + Send + Sync + 'static, -{ - /// Message handler - message_handler: Arc, - /// Task manager - task_manager: Arc, - /// Notification manager - notification_manager: Arc, - /// Agent info provider - agent_info: Arc, - /// Streaming handler - streaming_handler: Arc, +pub struct ConnectRpcAdapter { + service: TaskService, } -impl DefaultRequestProcessor -where - M: AsyncMessageHandler + Send + Sync + 'static, - T: AsyncTaskManager + Send + Sync + 'static, - N: AsyncNotificationManager + Send + Sync + 'static, - A: AgentInfoProvider + Send + Sync + 'static, -{ - /// Create a new request processor with the given handlers and default NoopStreamingHandler +impl ConnectRpcAdapter { + /// Create a new adapter from separate handlers, defaulting to a no-op + /// streaming handler. + /// + /// `tasks` supplies both the lifecycle and query capabilities. pub fn new( - message_handler: M, - task_manager: T, - notification_manager: N, - agent_info: A, + message_handler: impl AsyncMessageHandler + 'static, + tasks: impl AsyncTaskLifecycle + AsyncTaskQuery + 'static, + notification_manager: impl AsyncNotificationManager + 'static, + agent_info: impl AgentInfoProvider + 'static, ) -> Self { Self { - message_handler: Arc::new(message_handler), - task_manager: Arc::new(task_manager), - notification_manager: Arc::new(notification_manager), - agent_info: Arc::new(agent_info), - streaming_handler: Arc::new(NoopStreamingHandler), + service: TaskService::new( + message_handler, + tasks, + notification_manager, + agent_info, + NoopStreamingHandler, + crate::port::NoopPushNotifier, + ), } } -} -impl DefaultRequestProcessor -where - H: AsyncMessageHandler + AsyncTaskManager + AsyncNotificationManager + Send + Sync + 'static, - A: AgentInfoProvider + Send + Sync + 'static, -{ - /// Create a new request processor with a single handler that implements all traits - pub fn with_handler(handler: H, agent_info: A) -> Self { - let handler_arc = Arc::new(handler); + /// Create a new adapter from a single handler that implements every port, + /// defaulting to a no-op streaming handler. + pub fn with_handler( + handler: impl AsyncMessageHandler + + AsyncTaskLifecycle + + AsyncTaskQuery + + AsyncNotificationManager + + 'static, + agent_info: impl AgentInfoProvider + 'static, + ) -> Self { Self { - message_handler: handler_arc.clone(), - task_manager: handler_arc.clone(), - notification_manager: handler_arc, - agent_info: Arc::new(agent_info), - streaming_handler: Arc::new(NoopStreamingHandler), + service: TaskService::with_handler( + handler, + agent_info, + NoopStreamingHandler, + crate::port::NoopPushNotifier, + ), } } -} -impl DefaultRequestProcessor -where - M: AsyncMessageHandler + Send + Sync + 'static, - T: AsyncTaskManager + Send + Sync + 'static, - N: AsyncNotificationManager + Send + Sync + 'static, - A: AgentInfoProvider + Send + Sync + 'static, - S: AsyncStreamingHandler + Send + Sync + 'static, -{ - /// Builder-style method to inject custom streaming handler support - pub fn with_streaming_handler( + /// Builder-style method to inject custom streaming handler support. + pub fn with_streaming_handler( self, - streaming_handler: NewS, - ) -> DefaultRequestProcessor - where - NewS: AsyncStreamingHandler + Send + Sync + 'static, - { - DefaultRequestProcessor { - message_handler: self.message_handler, - task_manager: self.task_manager, - notification_manager: self.notification_manager, - agent_info: self.agent_info, - streaming_handler: Arc::new(streaming_handler), + streaming_handler: impl AsyncStreamingHandler + 'static, + ) -> Self { + Self { + service: self.service.with_streaming_handler(streaming_handler), + } + } + + /// Builder-style method to inject a custom push notifier. + pub fn with_push_notifier( + self, + push_notifier: impl crate::port::AsyncPushNotifier + 'static, + ) -> Self { + Self { + service: self.service.with_push_notifier(push_notifier), } } } @@ -189,14 +178,28 @@ fn map_artifact_update( } } -impl A2aService for DefaultRequestProcessor -where - M: AsyncMessageHandler + Send + Sync + 'static, - T: AsyncTaskManager + Send + Sync + 'static, - N: AsyncNotificationManager + Send + Sync + 'static, - A: AgentInfoProvider + Send + Sync + 'static, - S: AsyncStreamingHandler + Send + Sync + 'static, -{ +/// Map a domain [`UpdateEvent`] onto its wire [`StreamResponse`]. +/// +/// Shared with the JSON-RPC adapter so both transports map streaming updates +/// through one path. +pub(super) fn map_update_event(evt: UpdateEvent) -> StreamResponse { + match evt { + UpdateEvent::StatusUpdate(event) => StreamResponse { + payload: Some(stream_response::Payload::StatusUpdate(Box::new( + map_status_update(event), + ))), + ..Default::default() + }, + UpdateEvent::ArtifactUpdate(event) => StreamResponse { + payload: Some(stream_response::Payload::ArtifactUpdate(Box::new( + map_artifact_update(event), + ))), + ..Default::default() + }, + } +} + +impl A2aService for ConnectRpcAdapter { async fn send_message( &self, ctx: ::connectrpc::Context, @@ -218,32 +221,14 @@ where Some(message.context_id.as_str()) }; - let mut history_limit = None; - - // If push notification configuration is provided, configure it - if let Some(c) = config { - if let Some(mut push_config) = c.task_push_notification_config.into_option() { - push_config.task_id = task_id.clone(); - self.notification_manager - .set_task_notification_validated(&push_config) - .await - .map_err(map_err)?; - } - if let Some(limit) = c.history_length { - history_limit = Some(limit as u32); - } - } + let (push_config, history_limit) = decode_send_config(config); - let mut task = self - .message_handler - .process_message(&task_id, &message, session_id) + let task = self + .service + .send_message(&task_id, &message, session_id, push_config, history_limit) .await .map_err(map_err)?; - if let Some(limit) = history_limit { - task = task.with_limited_history(Some(limit)); - } - let response = SendMessageResponse { payload: Some(send_message_response::Payload::Task(Box::new(task))), ..Default::default() @@ -285,39 +270,14 @@ where Some(message.context_id.as_str()) }; - let mut history_limit = None; - - // Setup notification if present - if let Some(c) = config { - if let Some(mut push_config) = c.task_push_notification_config.into_option() { - push_config.task_id = task_id.clone(); - self.notification_manager - .set_task_notification_validated(&push_config) - .await - .map_err(map_err)?; - } - if let Some(limit) = c.history_length { - history_limit = Some(limit as u32); - } - } + let (push_config, history_limit) = decode_send_config(config); - // Start updates stream first so we don't miss early updates - let update_stream = self - .streaming_handler - .start_task_streaming(&task_id) + let (task, update_stream) = self + .service + .send_streaming_message(&task_id, &message, session_id, push_config, history_limit) .await .map_err(map_err)?; - let mut task = self - .message_handler - .process_message(&task_id, &message, session_id) - .await - .map_err(map_err)?; - - if let Some(limit) = history_limit { - task = task.with_limited_history(Some(limit)); - } - use futures::StreamExt; let initial_response = StreamResponse { @@ -325,23 +285,8 @@ where ..Default::default() }; - let mapped_stream = update_stream.map(|item| { - item.map(|evt| match evt { - UpdateEvent::StatusUpdate(event) => StreamResponse { - payload: Some(stream_response::Payload::StatusUpdate(Box::new( - map_status_update(event), - ))), - ..Default::default() - }, - UpdateEvent::ArtifactUpdate(event) => StreamResponse { - payload: Some(stream_response::Payload::ArtifactUpdate(Box::new( - map_artifact_update(event), - ))), - ..Default::default() - }, - }) - .map_err(map_err) - }); + let mapped_stream = + update_stream.map(|item| item.map(|seq| map_update_event(seq.event)).map_err(map_err)); let chained_stream = futures::stream::once(async { Ok(initial_response) }).chain(mapped_stream); @@ -356,9 +301,10 @@ where ) -> Result<(Task, ::connectrpc::Context), ::connectrpc::ConnectError> { let req = request.to_owned_message(); let history_length = req.history_length.map(|l| l as u32); + let id: TaskId = req.id.parse().map_err(map_err)?; let task = self - .task_manager - .get_task(&req.id, history_length) + .service + .get(&id, history_length) .await .map_err(map_err)?; Ok((task, ctx)) @@ -370,38 +316,9 @@ where request: ::buffa::view::OwnedView>, ) -> Result<(ListTasksResponse, ::connectrpc::Context), ::connectrpc::ConnectError> { let req = request.to_owned_message(); + let params = list_request_to_params(req); - let params = crate::domain::ListTasksParams { - context_id: if req.context_id.is_empty() { - None - } else { - Some(req.context_id) - }, - status: match req.status.to_i32() { - 0 => None, - val => Some(TaskState::from_i32(val).unwrap_or(TaskState::TASK_STATE_UNSPECIFIED)), - }, - page_size: req.page_size, - page_token: if req.page_token.is_empty() { - None - } else { - Some(req.page_token) - }, - history_length: req.history_length, - include_artifacts: req.include_artifacts, - status_timestamp_after: req.status_timestamp_after.as_option().map(|t| { - let dt = chrono::DateTime::::from_timestamp(t.seconds, t.nanos as u32) - .unwrap_or_default(); - dt.to_rfc3339() - }), - metadata: None, - }; - - let result = self - .task_manager - .list_tasks_v3(¶ms) - .await - .map_err(map_err)?; + let result = self.service.list(¶ms).await.map_err(map_err)?; let response = ListTasksResponse { tasks: result.tasks, @@ -420,11 +337,8 @@ where request: ::buffa::view::OwnedView>, ) -> Result<(Task, ::connectrpc::Context), ::connectrpc::ConnectError> { let req = request.to_owned_message(); - let task = self - .task_manager - .cancel_task(&req.id) - .await - .map_err(map_err)?; + let id: TaskId = req.id.parse().map_err(map_err)?; + let task = self.service.cancel(&id).await.map_err(map_err)?; Ok((task, ctx)) } @@ -446,39 +360,17 @@ where ::connectrpc::ConnectError, > { let req = request.to_owned_message(); - let task_id = req.id; - - let initial_task = match self.task_manager.get_task(&task_id, None).await { - Ok(task) => Some(task), - Err(A2AError::TaskNotFound(_)) => None, - Err(e) => return Err(map_err(e)), - }; - let update_stream = self - .streaming_handler - .start_task_streaming(&task_id) + let (initial_task, update_stream) = self + .service + .subscribe(&req.id, None) .await .map_err(map_err)?; use futures::StreamExt; - let mapped_stream = update_stream.map(|item| { - item.map(|evt| match evt { - UpdateEvent::StatusUpdate(event) => StreamResponse { - payload: Some(stream_response::Payload::StatusUpdate(Box::new( - map_status_update(event), - ))), - ..Default::default() - }, - UpdateEvent::ArtifactUpdate(event) => StreamResponse { - payload: Some(stream_response::Payload::ArtifactUpdate(Box::new( - map_artifact_update(event), - ))), - ..Default::default() - }, - }) - .map_err(map_err) - }); + let mapped_stream = + update_stream.map(|item| item.map(|seq| map_update_event(seq.event)).map_err(map_err)); if let Some(task) = initial_task { let initial_response = StreamResponse { @@ -501,8 +393,8 @@ where { let config = request.to_owned_message(); let created_config = self - .notification_manager - .set_task_notification_validated(&config) + .service + .set_push_config(&config) .await .map_err(map_err)?; Ok((created_config, ctx)) @@ -521,8 +413,8 @@ where metadata: None, }; let config = self - .task_manager - .get_push_notification_config(¶ms) + .service + .get_push_config(¶ms) .await .map_err(map_err)?; Ok((config, ctx)) @@ -545,8 +437,8 @@ where metadata: None, }; let configs = self - .task_manager - .list_push_notification_configs(¶ms) + .service + .list_push_configs(¶ms) .await .map_err(map_err)?; let response = ListTaskPushNotificationConfigsResponse { @@ -562,11 +454,7 @@ where request: ::buffa::view::OwnedView>, ) -> Result<(AgentCard, ::connectrpc::Context), ::connectrpc::ConnectError> { let _req = request.to_owned_message(); - let card = self - .agent_info - .get_authenticated_extended_card() - .await - .map_err(map_err)?; + let card = self.service.extended_agent_card().await.map_err(map_err)?; Ok((card, ctx)) } @@ -587,15 +475,62 @@ where push_notification_config_id: req.id, metadata: None, }; - self.task_manager - .delete_push_notification_config(¶ms) + self.service + .delete_push_config(¶ms) .await .map_err(map_err)?; Ok((::buffa_types::google::protobuf::Empty::default(), ctx)) } } -/// A no-op AsyncStreamingHandler implementation for request processor defaulting +/// Map a generated `ListTasksRequest` (proto wire message) onto the domain +/// [`ListTasksParams`]. Shared with the JSON-RPC adapter. +pub(super) fn list_request_to_params(req: ListTasksRequest) -> crate::domain::ListTasksParams { + crate::domain::ListTasksParams { + context_id: if req.context_id.is_empty() { + None + } else { + Some(req.context_id) + }, + status: match req.status.to_i32() { + 0 => None, + val => Some(TaskState::from_i32(val).unwrap_or(TaskState::TASK_STATE_UNSPECIFIED)), + }, + page_size: req.page_size, + page_token: if req.page_token.is_empty() { + None + } else { + Some(req.page_token) + }, + history_length: req.history_length, + include_artifacts: req.include_artifacts, + status_timestamp_after: req.status_timestamp_after.as_option().map(|t| { + let dt = chrono::DateTime::::from_timestamp(t.seconds, t.nanos as u32) + .unwrap_or_default(); + dt.to_rfc3339() + }), + metadata: None, + } +} + +/// Decode the optional `SendMessageConfiguration` view into the domain push +/// config + history limit the service expects. +/// +/// Shared with the JSON-RPC adapter (both decode the same generated config +/// message), so the two transports agree on the wire shape. +pub(super) fn decode_send_config( + config: Option, +) -> (Option, Option) { + let Some(c) = config else { + return (None, None); + }; + let push_config = c.task_push_notification_config.into_option(); + let history_limit = c.history_length.map(|limit| limit as u32); + (push_config, history_limit) +} + +/// A no-op [`AsyncStreamingHandler`] used as the adapter's default streaming port +/// when the caller has no real streaming backend to inject. #[derive(Clone, Debug, Default)] pub struct NoopStreamingHandler; @@ -676,10 +611,9 @@ impl AsyncStreamingHandler for NoopStreamingHandler { async fn combined_update_stream( &self, _task_id: &str, - ) -> Result< - Pin> + Send>>, - A2AError, - > { + _from_event_id: Option, + ) -> Result> + Send>>, A2AError> + { Err(A2AError::UnsupportedOperation( "Streaming not supported by this processor".to_string(), )) diff --git a/a2a-rs/src/adapter/transport/http/client.rs b/a2a-rs/src/adapter/transport/http/client.rs index 84b237f..81b3ab5 100644 --- a/a2a-rs/src/adapter/transport/http/client.rs +++ b/a2a-rs/src/adapter/transport/http/client.rs @@ -13,6 +13,7 @@ use tracing::{debug, instrument}; use crate::{ adapter::error::HttpClientError, + adapter::transport::codec::stream_response_to_item, domain::{ A2AError, AgentCard, ListTasksParams, ListTasksResult, Message, Task, TaskPushNotificationConfig, @@ -21,10 +22,9 @@ use crate::{ GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest, ListTaskPushNotificationConfigsRequest, ListTasksRequest, SendMessageConfiguration, SendMessageRequest, SubscribeToTaskRequest, TaskState, send_message_response, - stream_response, }, }, - services::client::{AsyncA2AClient, StreamItem}, + port::{StreamEvent, Transport}, }; fn map_connect_err(err: connectrpc::ConnectError) -> A2AError { @@ -52,19 +52,6 @@ fn map_connect_err(err: connectrpc::ConnectError) -> A2AError { } } -fn map_stream_response(resp: crate::domain::generated::StreamResponse) -> Option { - match resp.payload { - Some(stream_response::Payload::Task(task)) => Some(StreamItem::Task(*task)), - Some(stream_response::Payload::StatusUpdate(update)) => { - Some(StreamItem::StatusUpdate((*update).into())) - } - Some(stream_response::Payload::ArtifactUpdate(update)) => { - Some(StreamItem::ArtifactUpdate((*update).into())) - } - _ => None, - } -} - /// HTTP client for interacting with the A2A protocol via ConnectRPC pub struct HttpClient { /// Base URL of the A2A API @@ -248,7 +235,11 @@ impl HttpClient { } #[async_trait] -impl AsyncA2AClient for HttpClient { +impl Transport for HttpClient { + fn protocol(&self) -> &str { + "CONNECTRPC" + } + #[cfg_attr( feature = "tracing", instrument(skip(self, message), fields(task_id, session_id, history_length)) @@ -455,7 +446,10 @@ impl AsyncA2AClient for HttpClient { &self, task_id: &str, _history_length: Option, - ) -> Result> + Send>>, A2AError> { + // ConnectRPC streaming has no SSE `Last-Event-ID`; resumption is not + // supported on this transport, so the hint is ignored. + _last_event_id: Option<&str>, + ) -> Result> + Send>>, A2AError> { let request = SubscribeToTaskRequest { id: task_id.to_string(), ..Default::default() @@ -470,8 +464,8 @@ impl AsyncA2AClient for HttpClient { match s.message().await { Ok(Some(view)) => { let resp = view.to_owned_message(); - if let Some(item) = map_stream_response(resp) { - Some((Ok(item), s)) + if let Some(item) = stream_response_to_item(resp) { + Some((Ok(StreamEvent::untagged(item)), s)) } else { Some(( Err(A2AError::Internal( diff --git a/a2a-rs/src/adapter/transport/http/server.rs b/a2a-rs/src/adapter/transport/http/server.rs index 8ff6813..59421f1 100644 --- a/a2a-rs/src/adapter/transport/http/server.rs +++ b/a2a-rs/src/adapter/transport/http/server.rs @@ -29,7 +29,8 @@ where A: AgentInfoProvider + Send + Sync + 'static, Auth: Authenticator + Send + Sync + 'static, { - /// Request processor + /// The `A2aService` implementation this server dispatches requests to + /// (e.g. [`ConnectRpcAdapter`](crate::adapter::ConnectRpcAdapter)). processor: Arc

, /// Agent info provider agent_info: Arc, diff --git a/a2a-rs/src/adapter/transport/jsonrpc.rs b/a2a-rs/src/adapter/transport/jsonrpc.rs new file mode 100644 index 0000000..d3d0954 --- /dev/null +++ b/a2a-rs/src/adapter/transport/jsonrpc.rs @@ -0,0 +1,799 @@ +//! The JSON-RPC 2.0 + HTTP+JSON (REST) transport adapter. +//! +//! `JsonRpcAdapter` is a **sibling** of [`ConnectRpcAdapter`](super::connectrpc::ConnectRpcAdapter): +//! a thin transport adapter that wraps the same inner [`TaskService`] but speaks +//! the spec-mandated, ecosystem-interoperable **JSON-RPC 2.0** wire format +//! (and, via [`rest_router`], HTTP+JSON / REST). Its only job is to parse a +//! JSON-RPC envelope, deserialize `params` into the matching A2A request type, +//! delegate to [`TaskService`], and re-encode the domain result — mapping +//! [`A2AError`] onto JSON-RPC error codes. +//! +//! All use-case orchestration lives in [`TaskService`]; this layer holds no port +//! traits directly — exactly the layering of `connectrpc.rs`. +//! +//! # Wire format +//! +//! Request `params` and the `result` body are the **generated proto types** +//! (`SendMessageRequest`, `Task`, `SendMessageResponse`, …). Those already +//! serialize as canonical ProtoJSON — camelCase fields, SCREAMING_SNAKE enums, +//! RFC3339 timestamps, base64 `bytes`, bare `Struct` metadata, and tag-free +//! field-presence unions (`{"task": …}` / `{"statusUpdate": …}`). This is the +//! same representation the official SDK and the Go/C#/Python SDKs use, so an +//! off-the-shelf A2A client can talk to this server. The decode/encode helpers +//! (`decode_send_config`, `list_request_to_params`, `map_update_event`) are +//! **shared with the Connect adapter** so both transports agree on the wire. +//! +//! The hand-written A2A param types in `domain/core/task.rs` (`MessageSendParams` +//! with `pushNotificationConfig`/`blocking`) are the *legacy* JSON-RPC v0.x shape +//! and are intentionally **not** used here — the proto request types are the +//! v1.0 contract. + +use std::convert::Infallible; +use std::pin::Pin; +use std::sync::Arc; + +use axum::{ + Json, Router, + body::Bytes, + extract::{Path, Query, State}, + http::{HeaderMap, StatusCode}, + response::{ + IntoResponse, Response, + sse::{Event, KeepAlive, Sse}, + }, + routing::{get, post}, +}; +use futures::{Stream, StreamExt}; +use serde::Serialize; +use serde_json::Value; + +use crate::{ + application::TaskService, + domain::{ + A2AError, TaskId, TaskPushNotificationConfig, + generated::{ + CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, GetTaskRequest, + ListTaskPushNotificationConfigsRequest, ListTaskPushNotificationConfigsResponse, + ListTasksRequest, ListTasksResponse, SendMessageRequest, SendMessageResponse, + StreamResponse, SubscribeToTaskRequest, send_message_response, stream_response, + }, + }, + port::{ + AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskLifecycle, + AsyncTaskQuery, CallContext, CallInterceptor, CallSide, run_after, run_before, + }, + services::server::AgentInfoProvider, +}; + +use super::connectrpc::{ + NoopStreamingHandler, decode_send_config, list_request_to_params, map_update_event, +}; +// Re-exported so existing `transport::jsonrpc::{methods, error_code, JsonRpc*}` +// paths keep working now that these live in the shared wire module. +pub use super::jsonrpc_wire::{ + JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse, a2a_to_jsonrpc, error_code, methods, +}; + +/// A stream of wire [`StreamResponse`]s — the unified output of both streaming +/// methods, before it is framed as SSE (enveloped for JSON-RPC, bare for REST). +/// A stream of wire responses, each tagged with an optional per-task event id. +/// The id (when present) is emitted as the SSE `id:` field so a client can +/// resume via `Last-Event-ID`. The initial task snapshot carries `None`. +type StreamResponseStream = + Pin, StreamResponse), A2AError>> + Send>>; + +// --------------------------------------------------------------------------- +// Adapter +// --------------------------------------------------------------------------- + +/// JSON-RPC 2.0 / HTTP+JSON transport adapter over a [`TaskService`]. +/// +/// Mirrors [`ConnectRpcAdapter`](super::connectrpc::ConnectRpcAdapter)'s +/// constructors so an agent author swaps transports with one line. +#[derive(Clone)] +pub struct JsonRpcAdapter { + service: TaskService, + /// Server-side interceptor chain wrapping every unary/streaming dispatch. + interceptors: Vec>, +} + +impl JsonRpcAdapter { + /// Create an adapter from separate handlers (no real streaming backend). + /// + /// `tasks` supplies both the lifecycle and query capabilities. Uses the same + /// [`NoopStreamingHandler`] default as the Connect adapter. + pub fn new( + message_handler: impl AsyncMessageHandler + 'static, + tasks: impl AsyncTaskLifecycle + AsyncTaskQuery + 'static, + notification_manager: impl AsyncNotificationManager + 'static, + agent_info: impl AgentInfoProvider + 'static, + ) -> Self { + Self { + service: TaskService::new( + message_handler, + tasks, + notification_manager, + agent_info, + NoopStreamingHandler, + crate::port::NoopPushNotifier, + ), + interceptors: Vec::new(), + } + } + + /// Create an adapter from a single handler implementing every port. + pub fn with_handler( + handler: impl AsyncMessageHandler + + AsyncTaskLifecycle + + AsyncTaskQuery + + AsyncNotificationManager + + 'static, + agent_info: impl AgentInfoProvider + 'static, + ) -> Self { + Self { + service: TaskService::with_handler( + handler, + agent_info, + NoopStreamingHandler, + crate::port::NoopPushNotifier, + ), + interceptors: Vec::new(), + } + } + + /// Inject a real streaming handler (required for the streaming methods). + pub fn with_streaming_handler( + self, + streaming_handler: impl AsyncStreamingHandler + 'static, + ) -> Self { + Self { + service: self.service.with_streaming_handler(streaming_handler), + interceptors: self.interceptors, + } + } + + /// Inject a real push notifier (required for webhook delivery). + pub fn with_push_notifier( + self, + push_notifier: impl crate::port::AsyncPushNotifier + 'static, + ) -> Self { + Self { + service: self.service.with_push_notifier(push_notifier), + interceptors: self.interceptors, + } + } + + /// Append a server-side [`CallInterceptor`] to the chain. + /// + /// Interceptors wrap every unary and streaming dispatch: `before` hooks run + /// in registration order, then the method runs, then `after` hooks run in + /// reverse. Chainable, so callers can register several. + pub fn with_interceptor(mut self, interceptor: impl CallInterceptor + 'static) -> Self { + self.interceptors.push(Arc::new(interceptor)); + self + } +} + +// --------------------------------------------------------------------------- +// Method dispatch (transport-neutral core, shared by JSON-RPC and REST) +// --------------------------------------------------------------------------- + +impl JsonRpcAdapter { + /// Handle a single non-streaming JSON-RPC request, producing a response + /// envelope. Streaming methods are handled by the SSE path in the router and + /// must not reach here. + pub async fn handle_unary(&self, req: JsonRpcRequest) -> JsonRpcResponse { + let id = req.id.clone(); + let result = self.dispatch_intercepted(&req.method, req.params).await; + match result { + Ok(value) => JsonRpcResponse::ok(id, value), + Err(e) => JsonRpcResponse::err(id, a2a_to_jsonrpc(&e)), + } + } + + /// Run the server interceptor chain around a unary [`dispatch_unary`], so + /// both the JSON-RPC and REST entry points share one interception point. + /// + /// [`dispatch_unary`]: Self::dispatch_unary + async fn dispatch_intercepted( + &self, + method: &str, + params: Option, + ) -> Result { + if self.interceptors.is_empty() { + return self.dispatch_unary(method, params).await; + } + let ctx = CallContext::new(method, CallSide::Server); + if let Err(e) = run_before(&self.interceptors, &ctx).await { + run_after(&self.interceptors, &ctx, Err(&e)).await; + return Err(e); + } + let result = self.dispatch_unary(method, params).await; + run_after(&self.interceptors, &ctx, result.as_ref().map(|_| ())).await; + result + } + + /// Route a unary method name + `params` to the service and return the wire + /// `result` value. Reused by both JSON-RPC and REST. + async fn dispatch_unary(&self, method: &str, params: Option) -> Result { + match method { + methods::GET_TASK => self.get_task(params).await, + methods::LIST_TASKS => self.list_tasks(params).await, + methods::CANCEL_TASK => self.cancel_task(params).await, + methods::SEND_MESSAGE => self.send_message(params).await, + methods::CREATE_PUSH_CONFIG => self.create_push_config(params).await, + methods::GET_PUSH_CONFIG => self.get_push_config(params).await, + methods::LIST_PUSH_CONFIGS => self.list_push_configs(params).await, + methods::DELETE_PUSH_CONFIG => self.delete_push_config(params).await, + methods::GET_EXTENDED_AGENT_CARD => self.extended_card().await, + methods::SEND_STREAMING_MESSAGE | methods::SUBSCRIBE_TO_TASK => Err( + A2AError::InvalidParams("streaming method requires SSE transport".to_string()), + ), + unknown => Err(A2AError::MethodNotFound(unknown.to_string())), + } + } + + async fn get_task(&self, params: Option) -> Result { + let req: GetTaskRequest = parse_params(params)?; + let id: TaskId = req.id.parse()?; + let task = self + .service + .get(&id, req.history_length.map(|l| l as u32)) + .await?; + to_value(&task) + } + + async fn list_tasks(&self, params: Option) -> Result { + let req: ListTasksRequest = parse_params(params)?; + let result = self.service.list(&list_request_to_params(req)).await?; + let response = ListTasksResponse { + tasks: result.tasks, + next_page_token: result.next_page_token, + page_size: result.page_size, + total_size: result.total_size, + ..Default::default() + }; + to_value(&response) + } + + async fn cancel_task(&self, params: Option) -> Result { + let req: CancelTaskRequest = parse_params(params)?; + let id: TaskId = req.id.parse()?; + let task = self.service.cancel(&id).await?; + to_value(&task) + } + + async fn send_message(&self, params: Option) -> Result { + let (task_id, message, session_id, push_config, history_limit) = + decode_send_message(parse_params(params)?)?; + let task = self + .service + .send_message( + &task_id, + &message, + session_id.as_deref(), + push_config, + history_limit, + ) + .await?; + let response = SendMessageResponse { + payload: Some(send_message_response::Payload::Task(Box::new(task))), + ..Default::default() + }; + to_value(&response) + } + + async fn create_push_config(&self, params: Option) -> Result { + let config: TaskPushNotificationConfig = parse_params(params)?; + let created = self.service.set_push_config(&config).await?; + to_value(&created) + } + + async fn get_push_config(&self, params: Option) -> Result { + let req: GetTaskPushNotificationConfigRequest = parse_params(params)?; + let domain_params = crate::domain::GetTaskPushNotificationConfigParams { + id: req.task_id, + push_notification_config_id: Some(req.id), + metadata: None, + }; + let config = self.service.get_push_config(&domain_params).await?; + to_value(&config) + } + + async fn list_push_configs(&self, params: Option) -> Result { + let req: ListTaskPushNotificationConfigsRequest = parse_params(params)?; + let domain_params = crate::domain::ListTaskPushNotificationConfigsParams { + id: req.task_id, + metadata: None, + }; + let configs = self.service.list_push_configs(&domain_params).await?; + let response = ListTaskPushNotificationConfigsResponse { + configs, + ..Default::default() + }; + to_value(&response) + } + + async fn delete_push_config(&self, params: Option) -> Result { + let req: DeleteTaskPushNotificationConfigRequest = parse_params(params)?; + let domain_params = crate::domain::DeleteTaskPushNotificationConfigParams { + id: req.task_id, + push_notification_config_id: req.id, + metadata: None, + }; + self.service.delete_push_config(&domain_params).await?; + Ok(serde_json::json!({})) + } + + async fn extended_card(&self) -> Result { + let card = self.service.extended_agent_card().await?; + to_value(&card) + } + + // -- streaming -------------------------------------------------------- + + /// Open the SSE stream for a streaming method, running the server + /// interceptor chain around the open (per-frame interception is out of + /// scope — `after` observes whether the stream opened, not each event). + async fn open_stream( + &self, + method: &str, + params: Option, + from_event_id: Option, + ) -> Result { + if self.interceptors.is_empty() { + return self.open_stream_inner(method, params, from_event_id).await; + } + let ctx = CallContext::new(method, CallSide::Server); + if let Err(e) = run_before(&self.interceptors, &ctx).await { + run_after(&self.interceptors, &ctx, Err(&e)).await; + return Err(e); + } + let result = self.open_stream_inner(method, params, from_event_id).await; + run_after(&self.interceptors, &ctx, result.as_ref().map(|_| ())).await; + result + } + + /// Open the SSE stream for a streaming method, returning a unified stream of + /// wire [`StreamResponse`]s (initial task snapshot first, then updates). + /// + /// `from_event_id` carries the client's `Last-Event-ID` for resumption; it + /// applies to `tasks/subscribe` (a fresh `message/stream` always starts from + /// the beginning). + async fn open_stream_inner( + &self, + method: &str, + params: Option, + from_event_id: Option, + ) -> Result { + match method { + methods::SEND_STREAMING_MESSAGE => { + let (task_id, message, session_id, push_config, history_limit) = + decode_send_message(parse_params(params)?)?; + let (task, updates) = self + .service + .send_streaming_message( + &task_id, + &message, + session_id.as_deref(), + push_config, + history_limit, + ) + .await?; + Ok(chain_initial_task(Some(task), updates)) + } + methods::SUBSCRIBE_TO_TASK => { + let req: SubscribeToTaskRequest = parse_params(params)?; + let (initial, updates) = self.service.subscribe(&req.id, from_event_id).await?; + Ok(chain_initial_task(initial, updates)) + } + unknown => Err(A2AError::MethodNotFound(unknown.to_string())), + } + } +} + +// --------------------------------------------------------------------------- +// Routers (axum) +// --------------------------------------------------------------------------- + +/// Build the JSON-RPC 2.0 router: a single `POST /` endpoint. +/// +/// Compose it at the edge with [`rest_router`] and the agent-card route, e.g. +/// `jsonrpc_router(adapter.clone()).merge(rest_router(adapter))`. +pub fn jsonrpc_router(adapter: Arc) -> Router { + Router::new() + .route("/", post(jsonrpc_handler)) + .with_state(adapter) +} + +async fn jsonrpc_handler( + State(adapter): State>, + headers: HeaderMap, + body: Bytes, +) -> Response { + let req: JsonRpcRequest = match serde_json::from_slice(&body) { + Ok(r) => r, + Err(e) => { + return Json(JsonRpcResponse::err( + JsonRpcId::Null, + JsonRpcError { + code: error_code::PARSE_ERROR, + message: e.to_string(), + data: None, + }, + )) + .into_response(); + } + }; + + if req.jsonrpc != "2.0" { + return Json(JsonRpcResponse::err( + req.id, + JsonRpcError { + code: error_code::INVALID_REQUEST, + message: "jsonrpc must be \"2.0\"".to_string(), + data: None, + }, + )) + .into_response(); + } + + if methods::is_streaming(&req.method) { + let id = req.id.clone(); + let from_event_id = parse_last_event_id(&headers); + match adapter + .open_stream(&req.method, req.params, from_event_id) + .await + { + Ok(stream) => jsonrpc_sse(id, stream).into_response(), + Err(e) => Json(JsonRpcResponse::err(id, a2a_to_jsonrpc(&e))).into_response(), + } + } else { + Json(adapter.handle_unary(req).await).into_response() + } +} + +/// Frame a [`StreamResponseStream`] as JSON-RPC SSE — each event is a +/// `JsonRpcResponse` whose `result` is the (tag-free union) `StreamResponse`. +fn jsonrpc_sse( + id: JsonRpcId, + stream: StreamResponseStream, +) -> Sse>> { + let events = stream.map(move |item| { + let (seq_id, resp) = match item { + Ok((seq_id, sr)) => ( + seq_id, + JsonRpcResponse::ok(id.clone(), serde_json::to_value(&sr).unwrap_or(Value::Null)), + ), + Err(e) => (None, JsonRpcResponse::err(id.clone(), a2a_to_jsonrpc(&e))), + }; + let event = Event::default().data(serde_json::to_string(&resp).unwrap_or_default()); + Ok(match seq_id { + Some(n) => event.id(n.to_string()), + None => event, + }) + }); + Sse::new(events).keep_alive(KeepAlive::default()) +} + +/// Build the HTTP+JSON (REST) router with the official-SDK paths (no `/v1` +/// prefix). Bodies and responses are bare ProtoJSON (no JSON-RPC envelope). +/// +/// The canonical custom-method paths use a `:`-suffix on a collection segment +/// (`/message:send`) — those work as pure-literal segments. The *task* +/// custom-method paths (`/tasks/{id}:cancel`) would put a `:`-suffix on the +/// **same segment as a path parameter**, which axum's matchit router rejects +/// (it conflicts with `/tasks/{id}`). We therefore serve the equivalent +/// slash-form aliases (`/tasks/{id}/cancel`) for those, which official clients +/// also accept. +pub fn rest_router(adapter: Arc) -> Router { + Router::new() + .route("/message:send", post(rest_send_message)) + .route("/message/send", post(rest_send_message)) + .route("/message:stream", post(rest_stream_message)) + .route("/message/stream", post(rest_stream_message)) + .route("/tasks", get(rest_list_tasks)) + .route("/tasks/{id}", get(rest_get_task)) + .route("/tasks/{id}/cancel", post(rest_cancel_task)) + .route("/tasks/{id}/subscribe", get(rest_subscribe)) + .route( + "/tasks/{id}/pushNotificationConfigs", + post(rest_create_push_config).get(rest_list_push_configs), + ) + .route( + "/tasks/{id}/pushNotificationConfigs/{cfg}", + get(rest_get_push_config).delete(rest_delete_push_config), + ) + .route("/extendedAgentCard", get(rest_extended_card)) + .route("/card", get(rest_extended_card)) + .with_state(adapter) +} + +/// Convert a unary `Result` into a REST HTTP response: 200 with +/// the bare ProtoJSON body, or an error status + `{code, message, data?}`. +fn rest_result(result: Result) -> Response { + match result { + Ok(value) => Json(value).into_response(), + Err(e) => a2a_to_http(&e), + } +} + +/// Map a domain [`A2AError`] onto an HTTP status + JSON error body for REST. +fn a2a_to_http(err: &A2AError) -> Response { + let status = match err { + A2AError::TaskNotFound(_) | A2AError::MethodNotFound(_) => StatusCode::NOT_FOUND, + A2AError::InvalidParams(_) | A2AError::ValidationError { .. } => StatusCode::BAD_REQUEST, + A2AError::UnsupportedOperation(_) => StatusCode::NOT_IMPLEMENTED, + A2AError::AuthenticatedExtendedCardNotConfigured => StatusCode::PRECONDITION_FAILED, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + (status, Json(a2a_to_jsonrpc(err))).into_response() +} + +async fn rest_send_message(State(a): State>, body: Bytes) -> Response { + rest_result( + a.dispatch_intercepted(methods::SEND_MESSAGE, parse_body(&body)) + .await, + ) +} + +async fn rest_list_tasks( + State(a): State>, + Query(q): Query>, +) -> Response { + rest_result( + a.dispatch_intercepted(methods::LIST_TASKS, Some(query_to_list_request(&q))) + .await, + ) +} + +async fn rest_get_task( + State(a): State>, + Path(id): Path, + Query(q): Query>, +) -> Response { + let mut req = serde_json::json!({ "id": id }); + if let Some(h) = q.get("historyLength").and_then(|s| s.parse::().ok()) { + req["historyLength"] = h.into(); + } + rest_result(a.dispatch_intercepted(methods::GET_TASK, Some(req)).await) +} + +async fn rest_cancel_task( + State(a): State>, + Path(id): Path, +) -> Response { + rest_result( + a.dispatch_intercepted(methods::CANCEL_TASK, Some(serde_json::json!({ "id": id }))) + .await, + ) +} + +async fn rest_create_push_config( + State(a): State>, + Path(id): Path, + body: Bytes, +) -> Response { + // The path task id is authoritative for the config's parent. + let mut config = parse_body(&body).unwrap_or_else(|| serde_json::json!({})); + config["taskId"] = id.into(); + rest_result( + a.dispatch_intercepted(methods::CREATE_PUSH_CONFIG, Some(config)) + .await, + ) +} + +async fn rest_list_push_configs( + State(a): State>, + Path(id): Path, +) -> Response { + rest_result( + a.dispatch_intercepted( + methods::LIST_PUSH_CONFIGS, + Some(serde_json::json!({ "taskId": id })), + ) + .await, + ) +} + +async fn rest_get_push_config( + State(a): State>, + Path((id, cfg)): Path<(String, String)>, +) -> Response { + rest_result( + a.dispatch_intercepted( + methods::GET_PUSH_CONFIG, + Some(serde_json::json!({ "taskId": id, "id": cfg })), + ) + .await, + ) +} + +async fn rest_delete_push_config( + State(a): State>, + Path((id, cfg)): Path<(String, String)>, +) -> Response { + rest_result( + a.dispatch_intercepted( + methods::DELETE_PUSH_CONFIG, + Some(serde_json::json!({ "taskId": id, "id": cfg })), + ) + .await, + ) +} + +async fn rest_extended_card(State(a): State>) -> Response { + rest_result( + a.dispatch_intercepted(methods::GET_EXTENDED_AGENT_CARD, None) + .await, + ) +} + +async fn rest_stream_message( + State(a): State>, + headers: HeaderMap, + body: Bytes, +) -> Response { + let from_event_id = parse_last_event_id(&headers); + match a + .open_stream( + methods::SEND_STREAMING_MESSAGE, + parse_body(&body), + from_event_id, + ) + .await + { + Ok(stream) => rest_sse(stream).into_response(), + Err(e) => a2a_to_http(&e), + } +} + +async fn rest_subscribe( + State(a): State>, + headers: HeaderMap, + Path(id): Path, +) -> Response { + let from_event_id = parse_last_event_id(&headers); + match a + .open_stream( + methods::SUBSCRIBE_TO_TASK, + Some(serde_json::json!({ "id": id })), + from_event_id, + ) + .await + { + Ok(stream) => rest_sse(stream).into_response(), + Err(e) => a2a_to_http(&e), + } +} + +/// Parse the SSE `Last-Event-ID` header into a per-task event id for resumption. +/// +/// This is the server half of the a2a-rs resumption enhancement (not an A2A v1.0 +/// spec feature). Spec-compliant clients never send the header, so they always +/// get a fresh stream from current state — the `SubscribeToTask` behavior the +/// spec defines. The complementary SSE `id:` field is emitted by [`jsonrpc_sse`] +/// / [`rest_sse`] and is inert for clients that don't use it. +fn parse_last_event_id(headers: &HeaderMap) -> Option { + headers + .get("last-event-id") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.trim().parse::().ok()) +} + +/// Frame a [`StreamResponseStream`] as bare-ProtoJSON SSE (REST has no envelope). +fn rest_sse(stream: StreamResponseStream) -> Sse>> { + let events = stream.map(|item| { + let (seq_id, data) = match item { + Ok((seq_id, sr)) => (seq_id, serde_json::to_string(&sr).unwrap_or_default()), + Err(e) => ( + None, + serde_json::to_string(&a2a_to_jsonrpc(&e)).unwrap_or_default(), + ), + }; + let event = Event::default().data(data); + Ok(match seq_id { + Some(n) => event.id(n.to_string()), + None => event, + }) + }); + Sse::new(events).keep_alive(KeepAlive::default()) +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Deserialize JSON-RPC `params` into a concrete proto request type, mapping +/// serde failures to `InvalidParams`. +fn parse_params(params: Option) -> Result { + serde_json::from_value(params.unwrap_or(Value::Null)) + .map_err(|e| A2AError::InvalidParams(format!("invalid params: {e}"))) +} + +/// Parse a REST request body into a JSON value (empty body → `None`). +fn parse_body(body: &Bytes) -> Option { + if body.is_empty() { + None + } else { + serde_json::from_slice(body).ok() + } +} + +/// Serialize a domain/wire value into a JSON `result`, mapping failures to an +/// internal error. +fn to_value(value: &T) -> Result { + serde_json::to_value(value) + .map_err(|e| A2AError::InvalidParams(format!("failed to serialize result: {e}"))) +} + +/// Decode a [`SendMessageRequest`] into the arguments [`TaskService::send_message`] +/// expects. Mirrors the Connect adapter's `send_message` decode exactly. +type SendArgs = ( + String, + crate::domain::Message, + Option, + Option, + Option, +); +fn decode_send_message(req: SendMessageRequest) -> Result { + let message = req + .message + .into_option() + .ok_or_else(|| A2AError::InvalidParams("missing message".to_string()))?; + let task_id = message.task_id.clone(); + let session_id = (!message.context_id.is_empty()).then(|| message.context_id.clone()); + let (push_config, history_limit) = decode_send_config(req.configuration.into_option()); + Ok((task_id, message, session_id, push_config, history_limit)) +} + +/// Build the SSE stream: initial task snapshot (if present) followed by the +/// mapped update events. Mirrors `connectrpc.rs`'s `stream::once(task).chain(...)`. +fn chain_initial_task( + initial: Option, + updates: crate::application::UpdateStream, +) -> StreamResponseStream { + let mapped = updates.map(|item| item.map(|seq| (Some(seq.id), map_update_event(seq.event)))); + match initial { + Some(task) => { + let head = StreamResponse { + payload: Some(stream_response::Payload::Task(Box::new(task))), + ..Default::default() + }; + Box::pin(futures::stream::once(async move { Ok((None, head)) }).chain(mapped)) + } + None => Box::pin(mapped), + } +} + +/// Assemble a `ListTasksRequest`-shaped JSON object from REST query parameters, +/// coercing numeric/boolean fields to their proto JSON types. +fn query_to_list_request(q: &std::collections::HashMap) -> Value { + let mut req = serde_json::Map::new(); + if let Some(v) = q.get("contextId") { + req.insert("contextId".to_string(), v.clone().into()); + } + if let Some(v) = q.get("status") { + req.insert("status".to_string(), v.clone().into()); + } + if let Some(v) = q.get("pageToken") { + req.insert("pageToken".to_string(), v.clone().into()); + } + if let Some(v) = q.get("pageSize").and_then(|s| s.parse::().ok()) { + req.insert("pageSize".to_string(), v.into()); + } + if let Some(v) = q.get("historyLength").and_then(|s| s.parse::().ok()) { + req.insert("historyLength".to_string(), v.into()); + } + if let Some(v) = q + .get("includeArtifacts") + .and_then(|s| s.parse::().ok()) + { + req.insert("includeArtifacts".to_string(), v.into()); + } + if let Some(v) = q.get("statusTimestampAfter") { + req.insert("statusTimestampAfter".to_string(), v.clone().into()); + } + Value::Object(req) +} diff --git a/a2a-rs/src/adapter/transport/jsonrpc_client.rs b/a2a-rs/src/adapter/transport/jsonrpc_client.rs new file mode 100644 index 0000000..51baaf4 --- /dev/null +++ b/a2a-rs/src/adapter/transport/jsonrpc_client.rs @@ -0,0 +1,547 @@ +//! Wire-compatible JSON-RPC 2.0 client adapter. +//! +//! [`JsonRpcClient`] is the client-side counterpart of +//! [`JsonRpcAdapter`](super::jsonrpc::JsonRpcAdapter): it implements the +//! [`Transport`] port by speaking the spec-mandated JSON-RPC 2.0 wire format +//! (single `POST` endpoint, SSE for streaming) that the official Go/C#/Python +//! SDKs use. This lets our client talk to any standard A2A agent. +//! +//! Request `params` and response `result` bodies are the **generated proto +//! types** (`SendMessageRequest`, `Task`, `SendMessageResponse`, …), which +//! already serialize as canonical ProtoJSON — the same representation the server +//! adapter decodes. The method names, error codes, and envelopes come from the +//! shared [`jsonrpc_wire`](super::jsonrpc_wire) module so the two directions +//! cannot drift. + +use async_trait::async_trait; +use futures::stream::Stream; +use reqwest::{ + Client, + header::{HeaderMap, HeaderValue}, +}; +use serde::{Serialize, de::DeserializeOwned}; +use std::collections::VecDeque; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; + +use crate::{ + adapter::error::HttpClientError, + adapter::transport::codec::stream_response_to_item, + domain::{ + A2AError, AgentCard, ListTasksParams, ListTasksResult, Message, Task, + TaskPushNotificationConfig, + generated::{ + CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, GetTaskRequest, + ListTaskPushNotificationConfigsRequest, ListTaskPushNotificationConfigsResponse, + ListTasksRequest, ListTasksResponse, SendMessageConfiguration, SendMessageRequest, + SendMessageResponse, StreamResponse, SubscribeToTaskRequest, TaskState, + send_message_response, + }, + }, + port::{ + CallContext, CallInterceptor, CallSide, StreamEvent, StreamItem, Transport, run_after, + run_before, + }, +}; + +use super::jsonrpc_wire::{JsonRpcId, JsonRpcRequest, JsonRpcResponse, jsonrpc_to_a2a, methods}; + +/// A wire-compatible JSON-RPC 2.0 client for the A2A protocol. +/// +/// Mirrors [`HttpClient`](super::http::HttpClient)'s constructors so an +/// application can swap the ConnectRPC transport for JSON-RPC with one line. +pub struct JsonRpcClient { + /// Base URL of the agent (also the JSON-RPC `POST` endpoint root). + base_url: String, + client: Client, + auth_token: Option, + /// Request timeout in seconds. + timeout: u64, + /// Client-side interceptor chain wrapping every call. + interceptors: Vec>, +} + +impl JsonRpcClient { + /// Create a new JSON-RPC client targeting `base_url`. + pub fn new(base_url: String) -> Self { + Self { + base_url, + client: Client::new(), + auth_token: None, + timeout: 30, + interceptors: Vec::new(), + } + } + + /// Create a JSON-RPC client with a bearer auth token. + pub fn with_auth(base_url: String, auth_token: String) -> Self { + Self { + base_url, + client: Client::new(), + auth_token: Some(auth_token), + timeout: 30, + interceptors: Vec::new(), + } + } + + /// Set the request timeout (seconds). + pub fn with_timeout(mut self, timeout: u64) -> Self { + self.timeout = timeout; + self + } + + /// Append a client-side [`CallInterceptor`] to the chain. + /// + /// Interceptors wrap every call (`rpc` and the streaming subscribe): + /// `before` hooks run in registration order, then the request is sent, then + /// `after` hooks run in reverse. Chainable. + pub fn with_interceptor(mut self, interceptor: impl CallInterceptor + 'static) -> Self { + self.interceptors.push(Arc::new(interceptor)); + self + } + + /// Get the base URL of the client. + pub fn base_url(&self) -> &str { + &self.base_url + } + + fn headers(&self) -> Result { + let mut headers = HeaderMap::new(); + headers.insert( + reqwest::header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + if let Some(token) = &self.auth_token { + let value = HeaderValue::from_str(&format!("Bearer {token}")) + .map_err(|e| A2AError::Internal(format!("Invalid auth token for header: {e}")))?; + headers.insert(reqwest::header::AUTHORIZATION, value); + } + Ok(headers) + } + + /// Resolve a path relative to the base URL (handles trailing-slash variance). + fn join(&self, path: &str) -> String { + let base = self.base_url.trim_end_matches('/'); + let path = path.trim_start_matches('/'); + format!("{base}/{path}") + } + + /// Fetch the agent card from the well-known endpoint (plain HTTP GET). + /// + /// Tries the spec path `/.well-known/agent-card.json` first, falling back to + /// the legacy `/agent-card` path. + pub async fn get_agent_card(&self) -> Result { + for path in [".well-known/agent-card.json", "agent-card"] { + let url = self.join(path); + let resp = self + .client + .get(&url) + .headers(self.headers()?) + .timeout(Duration::from_secs(self.timeout)) + .send() + .await + .map_err(HttpClientError::Reqwest)?; + if resp.status().is_success() { + return resp.json::().await.map_err(|e| { + A2AError::Internal(format!("Failed to parse agent card JSON: {e}")) + }); + } + } + Err(A2AError::Internal(format!( + "Agent card not found at {}", + self.base_url + ))) + } + + /// Send a JSON-RPC request envelope and decode the typed `result`, running + /// the client interceptor chain around the call. + async fn rpc( + &self, + method: &str, + params: &P, + ) -> Result { + if self.interceptors.is_empty() { + return self.rpc_inner(method, params).await; + } + let ctx = CallContext::new(method, CallSide::Client); + run_before(&self.interceptors, &ctx).await?; + let result = self.rpc_inner(method, params).await; + run_after(&self.interceptors, &ctx, result.as_ref().map(|_| ())).await; + result + } + + /// The un-intercepted JSON-RPC round-trip. + async fn rpc_inner( + &self, + method: &str, + params: &P, + ) -> Result { + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: JsonRpcId::Num(1), + method: method.to_string(), + params: Some( + serde_json::to_value(params) + .map_err(|e| A2AError::Internal(format!("failed to encode params: {e}")))?, + ), + }; + + let response = self + .client + .post(&self.base_url) + .headers(self.headers()?) + .timeout(Duration::from_secs(self.timeout)) + .json(&request) + .send() + .await + .map_err(HttpClientError::Reqwest)?; + + let body: JsonRpcResponse = response + .json() + .await + .map_err(|e| A2AError::Internal(format!("invalid JSON-RPC response: {e}")))?; + + if let Some(err) = body.error { + return Err(jsonrpc_to_a2a(&err)); + } + let result = body + .result + .ok_or_else(|| A2AError::Internal("JSON-RPC response missing result".to_string()))?; + serde_json::from_value(result) + .map_err(|e| A2AError::Internal(format!("failed to decode result: {e}"))) + } + + /// The un-intercepted streaming subscribe (SSE round-trip). `last_event_id`, + /// when set, is sent as the `Last-Event-ID` header so the server replays + /// events after that id before streaming live updates. + async fn subscribe_inner( + &self, + task_id: &str, + last_event_id: Option<&str>, + ) -> Result> + Send>>, A2AError> { + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: JsonRpcId::Num(1), + method: methods::SUBSCRIBE_TO_TASK.to_string(), + params: Some( + serde_json::to_value(SubscribeToTaskRequest { + id: task_id.to_string(), + ..Default::default() + }) + .map_err(|e| A2AError::Internal(format!("failed to encode params: {e}")))?, + ), + }; + + let mut builder = self + .client + .post(&self.base_url) + .headers(self.headers()?) + .header(reqwest::header::ACCEPT, "text/event-stream"); + if let Some(id) = last_event_id { + builder = builder.header("last-event-id", id); + } + let response = builder + .json(&request) + .send() + .await + .map_err(HttpClientError::Reqwest)?; + + if !response.status().is_success() { + let status = response.status().as_u16(); + let body = response.text().await.unwrap_or_default(); + return Err(HttpClientError::Response { + status, + message: body, + } + .into()); + } + + Ok(Box::pin(sse_stream(response))) + } +} + +#[async_trait] +impl Transport for JsonRpcClient { + fn protocol(&self) -> &str { + "JSONRPC" + } + + async fn send_task_message( + &self, + task_id: &str, + message: &Message, + session_id: Option<&str>, + history_length: Option, + ) -> Result { + let mut msg = message.clone(); + msg.task_id = task_id.to_string(); + if let Some(sid) = session_id { + msg.context_id = sid.to_string(); + } + + let request = SendMessageRequest { + message: ::buffa::MessageField::some(msg), + configuration: ::buffa::MessageField::some(SendMessageConfiguration { + history_length: history_length.map(|l| l as i32), + ..Default::default() + }), + ..Default::default() + }; + + let response: SendMessageResponse = self.rpc(methods::SEND_MESSAGE, &request).await?; + match response.payload { + Some(send_message_response::Payload::Task(task)) => Ok(*task), + _ => Err(A2AError::Internal( + "Expected task in SendMessageResponse payload".to_string(), + )), + } + } + + async fn get_task(&self, task_id: &str, history_length: Option) -> Result { + let request = GetTaskRequest { + id: task_id.to_string(), + history_length: history_length.map(|l| l as i32), + ..Default::default() + }; + self.rpc(methods::GET_TASK, &request).await + } + + async fn cancel_task(&self, task_id: &str) -> Result { + let request = CancelTaskRequest { + id: task_id.to_string(), + ..Default::default() + }; + self.rpc(methods::CANCEL_TASK, &request).await + } + + async fn set_task_push_notification( + &self, + config: &TaskPushNotificationConfig, + ) -> Result { + self.rpc(methods::CREATE_PUSH_CONFIG, config).await + } + + async fn get_task_push_notification( + &self, + task_id: &str, + ) -> Result { + // Mirrors the ConnectRPC client: list configs and take the first. + let configs = self.list_push_notification_configs(task_id).await?; + configs.into_iter().next().ok_or_else(|| { + A2AError::TaskNotFound(format!( + "No push notification config found for task {task_id}" + )) + }) + } + + async fn list_tasks(&self, params: &ListTasksParams) -> Result { + let mut request = ListTasksRequest { + context_id: params.context_id.clone().unwrap_or_default(), + status: ::buffa::EnumValue::from( + params.status.unwrap_or(TaskState::TASK_STATE_UNSPECIFIED), + ), + page_size: params.page_size, + page_token: params.page_token.clone().unwrap_or_default(), + history_length: params.history_length, + include_artifacts: params.include_artifacts, + ..Default::default() + }; + if let Some(ref t_str) = params.status_timestamp_after { + if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(t_str) { + let utc_dt = dt.with_timezone(&chrono::Utc); + request.status_timestamp_after = + ::buffa::MessageField::some(::buffa_types::google::protobuf::Timestamp { + seconds: utc_dt.timestamp(), + nanos: utc_dt.timestamp_subsec_nanos() as i32, + ..Default::default() + }); + } + } + + let response: ListTasksResponse = self.rpc(methods::LIST_TASKS, &request).await?; + Ok(ListTasksResult { + tasks: response.tasks, + total_size: response.total_size, + page_size: response.page_size, + next_page_token: response.next_page_token, + }) + } + + async fn list_push_notification_configs( + &self, + task_id: &str, + ) -> Result, A2AError> { + let request = ListTaskPushNotificationConfigsRequest { + task_id: task_id.to_string(), + ..Default::default() + }; + let response: ListTaskPushNotificationConfigsResponse = + self.rpc(methods::LIST_PUSH_CONFIGS, &request).await?; + Ok(response.configs) + } + + async fn get_push_notification_config( + &self, + task_id: &str, + config_id: &str, + ) -> Result { + let request = GetTaskPushNotificationConfigRequest { + task_id: task_id.to_string(), + id: config_id.to_string(), + ..Default::default() + }; + self.rpc(methods::GET_PUSH_CONFIG, &request).await + } + + async fn delete_push_notification_config( + &self, + task_id: &str, + config_id: &str, + ) -> Result<(), A2AError> { + let request = DeleteTaskPushNotificationConfigRequest { + task_id: task_id.to_string(), + id: config_id.to_string(), + ..Default::default() + }; + // The server replies with an empty object `{}`; we only care that it + // succeeded. + let _: serde::de::IgnoredAny = self.rpc(methods::DELETE_PUSH_CONFIG, &request).await?; + Ok(()) + } + + async fn subscribe_to_task( + &self, + task_id: &str, + _history_length: Option, + last_event_id: Option<&str>, + ) -> Result> + Send>>, A2AError> { + if self.interceptors.is_empty() { + return self.subscribe_inner(task_id, last_event_id).await; + } + let ctx = CallContext::new(methods::SUBSCRIBE_TO_TASK, CallSide::Client); + run_before(&self.interceptors, &ctx).await?; + let result = self.subscribe_inner(task_id, last_event_id).await; + run_after(&self.interceptors, &ctx, result.as_ref().map(|_| ())).await; + result + } +} + +// --------------------------------------------------------------------------- +// SSE consumption +// --------------------------------------------------------------------------- + +/// Reassemble an `text/event-stream` body into a stream of [`StreamEvent`]s. +/// +/// Each SSE event is a `data:` payload carrying a [`JsonRpcResponse`] whose +/// `result` is a [`StreamResponse`] union (this is exactly what +/// [`JsonRpcAdapter`](super::jsonrpc::JsonRpcAdapter)'s SSE path emits), plus an +/// optional `id:` line carrying the server's per-task event id (surfaced on the +/// [`StreamEvent`] for `Last-Event-ID` resumption). Chunks from the socket may +/// split mid-event or mid-UTF-8-sequence, so we buffer and only emit on a +/// complete event boundary (`\n\n`). +fn sse_stream( + response: reqwest::Response, +) -> impl Stream> + Send { + struct State { + response: reqwest::Response, + buf: String, + pending: VecDeque>, + done: bool, + } + + let state = State { + response, + buf: String::new(), + pending: VecDeque::new(), + done: false, + }; + + futures::stream::unfold(state, |mut st| async move { + loop { + if let Some(item) = st.pending.pop_front() { + return Some((item, st)); + } + if st.done { + return None; + } + match st.response.chunk().await { + Ok(Some(chunk)) => { + st.buf.push_str(&String::from_utf8_lossy(&chunk)); + drain_sse_events(&mut st.buf, &mut st.pending, false); + } + Ok(None) => { + drain_sse_events(&mut st.buf, &mut st.pending, true); + st.done = true; + } + Err(e) => { + st.pending + .push_back(Err(A2AError::Internal(format!("SSE read error: {e}")))); + st.done = true; + } + } + } + }) +} + +/// Extract complete SSE events from `buf`, pushing each decoded event to `out`. +/// When `flush` is true, a trailing event with no terminating blank line is also +/// processed (end of stream). +fn drain_sse_events( + buf: &mut String, + out: &mut VecDeque>, + flush: bool, +) { + loop { + let event = match buf.find("\n\n") { + Some(i) => { + let event = buf[..i].to_string(); + *buf = buf[i + 2..].to_string(); + event + } + None => { + if flush && !buf.trim().is_empty() { + std::mem::take(buf) + } else { + return; + } + } + }; + + let data: String = event + .lines() + .filter_map(|line| line.strip_prefix("data:").map(str::trim_start)) + .collect::>() + .join("\n"); + + let event_id = event + .lines() + .find_map(|line| line.strip_prefix("id:").map(str::trim_start)) + .and_then(|s| s.parse::().ok()); + + if !data.is_empty() { + out.push_back(parse_sse_frame(&data).map(|item| StreamEvent::new(event_id, item))); + } + + if flush && buf.is_empty() { + return; + } + } +} + +/// Decode one SSE `data:` payload (a JSON-RPC response frame) into a [`StreamItem`]. +fn parse_sse_frame(data: &str) -> Result { + let frame: JsonRpcResponse = serde_json::from_str(data) + .map_err(|e| A2AError::Internal(format!("invalid SSE JSON-RPC frame: {e}")))?; + if let Some(err) = frame.error { + return Err(jsonrpc_to_a2a(&err)); + } + let value = frame + .result + .ok_or_else(|| A2AError::Internal("SSE frame missing result".to_string()))?; + let stream_response: StreamResponse = serde_json::from_value(value) + .map_err(|e| A2AError::Internal(format!("invalid StreamResponse: {e}")))?; + stream_response_to_item(stream_response) + .ok_or_else(|| A2AError::Internal("empty stream response payload".to_string())) +} diff --git a/a2a-rs/src/adapter/transport/jsonrpc_wire.rs b/a2a-rs/src/adapter/transport/jsonrpc_wire.rs new file mode 100644 index 0000000..65cea5a --- /dev/null +++ b/a2a-rs/src/adapter/transport/jsonrpc_wire.rs @@ -0,0 +1,253 @@ +//! Shared JSON-RPC 2.0 wire vocabulary for the A2A protocol. +//! +//! These method names, error codes, request/response envelopes, and the +//! `A2AError` ⇄ JSON-RPC error mappings are the contract both the server adapter +//! ([`JsonRpcAdapter`](super::jsonrpc::JsonRpcAdapter)) and the client adapter +//! ([`JsonRpcClient`](super::jsonrpc_client::JsonRpcClient)) must agree on +//! byte-for-byte. Keeping them in one module guarantees the two directions never +//! drift. +//! +//! Both envelopes derive `Serialize + Deserialize`: the server deserializes +//! requests / serializes responses, and the client does the inverse. + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::domain::{A2AError, ErrorDetail, ErrorInfo}; + +/// A2A JSON-RPC method names (PascalCase, per spec). +pub mod methods { + pub const SEND_MESSAGE: &str = "SendMessage"; + pub const SEND_STREAMING_MESSAGE: &str = "SendStreamingMessage"; + pub const GET_TASK: &str = "GetTask"; + pub const LIST_TASKS: &str = "ListTasks"; + pub const CANCEL_TASK: &str = "CancelTask"; + pub const SUBSCRIBE_TO_TASK: &str = "SubscribeToTask"; + pub const CREATE_PUSH_CONFIG: &str = "CreateTaskPushNotificationConfig"; + pub const GET_PUSH_CONFIG: &str = "GetTaskPushNotificationConfig"; + pub const LIST_PUSH_CONFIGS: &str = "ListTaskPushNotificationConfigs"; + pub const DELETE_PUSH_CONFIG: &str = "DeleteTaskPushNotificationConfig"; + pub const GET_EXTENDED_AGENT_CARD: &str = "GetExtendedAgentCard"; + + /// Streaming methods respond with SSE rather than a single response. + pub fn is_streaming(method: &str) -> bool { + matches!(method, SEND_STREAMING_MESSAGE | SUBSCRIBE_TO_TASK) + } +} + +/// Standard + A2A-specific JSON-RPC error codes (mirrors the official SDK). +pub mod error_code { + pub const PARSE_ERROR: i32 = -32700; + pub const INVALID_REQUEST: i32 = -32600; + pub const METHOD_NOT_FOUND: i32 = -32601; + pub const INVALID_PARAMS: i32 = -32602; + pub const INTERNAL_ERROR: i32 = -32603; + + pub const TASK_NOT_FOUND: i32 = -32001; + pub const TASK_NOT_CANCELABLE: i32 = -32002; + pub const PUSH_NOTIFICATION_NOT_SUPPORTED: i32 = -32003; + pub const UNSUPPORTED_OPERATION: i32 = -32004; + pub const CONTENT_TYPE_NOT_SUPPORTED: i32 = -32005; + pub const INVALID_AGENT_RESPONSE: i32 = -32006; + pub const EXTENDED_CARD_NOT_CONFIGURED: i32 = -32007; + + /// Custom application range (outside the spec's reserved codes). + pub const VERSION_CONFLICT: i32 = -32101; +} + +/// JSON-RPC request envelope (server deserializes; client serializes). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcRequest { + pub jsonrpc: String, + #[serde(default)] + pub id: JsonRpcId, + pub method: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +/// JSON-RPC response envelope (server serializes; client deserializes). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcResponse { + pub jsonrpc: String, + #[serde(default)] + pub id: JsonRpcId, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +impl JsonRpcResponse { + pub fn ok(id: JsonRpcId, result: Value) -> Self { + Self { + jsonrpc: "2.0".to_string(), + id, + result: Some(result), + error: None, + } + } + + pub fn err(id: JsonRpcId, error: JsonRpcError) -> Self { + Self { + jsonrpc: "2.0".to_string(), + id, + result: None, + error: Some(error), + } + } +} + +/// JSON-RPC request id — preserves the wire type (string, number, or null). +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(untagged)] +pub enum JsonRpcId { + Str(String), + Num(i64), + #[default] + Null, +} + +/// JSON-RPC error object. `data` carries typed A2A error details when available. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcError { + pub code: i32, + pub message: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +/// The JSON-RPC error code for a domain [`A2AError`]. +fn a2a_error_code(err: &A2AError) -> i32 { + use error_code::*; + match err { + A2AError::JsonRpc { code, .. } => *code, + A2AError::JsonParse(_) => PARSE_ERROR, + A2AError::InvalidRequest(_) => INVALID_REQUEST, + A2AError::MethodNotFound(_) => METHOD_NOT_FOUND, + A2AError::InvalidParams(_) | A2AError::ValidationError { .. } => INVALID_PARAMS, + A2AError::TaskNotFound(_) => TASK_NOT_FOUND, + A2AError::TaskNotCancelable(_) => TASK_NOT_CANCELABLE, + A2AError::PushNotificationNotSupported => PUSH_NOTIFICATION_NOT_SUPPORTED, + A2AError::UnsupportedOperation(_) => UNSUPPORTED_OPERATION, + A2AError::ContentTypeNotSupported(_) => CONTENT_TYPE_NOT_SUPPORTED, + A2AError::InvalidAgentResponse(_) => INVALID_AGENT_RESPONSE, + A2AError::AuthenticatedExtendedCardNotConfigured => EXTENDED_CARD_NOT_CONFIGURED, + A2AError::VersionConflict { .. } => VERSION_CONFLICT, + _ => INTERNAL_ERROR, + } +} + +/// Map a domain [`A2AError`] onto a JSON-RPC error object. +/// +/// This is the JSON-RPC analogue of `connectrpc::map_err`. The `data` array +/// carries the error's [typed details](A2AError::error_details): a Google-RPC +/// `BadRequest` for validation failures, plus an `ErrorInfo` reason code on every +/// error so clients can branch on a stable machine code rather than the message +/// string. [`jsonrpc_to_a2a`] reverses this. +pub fn a2a_to_jsonrpc(err: &A2AError) -> JsonRpcError { + let message = match err { + A2AError::ValidationError { field, message } => format!("{field}: {message}"), + other => other.to_string(), + }; + let details = err.error_details(); + let data = (!details.is_empty()).then(|| serde_json::json!(details)); + JsonRpcError { + code: a2a_error_code(err), + message, + data, + } +} + +/// Map a JSON-RPC error object back onto a domain [`A2AError`] — the inverse of +/// [`a2a_to_jsonrpc`], used by the client to reconstruct typed errors. +/// +/// A [`A2AError::VersionConflict`] is rebuilt from its `ErrorInfo` metadata when +/// present, so the typed expected/actual versions survive the round-trip. +pub fn jsonrpc_to_a2a(err: &JsonRpcError) -> A2AError { + use error_code::*; + match err.code { + TASK_NOT_FOUND => A2AError::TaskNotFound(err.message.clone()), + INVALID_PARAMS => A2AError::InvalidParams(err.message.clone()), + METHOD_NOT_FOUND => A2AError::MethodNotFound(err.message.clone()), + UNSUPPORTED_OPERATION => A2AError::UnsupportedOperation(err.message.clone()), + EXTENDED_CARD_NOT_CONFIGURED => A2AError::AuthenticatedExtendedCardNotConfigured, + VERSION_CONFLICT => version_conflict_from_data(err) + .unwrap_or_else(|| A2AError::Internal(err.message.clone())), + code => A2AError::JsonRpc { + code, + message: err.message.clone(), + data: err.data.clone(), + }, + } +} + +/// Reconstruct a [`A2AError::VersionConflict`] from the `ErrorInfo` metadata in a +/// wire error's `data` array, if it carries the expected/actual versions. +fn version_conflict_from_data(err: &JsonRpcError) -> Option { + let details: Vec = serde_json::from_value(err.data.clone()?).ok()?; + let ErrorInfo { metadata, .. } = details.into_iter().find_map(|d| match d { + ErrorDetail::ErrorInfo(info) => Some(info), + _ => None, + })?; + Some(A2AError::VersionConflict { + id: metadata.get("task_id").cloned().unwrap_or_default(), + expected: metadata.get("expected").and_then(|s| s.parse().ok())?, + actual: metadata.get("actual").and_then(|s| s.parse().ok())?, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validation_error_surfaces_bad_request_details() { + let err = A2AError::ValidationError { + field: "history_length".to_string(), + message: "too large".to_string(), + }; + let wire = a2a_to_jsonrpc(&err); + assert_eq!(wire.code, error_code::INVALID_PARAMS); + let data = wire.data.expect("validation errors carry data"); + // First detail is a Google-RPC BadRequest naming the field. + assert_eq!( + data[0]["@type"], + "type.googleapis.com/google.rpc.BadRequest" + ); + assert_eq!(data[0]["fieldViolations"][0]["field"], "history_length"); + // Second detail is the stable reason code. + assert_eq!(data[1]["reason"], "VALIDATION_ERROR"); + } + + #[test] + fn version_conflict_round_trips_through_the_wire() { + let err = A2AError::VersionConflict { + id: "task-42".to_string(), + expected: 3, + actual: 5, + }; + let wire = a2a_to_jsonrpc(&err); + assert_eq!(wire.code, error_code::VERSION_CONFLICT); + match jsonrpc_to_a2a(&wire) { + A2AError::VersionConflict { + id, + expected, + actual, + } => { + assert_eq!(id, "task-42"); + assert_eq!(expected, 3); + assert_eq!(actual, 5); + } + other => panic!("expected VersionConflict, got {other:?}"), + } + } + + #[test] + fn every_error_carries_a_reason_code() { + let wire = a2a_to_jsonrpc(&A2AError::TaskNotFound("x".to_string())); + let data = wire.data.expect("errors carry an ErrorInfo reason"); + assert_eq!(data[0]["reason"], "TASK_NOT_FOUND"); + assert_eq!(data[0]["domain"], "a2a-rs"); + } +} diff --git a/a2a-rs/src/adapter/transport/mod.rs b/a2a-rs/src/adapter/transport/mod.rs index c06dc56..8e52476 100644 --- a/a2a-rs/src/adapter/transport/mod.rs +++ b/a2a-rs/src/adapter/transport/mod.rs @@ -1,4 +1,38 @@ //! Transport protocol adapter implementations +/// Shared client-side wire decoding (`StreamResponse` → `StreamItem`). +#[cfg(feature = "client")] +pub mod codec; +/// ConnectRPC transport adapter (`impl A2aService`) over the application service. +#[cfg(feature = "server")] +pub mod connectrpc; #[cfg(any(feature = "http-client", feature = "http-server"))] pub mod http; +/// Wire-compatible JSON-RPC 2.0 + HTTP+JSON (REST) transport adapter. +#[cfg(feature = "jsonrpc-server")] +pub mod jsonrpc; +/// Wire-compatible JSON-RPC 2.0 client adapter (`impl Transport`). +#[cfg(feature = "jsonrpc-client")] +pub mod jsonrpc_client; +/// Shared JSON-RPC 2.0 wire vocabulary (method names, error codes, envelopes, +/// error maps) — the byte-for-byte contract between the JSON-RPC server and +/// client adapters. +#[cfg(any(feature = "jsonrpc-server", feature = "jsonrpc-client"))] +pub mod jsonrpc_wire; +/// Client-side transport negotiation from an agent card. +#[cfg(feature = "client")] +pub mod negotiation; +/// Resilient streaming: reconnect-with-backoff over the `Transport` port. +#[cfg(feature = "client")] +pub mod retry; + +#[cfg(feature = "server")] +pub use connectrpc::ConnectRpcAdapter; +#[cfg(feature = "jsonrpc-server")] +pub use jsonrpc::{JsonRpcAdapter, jsonrpc_router, rest_router}; +#[cfg(feature = "jsonrpc-client")] +pub use jsonrpc_client::JsonRpcClient; +#[cfg(feature = "client")] +pub use negotiation::{TransportFactory, TransportNegotiator, default_registry}; +#[cfg(feature = "client")] +pub use retry::{RetryingTransport, subscribe_resilient}; diff --git a/a2a-rs/src/adapter/transport/negotiation.rs b/a2a-rs/src/adapter/transport/negotiation.rs new file mode 100644 index 0000000..82d1ebd --- /dev/null +++ b/a2a-rs/src/adapter/transport/negotiation.rs @@ -0,0 +1,194 @@ +//! Client-side transport negotiation. +//! +//! A [`TransportFactory`] knows how to build a [`Transport`] for one wire +//! protocol from an agent interface. A [`TransportNegotiator`] holds an ordered +//! set of factories and, given an [`AgentCard`], picks the first interface it can +//! satisfy — ranked by **client preference** (factory registration order), which +//! dominates the card's own `preferred_transport`. +//! +//! This is composition-at-the-edge: the application assembles a negotiator with +//! exactly the transports it compiled in, then calls [`connect`] (or +//! [`TransportNegotiator::negotiate`]) to obtain a ready `Box`. + +use async_trait::async_trait; + +use crate::domain::{A2AError, AgentCard, AgentInterface}; +use crate::port::Transport; + +/// Builds a [`Transport`] for a single wire protocol from an agent interface. +#[async_trait] +pub trait TransportFactory: Send + Sync { + /// The protocol this factory handles, matching `AgentInterface::protocol_binding` + /// (e.g. `"JSONRPC"`, `"CONNECTRPC"`). + fn protocol(&self) -> &str; + + /// Construct a transport for `iface`. Returning `Err` lets the negotiator + /// fall through to the next compatible interface/factory. + async fn create( + &self, + card: &AgentCard, + iface: &AgentInterface, + ) -> Result, A2AError>; +} + +/// Factory for the wire-compatible JSON-RPC 2.0 transport. +#[cfg(feature = "jsonrpc-client")] +pub struct JsonRpcTransportFactory; + +#[cfg(feature = "jsonrpc-client")] +#[async_trait] +impl TransportFactory for JsonRpcTransportFactory { + fn protocol(&self) -> &str { + "JSONRPC" + } + + async fn create( + &self, + _card: &AgentCard, + iface: &AgentInterface, + ) -> Result, A2AError> { + Ok(Box::new(super::jsonrpc_client::JsonRpcClient::new( + iface.url.clone(), + ))) + } +} + +/// Factory for the ConnectRPC transport. +#[cfg(feature = "http-client")] +pub struct ConnectRpcTransportFactory; + +#[cfg(feature = "http-client")] +#[async_trait] +impl TransportFactory for ConnectRpcTransportFactory { + fn protocol(&self) -> &str { + "CONNECTRPC" + } + + async fn create( + &self, + _card: &AgentCard, + iface: &AgentInterface, + ) -> Result, A2AError> { + // `HttpClient::new` panics on an unparseable URL; validate first so a bad + // interface is a recoverable negotiation miss, not a crash. + iface.url.parse::().map_err(|e| { + A2AError::InvalidParams(format!("invalid interface url {}: {e}", iface.url)) + })?; + Ok(Box::new(super::http::HttpClient::new(iface.url.clone()))) + } +} + +/// An ordered registry of [`TransportFactory`]s that negotiates a transport from +/// an agent card. Registration order is the client's preference order. +#[derive(Default)] +pub struct TransportNegotiator { + factories: Vec>, +} + +impl TransportNegotiator { + /// An empty negotiator. Add factories with [`with`](Self::with). + pub fn new() -> Self { + Self::default() + } + + /// Register a factory (appended at lowest preference). + pub fn with(mut self, factory: impl TransportFactory + 'static) -> Self { + self.factories.push(Box::new(factory)); + self + } + + /// The protocols this negotiator can construct, in preference order. + pub fn supported(&self) -> impl Iterator { + self.factories.iter().map(|f| f.protocol()) + } + + /// Pick and construct the first transport that matches a card interface, + /// ranked by client preference (registration order). + pub async fn negotiate(&self, card: &AgentCard) -> Result, A2AError> { + for factory in &self.factories { + for iface in &card.supported_interfaces { + if iface.protocol_binding == factory.protocol() + && version_compatible(&iface.protocol_version) + { + match factory.create(card, iface).await { + Ok(transport) => return Ok(transport), + Err(_err) => continue, + } + } + } + } + Err(A2AError::UnsupportedOperation(format!( + "no compatible transport: client supports [{}], card offers [{}]", + self.supported().collect::>().join(", "), + card.supported_interfaces + .iter() + .map(|i| i.protocol_binding.as_str()) + .collect::>() + .join(", "), + ))) + } +} + +/// Permissive major-version check: accept the v1.x line (or an unspecified +/// version). A future major version on an interface is skipped, not errored. +fn version_compatible(version: &str) -> bool { + version.is_empty() || version.split('.').next() == Some("1") +} + +/// The default registry, holding every transport compiled into this build. +/// +/// Preference order is **CONNECTRPC then JSONRPC**: ConnectRPC is the in-tree, +/// first-class streaming transport, with JSON-RPC 2.0 as the interoperable +/// fallback. Flip the two `with` lines below for spec-default JSONRPC-first. +pub fn default_registry() -> TransportNegotiator { + #[allow(unused_mut)] + let mut negotiator = TransportNegotiator::new(); + #[cfg(feature = "http-client")] + { + negotiator = negotiator.with(ConnectRpcTransportFactory); + } + #[cfg(feature = "jsonrpc-client")] + { + negotiator = negotiator.with(JsonRpcTransportFactory); + } + negotiator +} + +/// Fetch an agent's card and negotiate a transport in one step. +/// +/// Fetches `/.well-known/agent-card.json` (falling back to `/agent-card`) from +/// `base_url`, then runs [`TransportNegotiator::negotiate`]. +#[cfg(any(feature = "http-client", feature = "jsonrpc-client"))] +pub async fn connect( + base_url: &str, + negotiator: &TransportNegotiator, +) -> Result, A2AError> { + let card = fetch_agent_card(base_url).await?; + negotiator.negotiate(&card).await +} + +/// Fetch an [`AgentCard`] from the agent's well-known endpoint (plain HTTP GET). +#[cfg(any(feature = "http-client", feature = "jsonrpc-client"))] +pub async fn fetch_agent_card(base_url: &str) -> Result { + use crate::adapter::error::HttpClientError; + + let client = reqwest::Client::new(); + let base = base_url.trim_end_matches('/'); + for path in [".well-known/agent-card.json", "agent-card"] { + let url = format!("{base}/{path}"); + let resp = client + .get(&url) + .send() + .await + .map_err(HttpClientError::Reqwest)?; + if resp.status().is_success() { + return resp + .json::() + .await + .map_err(|e| A2AError::Internal(format!("Failed to parse agent card JSON: {e}"))); + } + } + Err(A2AError::Internal(format!( + "Agent card not found at {base_url}" + ))) +} diff --git a/a2a-rs/src/adapter/transport/retry.rs b/a2a-rs/src/adapter/transport/retry.rs new file mode 100644 index 0000000..90bdf26 --- /dev/null +++ b/a2a-rs/src/adapter/transport/retry.rs @@ -0,0 +1,292 @@ +//! Resilient streaming over the [`Transport`] port: reconnect with exponential +//! backoff and resume via `Last-Event-ID`. +//! +//! This is the one place backoff lives. [`subscribe_resilient`] is the reusable +//! core — a free function that owns an `Arc` so the stream it +//! returns is `'static` and can re-subscribe after a disconnect, threading the +//! last observed event id back as `Last-Event-ID` so the server replays the gap. +//! [`RetryingTransport`] is a thin decorator that *is* a [`Transport`]: it passes +//! unary calls straight through and only wraps `subscribe_to_task`, so wrapping a +//! negotiated transport at the composition edge makes every existing call site +//! resilient with no signature change. +//! +//! # Spec note (A2A v1.0): this is an opt-in enhancement, not a spec feature +//! +//! The A2A protocol defines reconnection by re-issuing the subscribe call +//! (`SubscribeToTask`), which re-attaches from the task's *current* state; it +//! does **not** define `Last-Event-ID` gap-replay, and `SubscribeToTaskRequest` +//! has no resume field. The gap-free resumption here is an a2a-rs enhancement +//! built on the **W3C SSE-standard** `id:` field and `Last-Event-ID` header: +//! +//! - **Interop is preserved.** Against a spec-compliant server that ignores +//! `Last-Event-ID`, [`subscribe_resilient`] still reconnects via the spec's +//! subscribe call and resumes from current state — it simply can't replay the +//! gap. Against our own server it replays the missed tail. +//! - **It is not cross-SDK guaranteed.** Only an a2a-rs server honors our +//! `Last-Event-ID`; do not assume gap-free resume against third-party agents. +//! +//! For a strictly spec-shaped single subscribe (no reconnection, no +//! `Last-Event-ID`), call [`Transport::subscribe_to_task`] directly with +//! `last_event_id = None` instead of using this module. + +use std::pin::Pin; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; + +use async_trait::async_trait; +use futures::{Stream, StreamExt}; + +use crate::domain::core::task::TaskStateExt; +use crate::domain::{ + A2AError, ListTasksParams, ListTasksResult, Message, RetryPolicy, Task, + TaskPushNotificationConfig, +}; +use crate::port::{StreamEvent, StreamItem, Transport}; + +type EventStream = Pin> + Send>>; + +/// Subscribe to a task's updates with automatic reconnect + backoff. +/// +/// The returned stream forwards [`StreamEvent`]s until the task reaches a +/// terminal state, recording each event id along the way. On a disconnect (the +/// inner stream errors or ends without a terminal status) it sleeps per `policy` +/// and re-subscribes, passing the last seen id as `Last-Event-ID` so a resumable +/// server replays the missed tail. After `policy.max_retries` consecutive failed +/// reconnects it yields a final error and ends. +pub fn subscribe_resilient( + transport: Arc, + task_id: impl Into, + history_length: Option, + last_event_id: Option, + policy: RetryPolicy, +) -> EventStream { + let task_id = task_id.into(); + let seed = seed_for(&task_id); + + struct State { + transport: Arc, + task_id: String, + history_length: Option, + policy: RetryPolicy, + seed: u64, + last_event_id: Option, + attempt: u32, + inner: Option, + done: bool, + } + + let state = State { + transport, + task_id, + history_length, + policy, + seed, + last_event_id, + attempt: 0, + inner: None, + done: false, + }; + + Box::pin(futures::stream::unfold(state, |mut st| async move { + loop { + if st.done { + return None; + } + + // (Re)connect when we have no live inner stream. + if st.inner.is_none() { + if st.attempt > st.policy.max_retries { + st.done = true; + return Some(( + Err(A2AError::Internal(format!( + "subscription to '{}' failed after {} retries", + st.task_id, st.policy.max_retries + ))), + st, + )); + } + if st.attempt > 0 { + let delay = st.policy.backoff(st.attempt, st.seed); + tokio::time::sleep(delay).await; + } + let resume = st.last_event_id.map(|n| n.to_string()); + match st + .transport + .subscribe_to_task(&st.task_id, st.history_length, resume.as_deref()) + .await + { + Ok(stream) => st.inner = Some(stream), + Err(_) => { + st.attempt += 1; + continue; + } + } + } + + // Pull the next event from the live inner stream. + match st.inner.as_mut().unwrap().next().await { + Some(Ok(event)) => { + // Any progress resets the backoff counter. + st.attempt = 0; + if let Some(id) = event.event_id { + st.last_event_id = Some(id); + } + if is_terminal(&event.item) { + st.done = true; + } + return Some((Ok(event), st)); + } + // Stream errored or ended without a terminal status: reconnect. + Some(Err(_)) | None => { + st.inner = None; + st.attempt += 1; + continue; + } + } + } + })) +} + +/// Whether a stream item represents a terminal task state (ends the stream). +fn is_terminal(item: &StreamItem) -> bool { + match item { + StreamItem::Task(task) => task + .status + .as_option() + .map(|s| s.state.is_terminal()) + .unwrap_or(false), + StreamItem::StatusUpdate(event) => event.status.state.is_terminal(), + StreamItem::ArtifactUpdate(_) => false, + } +} + +/// Derive a per-task jitter seed from the task id and a coarse time sample. The +/// time sample (the only impure input) lives here in the adapter, keeping +/// [`RetryPolicy::backoff`] pure. +fn seed_for(task_id: &str) -> u64 { + let mut state = 0u64; + for &b in task_id.as_bytes() { + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(b as u64); + } + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_nanos() as u64) + .unwrap_or(0); + state.wrapping_mul(6364136223846793005).wrapping_add(now) +} + +/// A [`Transport`] decorator that adds reconnect + backoff to `subscribe_to_task` +/// and passes every unary method straight through to the inner transport. +/// +/// Wrap a negotiated transport once at the composition edge — +/// `RetryingTransport::wrap(connect(...).await?, policy)` — and all callers gain +/// resilient streaming transparently. +pub struct RetryingTransport { + inner: Arc, + policy: RetryPolicy, +} + +impl RetryingTransport { + /// Decorate a shared transport with a retry policy. + pub fn new(inner: Arc, policy: RetryPolicy) -> Self { + Self { inner, policy } + } + + /// Decorate an owned (e.g. negotiated `Box`) transport. + pub fn wrap(inner: Box, policy: RetryPolicy) -> Self { + Self { + inner: Arc::from(inner), + policy, + } + } +} + +#[async_trait] +impl Transport for RetryingTransport { + fn protocol(&self) -> &str { + self.inner.protocol() + } + + async fn send_task_message( + &self, + task_id: &str, + message: &Message, + session_id: Option<&str>, + history_length: Option, + ) -> Result { + self.inner + .send_task_message(task_id, message, session_id, history_length) + .await + } + + async fn get_task(&self, task_id: &str, history_length: Option) -> Result { + self.inner.get_task(task_id, history_length).await + } + + async fn cancel_task(&self, task_id: &str) -> Result { + self.inner.cancel_task(task_id).await + } + + async fn set_task_push_notification( + &self, + config: &TaskPushNotificationConfig, + ) -> Result { + self.inner.set_task_push_notification(config).await + } + + async fn get_task_push_notification( + &self, + task_id: &str, + ) -> Result { + self.inner.get_task_push_notification(task_id).await + } + + async fn list_tasks(&self, params: &ListTasksParams) -> Result { + self.inner.list_tasks(params).await + } + + async fn list_push_notification_configs( + &self, + task_id: &str, + ) -> Result, A2AError> { + self.inner.list_push_notification_configs(task_id).await + } + + async fn get_push_notification_config( + &self, + task_id: &str, + config_id: &str, + ) -> Result { + self.inner + .get_push_notification_config(task_id, config_id) + .await + } + + async fn delete_push_notification_config( + &self, + task_id: &str, + config_id: &str, + ) -> Result<(), A2AError> { + self.inner + .delete_push_notification_config(task_id, config_id) + .await + } + + async fn subscribe_to_task( + &self, + task_id: &str, + history_length: Option, + last_event_id: Option<&str>, + ) -> Result> + Send>>, A2AError> { + let resume = last_event_id.and_then(|s| s.trim().parse::().ok()); + Ok(subscribe_resilient( + self.inner.clone(), + task_id.to_string(), + history_length, + resume, + self.policy, + )) + } +} diff --git a/a2a-rs/src/application/mod.rs b/a2a-rs/src/application/mod.rs index 59aec1b..584c1a9 100644 --- a/a2a-rs/src/application/mod.rs +++ b/a2a-rs/src/application/mod.rs @@ -1 +1,13 @@ //! Application services for the A2A protocol + +#[cfg(feature = "server")] +pub mod task_service; +#[cfg(feature = "server")] +pub mod task_status_broadcast; + +#[cfg(feature = "server")] +pub use task_service::{TaskService, UpdateStream}; +#[cfg(feature = "server")] +pub use task_status_broadcast::{ + HasPushNotifier, HasStreaming, HasTaskLifecycle, TaskStatusBroadcast, +}; diff --git a/a2a-rs/src/application/task_service.rs b/a2a-rs/src/application/task_service.rs new file mode 100644 index 0000000..9764306 --- /dev/null +++ b/a2a-rs/src/application/task_service.rs @@ -0,0 +1,297 @@ +//! The task application service: use-case orchestration over the port traits. +//! +//! `TaskService` is the **inner** half of the service/transport split: it owns +//! the ports (`Arc`), orchestrates them, and speaks only the domain +//! vocabulary (`Task`, `Message`, `TaskId`, `A2AError`). It knows nothing about +//! ConnectRPC, `buffa` views, or wire error codes — that glue lives in the +//! transport adapter ([`ConnectRpcAdapter`](crate::adapter::ConnectRpcAdapter)), +//! which decodes wire requests into these domain calls and re-encodes the +//! results. +//! +//! Because the service holds both the lifecycle and streaming ports it exposes +//! them as mixin ingredients ([`HasTaskLifecycle`], [`HasStreaming`]) and so +//! gains [`TaskStatusBroadcast::update_and_broadcast`] for free +//! (`.claude/rules/hexagonal_architecture.md` §9). The accessors return `&dyn` +//! **ports**, never the concrete adapters behind them, so the dependency arrow +//! still points inward. +//! +//! [`TaskStatusBroadcast::update_and_broadcast`]: crate::application::TaskStatusBroadcast::update_and_broadcast + +use std::pin::Pin; +use std::sync::Arc; + +use futures::Stream; + +use crate::application::{HasPushNotifier, HasStreaming, HasTaskLifecycle, TaskStatusBroadcast}; +use crate::domain::{ + A2AError, AgentCard, DeleteTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigParams, ListTaskPushNotificationConfigsParams, ListTasksParams, + ListTasksResult, Message, Task, TaskId, TaskPushNotificationConfig, +}; +use crate::port::{ + AsyncMessageHandler, AsyncNotificationManager, AsyncNotificationManagerExt, AsyncPushNotifier, + AsyncStreamingHandler, AsyncTaskLifecycle, AsyncTaskQuery, SeqEvent, +}; +use crate::services::server::AgentInfoProvider; + +/// A stream of sequenced update events for a task. Each [`SeqEvent`] carries a +/// per-task monotonic id (surfaced as the SSE `id:` field); the transport +/// adapter maps the inner update onto its wire representation. +pub type UpdateStream = Pin> + Send>>; + +/// Use-case orchestration over the A2A ports. +/// +/// Constructed at the composition edge with concrete adapters injected; the +/// fields are `Arc` so the service type carries no generic parameters. +/// All methods return domain types and [`A2AError`] — there is no transport +/// vocabulary in this layer. +#[derive(Clone)] +pub struct TaskService { + message_handler: Arc, + task_lifecycle: Arc, + task_query: Arc, + notification_manager: Arc, + agent_info: Arc, + streaming_handler: Arc, + push_notifier: Arc, +} + +impl TaskService { + /// Assemble a service from separate handlers. + /// + /// `tasks` supplies both the lifecycle and query capabilities; it is + /// stored once and shared between the two `Arc` fields. + pub fn new( + message_handler: impl AsyncMessageHandler + 'static, + tasks: impl AsyncTaskLifecycle + AsyncTaskQuery + 'static, + notification_manager: impl AsyncNotificationManager + 'static, + agent_info: impl AgentInfoProvider + 'static, + streaming_handler: impl AsyncStreamingHandler + 'static, + push_notifier: impl AsyncPushNotifier + 'static, + ) -> Self { + let tasks = Arc::new(tasks); + Self { + message_handler: Arc::new(message_handler), + task_lifecycle: tasks.clone(), + task_query: tasks, + notification_manager: Arc::new(notification_manager), + agent_info: Arc::new(agent_info), + streaming_handler: Arc::new(streaming_handler), + push_notifier: Arc::new(push_notifier), + } + } + + /// Assemble a service from a single handler that implements every port. + pub fn with_handler( + handler: impl AsyncMessageHandler + + AsyncTaskLifecycle + + AsyncTaskQuery + + AsyncNotificationManager + + 'static, + agent_info: impl AgentInfoProvider + 'static, + streaming_handler: impl AsyncStreamingHandler + 'static, + push_notifier: impl AsyncPushNotifier + 'static, + ) -> Self { + let handler = Arc::new(handler); + Self { + message_handler: handler.clone(), + task_lifecycle: handler.clone(), + task_query: handler.clone(), + notification_manager: handler, + agent_info: Arc::new(agent_info), + streaming_handler: Arc::new(streaming_handler), + push_notifier: Arc::new(push_notifier), + } + } + + /// Replace the streaming handler, returning the updated service. + pub fn with_streaming_handler( + mut self, + streaming_handler: impl AsyncStreamingHandler + 'static, + ) -> Self { + self.streaming_handler = Arc::new(streaming_handler); + self + } + + /// Replace the push notifier, returning the updated service. + pub fn with_push_notifier(mut self, push_notifier: impl AsyncPushNotifier + 'static) -> Self { + self.push_notifier = Arc::new(push_notifier); + self + } + + /// Process a message for a task, optionally configuring push notifications + /// and limiting the returned history. + pub async fn send_message( + &self, + task_id: &str, + message: &Message, + session_id: Option<&str>, + push_config: Option, + history_limit: Option, + ) -> Result { + if let Some(mut push_config) = push_config { + push_config.task_id = task_id.to_string(); + self.notification_manager + .set_validated(&push_config) + .await?; + } + + let mut task = self + .message_handler + .process_message(task_id, message, session_id) + .await?; + + if let Some(limit) = history_limit { + task = task.with_limited_history(Some(limit)); + } + + Ok(task) + } + + /// Process a message and subscribe to its update stream. + /// + /// The update stream is started **before** the message is processed so no + /// early updates are missed. Returns the initial task and the stream; the + /// caller is responsible for emitting the initial task ahead of stream + /// items. + pub async fn send_streaming_message( + &self, + task_id: &str, + message: &Message, + session_id: Option<&str>, + push_config: Option, + history_limit: Option, + ) -> Result<(Task, UpdateStream), A2AError> { + if let Some(mut push_config) = push_config { + push_config.task_id = task_id.to_string(); + self.notification_manager + .set_validated(&push_config) + .await?; + } + + // Start updates stream first so we don't miss early updates. + let update_stream = self + .streaming_handler + .start_task_streaming(task_id, None) + .await?; + + let mut task = self + .message_handler + .process_message(task_id, message, session_id) + .await?; + + if let Some(limit) = history_limit { + task = task.with_limited_history(Some(limit)); + } + + Ok((task, update_stream)) + } + + /// Get a task by ID with optional history length limit. + pub async fn get(&self, id: &TaskId, history_length: Option) -> Result { + self.task_lifecycle.get(id, history_length).await + } + + /// List tasks with filtering and pagination. + pub async fn list(&self, params: &ListTasksParams) -> Result { + self.task_query.list(params).await + } + + /// Cancel a task, then announce the terminal status to streaming + /// subscribers. + /// + /// Storage no longer self-broadcasts on cancellation (§4.0.2), so the + /// service owns the "commit then announce" step via the + /// [`TaskStatusBroadcast`] mixin it hosts. + pub async fn cancel(&self, id: &TaskId) -> Result { + self.cancel_and_broadcast(id).await + } + + /// Subscribe to a task's update stream, returning the current task (if it + /// exists) and the stream of subsequent updates. + /// + /// `from_event_id` carries a client's `Last-Event-ID` for resumption: when + /// set, the handler replays buffered events with a greater id before + /// streaming live updates. + pub async fn subscribe( + &self, + task_id: &str, + from_event_id: Option, + ) -> Result<(Option, UpdateStream), A2AError> { + let id: TaskId = task_id.parse()?; + + let initial_task = match self.task_lifecycle.get(&id, None).await { + Ok(task) => Some(task), + Err(A2AError::TaskNotFound(_)) => None, + Err(e) => return Err(e), + }; + + let update_stream = self + .streaming_handler + .start_task_streaming(task_id, from_event_id) + .await?; + + Ok((initial_task, update_stream)) + } + + /// Create or replace a push-notification config (validated). + pub async fn set_push_config( + &self, + config: &TaskPushNotificationConfig, + ) -> Result { + self.notification_manager.set_validated(config).await + } + + /// Get a push-notification config for a task. + pub async fn get_push_config( + &self, + params: &GetTaskPushNotificationConfigParams, + ) -> Result { + self.notification_manager.get_config(params).await + } + + /// List push-notification configs for a task. + pub async fn list_push_configs( + &self, + params: &ListTaskPushNotificationConfigsParams, + ) -> Result, A2AError> { + self.notification_manager.list_configs(params).await + } + + /// Delete a push-notification config. + pub async fn delete_push_config( + &self, + params: &DeleteTaskPushNotificationConfigParams, + ) -> Result<(), A2AError> { + self.notification_manager.delete_config(params).await + } + + /// Fetch the authenticated extended agent card. + pub async fn extended_agent_card(&self) -> Result { + self.agent_info.get_authenticated_extended_card().await + } +} + +// The service is the composed assembly holding both the lifecycle and streaming +// ports, so it exposes them as mixin ingredients (see +// `.claude/rules/hexagonal_architecture.md` §9). This grants it the +// `TaskStatusBroadcast::update_and_broadcast` "commit then announce" capability +// for free, without coupling either port to the other. The accessors return +// `&dyn` **ports**, never the concrete adapters behind them. +impl HasTaskLifecycle for TaskService { + fn lifecycle(&self) -> &dyn AsyncTaskLifecycle { + self.task_lifecycle.as_ref() + } +} + +impl HasStreaming for TaskService { + fn streaming(&self) -> &dyn AsyncStreamingHandler { + self.streaming_handler.as_ref() + } +} + +impl HasPushNotifier for TaskService { + fn push_notifier(&self) -> &dyn AsyncPushNotifier { + self.push_notifier.as_ref() + } +} diff --git a/a2a-rs/src/application/task_status_broadcast.rs b/a2a-rs/src/application/task_status_broadcast.rs new file mode 100644 index 0000000..cb336b6 --- /dev/null +++ b/a2a-rs/src/application/task_status_broadcast.rs @@ -0,0 +1,349 @@ +//! Cross-port orchestration: update a task's status *and* broadcast it. +//! +//! This is the [capability-mixin] pattern applied at the port boundary +//! (`.claude/rules/hexagonal_architecture.md` §9). Two narrow **accessor** +//! ingredients ([`HasTaskLifecycle`], [`HasStreaming`]) expose the ports a host +//! already holds; the [`TaskStatusBroadcast`] mixin provides the derived +//! "update then broadcast" behavior as a blanket-impl'd default. Any assembly +//! that exposes both ports — the request processor, the MCP bridge, a test +//! rig — gains `update_and_broadcast` for free, and on nothing inner. +//! +//! Why a mixin and not just a method on the processor: the orchestration is +//! defined independently of any one struct (reusable across hosts) and is +//! testable against a minimal rig that wires only these two ports over +//! in-memory adapters — see the tests below. +//! +//! [capability-mixin]: crate::port +//! +//! ## Layering note +//! +//! The accessor associated returns are bounded by **port traits** +//! (`&dyn AsyncTaskLifecycle`, `&dyn AsyncStreamingHandler`), never concrete +//! adapters, and the mixin default touches only those ports plus pure domain +//! constructors (`TaskStatus::new`). The dependency arrow therefore still +//! points inward even though the logic lives in a blanket impl. + +use async_trait::async_trait; + +use crate::domain::{ + A2AError, Message, Task, TaskArtifactUpdateEvent, TaskId, TaskState, TaskStatusUpdateEvent, +}; +use crate::port::{AsyncPushNotifier, AsyncStreamingHandler, AsyncTaskLifecycle}; + +/// Ingredient: an assembly that can hand out a task-lifecycle port. +/// +/// Note the return is a `&dyn` **port**, not a concrete adapter — that is what +/// keeps any mixin built on this ingredient inside the dependency rule. +pub trait HasTaskLifecycle { + fn lifecycle(&self) -> &dyn AsyncTaskLifecycle; +} + +/// Ingredient: an assembly that can hand out a streaming port. +pub trait HasStreaming { + fn streaming(&self) -> &dyn AsyncStreamingHandler; +} + +/// Ingredient: an assembly that can hand out a push-notifier port. +/// +/// Kept separate from [`HasStreaming`] on purpose: in-process streaming fan-out +/// and out-of-band webhook delivery are distinct capabilities with distinct +/// backends, so the mixin orchestrates both rather than fusing them into one +/// adapter. +pub trait HasPushNotifier { + fn push_notifier(&self) -> &dyn AsyncPushNotifier; +} + +/// Derived capability: mutate task status through the lifecycle port, then +/// broadcast the resulting status to streaming subscribers. +/// +/// Blanket-implemented for every `Send + Sync` host that exposes both +/// ingredients, so it never needs an explicit `impl`. A host that exposes only +/// one ingredient does **not** get this method — that omission is a compile +/// error at the call site, not a runtime surprise (see the `compile_fail` doc +/// test on [`update_and_broadcast`]). +/// +/// [`update_and_broadcast`]: TaskStatusBroadcast::update_and_broadcast +#[async_trait] +pub trait TaskStatusBroadcast: + HasTaskLifecycle + HasStreaming + HasPushNotifier + Send + Sync +{ + /// Update a task's status, then broadcast the new status to subscribers. + /// + /// The broadcast is best-effort relative to the store: the status is + /// persisted first (via the lifecycle port) and only then announced, so a + /// subscriber never sees a state the store hasn't committed. + /// + /// A host that exposes only *one* of the two ingredients does not get this + /// method — the missing supertrait makes the blanket impl inapplicable, so + /// the call fails to compile: + /// + /// ```compile_fail + /// use std::sync::Arc; + /// use a2a_rs::AsyncTaskLifecycle; + /// use a2a_rs::adapter::storage::InMemoryTaskStorage; + /// use a2a_rs::application::{HasTaskLifecycle, TaskStatusBroadcast}; + /// use a2a_rs::domain::{TaskId, TaskState}; + /// + /// // Exposes the lifecycle ingredient, but NOT `HasStreaming`. + /// struct HalfRig { + /// store: Arc, + /// } + /// impl HasTaskLifecycle for HalfRig { + /// fn lifecycle(&self) -> &dyn AsyncTaskLifecycle { + /// self.store.as_ref() + /// } + /// } + /// + /// async fn use_it(rig: HalfRig, id: TaskId) { + /// // `update_and_broadcast` does not exist on a one-ingredient host: + /// rig.update_and_broadcast(&id, TaskState::Completed, None).await.unwrap(); + /// } + /// ``` + async fn update_and_broadcast( + &self, + id: &TaskId, + state: TaskState, + message: Option, + ) -> Result { + let task = self.lifecycle().update_status(id, state, message).await?; + self.broadcast_current_status(id, &task).await?; + Ok(task) + } + + /// Cancel a task through the lifecycle port, then broadcast the resulting + /// (terminal) status to subscribers. + /// + /// The counterpart to [`update_and_broadcast`](Self::update_and_broadcast) + /// for cancellation: `cancel` carries its own state transition and history + /// message, so it cannot be expressed as an `update_status` call, but the + /// "commit then announce" ordering is identical. + async fn cancel_and_broadcast(&self, id: &TaskId) -> Result { + let task = self.lifecycle().cancel(id).await?; + self.broadcast_current_status(id, &task).await?; + Ok(task) + } + + /// Broadcast an artifact update: fan it out to streaming subscribers, then + /// deliver it to the task's push endpoint (best-effort). + /// + /// The artifact counterpart to the status path. Hosts that produce artifacts + /// route through here so streaming and push stay consistent — exactly as the + /// status mutators do via [`broadcast_current_status`](Self::broadcast_current_status). + async fn broadcast_artifact( + &self, + id: &TaskId, + event: TaskArtifactUpdateEvent, + ) -> Result<(), A2AError> { + self.streaming() + .broadcast_artifact_update(id.as_str(), event.clone()) + .await?; + self.notify_push_artifact(id, &event).await; + Ok(()) + } + + /// Announce a task's current status to streaming subscribers, then deliver a + /// push notification (best-effort). + /// + /// Shared by the mutate-then-broadcast methods above; not intended to be + /// overridden. The event is built from the freshly-committed `task` so the + /// announcement always reflects what the store now holds. Push delivery is + /// best-effort: a webhook that is down is logged but does not fail the + /// mutation that triggered it. + #[doc(hidden)] + async fn broadcast_current_status(&self, id: &TaskId, task: &Task) -> Result<(), A2AError> { + let event = TaskStatusUpdateEvent { + task_id: task.id.clone(), + context_id: task.context_id.clone(), + kind: "status-update".to_string(), + status: task.status.clone().into_option().unwrap_or_default(), + metadata: None, + }; + + self.streaming() + .broadcast_status_update(id.as_str(), event.clone()) + .await?; + self.notify_push_status(id, &event).await; + Ok(()) + } + + /// Deliver a status push notification, swallowing (and logging) any delivery + /// error so it never fails the mutation. + #[doc(hidden)] + async fn notify_push_status(&self, id: &TaskId, event: &TaskStatusUpdateEvent) { + if let Err(_e) = self.push_notifier().notify_status(id.as_str(), event).await { + #[cfg(feature = "tracing")] + tracing::warn!(task_id = %id.as_str(), error = %_e, "push status notification failed"); + } + } + + /// Deliver an artifact push notification, swallowing (and logging) any + /// delivery error. + #[doc(hidden)] + async fn notify_push_artifact(&self, id: &TaskId, event: &TaskArtifactUpdateEvent) { + if let Err(_e) = self + .push_notifier() + .notify_artifact(id.as_str(), event) + .await + { + #[cfg(feature = "tracing")] + tracing::warn!(task_id = %id.as_str(), error = %_e, "push artifact notification failed"); + } + } +} + +/// The single blanket impl — the linchpin of the pattern. `?Sized` lets the +/// mixin attach to a `dyn`-typed host as well as a concrete one. +impl + TaskStatusBroadcast for T +{ +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::adapter::storage::InMemoryTaskStorage; + use crate::adapter::streaming::InMemoryStreamingHandler; + use crate::port::NoopPushNotifier; + use crate::port::streaming_handler::Subscriber; + use std::sync::{Arc, Mutex}; + + /// A "partial platform" test rig: it wires the three ingredients this mixin + /// needs — a persistence adapter, a separate streaming adapter, and a push + /// notifier — over in-memory implementations. Standing this up requires + /// neither the transport layer nor the full request processor, so the + /// orchestration is tested in isolation. The split between `store` and + /// `streaming` is the whole point: they are distinct ports now. + struct BroadcastRig { + store: Arc, + streaming: InMemoryStreamingHandler, + push: NoopPushNotifier, + } + + impl HasTaskLifecycle for BroadcastRig { + fn lifecycle(&self) -> &dyn AsyncTaskLifecycle { + self.store.as_ref() + } + } + + impl HasStreaming for BroadcastRig { + fn streaming(&self) -> &dyn AsyncStreamingHandler { + &self.streaming + } + } + + impl HasPushNotifier for BroadcastRig { + fn push_notifier(&self) -> &dyn AsyncPushNotifier { + &self.push + } + } + + /// A streaming subscriber that records every status it is handed, so a test + /// can assert exactly which transitions reached subscribers. + #[derive(Clone, Default)] + struct Recorder { + states: Arc>>>, + } + + #[async_trait] + impl Subscriber for Recorder { + async fn on_update(&self, update: TaskStatusUpdateEvent) -> Result<(), A2AError> { + self.states.lock().unwrap().push(update.status.state); + Ok(()) + } + } + + fn rig(store: Arc) -> BroadcastRig { + BroadcastRig { + store, + streaming: InMemoryStreamingHandler::new(), + push: NoopPushNotifier, + } + } + + #[tokio::test] + async fn update_and_broadcast_persists_then_announces() { + let store = Arc::new(InMemoryTaskStorage::new()); + let id = TaskId::try_from("task-1").unwrap(); + let ctx = crate::domain::ContextId::try_from("ctx-1").unwrap(); + + store.create(&id, &ctx).await.unwrap(); + store + .update_status(&id, TaskState::Working, None) + .await + .unwrap(); + + let rig = rig(store); + + // The mixin method exists purely because the rig exposes ALL ingredients. + let task = rig + .update_and_broadcast(&id, TaskState::Completed, None) + .await + .unwrap(); + + assert_eq!(task.status.state, TaskState::Completed); + } + + /// A direct lifecycle mutation must NOT announce anything: persistence and + /// streaming are fully separate adapters now. The subscriber lives on the + /// streaming handler, which the bare store mutation never touches, so the + /// recorder stays empty. + #[tokio::test] + async fn bare_update_status_does_not_broadcast() { + let store = Arc::new(InMemoryTaskStorage::new()); + let id = TaskId::try_from("task-1").unwrap(); + let ctx = crate::domain::ContextId::try_from("ctx-1").unwrap(); + + let streaming = InMemoryStreamingHandler::new(); + let recorder = Recorder::default(); + streaming + .add_status_subscriber(id.as_str(), Box::new(recorder.clone())) + .await + .unwrap(); + + store.create(&id, &ctx).await.unwrap(); + store + .update_status(&id, TaskState::Working, None) + .await + .unwrap(); + store.cancel(&id).await.unwrap(); + + assert!( + recorder.states.lock().unwrap().is_empty(), + "storage mutators must not self-broadcast" + ); + } + + /// Routed through the mixin, the same mutations DO reach subscribers — once + /// each, in order. (One announcement per call proves there is no lingering + /// self-broadcast doubling the events.) The recorder is registered on the + /// rig's *streaming* handler, which the mixin fans out to. + #[tokio::test] + async fn mixin_announces_each_mutation_once() { + let store = Arc::new(InMemoryTaskStorage::new()); + let id = TaskId::try_from("task-1").unwrap(); + let ctx = crate::domain::ContextId::try_from("ctx-1").unwrap(); + + store.create(&id, &ctx).await.unwrap(); + + let rig = rig(store); + + let recorder = Recorder::default(); + rig.streaming + .add_status_subscriber(id.as_str(), Box::new(recorder.clone())) + .await + .unwrap(); + + rig.update_and_broadcast(&id, TaskState::Working, None) + .await + .unwrap(); + rig.cancel_and_broadcast(&id).await.unwrap(); + + assert_eq!( + *recorder.states.lock().unwrap(), + vec![ + ::buffa::EnumValue::from(TaskState::Working), + ::buffa::EnumValue::from(TaskState::Canceled), + ], + ); + } +} diff --git a/a2a-rs/src/domain/core/mod.rs b/a2a-rs/src/domain/core/mod.rs index 8513e89..6909faa 100644 --- a/a2a-rs/src/domain/core/mod.rs +++ b/a2a-rs/src/domain/core/mod.rs @@ -15,5 +15,5 @@ pub use task::{ DeleteTaskPushNotificationConfigParams, GetTaskPushNotificationConfigParams, ListTaskPushNotificationConfigsParams, ListTasksParams, ListTasksResult, MessageSendConfiguration, MessageSendParams, Task, TaskIdParams, TaskPushNotificationConfig, - TaskQueryParams, TaskSendParams, TaskState, TaskStateExt, TaskStatus, + TaskQueryParams, TaskSendParams, TaskState, TaskStateExt, TaskStatus, VersionedTask, }; diff --git a/a2a-rs/src/domain/core/task.rs b/a2a-rs/src/domain/core/task.rs index a9d88ca..631abd8 100644 --- a/a2a-rs/src/domain/core/task.rs +++ b/a2a-rs/src/domain/core/task.rs @@ -392,9 +392,9 @@ impl Task { tracing::debug!("Validating task"); let mut message_ids = std::collections::HashSet::new(); - for (index, message) in self.history.iter().enumerate() { + for (_index, message) in self.history.iter().enumerate() { #[cfg(feature = "tracing")] - tracing::trace!("Validating message {} in history", index); + tracing::trace!("Validating message {} in history", _index); if !message_ids.insert(&message.message_id) { #[cfg(feature = "tracing")] @@ -420,3 +420,27 @@ impl Task { Ok(()) } } + +/// A task paired with its storage version — the optimistic-concurrency token. +/// +/// The version is a monotonic counter the storage adapter bumps on every +/// successful mutation of the task. A caller reads a task and its version, then +/// passes that version back on a conditional update +/// ([`AsyncTaskVersioning::update_status_checked`](crate::port::AsyncTaskVersioning::update_status_checked)); +/// if another writer advanced the task in between, the update fails with +/// [`A2AError::VersionConflict`](crate::domain::A2AError::VersionConflict) instead +/// of silently clobbering it. +#[derive(Debug, Clone, PartialEq)] +pub struct VersionedTask { + /// The task at this version. + pub task: Task, + /// The storage version this snapshot was read or written at. + pub version: u64, +} + +impl VersionedTask { + /// Pair a task with a version. + pub fn new(task: Task, version: u64) -> Self { + Self { task, version } + } +} diff --git a/a2a-rs/src/domain/error.rs b/a2a-rs/src/domain/error.rs index 8f43710..5539d83 100644 --- a/a2a-rs/src/domain/error.rs +++ b/a2a-rs/src/domain/error.rs @@ -1,5 +1,13 @@ use thiserror::Error; +use crate::domain::error_details::{ErrorDetail, FieldViolation}; + +/// Convenience alias for results that fail with [`A2AError`]. +/// +/// Mirrors the `std::io::Result` / `serde_json::Result` convention so call +/// sites can write `Result` instead of `Result`. +pub type Result = std::result::Result; + /// Standard JSON-RPC error codes pub const PARSE_ERROR: i32 = -32700; pub const INVALID_REQUEST: i32 = -32600; @@ -18,6 +26,8 @@ pub const AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED: i32 = -32007; /// Custom application-specific error codes (outside spec range) pub const DATABASE_ERROR: i32 = -32100; +/// Optimistic-concurrency version mismatch on a task mutation. +pub const VERSION_CONFLICT: i32 = -32101; /// Error type for the A2A protocol operations #[derive(Error, Debug)] @@ -68,6 +78,13 @@ pub enum A2AError { #[error("Validation error in {field}: {message}")] ValidationError { field: String, message: String }, + #[error("Version conflict for task {id}: expected {expected}, found {actual}")] + VersionConflict { + id: String, + expected: u64, + actual: u64, + }, + #[error("Database error: {0}")] DatabaseError(String), @@ -101,6 +118,7 @@ impl A2AError { "Authenticated Extended Card is not configured", ), A2AError::ValidationError { .. } => (INVALID_PARAMS, "Validation error"), + A2AError::VersionConflict { .. } => (VERSION_CONFLICT, "Task version conflict"), A2AError::DatabaseError(_) => (DATABASE_ERROR, "Database error"), A2AError::Internal(_) => (INTERNAL_ERROR, "Internal error"), _ => (INTERNAL_ERROR, "Internal error"), @@ -112,4 +130,61 @@ impl A2AError { "data": null, }) } + + /// Stable, machine-readable reason code for this error + /// (`SCREAMING_SNAKE_CASE`, used as the `ErrorInfo.reason` on the wire). + pub fn reason_code(&self) -> &'static str { + match self { + A2AError::JsonRpc { .. } => "JSON_RPC_ERROR", + A2AError::JsonParse(_) => "PARSE_ERROR", + A2AError::InvalidRequest(_) => "INVALID_REQUEST", + A2AError::InvalidParams(_) => "INVALID_PARAMS", + A2AError::MethodNotFound(_) => "METHOD_NOT_FOUND", + A2AError::TaskNotFound(_) => "TASK_NOT_FOUND", + A2AError::TaskNotCancelable(_) => "TASK_NOT_CANCELABLE", + A2AError::PushNotificationNotSupported => "PUSH_NOTIFICATION_NOT_SUPPORTED", + A2AError::UnsupportedOperation(_) => "UNSUPPORTED_OPERATION", + A2AError::ContentTypeNotSupported(_) => "CONTENT_TYPE_NOT_SUPPORTED", + A2AError::InvalidAgentResponse(_) => "INVALID_AGENT_RESPONSE", + A2AError::AuthenticatedExtendedCardNotConfigured => { + "AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED" + } + A2AError::Internal(_) => "INTERNAL_ERROR", + A2AError::ValidationError { .. } => "VALIDATION_ERROR", + A2AError::VersionConflict { .. } => "VERSION_CONFLICT", + A2AError::DatabaseError(_) => "DATABASE_ERROR", + A2AError::Io(_) => "IO_ERROR", + } + } + + /// Typed details for the JSON-RPC `error.data` array. + /// + /// Validation failures surface as a Google-RPC `BadRequest` with field + /// violations; version conflicts attach the expected/actual versions as + /// `ErrorInfo` metadata; every other variant carries at least its stable + /// [`reason_code`](Self::reason_code) as an `ErrorInfo`, so a client can + /// branch on a machine code instead of parsing the message string. + pub fn error_details(&self) -> Vec { + match self { + A2AError::ValidationError { field, message } => vec![ + ErrorDetail::BadRequest { + field_violations: vec![FieldViolation::new(field, message)], + }, + ErrorDetail::reason(self.reason_code()), + ], + A2AError::VersionConflict { + id, + expected, + actual, + } => { + let mut info = crate::domain::error_details::ErrorInfo::new(self.reason_code()); + info = info + .with_metadata("task_id", id) + .with_metadata("expected", expected.to_string()) + .with_metadata("actual", actual.to_string()); + vec![ErrorDetail::ErrorInfo(info)] + } + _ => vec![ErrorDetail::reason(self.reason_code())], + } + } } diff --git a/a2a-rs/src/domain/error_details.rs b/a2a-rs/src/domain/error_details.rs new file mode 100644 index 0000000..f82fec0 --- /dev/null +++ b/a2a-rs/src/domain/error_details.rs @@ -0,0 +1,108 @@ +//! Typed error details, surfaced in the JSON-RPC `error.data` array. +//! +//! The A2A spec (following the Go/C#/Python SDKs) carries machine-readable error +//! details as a list of Google-RPC `Any`-shaped objects in `error.data`. Each +//! entry is tagged by an `@type` URL (`type.googleapis.com/google.rpc.*`). We +//! model the two A2A actually uses — [`ErrorDetail::BadRequest`] for field-level +//! validation failures and [`ErrorDetail::ErrorInfo`] for a stable machine +//! `reason` code — as plain serde types so the adapter layer can attach them +//! without hand-writing JSON, and the client can round-trip them back. +//! +//! These are pure domain value objects: no I/O, no framework types. The +//! [`A2AError`](crate::domain::A2AError) ⇄ wire mapping lives in the transport +//! adapter; [`A2AError::error_details`](crate::domain::A2AError::error_details) +//! derives the default detail set for each variant. + +use std::collections::BTreeMap; + +use serde::{Deserialize, Serialize}; + +/// A single field-level validation failure +/// (mirrors `google.rpc.BadRequest.FieldViolation`). +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct FieldViolation { + /// Path to the offending field (e.g. `"history_length"`). + pub field: String, + /// Human-readable explanation of why the field is invalid. + pub description: String, +} + +impl FieldViolation { + /// Construct a field violation from any string-likes. + pub fn new(field: impl Into, description: impl Into) -> Self { + Self { + field: field.into(), + description: description.into(), + } + } +} + +/// Stable, machine-readable error metadata (mirrors `google.rpc.ErrorInfo`). +/// +/// `reason` is a `SCREAMING_SNAKE_CASE` constant identifying the failure (e.g. +/// `"TASK_NOT_FOUND"`), `domain` scopes the reason namespace, and `metadata` +/// carries arbitrary key/value context. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ErrorInfo { + /// Stable reason code, unique within `domain`. + pub reason: String, + /// Logical owner of the reason namespace (always `"a2a-rs"` here). + pub domain: String, + /// Additional structured context; omitted from the wire when empty. + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + pub metadata: BTreeMap, +} + +impl ErrorInfo { + /// Construct an `ErrorInfo` in the `a2a-rs` domain with no metadata. + pub fn new(reason: impl Into) -> Self { + Self { + reason: reason.into(), + domain: DOMAIN.to_string(), + metadata: BTreeMap::new(), + } + } + + /// Attach a metadata key/value pair (builder-style). + pub fn with_metadata(mut self, key: impl Into, value: impl Into) -> Self { + self.metadata.insert(key.into(), value.into()); + self + } +} + +/// The `domain` namespace for every [`ErrorInfo`] this crate emits. +pub const DOMAIN: &str = "a2a-rs"; + +/// One typed entry of the JSON-RPC `error.data` array. +/// +/// Serializes as a Google-RPC `Any`: an `@type` discriminator plus the payload +/// fields inline (`{"@type": "…/google.rpc.ErrorInfo", "reason": …}`), exactly +/// the shape the official SDKs read. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "@type")] +pub enum ErrorDetail { + /// Field-level validation failures. + #[serde(rename = "type.googleapis.com/google.rpc.BadRequest")] + BadRequest { + /// The set of violated fields. + #[serde(rename = "fieldViolations")] + field_violations: Vec, + }, + /// A stable machine-readable reason code. + #[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")] + ErrorInfo(ErrorInfo), +} + +impl ErrorDetail { + /// Convenience constructor for a single-field `BadRequest`. + pub fn bad_request(field: impl Into, description: impl Into) -> Self { + Self::BadRequest { + field_violations: vec![FieldViolation::new(field, description)], + } + } + + /// Convenience constructor for an `ErrorInfo` reason in the `a2a-rs` domain. + pub fn reason(reason: impl Into) -> Self { + Self::ErrorInfo(ErrorInfo::new(reason)) + } +} diff --git a/a2a-rs/src/domain/ids.rs b/a2a-rs/src/domain/ids.rs new file mode 100644 index 0000000..d648693 --- /dev/null +++ b/a2a-rs/src/domain/ids.rs @@ -0,0 +1,131 @@ +//! Strongly-typed identifiers for the A2A protocol. +//! +//! Applies "parse, don't validate" to the codebase's own identifiers: a +//! [`TaskId`], [`ContextId`], or [`PushConfigId`] can only be constructed from a +//! non-empty string via [`FromStr`]/[`TryFrom`], so port methods that accept one +//! never have to re-check emptiness, and argument-order mix-ups +//! (`cancel(context_id, task_id)`) become compile errors. +//! +//! ## Deserialization caveat +//! +//! These newtypes derive `Deserialize` with `#[serde(transparent)]`, which means +//! a value reconstructed from the wire does **not** pass through the validating +//! [`FromStr`] path. That is intentional: deserialized identifiers are validated +//! once at the RPC boundary (the request processor converts wire strings through +//! [`FromStr`] before they reach a port). Treat [`FromStr`]/[`TryFrom`] as the +//! only validating constructors; `Deserialize` is a transport convenience. + +use std::fmt; +use std::str::FromStr; + +use serde::{Deserialize, Serialize}; + +use crate::domain::error::A2AError; + +/// Generates a validating string newtype identifier. +macro_rules! define_id { + ($(#[$meta:meta])* $name:ident, $field:literal) => { + $(#[$meta])* + #[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] + #[serde(transparent)] + pub struct $name(String); + + impl $name { + /// Borrow the identifier as a string slice. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Consume the identifier, returning the owned string. + pub fn into_string(self) -> String { + self.0 + } + } + + impl FromStr for $name { + type Err = A2AError; + + fn from_str(s: &str) -> Result { + if s.trim().is_empty() { + return Err(A2AError::ValidationError { + field: $field.to_string(), + message: concat!($field, " cannot be empty").to_string(), + }); + } + Ok(Self(s.to_owned())) + } + } + + impl TryFrom<&str> for $name { + type Error = A2AError; + + fn try_from(s: &str) -> Result { + s.parse() + } + } + + impl TryFrom for $name { + type Error = A2AError; + + fn try_from(s: String) -> Result { + s.as_str().parse() + } + } + + impl AsRef for $name { + fn as_ref(&self) -> &str { + &self.0 + } + } + + impl fmt::Display for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) + } + } + }; +} + +define_id!( + /// Identifies a task within an agent. + TaskId, + "task_id" +); + +define_id!( + /// Identifies a conversation/session context grouping related tasks. + ContextId, + "context_id" +); + +define_id!( + /// Identifies a single push-notification configuration for a task. + PushConfigId, + "push_notification_config_id" +); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rejects_empty_and_whitespace() { + assert!(TaskId::from_str("").is_err()); + assert!(TaskId::from_str(" ").is_err()); + assert!(ContextId::from_str("").is_err()); + } + + #[test] + fn accepts_non_empty() { + let id = TaskId::from_str("task-123").unwrap(); + assert_eq!(id.as_str(), "task-123"); + assert_eq!(id.to_string(), "task-123"); + } + + #[test] + fn try_from_owned_and_borrowed() { + assert!(TaskId::try_from("x").is_ok()); + assert!(TaskId::try_from("x".to_string()).is_ok()); + assert!(TaskId::try_from(String::new()).is_err()); + } +} diff --git a/a2a-rs/src/domain/mod.rs b/a2a-rs/src/domain/mod.rs index 06db2f3..c3374fc 100644 --- a/a2a-rs/src/domain/mod.rs +++ b/a2a-rs/src/domain/mod.rs @@ -2,8 +2,11 @@ pub mod core; pub mod error; +pub mod error_details; pub mod events; pub mod generated; +pub mod ids; +pub mod retry; #[cfg(test)] mod tests; pub mod validation; @@ -17,9 +20,12 @@ pub use core::{ ListTasksParams, ListTasksResult, Message, MessageSendConfiguration, MessageSendParams, OAuthFlows, Part, PartBuilder, PushNotificationAuthenticationInfo, Role, SecurityRequirement, SecurityScheme, StringList, Task, TaskIdParams, TaskPushNotificationConfig, TaskQueryParams, - TaskSendParams, TaskState, TaskStateExt, TaskStatus, part, + TaskSendParams, TaskState, TaskStateExt, TaskStatus, VersionedTask, part, }; -pub use error::A2AError; +pub use error::{A2AError, Result}; +pub use error_details::{ErrorDetail, ErrorInfo, FieldViolation}; pub use events::{TaskArtifactUpdateEvent, TaskStatusUpdateEvent}; pub use generated::{o_auth_flows, security_scheme}; +pub use ids::{ContextId, PushConfigId, TaskId}; +pub use retry::RetryPolicy; pub use validation::{Validate, ValidationResult}; diff --git a/a2a-rs/src/domain/retry.rs b/a2a-rs/src/domain/retry.rs new file mode 100644 index 0000000..0e8a192 --- /dev/null +++ b/a2a-rs/src/domain/retry.rs @@ -0,0 +1,135 @@ +//! Retry/backoff policy for resilient streaming subscriptions. +//! +//! [`RetryPolicy`] is a pure value object: it carries the knobs for exponential +//! backoff with jitter and computes the delay for a given attempt, with no I/O, +//! no clock, and no randomness source of its own. The impure parts — sleeping +//! and seeding the jitter — live in the transport adapter that consumes it +//! (`adapter::transport::subscribe_resilient`), keeping this type +//! domain-pure and unit-testable. + +use std::time::Duration; + +/// Exponential-backoff-with-jitter policy for reconnecting a dropped stream. +/// +/// The delay before retry *n* (1-based) is `base_delay * 2^(n-1)`, capped at +/// `max_delay`, plus up to `jitter_ms` of seeded jitter (and re-capped at +/// `max_delay`). After `max_retries` consecutive failures the consumer gives up. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RetryPolicy { + /// Delay before the first retry; doubles each subsequent attempt. + pub base_delay: Duration, + /// Upper bound on any single delay. + pub max_delay: Duration, + /// Maximum consecutive failed attempts before giving up. + pub max_retries: u32, + /// Maximum jitter span in milliseconds added to each delay (`0` disables + /// jitter, making delays deterministic). + pub jitter_ms: u64, +} + +impl Default for RetryPolicy { + fn default() -> Self { + Self { + base_delay: Duration::from_millis(500), + max_delay: Duration::from_secs(10), + max_retries: 15, + jitter_ms: 200, + } + } +} + +impl RetryPolicy { + /// A policy that never retries — the subscription fails on first disconnect. + pub fn no_retry() -> Self { + Self { + max_retries: 0, + ..Self::default() + } + } + + /// Compute the delay before `attempt` (1-based), folding `seed` into the + /// jitter. Pure: identical `(attempt, seed)` always yields the same delay. + pub fn backoff(&self, attempt: u32, seed: u64) -> Duration { + let base_ms = self.base_delay.as_millis() as u64; + let max_ms = self.max_delay.as_millis() as u64; + + // Exponential growth: base * 2^(attempt-1), saturating, then capped. + let shift = attempt.saturating_sub(1).min(63); + let factor = 1u64.checked_shl(shift).unwrap_or(u64::MAX); + let grown = base_ms.saturating_mul(factor).min(max_ms); + + let jitter = if self.jitter_ms == 0 { + 0 + } else { + mix(seed) % self.jitter_ms + }; + + Duration::from_millis(grown.saturating_add(jitter).min(max_ms)) + } +} + +/// Deterministic jitter mixer (a single SplitMix64-style round). Keeps jitter +/// dependency-free (no `rand`) while spreading reconnect storms across clients +/// with different seeds. +#[inline] +fn mix(seed: u64) -> u64 { + let mut z = seed.wrapping_add(0x9E37_79B9_7F4A_7C15); + z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + z ^ (z >> 31) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn policy(jitter_ms: u64) -> RetryPolicy { + RetryPolicy { + base_delay: Duration::from_millis(100), + max_delay: Duration::from_secs(10), + max_retries: 10, + jitter_ms, + } + } + + #[test] + fn grows_exponentially_without_jitter() { + let p = policy(0); + assert_eq!(p.backoff(1, 0), Duration::from_millis(100)); + assert_eq!(p.backoff(2, 0), Duration::from_millis(200)); + assert_eq!(p.backoff(3, 0), Duration::from_millis(400)); + assert_eq!(p.backoff(4, 0), Duration::from_millis(800)); + } + + #[test] + fn caps_at_max_delay() { + let p = policy(0); + // 100ms * 2^20 is far past the 10s cap. + assert_eq!(p.backoff(20, 12345), Duration::from_secs(10)); + // Huge attempt must not panic on overflow. + assert_eq!(p.backoff(u32::MAX, 1), Duration::from_secs(10)); + } + + #[test] + fn jitter_stays_within_span_and_is_deterministic() { + let p = policy(200); + for seed in 0..1000u64 { + let d = p.backoff(1, seed).as_millis() as u64; + // base 100ms + jitter in [0, 200) + assert!( + (100..300).contains(&d), + "delay {d} out of range for seed {seed}" + ); + // deterministic + assert_eq!(p.backoff(1, seed), p.backoff(1, seed)); + } + } + + #[test] + fn jitter_varies_across_seeds() { + let p = policy(200); + let a = p.backoff(1, 1); + let b = p.backoff(1, 2); + assert_ne!(a, b, "different seeds should generally jitter differently"); + } +} diff --git a/a2a-rs/src/lib.rs b/a2a-rs/src/lib.rs index 82ef276..cdb2281 100644 --- a/a2a-rs/src/lib.rs +++ b/a2a-rs/src/lib.rs @@ -20,7 +20,7 @@ //! # #[cfg(feature = "http-client")] //! # { //! use a2a_rs::{HttpClient, Message}; -//! use a2a_rs::services::AsyncA2AClient; +//! use a2a_rs::Transport; //! //! #[tokio::main] //! async fn main() -> Result<(), Box> { @@ -40,7 +40,7 @@ //! ## Creating a server //! //! ```rust,ignore -//! use a2a_rs::{HttpServer, SimpleAgentInfo, DefaultRequestProcessor}; +//! use a2a_rs::{HttpServer, SimpleAgentInfo, ConnectRpcAdapter}; //! use my_app::{MyMessageHandler, MyTaskManager, MyNotificationManager}; //! //! #[tokio::main] @@ -51,8 +51,8 @@ //! let notification_manager = MyNotificationManager::new(); //! let agent_info = SimpleAgentInfo::new("my-agent".to_string(), "https://api.example.com".to_string()); //! -//! // Create a request processor with your handlers -//! let processor = DefaultRequestProcessor::new( +//! // Wrap your handlers in the ConnectRPC transport adapter +//! let adapter = ConnectRpcAdapter::new( //! message_handler, //! task_manager, //! notification_manager, @@ -61,7 +61,7 @@ //! //! // Create and start the server //! let server = HttpServer::new( -//! processor, +//! adapter, //! agent_info, //! "127.0.0.1:8080".to_string(), //! ); @@ -84,31 +84,45 @@ pub mod observability; pub use domain::{ A2AError, AgentCapabilities, AgentCard, AgentCardSignature, AgentExtension, AgentInterface, AgentProvider, AgentSkill, Artifact, AuthorizationCodeOAuthFlow, ClientCredentialsOAuthFlow, - DeleteTaskPushNotificationConfigParams, DeviceCodeOAuthFlow, - GetTaskPushNotificationConfigParams, ListTaskPushNotificationConfigsParams, ListTasksParams, - ListTasksResult, Message, MessageSendConfiguration, MessageSendParams, OAuthFlows, Part, - PushNotificationAuthenticationInfo, Role, SecurityScheme, Task, TaskArtifactUpdateEvent, - TaskIdParams, TaskPushNotificationConfig, TaskQueryParams, TaskSendParams, TaskState, - TaskStatus, TaskStatusUpdateEvent, + ContextId, DeleteTaskPushNotificationConfigParams, DeviceCodeOAuthFlow, ErrorDetail, ErrorInfo, + FieldViolation, GetTaskPushNotificationConfigParams, ListTaskPushNotificationConfigsParams, + ListTasksParams, ListTasksResult, Message, MessageSendConfiguration, MessageSendParams, + OAuthFlows, Part, PushConfigId, PushNotificationAuthenticationInfo, Result, RetryPolicy, Role, + SecurityScheme, Task, TaskArtifactUpdateEvent, TaskId, TaskIdParams, + TaskPushNotificationConfig, TaskQueryParams, TaskSendParams, TaskState, TaskStatus, + TaskStatusUpdateEvent, VersionedTask, }; // Port traits for better separation of concerns pub use port::{ - AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager, - MessageHandler, NotificationManager, StreamingHandler, StreamingSubscriber, TaskManager, - UpdateEvent, + AsyncMessageHandler, AsyncNotificationManager, AsyncNotificationManagerExt, AsyncPushNotifier, + AsyncStreamingHandler, AsyncTaskLifecycle, AsyncTaskLifecycleExt, AsyncTaskQuery, + AsyncTaskVersioning, CallContext, CallInterceptor, CallSide, NoopPushNotifier, SeqEvent, + StreamEvent, StreamItem, StreamingSubscriber, Transport, UpdateEvent, }; #[cfg(feature = "http-client")] pub use adapter::HttpClient; +#[cfg(feature = "jsonrpc-client")] +pub use adapter::JsonRpcClient; + +#[cfg(feature = "client")] +pub use adapter::{TransportFactory, TransportNegotiator, default_registry}; + +#[cfg(feature = "client")] +pub use adapter::{RetryingTransport, subscribe_resilient}; + +#[cfg(any(feature = "http-client", feature = "jsonrpc-client"))] +pub use adapter::{connect, fetch_agent_card}; + #[cfg(feature = "http-server")] pub use adapter::HttpServer; #[cfg(feature = "server")] pub use adapter::{ - DefaultRequestProcessor, InMemoryTaskStorage, NoopPushNotificationSender, - PushNotificationRegistry, PushNotificationSender, SimpleAgentInfo, + ConnectRpcAdapter, InMemoryStreamingHandler, InMemoryTaskStorage, NoopPushNotificationSender, + NoopStreamingHandler, PushNotificationRegistry, PushNotificationSender, SimpleAgentInfo, }; #[cfg(all(feature = "server", feature = "http-client"))] @@ -120,3 +134,6 @@ pub use adapter::{ApiKeyAuthenticator, BearerTokenAuthenticator, NoopAuthenticat pub use adapter::{JwtAuthenticator, OAuth2Authenticator, OpenIdConnectAuthenticator}; #[cfg(feature = "http-server")] pub use port::Authenticator; + +#[cfg(feature = "tracing")] +pub use adapter::LoggingInterceptor; diff --git a/a2a-rs/src/port/client.rs b/a2a-rs/src/port/client.rs new file mode 100644 index 0000000..dc19de5 --- /dev/null +++ b/a2a-rs/src/port/client.rs @@ -0,0 +1,156 @@ +//! The client-side `Transport` port. +//! +//! [`Transport`] is the outbound port a client uses to talk to a remote A2A +//! agent: the application names the capability it needs ("send a message", "get a +//! task", "subscribe to updates"), and a concrete transport **adapter** +//! (ConnectRPC, JSON-RPC 2.0, …) fulfils it over the wire. This is the mirror of +//! the inbound server ports — same hexagonal shape, opposite direction. +//! +//! Each adapter reports its wire protocol via [`Transport::protocol`] so a +//! card-driven negotiator can pick the right one from an agent card's +//! `supported_interfaces`. +//! +//! The port carries no feature gate (hex rule 5 — gate adapters, not ports); it +//! depends only on the always-available `async-trait`/`futures` and domain types. + +use async_trait::async_trait; +use futures::Stream; +use std::pin::Pin; + +use crate::domain::{ + A2AError, ListTasksParams, ListTasksResult, Message, Task, TaskArtifactUpdateEvent, + TaskPushNotificationConfig, TaskStatusUpdateEvent, +}; + +/// The capability a client needs from a remote A2A agent, independent of wire +/// protocol. Implemented by each transport adapter (`HttpClient` for ConnectRPC, +/// `JsonRpcClient` for JSON-RPC 2.0, …). +#[async_trait] +pub trait Transport: Send + Sync { + /// The wire protocol this transport speaks, matching an agent interface's + /// `protocol_binding` (e.g. `"JSONRPC"`, `"CONNECTRPC"`, `"GRPC"`). + fn protocol(&self) -> &str; + + /// Send a message to a task + async fn send_task_message( + &self, + task_id: &str, + message: &Message, + session_id: Option<&str>, + history_length: Option, + ) -> Result; + + /// Get a task by ID + async fn get_task(&self, task_id: &str, history_length: Option) -> Result; + + /// Cancel a task + async fn cancel_task(&self, task_id: &str) -> Result; + + /// Set up push notifications for a task + async fn set_task_push_notification( + &self, + config: &TaskPushNotificationConfig, + ) -> Result; + + /// Get push notification configuration for a task + async fn get_task_push_notification( + &self, + task_id: &str, + ) -> Result; + + /// List tasks with filtering and pagination (v1.0.0) + async fn list_tasks(&self, params: &ListTasksParams) -> Result; + + /// List all push notification configs for a task (v1.0.0) + async fn list_push_notification_configs( + &self, + task_id: &str, + ) -> Result, A2AError>; + + /// Get a specific push notification config by ID (v1.0.0) + async fn get_push_notification_config( + &self, + task_id: &str, + config_id: &str, + ) -> Result; + + /// Delete a specific push notification config (v1.0.0) + async fn delete_push_notification_config( + &self, + task_id: &str, + config_id: &str, + ) -> Result<(), A2AError>; + + /// Subscribe to task updates (for streaming). + /// + /// Passing `last_event_id = None` is the spec-compliant subscribe: it maps + /// to the A2A `SubscribeToTask` call and streams from the task's current + /// state — exactly what a spec client expects. + /// + /// `last_event_id = Some(..)` opts into the a2a-rs **`Last-Event-ID` + /// resumption enhancement** (not part of the A2A v1.0 spec): a resumable + /// transport sends it as the SSE `Last-Event-ID` header so an a2a-rs server + /// replays the events after that id before streaming live. A spec-compliant + /// server ignores the header and simply streams from current state, so this + /// stays interoperable either way. + async fn subscribe_to_task( + &self, + task_id: &str, + history_length: Option, + last_event_id: Option<&str>, + ) -> Result> + Send>>, A2AError>; +} + +/// A streamed [`StreamItem`] tagged with the server's SSE event id (when the +/// transport supports it). A resilient client records the most recent `event_id` +/// and echoes it as `Last-Event-ID` on reconnect to resume without gaps. +/// +/// The `event_id` is part of the a2a-rs resumption enhancement (see +/// [`subscribe_to_task`](Transport::subscribe_to_task)); spec clients that only +/// read `item` are unaffected. +#[derive(Debug, Clone)] +pub struct StreamEvent { + /// The server-assigned per-task event id, parsed from the SSE `id:` field. + /// `None` for the initial task snapshot, for transports without event ids, + /// or when talking to a spec-compliant server that does not emit `id:`. + pub event_id: Option, + /// The update payload. + pub item: StreamItem, +} + +impl StreamEvent { + /// Construct a stream event. + #[inline] + pub fn new(event_id: Option, item: StreamItem) -> Self { + Self { event_id, item } + } + + /// A stream event with no id (initial snapshot / id-less transport). + #[inline] + pub fn untagged(item: StreamItem) -> Self { + Self { + event_id: None, + item, + } + } +} + +/// Items that can be streamed from the server during task subscriptions. +/// +/// When subscribing to streaming updates for a task, the server can send +/// different types of items: +/// - `Task`: The complete initial task state when subscription starts +/// - `StatusUpdate`: Updates to the task's status (state changes, progress) +/// - `ArtifactUpdate`: Notifications about new or updated artifacts +/// +/// This allows clients to receive real-time updates about task progress +/// and results as they become available. +#[derive(Debug, Clone)] +pub enum StreamItem { + /// The initial task state + Task(Task), + /// A task status update + StatusUpdate(TaskStatusUpdateEvent), + /// A task artifact update + ArtifactUpdate(TaskArtifactUpdateEvent), +} diff --git a/a2a-rs/src/port/interceptor.rs b/a2a-rs/src/port/interceptor.rs new file mode 100644 index 0000000..8492838 --- /dev/null +++ b/a2a-rs/src/port/interceptor.rs @@ -0,0 +1,96 @@ +//! The `CallInterceptor` port: before/after middleware around A2A calls. +//! +//! An interceptor is a cross-cutting hook that runs *around* every A2A call — +//! the chain-of-responsibility analogue of the official SDK's `CallInterceptor`. +//! It is a **port** (a capability the application needs from the edge), so the +//! trait lives here and concrete interceptors (logging, metrics, auth-token +//! injection) are adapters. The same trait is wired into both the client +//! transport ([`JsonRpcClient`](crate::adapter::JsonRpcClient)) and the server +//! transport ([`JsonRpcAdapter`](crate::adapter::JsonRpcAdapter)); the +//! [`CallContext::side`] tells an interceptor which direction it is observing. +//! +//! Chains run `before` hooks in registration order, dispatch the call, then run +//! `after` hooks in reverse order — the conventional onion ordering, so an +//! interceptor's `after` wraps everything its `before` set up. A `before` that +//! returns `Err` short-circuits the call (the dispatch never happens) but its +//! `after` still runs, observing the error. +//! +//! The hooks see call *metadata* (method name, side), not the typed +//! request/response — those differ per method and would force the trait generic. +//! Metadata is enough for the canonical uses (logging, metrics, tracing spans, +//! header/auth propagation handled by the adapter around the chain). + +use std::sync::Arc; + +use async_trait::async_trait; + +use crate::domain::A2AError; + +/// Which side of the wire an interceptor chain is running on. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CallSide { + /// The outbound client transport is making the call. + Client, + /// The inbound server transport is handling the call. + Server, +} + +/// Metadata about an in-flight A2A call, passed to each interceptor hook. +#[derive(Debug, Clone)] +pub struct CallContext { + /// The A2A method name (PascalCase wire name, e.g. `"SendMessage"`). + pub method: String, + /// Whether this chain runs on the client or server side. + pub side: CallSide, +} + +impl CallContext { + /// Construct a context for `method` on the given `side`. + pub fn new(method: impl Into, side: CallSide) -> Self { + Self { + method: method.into(), + side, + } + } +} + +/// A before/after hook around an A2A call (auth, logging, metrics, tracing). +/// +/// Both hooks have default no-op bodies, so an interceptor overrides only the +/// side it cares about. +#[async_trait] +pub trait CallInterceptor: Send + Sync { + /// Run before the call is dispatched. Returning `Err` short-circuits the + /// call: the dispatch is skipped and the error is returned to the caller + /// (after `after` hooks still run, observing it). + async fn before(&self, _ctx: &CallContext) -> Result<(), A2AError> { + Ok(()) + } + + /// Run after the call completes, observing its outcome (`Ok` on success, + /// `Err` with a borrow of the error otherwise). + async fn after(&self, _ctx: &CallContext, _outcome: Result<(), &A2AError>) {} +} + +/// Run a chain's `before` hooks in registration order; the first `Err` +/// short-circuits and is returned without invoking the remaining hooks. +pub async fn run_before( + interceptors: &[Arc], + ctx: &CallContext, +) -> Result<(), A2AError> { + for interceptor in interceptors { + interceptor.before(ctx).await?; + } + Ok(()) +} + +/// Run a chain's `after` hooks in reverse registration order (onion unwinding). +pub async fn run_after( + interceptors: &[Arc], + ctx: &CallContext, + outcome: Result<(), &A2AError>, +) { + for interceptor in interceptors.iter().rev() { + interceptor.after(ctx, outcome).await; + } +} diff --git a/a2a-rs/src/port/message_handler.rs b/a2a-rs/src/port/message_handler.rs index aa74151..ef5a0b4 100644 --- a/a2a-rs/src/port/message_handler.rs +++ b/a2a-rs/src/port/message_handler.rs @@ -1,40 +1,9 @@ //! Message handling port definitions -#[cfg(feature = "server")] use async_trait::async_trait; use crate::domain::{A2AError, Message, Task}; -/// A trait for handling message processing operations -pub trait MessageHandler { - /// Process a message for a specific task - fn process_message( - &self, - task_id: &str, - message: &Message, - session_id: Option<&str>, - ) -> Result; - - /// Validate a message before processing - fn validate_message(&self, message: &Message) -> Result<(), A2AError> { - // Default implementation - can be overridden - if message.parts.is_empty() { - return Err(A2AError::ValidationError { - field: "message.parts".to_string(), - message: "Message must contain at least one part".to_string(), - }); - } - Ok(()) - } - - /// Transform a message before processing (e.g., for content filtering) - fn transform_message(&self, message: Message) -> Result { - // Default implementation - pass through unchanged - Ok(message) - } -} - -#[cfg(feature = "server")] #[async_trait] /// An async trait for handling message processing operations pub trait AsyncMessageHandler: Send + Sync { diff --git a/a2a-rs/src/port/mod.rs b/a2a-rs/src/port/mod.rs index c4c9956..ace1c27 100644 --- a/a2a-rs/src/port/mod.rs +++ b/a2a-rs/src/port/mod.rs @@ -14,6 +14,8 @@ // Business capability ports (focused domain interfaces) pub mod authenticator; +pub mod client; +pub mod interceptor; pub mod message_handler; pub mod notification_manager; pub mod streaming_handler; @@ -23,9 +25,15 @@ pub mod task_manager; pub use authenticator::{ AuthContext, AuthContextExtractor, AuthPrincipal, Authenticator, CompositeAuthenticator, }; -pub use message_handler::{AsyncMessageHandler, MessageHandler}; -pub use notification_manager::{AsyncNotificationManager, NotificationManager}; +pub use client::{StreamEvent, StreamItem, Transport}; +pub use interceptor::{CallContext, CallInterceptor, CallSide, run_after, run_before}; +pub use message_handler::AsyncMessageHandler; +pub use notification_manager::{ + AsyncNotificationManager, AsyncNotificationManagerExt, AsyncPushNotifier, NoopPushNotifier, +}; pub use streaming_handler::{ - AsyncStreamingHandler, StreamingHandler, Subscriber as StreamingSubscriber, UpdateEvent, + AsyncStreamingHandler, SeqEvent, Subscriber as StreamingSubscriber, UpdateEvent, +}; +pub use task_manager::{ + AsyncTaskLifecycle, AsyncTaskLifecycleExt, AsyncTaskQuery, AsyncTaskVersioning, }; -pub use task_manager::{AsyncTaskManager, TaskManager}; diff --git a/a2a-rs/src/port/notification_manager.rs b/a2a-rs/src/port/notification_manager.rs index 0c6d0c5..8af52fd 100644 --- a/a2a-rs/src/port/notification_manager.rs +++ b/a2a-rs/src/port/notification_manager.rs @@ -1,9 +1,12 @@ //! Push notification management port definitions -#[cfg(feature = "server")] use async_trait::async_trait; -use crate::domain::{A2AError, TaskIdParams, TaskPushNotificationConfig}; +use crate::domain::{ + A2AError, DeleteTaskPushNotificationConfigParams, GetTaskPushNotificationConfigParams, + ListTaskPushNotificationConfigsParams, TaskArtifactUpdateEvent, TaskPushNotificationConfig, + TaskStatusUpdateEvent, +}; /// Validate a push notification config URL. /// @@ -46,163 +49,148 @@ fn validate_push_notification_url(config: &TaskPushNotificationConfig) -> Result Ok(()) } -/// A trait for managing push notification configurations and delivery -pub trait NotificationManager { - /// Set up push notifications for a task - fn set_task_notification( - &self, - config: &TaskPushNotificationConfig, - ) -> Result; - - /// Get the push notification configuration for a task - fn get_task_notification(&self, task_id: &str) -> Result; - - /// Remove push notification configuration for a task - fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError>; - - /// Check if push notifications are configured for a task - fn has_task_notification(&self, task_id: &str) -> Result { - match self.get_task_notification(task_id) { - Ok(_) => Ok(true), - Err(A2AError::TaskNotFound(_)) => Ok(false), - Err(e) => Err(e), - } - } - - /// Validate push notification configuration - fn validate_notification_config( - &self, - config: &TaskPushNotificationConfig, - ) -> Result<(), A2AError> { - validate_push_notification_url(config) - } - - /// Send a test notification to verify configuration - fn send_test_notification(&self, config: &TaskPushNotificationConfig) -> Result<(), A2AError> { - // Default implementation - can be overridden - self.validate_notification_config(config)?; - // In a real implementation, this would send a test notification - Ok(()) - } -} - -#[cfg(feature = "server")] +/// Async management of push-notification configurations. +/// +/// Expressed in terms of the A2A v1.0.0 multi-config CRUD model — the richest +/// shape — so a single capability covers both single- and multi-config storage. +/// Validation conveniences (URL/task-id checks) live on +/// [`AsyncNotificationManagerExt`], which is blanket-implemented for every +/// `AsyncNotificationManager`. #[async_trait] -/// An async trait for managing push notification configurations and delivery pub trait AsyncNotificationManager: Send + Sync { - /// Set up push notifications for a task - async fn set_task_notification( + /// Create or replace a push-notification config, returning it with any + /// server-assigned ID populated. + async fn set_config( &self, config: &TaskPushNotificationConfig, ) -> Result; - /// Get the push notification configuration for a task - async fn get_task_notification( + /// Get a push-notification config for a task. + async fn get_config( &self, - task_id: &str, + params: &GetTaskPushNotificationConfigParams, ) -> Result; - /// Remove push notification configuration for a task - async fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError>; - - /// Check if push notifications are configured for a task - async fn has_task_notification(&self, task_id: &str) -> Result { - match self.get_task_notification(task_id).await { - Ok(_) => Ok(true), - Err(A2AError::TaskNotFound(_)) => Ok(false), - Err(e) => Err(e), - } - } - - /// Validate push notification configuration - async fn validate_notification_config( + /// List all push-notification configs for a task. + async fn list_configs( &self, - config: &TaskPushNotificationConfig, - ) -> Result<(), A2AError> { - validate_push_notification_url(config) - } + params: &ListTaskPushNotificationConfigsParams, + ) -> Result, A2AError>; - /// Send a test notification to verify configuration - async fn send_test_notification( + /// Delete a push-notification config. Idempotent per the v1.0.0 spec. + async fn delete_config( &self, - config: &TaskPushNotificationConfig, - ) -> Result<(), A2AError> { - // Default implementation - can be overridden - self.validate_notification_config(config).await?; - // In a real implementation, this would send a test notification - Ok(()) + params: &DeleteTaskPushNotificationConfigParams, + ) -> Result<(), A2AError>; +} + +/// Validation conveniences over [`AsyncNotificationManager`]. +/// +/// Blanket-implemented for every `AsyncNotificationManager`, so implementors +/// only stub the core CRUD primitives. +#[async_trait] +pub trait AsyncNotificationManagerExt: AsyncNotificationManager { + /// Validate a push-notification config's webhook URL. + fn validate_config(&self, config: &TaskPushNotificationConfig) -> Result<(), A2AError> { + validate_push_notification_url(config) } - /// Set task notification with validation - async fn set_task_notification_validated( + /// Validate the task ID and webhook URL, then store the config. + async fn set_validated( &self, config: &TaskPushNotificationConfig, ) -> Result { - // Validate the task ID if config.task_id.trim().is_empty() { return Err(A2AError::ValidationError { field: "task_id".to_string(), message: "Task ID cannot be empty".to_string(), }); } + self.validate_config(config)?; + self.set_config(config).await + } +} - // Validate the notification config - self.validate_notification_config(config).await?; +impl AsyncNotificationManagerExt for T {} - // Set the notification - self.set_task_notification(config).await - } +/// Out-of-band delivery of task updates to a task's configured push endpoint. +/// +/// This is the **delivery** half of push notifications, deliberately separate +/// from the config-CRUD capability ([`AsyncNotificationManager`]) and from the +/// in-process streaming fan-out ([`AsyncStreamingHandler`](crate::port::AsyncStreamingHandler)). +/// Keeping delivery behind its own port is what lets the orchestration layer +/// (the [`TaskStatusBroadcast`](crate::application::TaskStatusBroadcast) mixin) +/// "commit, announce to subscribers, then notify the webhook" without any one +/// adapter taking on a second job — and lets the notification backend be swapped +/// freely (HTTP webhook, no-op, a queue, a test spy) at the composition edge. +/// +/// Errors are surfaced to the caller, but the orchestration layer treats +/// delivery as best-effort: a webhook that is down must not fail the task +/// mutation that triggered it. +#[async_trait] +pub trait AsyncPushNotifier: Send + Sync { + /// Deliver a status update to the task's configured push endpoint, if any. + /// + /// A task with no registered config is not an error — implementations + /// return `Ok(())`. + async fn notify_status( + &self, + task_id: &str, + event: &TaskStatusUpdateEvent, + ) -> Result<(), A2AError>; - /// Get task notification with validation - async fn get_task_notification_validated( + /// Deliver an artifact update to the task's configured push endpoint, if any. + async fn notify_artifact( &self, - params: &TaskIdParams, - ) -> Result { - if params.id.trim().is_empty() { - return Err(A2AError::ValidationError { - field: "task_id".to_string(), - message: "Task ID cannot be empty".to_string(), - }); - } + task_id: &str, + event: &TaskArtifactUpdateEvent, + ) -> Result<(), A2AError>; +} - self.get_task_notification(¶ms.id).await +/// Deref-forwarding impl so an `Arc` (e.g. the value +/// handed out by `InMemoryTaskStorage::push_notifier`) satisfies `impl +/// AsyncPushNotifier` bounds directly, without re-wrapping. +#[async_trait] +impl AsyncPushNotifier for std::sync::Arc { + async fn notify_status( + &self, + task_id: &str, + event: &TaskStatusUpdateEvent, + ) -> Result<(), A2AError> { + (**self).notify_status(task_id, event).await } - /// Send notification for task status update - async fn notify_task_status_update( + async fn notify_artifact( &self, task_id: &str, - _status_update: &crate::domain::TaskStatusUpdateEvent, + event: &TaskArtifactUpdateEvent, ) -> Result<(), A2AError> { - // Default implementation - can be overridden - // Check if notifications are configured for this task - if !self.has_task_notification(task_id).await? { - return Ok(()); // No notification configured, silently succeed - } + (**self).notify_artifact(task_id, event).await + } +} - // In a real implementation, this would send the actual notification - // For now, just validate that we have the configuration - let _config = self.get_task_notification(task_id).await?; +/// A no-op [`AsyncPushNotifier`] for compositions with no push backend wired. +/// +/// Every method succeeds without doing anything, mirroring `NoopStreamingHandler` +/// on the streaming side. +#[derive(Clone, Debug, Default)] +pub struct NoopPushNotifier; +#[async_trait] +impl AsyncPushNotifier for NoopPushNotifier { + async fn notify_status( + &self, + _task_id: &str, + _event: &TaskStatusUpdateEvent, + ) -> Result<(), A2AError> { Ok(()) } - /// Send notification for task artifact update - async fn notify_task_artifact_update( + async fn notify_artifact( &self, - task_id: &str, - _artifact_update: &crate::domain::TaskArtifactUpdateEvent, + _task_id: &str, + _event: &TaskArtifactUpdateEvent, ) -> Result<(), A2AError> { - // Default implementation - can be overridden - // Check if notifications are configured for this task - if !self.has_task_notification(task_id).await? { - return Ok(()); // No notification configured, silently succeed - } - - // In a real implementation, this would send the actual notification - // For now, just validate that we have the configuration - let _config = self.get_task_notification(task_id).await?; - Ok(()) } } diff --git a/a2a-rs/src/port/streaming_handler.rs b/a2a-rs/src/port/streaming_handler.rs index 6f7a752..4af6a69 100644 --- a/a2a-rs/src/port/streaming_handler.rs +++ b/a2a-rs/src/port/streaming_handler.rs @@ -1,6 +1,5 @@ //! Streaming and real-time update handling port definitions -#[cfg(feature = "server")] use async_trait::async_trait; use futures::Stream; use std::pin::Pin; @@ -9,7 +8,6 @@ use crate::domain::core::task::TaskStateExt; use crate::domain::{A2AError, TaskArtifactUpdateEvent, TaskStatusUpdateEvent}; /// A trait for subscribing to real-time updates -#[cfg(feature = "server")] #[async_trait] pub trait Subscriber: Send + Sync { /// Handle an update @@ -29,39 +27,6 @@ pub trait Subscriber: Send + Sync { } } -/// A trait for managing streaming connections and real-time updates -pub trait StreamingHandler { - /// Add a status update subscriber for a task - fn add_status_subscriber( - &self, - task_id: &str, - subscriber: Box + Send + Sync>, - ) -> Result; // Returns subscription ID - - /// Add an artifact update subscriber for a task - fn add_artifact_subscriber( - &self, - task_id: &str, - subscriber: Box + Send + Sync>, - ) -> Result; // Returns subscription ID - - /// Remove a specific subscription - fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError>; - - /// Remove all subscribers for a task - fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError>; - - /// Get the number of active subscribers for a task - fn get_subscriber_count(&self, task_id: &str) -> Result; - - /// Check if a task has any active subscribers - fn has_subscribers(&self, task_id: &str) -> Result { - let count = self.get_subscriber_count(task_id)?; - Ok(count > 0) - } -} - -#[cfg(feature = "server")] #[async_trait] /// An async trait for managing streaming connections and real-time updates pub trait AsyncStreamingHandler: Send + Sync { @@ -123,11 +88,17 @@ pub trait AsyncStreamingHandler: Send + Sync { A2AError, >; - /// Create a combined stream of all updates for a task + /// Create a combined stream of all updates for a task. + /// + /// Each yielded [`SeqEvent`] carries a per-task monotonic id so a client can + /// resume after a disconnect. When `from_event_id` is `Some(n)`, the handler + /// first replays any buffered events with id `> n` (best-effort, bounded by + /// the handler's replay buffer) before streaming live updates. async fn combined_update_stream( &self, task_id: &str, - ) -> Result> + Send>>, A2AError>; + from_event_id: Option, + ) -> Result> + Send>>, A2AError>; /// Validate streaming parameters async fn validate_streaming_params(&self, task_id: &str) -> Result<(), A2AError> { @@ -140,13 +111,19 @@ pub trait AsyncStreamingHandler: Send + Sync { Ok(()) } - /// Start streaming for a task with automatic cleanup + /// Start streaming for a task with automatic cleanup. + /// + /// `from_event_id` is forwarded to [`combined_update_stream`] for + /// Last-Event-ID resumption. + /// + /// [`combined_update_stream`]: AsyncStreamingHandler::combined_update_stream async fn start_task_streaming( &self, task_id: &str, - ) -> Result> + Send>>, A2AError> { + from_event_id: Option, + ) -> Result> + Send>>, A2AError> { self.validate_streaming_params(task_id).await?; - self.combined_update_stream(task_id).await + self.combined_update_stream(task_id, from_event_id).await } /// Stop all streaming for a task @@ -155,6 +132,116 @@ pub trait AsyncStreamingHandler: Send + Sync { } } +/// Forwarding impl so a type-erased `Arc` can itself +/// be passed wherever an `impl AsyncStreamingHandler` is expected (e.g. +/// `TaskService::with_streaming_handler`). This lets a single shared streaming +/// backend be injected into both a message handler and a transport adapter +/// without naming its concrete type. Only the required methods are forwarded; +/// the trait's default methods ride along on top of them. +#[async_trait] +impl AsyncStreamingHandler for std::sync::Arc { + async fn add_status_subscriber( + &self, + task_id: &str, + subscriber: Box + Send + Sync>, + ) -> Result { + (**self).add_status_subscriber(task_id, subscriber).await + } + + async fn add_artifact_subscriber( + &self, + task_id: &str, + subscriber: Box + Send + Sync>, + ) -> Result { + (**self).add_artifact_subscriber(task_id, subscriber).await + } + + async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError> { + (**self).remove_subscription(subscription_id).await + } + + async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> { + (**self).remove_task_subscribers(task_id).await + } + + async fn get_subscriber_count(&self, task_id: &str) -> Result { + (**self).get_subscriber_count(task_id).await + } + + async fn broadcast_status_update( + &self, + task_id: &str, + update: TaskStatusUpdateEvent, + ) -> Result<(), A2AError> { + (**self).broadcast_status_update(task_id, update).await + } + + async fn broadcast_artifact_update( + &self, + task_id: &str, + update: TaskArtifactUpdateEvent, + ) -> Result<(), A2AError> { + (**self).broadcast_artifact_update(task_id, update).await + } + + async fn status_update_stream( + &self, + task_id: &str, + ) -> Result> + Send>>, A2AError> + { + (**self).status_update_stream(task_id).await + } + + async fn artifact_update_stream( + &self, + task_id: &str, + ) -> Result< + Pin> + Send>>, + A2AError, + > { + (**self).artifact_update_stream(task_id).await + } + + async fn combined_update_stream( + &self, + task_id: &str, + from_event_id: Option, + ) -> Result> + Send>>, A2AError> { + (**self) + .combined_update_stream(task_id, from_event_id) + .await + } +} + +/// A streamed [`UpdateEvent`] tagged with a per-task monotonic id. +/// +/// The id is assigned by the streaming handler when the event is broadcast and +/// is surfaced to clients as the SSE `id:` field. On reconnect a client echoes +/// the last id it saw via `Last-Event-ID`, and the handler replays buffered +/// events with a greater id (see +/// [`combined_update_stream`](AsyncStreamingHandler::combined_update_stream)). +/// +/// This id/`Last-Event-ID` resumption is an a2a-rs enhancement on top of the +/// W3C SSE standard, **not** part of the A2A v1.0 spec. Emitting the `id:` field +/// is inert for spec clients (they read only the event payload), so it does not +/// affect interop. +#[derive(Debug, Clone)] +pub struct SeqEvent { + /// Per-task monotonic event id (starts at 1; `0` is reserved for the + /// initial task snapshot, which carries no replayable id). + pub id: u64, + /// The update payload. + pub event: UpdateEvent, +} + +impl SeqEvent { + /// Construct a sequenced event. + #[inline] + pub fn new(id: u64, event: UpdateEvent) -> Self { + Self { id, event } + } +} + /// Union type for different kinds of updates that can be streamed #[derive(Debug, Clone)] pub enum UpdateEvent { diff --git a/a2a-rs/src/port/task_manager.rs b/a2a-rs/src/port/task_manager.rs index e42752d..892e7ec 100644 --- a/a2a-rs/src/port/task_manager.rs +++ b/a2a-rs/src/port/task_manager.rs @@ -1,150 +1,123 @@ //! Task management port definitions -#[cfg(feature = "server")] use async_trait::async_trait; use crate::{ Message, domain::{ - A2AError, DeleteTaskPushNotificationConfigParams, GetTaskPushNotificationConfigParams, - ListTaskPushNotificationConfigsParams, ListTasksParams, ListTasksResult, Task, - TaskIdParams, TaskPushNotificationConfig, TaskQueryParams, TaskState, + A2AError, ContextId, ListTasksParams, ListTasksResult, Task, TaskId, TaskIdParams, + TaskQueryParams, TaskState, VersionedTask, }, }; -/// A trait for managing task lifecycle and operations -pub trait TaskManager { - /// Create a new task - fn create_task(&self, task_id: &str, context_id: &str) -> Result; +/// Async task lifecycle management: the core CRUD capability over individual tasks. +/// +/// A handler implements this trait if it can create, read, mutate, and cancel +/// tasks. Listing/querying across tasks is a separate capability — see +/// [`AsyncTaskQuery`]. Convenience wrappers that validate request parameters +/// live on [`AsyncTaskLifecycleExt`], which is blanket-implemented for every +/// `AsyncTaskLifecycle`. +#[async_trait] +pub trait AsyncTaskLifecycle: Send + Sync { + /// Create a new task in the given context. + async fn create(&self, id: &TaskId, context_id: &ContextId) -> Result; - /// Get a task by ID with optional history - fn get_task(&self, task_id: &str, history_length: Option) -> Result; + /// Get a task by ID with optional history length limit. + async fn get(&self, id: &TaskId, history_length: Option) -> Result; - /// Update task status with an optional message to add to history - fn update_task_status( + /// Update task status, optionally appending a message to history. + async fn update_status( &self, - task_id: &str, + id: &TaskId, state: TaskState, message: Option, ) -> Result; - /// Cancel a task - fn cancel_task(&self, task_id: &str) -> Result; - - /// Check if a task exists - fn task_exists(&self, task_id: &str) -> Result; - - /// List tasks with optional filtering - fn list_tasks( - &self, - _context_id: Option<&str>, - _limit: Option, - ) -> Result, A2AError> { - // Default implementation - can be overridden - // Basic implementation that doesn't support filtering - Err(A2AError::UnsupportedOperation( - "Task listing not implemented".to_string(), - )) - } + /// Cancel a task. + async fn cancel(&self, id: &TaskId) -> Result; - /// Get task metadata - fn get_task_metadata( - &self, - task_id: &str, - ) -> Result, A2AError> { - let task = self.get_task(task_id, None)?; - if let Some(metadata) = task.metadata.as_option() { - let val = serde_json::to_value(metadata)?; - if let serde_json::Value::Object(map) = val { - return Ok(map); - } - } - Ok(serde_json::Map::new()) - } - - /// Validate task parameters - fn validate_task_params(&self, params: &TaskQueryParams) -> Result<(), A2AError> { - if params.id.trim().is_empty() { - return Err(A2AError::ValidationError { - field: "task_id".to_string(), - message: "Task ID cannot be empty".to_string(), - }); - } - - if let Some(history_length) = params.history_length { - if history_length > 1000 { - return Err(A2AError::ValidationError { - field: "history_length".to_string(), - message: "History length cannot exceed 1000".to_string(), - }); - } - } - - Ok(()) - } + /// Check whether a task exists. + async fn exists(&self, id: &TaskId) -> Result; } -#[cfg(feature = "server")] +/// Async task querying: listing tasks with filtering and pagination. +/// +/// Kept distinct from [`AsyncTaskLifecycle`] so a handler that only stores and +/// mutates individual tasks is not forced to implement cross-task search. #[async_trait] -/// An async trait for managing task lifecycle and operations -pub trait AsyncTaskManager: Send + Sync { - /// Create a new task - async fn create_task(&self, task_id: &str, context_id: &str) -> Result; +pub trait AsyncTaskQuery: Send + Sync { + /// List tasks with filtering and pagination (A2A v1.0.0 `tasks/list`). + async fn list(&self, params: &ListTasksParams) -> Result; +} - /// Get a task by ID with optional history - async fn get_task(&self, task_id: &str, history_length: Option) -> Result; +/// Optimistic-concurrency control over task mutations. +/// +/// A distinct capability from [`AsyncTaskLifecycle`] (hex rule 2 — narrow ports): +/// a store that needs lost-update protection implements this, while the plain +/// lifecycle path stays version-free for callers that don't. The version is a +/// monotonic counter the store bumps on **every** successful mutation, including +/// the unversioned [`AsyncTaskLifecycle`] writes — so the two views never drift. +/// +/// The classic read-modify-write loop: +/// +/// ``` +/// # use a2a_rs::{AsyncTaskVersioning, VersionedTask}; +/// # use a2a_rs::domain::{A2AError, Message, TaskId, TaskState}; +/// # async fn read_modify_write( +/// # store: &impl AsyncTaskVersioning, +/// # id: TaskId, +/// # next_state: TaskState, +/// # msg: Option, +/// # ) -> Result<(), A2AError> { +/// let VersionedTask { task, version } = store.get_versioned(&id, None).await?; +/// // … decide the next state from `task` … +/// let _ = &task; +/// match store.update_status_checked(&id, version, next_state, msg).await { +/// Ok(updated) => { /* committed at updated.version */ } +/// Err(A2AError::VersionConflict { .. }) => { /* re-read and retry */ } +/// Err(e) => return Err(e), +/// } +/// # Ok(()) +/// # } +/// ``` +#[async_trait] +pub trait AsyncTaskVersioning: Send + Sync { + /// Current stored version of a task. Bumped on every successful mutation. + async fn version(&self, id: &TaskId) -> Result; - /// Update task status with an optional message to add to history - async fn update_task_status( + /// Fetch a task together with its current version (history-limited as in + /// [`AsyncTaskLifecycle::get`]). + async fn get_versioned( &self, - task_id: &str, + id: &TaskId, + history_length: Option, + ) -> Result; + + /// Update status only if the stored version equals `expected`. + /// + /// On success returns the mutated task and its newly bumped version. If the + /// stored version has advanced past `expected`, fails with + /// [`A2AError::VersionConflict`] and leaves the task untouched. + async fn update_status_checked( + &self, + id: &TaskId, + expected: u64, state: TaskState, message: Option, - ) -> Result; - - /// Cancel a task - async fn cancel_task(&self, task_id: &str) -> Result; - - /// Check if a task exists - async fn task_exists(&self, task_id: &str) -> Result; - - /// List tasks with optional filtering - async fn list_tasks( - &self, - _context_id: Option<&str>, - _limit: Option, - ) -> Result, A2AError> { - // Default implementation - can be overridden - // Basic implementation that doesn't support filtering - Err(A2AError::UnsupportedOperation( - "Task listing not implemented".to_string(), - )) - } - - /// Get task metadata - async fn get_task_metadata( - &self, - task_id: &str, - ) -> Result, A2AError> { - let task = self.get_task(task_id, None).await?; - if let Some(metadata) = task.metadata.as_option() { - let val = serde_json::to_value(metadata)?; - if let serde_json::Value::Object(map) = val { - return Ok(map); - } - } - Ok(serde_json::Map::new()) - } - - /// Validate task parameters - async fn validate_task_params(&self, params: &TaskQueryParams) -> Result<(), A2AError> { - if params.id.trim().is_empty() { - return Err(A2AError::ValidationError { - field: "task_id".to_string(), - message: "Task ID cannot be empty".to_string(), - }); - } + ) -> Result; +} +/// Validation conveniences over [`AsyncTaskLifecycle`]. +/// +/// Blanket-implemented for every `AsyncTaskLifecycle`, so implementors get these +/// for free and only ever stub the core primitives. Constructing a [`TaskId`] +/// from request parameters performs the empty-string validation, so these +/// wrappers parse the wire parameters once at the boundary. +#[async_trait] +pub trait AsyncTaskLifecycleExt: AsyncTaskLifecycle { + /// Validate query parameters, then fetch the task. + async fn get_validated(&self, params: &TaskQueryParams) -> Result { + let id: TaskId = params.id.parse()?; if let Some(history_length) = params.history_length { if history_length > 1000 { return Err(A2AError::ValidationError { @@ -153,68 +126,14 @@ pub trait AsyncTaskManager: Send + Sync { }); } } - - Ok(()) + self.get(&id, params.history_length).await } - /// Get task with validation - async fn get_task_validated(&self, params: &TaskQueryParams) -> Result { - self.validate_task_params(params).await?; - self.get_task(¶ms.id, params.history_length).await - } - - /// Cancel task with validation - async fn cancel_task_validated(&self, params: &TaskIdParams) -> Result { - if params.id.trim().is_empty() { - return Err(A2AError::ValidationError { - field: "task_id".to_string(), - message: "Task ID cannot be empty".to_string(), - }); - } - - self.cancel_task(¶ms.id).await - } - - // ===== v1.0.0 New Methods ===== - - /// List tasks with comprehensive filtering and pagination (v1.0.0) - async fn list_tasks_v3(&self, _params: &ListTasksParams) -> Result { - // Default implementation returns unsupported error - Err(A2AError::UnsupportedOperation( - "Task listing with pagination not implemented".to_string(), - )) - } - - /// Get push notification config by ID (v1.0.0) - async fn get_push_notification_config( - &self, - _params: &GetTaskPushNotificationConfigParams, - ) -> Result { - // Default implementation returns unsupported error - Err(A2AError::UnsupportedOperation( - "Get push notification config not implemented".to_string(), - )) - } - - /// List all push notification configs for a task (v1.0.0) - async fn list_push_notification_configs( - &self, - _params: &ListTaskPushNotificationConfigsParams, - ) -> Result, A2AError> { - // Default implementation returns unsupported error - Err(A2AError::UnsupportedOperation( - "List push notification configs not implemented".to_string(), - )) - } - - /// Delete a specific push notification config (v1.0.0) - async fn delete_push_notification_config( - &self, - _params: &DeleteTaskPushNotificationConfigParams, - ) -> Result<(), A2AError> { - // Default implementation returns unsupported error - Err(A2AError::UnsupportedOperation( - "Delete push notification config not implemented".to_string(), - )) + /// Validate ID parameters, then cancel the task. + async fn cancel_validated(&self, params: &TaskIdParams) -> Result { + let id: TaskId = params.id.parse()?; + self.cancel(&id).await } } + +impl AsyncTaskLifecycleExt for T {} diff --git a/a2a-rs/src/services/client.rs b/a2a-rs/src/services/client.rs deleted file mode 100644 index 88184ee..0000000 --- a/a2a-rs/src/services/client.rs +++ /dev/null @@ -1,80 +0,0 @@ -use async_trait::async_trait; -use futures::Stream; -use std::pin::Pin; - -use crate::domain::{ - A2AError, ListTasksParams, ListTasksResult, Message, Task, TaskArtifactUpdateEvent, - TaskPushNotificationConfig, TaskStatusUpdateEvent, -}; - -#[async_trait] -/// An async trait defining the methods an async client should implement -pub trait AsyncA2AClient: Send + Sync { - /// Send a message to a task - async fn send_task_message( - &self, - task_id: &str, - message: &Message, - session_id: Option<&str>, - history_length: Option, - ) -> Result; - - /// Get a task by ID - async fn get_task(&self, task_id: &str, history_length: Option) -> Result; - - /// Cancel a task - async fn cancel_task(&self, task_id: &str) -> Result; - - /// Set up push notifications for a task - async fn set_task_push_notification( - &self, - config: &TaskPushNotificationConfig, - ) -> Result; - - /// Get push notification configuration for a task - async fn get_task_push_notification( - &self, - task_id: &str, - ) -> Result; - - /// List tasks with filtering and pagination (v1.0.0) - async fn list_tasks(&self, params: &ListTasksParams) -> Result; - - /// List all push notification configs for a task (v1.0.0) - async fn list_push_notification_configs( - &self, - task_id: &str, - ) -> Result, A2AError>; - - /// Get a specific push notification config by ID (v1.0.0) - async fn get_push_notification_config( - &self, - task_id: &str, - config_id: &str, - ) -> Result; - - /// Delete a specific push notification config (v1.0.0) - async fn delete_push_notification_config( - &self, - task_id: &str, - config_id: &str, - ) -> Result<(), A2AError>; - - /// Subscribe to task updates (for streaming) - async fn subscribe_to_task( - &self, - task_id: &str, - history_length: Option, - ) -> Result> + Send>>, A2AError>; -} - -/// Items that can be streamed from the server during task subscriptions.\n///\n/// When subscribing to streaming updates for a task, the server can send\n/// different types of items:\n/// - `Task`: The complete initial task state when subscription starts\n/// - `StatusUpdate`: Updates to the task's status (state changes, progress)\n/// - `ArtifactUpdate`: Notifications about new or updated artifacts\n///\n/// This allows clients to receive real-time updates about task progress\n/// and results as they become available. -#[derive(Debug, Clone)] -pub enum StreamItem { - /// The initial task state - Task(Task), - /// A task status update - StatusUpdate(TaskStatusUpdateEvent), - /// A task artifact update - ArtifactUpdate(TaskArtifactUpdateEvent), -} diff --git a/a2a-rs/src/services/mod.rs b/a2a-rs/src/services/mod.rs index e63ce25..7b7d341 100644 --- a/a2a-rs/src/services/mod.rs +++ b/a2a-rs/src/services/mod.rs @@ -3,14 +3,8 @@ //! Services provide application-level abstractions that orchestrate //! between ports and adapters. -#[cfg(feature = "client")] -pub mod client; - #[cfg(feature = "server")] pub mod server; -#[cfg(feature = "client")] -pub use client::{AsyncA2AClient, StreamItem}; - #[cfg(feature = "server")] pub use server::AgentInfoProvider; diff --git a/a2a-rs/tests/authenticated_card_test.rs b/a2a-rs/tests/authenticated_card_test.rs deleted file mode 100644 index 48953bd..0000000 --- a/a2a-rs/tests/authenticated_card_test.rs +++ /dev/null @@ -1,216 +0,0 @@ -//! Integration tests for agent/getAuthenticatedExtendedCard endpoint (v1.0.0) - -#![cfg(all(feature = "http-client", feature = "http-server"))] - -mod common; - -use a2a_rs::{ - adapter::{ - DefaultRequestProcessor, HttpClient, HttpServer, InMemoryTaskStorage, SimpleAgentInfo, - }, - domain::A2AError, -}; -use common::TestBusinessHandler; -use std::time::Duration; -use tokio::sync::oneshot; - -async fn setup_server(port: u16, supports_authenticated_card: bool) -> oneshot::Sender<()> { - let storage = InMemoryTaskStorage::new(); - let handler = TestBusinessHandler::with_storage(storage); - - // Create a single agent info instance to use for both processor and server - let mut agent_info = SimpleAgentInfo::new( - "Authenticated Card Test Agent".to_string(), - format!("http://localhost:{}", port), - ) - .with_description("Agent for testing authenticated extended card".to_string()); - - if supports_authenticated_card { - agent_info = agent_info.with_authenticated_extended_card(); - } - - // Clone the agent info for the processor - let processor = DefaultRequestProcessor::with_handler(handler, agent_info.clone()); - - let server = HttpServer::new(processor, agent_info, format!("127.0.0.1:{}", port)); - - let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); - - tokio::spawn(async move { - tokio::select! { - _ = server.start() => {}, - _ = shutdown_rx => {} - } - }); - - tokio::time::sleep(Duration::from_millis(100)).await; - - shutdown_tx -} - -#[tokio::test] -async fn test_get_authenticated_extended_card_not_configured() { - let port = 9100; - let shutdown_tx = setup_server(port, false).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - let response = client.get_extended_agent_card(None).await; - - // Should return error -32007 - assert!(response.is_err(), "Should have error response"); - if let Err(A2AError::JsonRpc { code, message, .. }) = response { - assert_eq!( - code, -32007, - "Should return error code -32007 (AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED)" - ); - assert!( - message.contains("not configured") - || message.contains("not supported") - || message.contains("not available"), - "Error message should indicate card not configured: {}", - message - ); - } else { - panic!("Expected JsonRpc error, got {:?}", response); - } - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_get_authenticated_extended_card_success() { - let port = 9101; - let shutdown_tx = setup_server(port, true).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - let response = client.get_extended_agent_card(None).await; - - assert!(response.is_ok(), "Should not have error response"); - let card = response.unwrap(); - - // Verify it's a valid agent card - assert_eq!(card.name, "Authenticated Card Test Agent"); - assert!(!card.description.is_empty()); - assert_eq!(card.protocol_version(), "1.0"); - - // Verify this is the authenticated version (may have additional info) - // The authenticated card should support this capability - assert!( - card.capabilities.extended_agent_card.unwrap_or(false), - "Authenticated card should indicate support" - ); - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_authenticated_card_vs_regular_card() { - let port = 9102; - let shutdown_tx = setup_server(port, true).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Get regular card via HTTP endpoint - let http_client = reqwest::Client::new(); - let regular_card_response = http_client - .get(format!("http://localhost:{}/agent-card", port)) - .send() - .await - .expect("Failed to fetch regular agent card"); - - let regular_card: a2a_rs::domain::AgentCard = regular_card_response - .json() - .await - .expect("Failed to parse regular agent card"); - - // Get authenticated extended card - let auth_card = client - .get_extended_agent_card(None) - .await - .expect("Failed to get authenticated extended card"); - - // Both should have same basic info - assert_eq!(regular_card.name, auth_card.name); - assert_eq!(regular_card.url(), auth_card.url()); - assert_eq!( - regular_card.protocol_version(), - auth_card.protocol_version() - ); - - // Both should indicate support for authenticated extended card - assert!( - regular_card - .capabilities - .extended_agent_card - .unwrap_or(false) - ); - assert!(auth_card.capabilities.extended_agent_card.unwrap_or(false)); - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_authenticated_card_error_structure() { - let port = 9103; - let shutdown_tx = setup_server(port, false).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Try to get authenticated card when not configured - let response = client.get_extended_agent_card(None).await; - - assert!(response.is_err()); - let error = response.unwrap_err(); - - // Verify error structure matches JSON-RPC spec - if let A2AError::JsonRpc { code, message, .. } = error { - assert_eq!(code, -32007); - assert!(!message.is_empty()); - } else { - panic!("Expected JsonRpc error, got {:?}", error); - } - - shutdown_tx.send(()).ok(); -} - -#[test] -fn test_authenticated_extended_card_error_code() { - // Test that the error enum produces correct error code - let error = A2AError::AuthenticatedExtendedCardNotConfigured; - let jsonrpc_error = error.to_jsonrpc_error(); - - assert_eq!(jsonrpc_error["code"], -32007); - assert_eq!( - jsonrpc_error["message"], - "Authenticated Extended Card is not configured" - ); -} - -#[tokio::test] -async fn test_authenticated_card_with_extensions() { - let port = 9104; - let shutdown_tx = setup_server(port, true).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Get authenticated card - let card = client - .get_extended_agent_card(None) - .await - .expect("Failed to get card"); - - // Card should have v1.0.0 fields - assert_eq!(card.protocol_version(), "1.0"); - assert_eq!(card.preferred_transport(), "JSONRPC"); - - // Should be able to have extensions (even if empty) - // This verifies the authenticated card includes all v1.0.0 fields - let capabilities = &card.capabilities; - // Extensions field should be present in capabilities - // (may be None or Some(vec![])) - let _ = &capabilities.extensions; - - shutdown_tx.send(()).ok(); -} diff --git a/a2a-rs/tests/client_v3_methods_test.rs b/a2a-rs/tests/client_v3_methods_test.rs deleted file mode 100644 index 6dbf9ac..0000000 --- a/a2a-rs/tests/client_v3_methods_test.rs +++ /dev/null @@ -1,623 +0,0 @@ -//! A2A Protocol v1.0.0 Client SDK Tests -//! -//! This module tests the HttpClient and WebSocketClient wrappers for v1.0.0 methods - -mod common; - -use a2a_rs::{ - adapter::{ - DefaultRequestProcessor, HttpClient, HttpServer, InMemoryTaskStorage, SimpleAgentInfo, - }, - domain::{ListTasksParams, Message, TaskState}, - port::AsyncTaskManager, - services::AsyncA2AClient, -}; -use common::TestBusinessHandler; -use std::time::Duration; -use tokio::sync::oneshot; - -/// Helper function to setup a test server with pre-populated tasks -async fn setup_test_server(port: u16) -> (oneshot::Sender<()>, InMemoryTaskStorage) { - let storage = InMemoryTaskStorage::new(); - - // Create some test tasks - for i in 0..5 { - let task_id = format!("task-{}", i); - let context_id = format!("ctx-{}", i % 2); // Alternate between ctx-0 and ctx-1 - - storage - .create_task(&task_id, &context_id) - .await - .expect("Failed to create task"); - - // Update task status with a message - storage - .update_task_status( - &task_id, - if i % 2 == 0 { - TaskState::Working - } else { - TaskState::Completed - }, - Some(Message::agent_text( - format!("Task {} message", i), - format!("msg-{}", i), - )), - ) - .await - .expect("Failed to update task"); - } - - let handler = TestBusinessHandler::with_storage(storage.clone()); - let agent_info = SimpleAgentInfo::new( - "Test Agent v1.0.0".to_string(), - format!("http://localhost:{}", port), - ) - .with_version("2.0.0".to_string()) - .with_description("Agent for v1.0.0 testing".to_string()); - - let processor = DefaultRequestProcessor::with_handler(handler, agent_info.clone()); - let server = HttpServer::new(processor, agent_info, format!("127.0.0.1:{}", port)); - let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); - - tokio::spawn(async move { - tokio::select! { - _ = server.start() => {}, - _ = shutdown_rx => {} - } - }); - - // Give server time to start - tokio::time::sleep(Duration::from_millis(100)).await; - - (shutdown_tx, storage) -} - -#[tokio::test] -async fn test_http_client_list_tasks() { - let port = 9600; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Test basic listing without filters - let params = ListTasksParams::default(); - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - println!("List tasks result: {:?}", result); - - assert_eq!(result.total_size, 5, "Should have 5 tasks total"); - assert!(!result.tasks.is_empty(), "Should return tasks"); - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_list_tasks_with_filters() { - let port = 9601; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Test filtering by context - let params = ListTasksParams { - context_id: Some("ctx-0".to_string()), - ..Default::default() - }; - - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks with context filter"); - - println!("Filtered tasks by context: {:?}", result); - - // All returned tasks should have context_id "ctx-0" - for task in &result.tasks { - assert_eq!(task.context_id, "ctx-0", "Task should match context filter"); - } - - // Test filtering by status - let params = ListTasksParams { - status: Some(TaskState::Working), - ..Default::default() - }; - - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks with status filter"); - - println!("Filtered tasks by status: {:?}", result); - - for task in &result.tasks { - assert_eq!( - task.status.state, - TaskState::Working, - "Task should match status filter" - ); - } - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_list_tasks_pagination() { - let port = 9602; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Test pagination with small page size - let params = ListTasksParams { - page_size: Some(2), - ..Default::default() - }; - - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks with pagination"); - - println!("Paginated result: {:?}", result); - - assert!(result.tasks.len() <= 2, "Should return at most 2 tasks"); - assert_eq!(result.total_size, 5, "Total size should still be 5"); - - // If there's a next page token, fetch the next page - if !result.next_page_token.is_empty() { - let next_token = result.next_page_token; - let params = ListTasksParams { - page_size: Some(2), - page_token: Some(next_token), - ..Default::default() - }; - - let next_result = client - .list_tasks(¶ms) - .await - .expect("Failed to fetch next page"); - - println!("Next page result: {:?}", next_result); - assert!(!next_result.tasks.is_empty(), "Next page should have tasks"); - } - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_list_tasks_history_length() { - let port = 9603; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Test with history_length parameter - let params = ListTasksParams { - history_length: Some(5), - ..Default::default() - }; - - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks with history_length"); - - println!("Tasks with history: {:?}", result); - - // Verify that tasks have history - for task in &result.tasks { - let history = &task.history; - assert!(history.len() <= 5, "History should be limited to 5 entries"); - } - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_list_tasks_include_artifacts() { - let port = 9604; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Test with include_artifacts = true - let params = ListTasksParams { - include_artifacts: Some(true), - ..Default::default() - }; - - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks with artifacts"); - - println!("Tasks with artifacts: {:?}", result); - assert!(!result.tasks.is_empty(), "Should return tasks"); - - // Test with include_artifacts = false - let params = ListTasksParams { - include_artifacts: Some(false), - ..Default::default() - }; - - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks without artifacts"); - - println!("Tasks without artifacts: {:?}", result); - assert!(!result.tasks.is_empty(), "Should return tasks"); - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_push_config_list() { - let port = 9605; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // First, set a push notification config - let task_id = "task-0"; - let config = a2a_rs::domain::TaskPushNotificationConfig { - tenant: String::new(), - task_id: task_id.to_string(), - id: "config-1".to_string(), - url: "https://client.example.com/webhook".to_string(), - token: "test-token".to_string(), - authentication: None.into(), - ..Default::default() - }; - - client - .set_task_push_notification(&config) - .await - .expect("Failed to set push notification config"); - - // Now list configs for the task - let configs = client - .list_push_notification_configs(task_id) - .await - .expect("Failed to list push configs"); - - println!("Push configs: {:?}", configs); - - assert!(!configs.is_empty(), "Should have at least one config"); - assert_eq!(configs[0].id, "config-1".to_string()); - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_push_config_get() { - let port = 9606; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Set a push notification config - let task_id = "task-0"; - let config = a2a_rs::domain::TaskPushNotificationConfig { - tenant: String::new(), - task_id: task_id.to_string(), - id: "config-get-test".to_string(), - url: "https://client.example.com/webhook".to_string(), - token: "test-token-get".to_string(), - authentication: None.into(), - ..Default::default() - }; - - client - .set_task_push_notification(&config) - .await - .expect("Failed to set push notification config"); - - // Get the specific config by ID - let retrieved_config = client - .get_push_notification_config(task_id, "config-get-test") - .await - .expect("Failed to get push config"); - - println!("Retrieved push config: {:?}", retrieved_config); - - assert_eq!(retrieved_config.id, "config-get-test".to_string()); - assert_eq!(retrieved_config.url, "https://client.example.com/webhook"); - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_push_config_delete() { - let port = 9607; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Set a push notification config - let task_id = "task-0"; - let config = a2a_rs::domain::TaskPushNotificationConfig { - tenant: String::new(), - task_id: task_id.to_string(), - id: "config-delete-test".to_string(), - url: "https://client.example.com/webhook".to_string(), - token: "test-token-delete".to_string(), - authentication: None.into(), - ..Default::default() - }; - - client - .set_task_push_notification(&config) - .await - .expect("Failed to set push notification config"); - - // Verify it exists - let configs_before = client - .list_push_notification_configs(task_id) - .await - .expect("Failed to list configs"); - - assert!( - !configs_before.is_empty(), - "Config should exist before deletion" - ); - - // Delete the config - client - .delete_push_notification_config(task_id, "config-delete-test") - .await - .expect("Failed to delete push config"); - - // Verify it's deleted by listing again - let configs_after = client - .list_push_notification_configs(task_id) - .await - .expect("Failed to list configs after deletion"); - - // The config should either be empty or not contain our deleted config - if !configs_after.is_empty() { - assert!( - configs_after.iter().all(|c| c.id != "config-delete-test"), - "Deleted config should not appear in list" - ); - } - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_push_config_multiple() { - let port = 9608; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - let task_id = "task-0"; - - // Set multiple push notification configs - for i in 1..=3 { - let config = a2a_rs::domain::TaskPushNotificationConfig { - tenant: String::new(), - task_id: task_id.to_string(), - id: format!("config-multi-{}", i), - url: format!("https://client.example.com/webhook-{}", i), - token: format!("token-{}", i), - authentication: None.into(), - ..Default::default() - }; - - client - .set_task_push_notification(&config) - .await - .expect("Failed to set push notification config"); - } - - // List all configs - let configs = client - .list_push_notification_configs(task_id) - .await - .expect("Failed to list configs"); - - println!("Multiple configs: {:?}", configs); - - // Since set_task_push_notification replaces configs (not appends), - // we should have the last config set - assert!(!configs.is_empty(), "Should have at least one config"); - - // The config should be one of the ones we set - let has_our_configs = configs.iter().any(|c| c.id.starts_with("config-multi-")); - assert!(has_our_configs, "Should have our configs"); - - // Verify we can retrieve the configs that exist - for config_wrapper in &configs { - let config_id = &config_wrapper.id; - let retrieved = client - .get_push_notification_config(task_id, config_id) - .await - .expect("Failed to get individual config"); - - assert_eq!(retrieved.id, config_id.clone()); - } - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_with_authentication() { - let port = 9609; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - // Create client with authentication token - let client = HttpClient::with_auth( - format!("http://localhost:{}", port), - "test-bearer-token".to_string(), - ); - - // The server doesn't actually validate the token in this test, - // but we verify the client sends requests successfully - let params = ListTasksParams::default(); - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks with auth"); - - assert_eq!(result.total_size, 5); - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_with_timeout() { - let port = 9610; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - // Create client with custom timeout - let client = HttpClient::new(format!("http://localhost:{}", port)).with_timeout(60); - - let params = ListTasksParams::default(); - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks with timeout"); - - assert_eq!(result.total_size, 5); - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_list_tasks_combined_filters() { - let port = 9611; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Test with multiple filters combined - let params = ListTasksParams { - context_id: Some("ctx-0".to_string()), - status: Some(TaskState::Working), - page_size: Some(10), - history_length: Some(5), - include_artifacts: Some(true), - ..Default::default() - }; - - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks with combined filters"); - - println!("Combined filter result: {:?}", result); - - // Verify filters are applied - for task in &result.tasks { - assert_eq!(task.context_id, "ctx-0"); - assert_eq!(task.status.state, TaskState::Working); - } - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_push_config_not_found() { - let port = 9612; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Try to get a non-existent config - let result = client - .get_push_notification_config("task-0", "non-existent-config") - .await; - - println!("Not found result: {:?}", result); - - // Should return an error (either TaskNotFound or config not found) - assert!( - result.is_err(), - "Should return error for non-existent config" - ); - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_list_tasks_empty_result() { - let port = 9613; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Filter for a context that doesn't exist - let params = ListTasksParams { - context_id: Some("non-existent-context".to_string()), - ..Default::default() - }; - - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks with non-existent context"); - - println!("Empty result: {:?}", result); - - assert_eq!(result.tasks.len(), 0, "Should return empty task list"); - assert_eq!(result.total_size, 0, "Total size should be 0"); - assert!(result.next_page_token.is_empty(), "No next page token"); - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_http_client_delete_push_config_idempotent() { - let port = 9614; - let (shutdown_tx, _storage) = setup_test_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - let task_id = "task-0"; - let config_id = "config-idempotent-test"; - - // Set a config - let config = a2a_rs::domain::TaskPushNotificationConfig { - tenant: String::new(), - task_id: task_id.to_string(), - id: config_id.to_string(), - url: "https://client.example.com/webhook".to_string(), - token: "test-token".to_string(), - authentication: None.into(), - ..Default::default() - }; - - client - .set_task_push_notification(&config) - .await - .expect("Failed to set push notification config"); - - // Delete it once - client - .delete_push_notification_config(task_id, config_id) - .await - .expect("Failed to delete push config first time"); - - // Delete it again (should be idempotent, no error) - let result = client - .delete_push_notification_config(task_id, config_id) - .await; - - println!("Idempotent delete result: {:?}", result); - - // According to REST principles, DELETE should be idempotent - // The implementation may return success or an error, but it shouldn't crash - // We accept either outcome as valid depending on implementation - assert!( - result.is_ok() || result.is_err(), - "Delete should complete without panic" - ); - - shutdown_tx.send(()).ok(); -} diff --git a/a2a-rs/tests/common/test_handler.rs b/a2a-rs/tests/common/test_handler.rs index ad35c93..76a0851 100644 --- a/a2a-rs/tests/common/test_handler.rs +++ b/a2a-rs/tests/common/test_handler.rs @@ -8,12 +8,17 @@ use std::sync::Arc; use async_trait::async_trait; use a2a_rs::{ - adapter::{business::DefaultMessageHandler, storage::InMemoryTaskStorage}, - domain::{A2AError, Message, Task, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent}, + adapter::{ + business::ResponderMessageHandler, storage::InMemoryTaskStorage, + streaming::InMemoryStreamingHandler, + }, + domain::{ + A2AError, ContextId, Message, Task, TaskArtifactUpdateEvent, TaskId, TaskState, + TaskStatusUpdateEvent, + }, port::{ - AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager, - MessageHandler, NotificationManager, StreamingHandler, TaskManager, - streaming_handler::Subscriber, + AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskLifecycle, + AsyncTaskQuery, streaming_handler::Subscriber, }, }; @@ -21,8 +26,10 @@ use a2a_rs::{ /// by delegating to InMemoryTaskStorage #[derive(Clone)] pub struct TestBusinessHandler { - /// Task storage that implements all the business capabilities + /// Task storage (persistence + push-config CRUD) storage: Arc, + /// Dedicated streaming fan-out + streaming: InMemoryStreamingHandler, } impl TestBusinessHandler { @@ -30,6 +37,7 @@ impl TestBusinessHandler { pub fn new() -> Self { Self { storage: Arc::new(InMemoryTaskStorage::new()), + streaming: InMemoryStreamingHandler::new(), } } @@ -38,6 +46,7 @@ impl TestBusinessHandler { pub fn with_storage(storage: InMemoryTaskStorage) -> Self { Self { storage: Arc::new(storage), + streaming: InMemoryStreamingHandler::new(), } } @@ -54,123 +63,6 @@ impl Default for TestBusinessHandler { } } -// Synchronous trait implementations - not supported since we use async storage -impl MessageHandler for TestBusinessHandler { - fn process_message( - &self, - _task_id: &str, - _message: &Message, - _session_id: Option<&str>, - ) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous message processing not supported. Use async version.".to_string(), - )) - } -} - -impl TaskManager for TestBusinessHandler { - fn create_task(&self, _task_id: &str, _context_id: &str) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous task creation not supported. Use async version.".to_string(), - )) - } - - fn get_task(&self, _task_id: &str, _history_length: Option) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous task retrieval not supported. Use async version.".to_string(), - )) - } - - fn update_task_status( - &self, - _task_id: &str, - _state: TaskState, - _message: Option, - ) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous task status update not supported. Use async version.".to_string(), - )) - } - - fn cancel_task(&self, _task_id: &str) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous task cancellation not supported. Use async version.".to_string(), - )) - } - - fn task_exists(&self, _task_id: &str) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous task existence check not supported. Use async version.".to_string(), - )) - } -} - -impl NotificationManager for TestBusinessHandler { - fn set_task_notification( - &self, - _config: &a2a_rs::domain::TaskPushNotificationConfig, - ) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous notification setup not supported. Use async version.".to_string(), - )) - } - - fn get_task_notification( - &self, - _task_id: &str, - ) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous notification retrieval not supported. Use async version.".to_string(), - )) - } - - fn remove_task_notification(&self, _task_id: &str) -> Result<(), A2AError> { - Err(A2AError::UnsupportedOperation( - "Synchronous notification removal not supported. Use async version.".to_string(), - )) - } -} - -impl StreamingHandler for TestBusinessHandler { - fn add_status_subscriber( - &self, - _task_id: &str, - _subscriber: Box + Send + Sync>, - ) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous streaming subscription not supported. Use async version.".to_string(), - )) - } - - fn add_artifact_subscriber( - &self, - _task_id: &str, - _subscriber: Box + Send + Sync>, - ) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous streaming subscription not supported. Use async version.".to_string(), - )) - } - - fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> { - Err(A2AError::UnsupportedOperation( - "Synchronous streaming unsubscription not supported. Use async version.".to_string(), - )) - } - - fn remove_task_subscribers(&self, _task_id: &str) -> Result<(), A2AError> { - Err(A2AError::UnsupportedOperation( - "Synchronous streaming unsubscription not supported. Use async version.".to_string(), - )) - } - - fn get_subscriber_count(&self, _task_id: &str) -> Result { - Err(A2AError::UnsupportedOperation( - "Synchronous subscriber count not supported. Use async version.".to_string(), - )) - } -} - // Asynchronous trait implementations - delegate to storage #[async_trait] @@ -181,8 +73,12 @@ impl AsyncMessageHandler for TestBusinessHandler { message: &Message, session_id: Option<&str>, ) -> Result { - // Create a message handler and delegate - let message_handler = DefaultMessageHandler::new((*self.storage).clone()); + // Create a message handler and delegate, sharing the streaming handler. + let message_handler = ResponderMessageHandler::echo( + (*self.storage).clone(), + self.streaming.clone(), + self.storage.push_notifier(), + ); message_handler .process_message(task_id, message, session_id) .await @@ -190,81 +86,71 @@ impl AsyncMessageHandler for TestBusinessHandler { } #[async_trait] -impl AsyncTaskManager for TestBusinessHandler { - async fn create_task(&self, task_id: &str, context_id: &str) -> Result { - self.storage.create_task(task_id, context_id).await +impl AsyncTaskLifecycle for TestBusinessHandler { + async fn create(&self, id: &TaskId, context_id: &ContextId) -> Result { + self.storage.create(id, context_id).await } - async fn get_task(&self, task_id: &str, history_length: Option) -> Result { - self.storage.get_task(task_id, history_length).await + async fn get(&self, id: &TaskId, history_length: Option) -> Result { + self.storage.get(id, history_length).await } - async fn update_task_status( + async fn update_status( &self, - task_id: &str, + id: &TaskId, state: TaskState, message: Option, ) -> Result { - self.storage - .update_task_status(task_id, state, message) - .await + self.storage.update_status(id, state, message).await } - async fn cancel_task(&self, task_id: &str) -> Result { - self.storage.cancel_task(task_id).await + async fn cancel(&self, id: &TaskId) -> Result { + self.storage.cancel(id).await } - async fn task_exists(&self, task_id: &str) -> Result { - self.storage.task_exists(task_id).await + async fn exists(&self, id: &TaskId) -> Result { + self.storage.exists(id).await } +} - async fn list_tasks_v3( +#[async_trait] +impl AsyncTaskQuery for TestBusinessHandler { + async fn list( &self, params: &a2a_rs::domain::ListTasksParams, ) -> Result { - self.storage.list_tasks_v3(params).await + self.storage.list(params).await } +} - async fn get_push_notification_config( +#[async_trait] +impl AsyncNotificationManager for TestBusinessHandler { + async fn set_config( + &self, + config: &a2a_rs::domain::TaskPushNotificationConfig, + ) -> Result { + self.storage.set_config(config).await + } + + async fn get_config( &self, params: &a2a_rs::domain::GetTaskPushNotificationConfigParams, ) -> Result { - self.storage.get_push_notification_config(params).await + self.storage.get_config(params).await } - async fn list_push_notification_configs( + async fn list_configs( &self, params: &a2a_rs::domain::ListTaskPushNotificationConfigsParams, ) -> Result, A2AError> { - self.storage.list_push_notification_configs(params).await + self.storage.list_configs(params).await } - async fn delete_push_notification_config( + async fn delete_config( &self, params: &a2a_rs::domain::DeleteTaskPushNotificationConfigParams, ) -> Result<(), A2AError> { - self.storage.delete_push_notification_config(params).await - } -} - -#[async_trait] -impl AsyncNotificationManager for TestBusinessHandler { - async fn set_task_notification( - &self, - config: &a2a_rs::domain::TaskPushNotificationConfig, - ) -> Result { - self.storage.set_task_notification(config).await - } - - async fn get_task_notification( - &self, - task_id: &str, - ) -> Result { - self.storage.get_task_notification(task_id).await - } - - async fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError> { - self.storage.remove_task_notification(task_id).await + self.storage.delete_config(params).await } } @@ -275,7 +161,7 @@ impl AsyncStreamingHandler for TestBusinessHandler { task_id: &str, subscriber: Box + Send + Sync>, ) -> Result { - self.storage + self.streaming .add_status_subscriber(task_id, subscriber) .await } @@ -285,21 +171,21 @@ impl AsyncStreamingHandler for TestBusinessHandler { task_id: &str, subscriber: Box + Send + Sync>, ) -> Result { - self.storage + self.streaming .add_artifact_subscriber(task_id, subscriber) .await } async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError> { - self.storage.remove_subscription(subscription_id).await + self.streaming.remove_subscription(subscription_id).await } async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> { - self.storage.remove_task_subscribers(task_id).await + self.streaming.remove_task_subscribers(task_id).await } async fn get_subscriber_count(&self, task_id: &str) -> Result { - self.storage.get_subscriber_count(task_id).await + self.streaming.get_subscriber_count(task_id).await } async fn broadcast_status_update( @@ -307,7 +193,9 @@ impl AsyncStreamingHandler for TestBusinessHandler { task_id: &str, update: TaskStatusUpdateEvent, ) -> Result<(), A2AError> { - self.storage.broadcast_status_update(task_id, update).await + self.streaming + .broadcast_status_update(task_id, update) + .await } async fn broadcast_artifact_update( @@ -315,7 +203,7 @@ impl AsyncStreamingHandler for TestBusinessHandler { task_id: &str, update: TaskArtifactUpdateEvent, ) -> Result<(), A2AError> { - self.storage + self.streaming .broadcast_artifact_update(task_id, update) .await } @@ -329,7 +217,7 @@ impl AsyncStreamingHandler for TestBusinessHandler { >, A2AError, > { - self.storage.status_update_stream(task_id).await + self.streaming.status_update_stream(task_id).await } async fn artifact_update_stream( @@ -341,22 +229,25 @@ impl AsyncStreamingHandler for TestBusinessHandler { >, A2AError, > { - self.storage.artifact_update_stream(task_id).await + self.streaming.artifact_update_stream(task_id).await } async fn combined_update_stream( &self, task_id: &str, + from_event_id: Option, ) -> Result< std::pin::Pin< Box< dyn futures::Stream< - Item = Result, + Item = Result, > + Send, >, >, A2AError, > { - self.storage.combined_update_stream(task_id).await + self.streaming + .combined_update_stream(task_id, from_event_id) + .await } } diff --git a/a2a-rs/tests/in_memory_storage_v3_test.rs b/a2a-rs/tests/in_memory_storage_v3_test.rs index 9c7fda6..6f8d302 100644 --- a/a2a-rs/tests/in_memory_storage_v3_test.rs +++ b/a2a-rs/tests/in_memory_storage_v3_test.rs @@ -7,10 +7,18 @@ use a2a_rs::{ DeleteTaskPushNotificationConfigParams, GetTaskPushNotificationConfigParams, ListTaskPushNotificationConfigsParams, ListTasksParams, TaskState, }, - port::{AsyncNotificationManager, AsyncTaskManager}, + port::{AsyncNotificationManager, AsyncTaskLifecycle, AsyncTaskQuery}, }; use std::time::Duration; +fn tid(s: &str) -> a2a_rs::domain::TaskId { + s.parse().unwrap() +} + +fn cid(s: &str) -> a2a_rs::domain::ContextId { + s.parse().unwrap() +} + /// Helper to create tasks with different states and contexts async fn create_test_tasks( storage: &InMemoryTaskStorage, @@ -21,7 +29,7 @@ async fn create_test_tasks( for i in 0..count { let task_id = format!("test-task-{}-{}", context_id, i); storage - .create_task(&task_id, context_id) + .create(&tid(&task_id), &cid(context_id)) .await .expect("Failed to create task"); task_ids.push(task_id); @@ -40,10 +48,7 @@ async fn test_list_tasks_v3_basic() { // List all tasks with default parameters let params = ListTasksParams::default(); - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(result.total_size, 5, "Should have 5 tasks"); assert_eq!(result.tasks.len(), 5, "Should return 5 tasks"); @@ -78,10 +83,7 @@ async fn test_list_tasks_v3_filter_by_context() { context_id: Some("context-a".to_string()), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(result.total_size, 3, "Should have 3 tasks in context A"); assert_eq!(result.tasks.len(), 3); @@ -95,10 +97,7 @@ async fn test_list_tasks_v3_filter_by_context() { context_id: Some("context-b".to_string()), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(result.total_size, 2, "Should have 2 tasks in context B"); for task in &result.tasks { @@ -116,23 +115,23 @@ async fn test_list_tasks_v3_filter_by_status() { // Update tasks to different states storage - .update_task_status(&task_ids[0], TaskState::Working, None) + .update_status(&tid(&task_ids[0]), TaskState::Working, None) .await .expect("Failed to update task"); storage - .update_task_status(&task_ids[1], TaskState::Working, None) + .update_status(&tid(&task_ids[1]), TaskState::Working, None) .await .expect("Failed to update task"); storage - .update_task_status(&task_ids[2], TaskState::Completed, None) + .update_status(&tid(&task_ids[2]), TaskState::Completed, None) .await .expect("Failed to update task"); storage - .update_task_status(&task_ids[3], TaskState::Completed, None) + .update_status(&tid(&task_ids[3]), TaskState::Completed, None) .await .expect("Failed to update task"); storage - .update_task_status(&task_ids[4], TaskState::Failed, None) + .update_status(&tid(&task_ids[4]), TaskState::Failed, None) .await .expect("Failed to update task"); // task_ids[5] remains in Submitted state @@ -142,10 +141,7 @@ async fn test_list_tasks_v3_filter_by_status() { status: Some(TaskState::Working), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(result.total_size, 2, "Should have 2 working tasks"); for task in &result.tasks { @@ -157,10 +153,7 @@ async fn test_list_tasks_v3_filter_by_status() { status: Some(TaskState::Completed), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(result.total_size, 2, "Should have 2 completed tasks"); for task in &result.tasks { @@ -172,10 +165,7 @@ async fn test_list_tasks_v3_filter_by_status() { status: Some(TaskState::Failed), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(result.total_size, 1, "Should have 1 failed task"); assert_eq!(result.tasks[0].status.state, TaskState::Failed); @@ -185,10 +175,7 @@ async fn test_list_tasks_v3_filter_by_status() { status: Some(TaskState::Submitted), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(result.total_size, 1, "Should have 1 submitted task"); assert_eq!(result.tasks[0].status.state, TaskState::Submitted); @@ -203,7 +190,7 @@ async fn test_list_tasks_v3_filter_by_last_updated_after() { // Get the timestamp of the 3rd task (index 2) let middle_task = storage - .get_task(&task_ids[2], None) + .get(&tid(&task_ids[2]), None) .await .expect("Failed to get task"); let middle_timestamp = middle_task @@ -224,10 +211,7 @@ async fn test_list_tasks_v3_filter_by_last_updated_after() { status_timestamp_after: Some(middle_rfc), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); // Should get tasks 3 and 4 (created after task 2) assert_eq!( @@ -258,19 +242,19 @@ async fn test_list_tasks_v3_combined_filters() { // Set different states for context A tasks storage - .update_task_status(&context_a_ids[0], TaskState::Working, None) + .update_status(&tid(&context_a_ids[0]), TaskState::Working, None) .await .expect("Failed to update task"); storage - .update_task_status(&context_a_ids[1], TaskState::Working, None) + .update_status(&tid(&context_a_ids[1]), TaskState::Working, None) .await .expect("Failed to update task"); storage - .update_task_status(&context_a_ids[2], TaskState::Completed, None) + .update_status(&tid(&context_a_ids[2]), TaskState::Completed, None) .await .expect("Failed to update task"); storage - .update_task_status(&context_a_ids[3], TaskState::Completed, None) + .update_status(&tid(&context_a_ids[3]), TaskState::Completed, None) .await .expect("Failed to update task"); // context_a_ids[4] remains Submitted @@ -281,10 +265,7 @@ async fn test_list_tasks_v3_combined_filters() { status: Some(TaskState::Working), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!( result.total_size, 2, @@ -301,10 +282,7 @@ async fn test_list_tasks_v3_combined_filters() { status: Some(TaskState::Completed), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!( result.total_size, 2, @@ -328,10 +306,7 @@ async fn test_list_tasks_v3_pagination_basic() { page_size: Some(3), ..Default::default() }; - let page1 = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let page1 = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(page1.total_size, 10, "Total size should be 10"); assert_eq!(page1.page_size, 3, "Page size should be 3"); @@ -344,10 +319,7 @@ async fn test_list_tasks_v3_pagination_basic() { page_token: Some(page1.next_page_token.clone()), ..Default::default() }; - let page2 = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let page2 = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(page2.total_size, 10); assert_eq!(page2.page_size, 3); @@ -377,10 +349,7 @@ async fn test_list_tasks_v3_pagination_last_page() { page_size: Some(3), ..Default::default() }; - let page1 = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let page1 = storage.list(¶ms).await.expect("Failed to list tasks"); // Get second page (3 tasks) let params = ListTasksParams { @@ -388,10 +357,7 @@ async fn test_list_tasks_v3_pagination_last_page() { page_token: Some(page1.next_page_token.clone()), ..Default::default() }; - let page2 = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let page2 = storage.list(¶ms).await.expect("Failed to list tasks"); // Get last page (1 task) let params = ListTasksParams { @@ -399,10 +365,7 @@ async fn test_list_tasks_v3_pagination_last_page() { page_token: Some(page2.next_page_token.clone()), ..Default::default() }; - let page3 = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let page3 = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(page3.tasks.len(), 1, "Last page should have 1 task"); assert!( @@ -423,10 +386,7 @@ async fn test_list_tasks_v3_page_size_clamping() { page_size: Some(200), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(result.page_size, 100, "Page size should be clamped to 100"); @@ -435,10 +395,7 @@ async fn test_list_tasks_v3_page_size_clamping() { page_size: Some(0), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(result.page_size, 1, "Page size should be clamped to 1"); assert_eq!(result.tasks.len(), 1, "Should return 1 task"); @@ -453,10 +410,7 @@ async fn test_list_tasks_v3_large_dataset() { // Get first page with default page size (50) let params = ListTasksParams::default(); - let page1 = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let page1 = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(page1.total_size, 150); assert_eq!(page1.page_size, 50); @@ -468,10 +422,7 @@ async fn test_list_tasks_v3_large_dataset() { page_token: Some(page1.next_page_token.clone()), ..Default::default() }; - let page2 = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let page2 = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(page2.tasks.len(), 50); assert!(!page2.next_page_token.is_empty()); @@ -481,10 +432,7 @@ async fn test_list_tasks_v3_large_dataset() { page_token: Some(page2.next_page_token.clone()), ..Default::default() }; - let page3 = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let page3 = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(page3.tasks.len(), 50); assert!( @@ -510,14 +458,14 @@ async fn test_list_tasks_v3_history_length() { // Create a task let task_id = "history-task"; storage - .create_task(task_id, "test-context") + .create(&tid(task_id), &cid("test-context")) .await .expect("Failed to create task"); // Make several state transitions to create history storage - .update_task_status( - task_id, + .update_status( + &tid(task_id), TaskState::Working, Some(a2a_rs::Message::user_text( "working 1".to_string(), @@ -527,8 +475,8 @@ async fn test_list_tasks_v3_history_length() { .await .expect("Failed to update"); storage - .update_task_status( - task_id, + .update_status( + &tid(task_id), TaskState::InputRequired, Some(a2a_rs::Message::user_text( "input required".to_string(), @@ -538,8 +486,8 @@ async fn test_list_tasks_v3_history_length() { .await .expect("Failed to update"); storage - .update_task_status( - task_id, + .update_status( + &tid(task_id), TaskState::Working, Some(a2a_rs::Message::user_text( "working 2".to_string(), @@ -549,8 +497,8 @@ async fn test_list_tasks_v3_history_length() { .await .expect("Failed to update"); storage - .update_task_status( - task_id, + .update_status( + &tid(task_id), TaskState::Completed, Some(a2a_rs::Message::user_text( "completed".to_string(), @@ -565,10 +513,7 @@ async fn test_list_tasks_v3_history_length() { history_length: Some(2), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); let task = &result.tasks[0]; assert_eq!(task.history.len(), 2, "History should be limited to 2"); @@ -578,10 +523,7 @@ async fn test_list_tasks_v3_history_length() { history_length: Some(0), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); let task = &result.tasks[0]; assert!( @@ -596,7 +538,7 @@ async fn test_list_tasks_v3_include_artifacts() { let task_id = "artifact-task"; let mut task = storage - .create_task(task_id, "test-context") + .create(&tid(task_id), &cid("test-context")) .await .expect("Failed to create task"); @@ -613,7 +555,7 @@ async fn test_list_tasks_v3_include_artifacts() { // Update task in storage (through status update to trigger save) storage - .update_task_status(task_id, TaskState::Working, None) + .update_status(&tid(task_id), TaskState::Working, None) .await .expect("Failed to update task"); @@ -622,10 +564,7 @@ async fn test_list_tasks_v3_include_artifacts() { include_artifacts: Some(false), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); let task = &result.tasks[0]; assert!( @@ -638,10 +577,7 @@ async fn test_list_tasks_v3_include_artifacts() { include_artifacts: Some(true), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); let _task = &result.tasks[0]; // Note: Current implementation may not persist artifacts properly @@ -654,10 +590,7 @@ async fn test_list_tasks_v3_empty_results() { // List tasks when storage is empty let params = ListTasksParams::default(); - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(result.total_size, 0); assert_eq!(result.tasks.len(), 0); @@ -671,10 +604,7 @@ async fn test_list_tasks_v3_empty_results() { context_id: Some("non-existent-context".to_string()), ..Default::default() }; - let result = storage - .list_tasks_v3(¶ms) - .await - .expect("Failed to list tasks"); + let result = storage.list(¶ms).await.expect("Failed to list tasks"); assert_eq!(result.total_size, 0); assert_eq!(result.tasks.len(), 0); @@ -687,7 +617,7 @@ async fn test_get_push_notification_config() { let task_id = "push-config-task"; storage - .create_task(task_id, "test-context") + .create(&tid(task_id), &cid("test-context")) .await .expect("Failed to create task"); @@ -703,7 +633,7 @@ async fn test_get_push_notification_config() { }; storage - .set_task_notification(&config) + .set_config(&config) .await .expect("Failed to set notification"); @@ -715,7 +645,7 @@ async fn test_get_push_notification_config() { }; let retrieved = storage - .get_push_notification_config(¶ms) + .get_config(¶ms) .await .expect("Failed to get config"); @@ -736,7 +666,7 @@ async fn test_get_push_notification_config_not_found() { metadata: None, }; - let result = storage.get_push_notification_config(¶ms).await; + let result = storage.get_config(¶ms).await; assert!( result.is_err(), "Should return error for non-existent config" @@ -749,7 +679,7 @@ async fn test_list_push_notification_configs() { let task_id = "list-push-task"; storage - .create_task(task_id, "test-context") + .create(&tid(task_id), &cid("test-context")) .await .expect("Failed to create task"); @@ -759,7 +689,7 @@ async fn test_list_push_notification_configs() { metadata: None, }; let configs = storage - .list_push_notification_configs(¶ms) + .list_configs(¶ms) .await .expect("Failed to list configs"); @@ -777,13 +707,13 @@ async fn test_list_push_notification_configs() { }; storage - .set_task_notification(&config) + .set_config(&config) .await .expect("Failed to set notification"); // Now should have 1 config let configs = storage - .list_push_notification_configs(¶ms) + .list_configs(¶ms) .await .expect("Failed to list configs"); @@ -798,7 +728,7 @@ async fn test_delete_push_notification_config() { let task_id = "delete-push-task"; storage - .create_task(task_id, "test-context") + .create(&tid(task_id), &cid("test-context")) .await .expect("Failed to create task"); @@ -814,7 +744,7 @@ async fn test_delete_push_notification_config() { }; storage - .set_task_notification(&config) + .set_config(&config) .await .expect("Failed to set notification"); @@ -824,7 +754,7 @@ async fn test_delete_push_notification_config() { metadata: None, }; let configs = storage - .list_push_notification_configs(&list_params) + .list_configs(&list_params) .await .expect("Failed to list configs"); assert_eq!(configs.len(), 1); @@ -837,13 +767,13 @@ async fn test_delete_push_notification_config() { }; storage - .delete_push_notification_config(&delete_params) + .delete_config(&delete_params) .await .expect("Failed to delete config"); // Verify config is gone let configs = storage - .list_push_notification_configs(&list_params) + .list_configs(&list_params) .await .expect("Failed to list configs"); assert_eq!(configs.len(), 0, "Config should be deleted"); @@ -855,7 +785,7 @@ async fn test_delete_push_notification_config_idempotent() { let task_id = "idempotent-delete-task"; storage - .create_task(task_id, "test-context") + .create(&tid(task_id), &cid("test-context")) .await .expect("Failed to create task"); @@ -866,9 +796,7 @@ async fn test_delete_push_notification_config_idempotent() { metadata: None, }; - let result = storage - .delete_push_notification_config(&delete_params) - .await; + let result = storage.delete_config(&delete_params).await; // Should succeed (idempotent behavior) assert!( result.is_ok(), diff --git a/a2a-rs/tests/integration_test.rs b/a2a-rs/tests/integration_test.rs deleted file mode 100644 index 02e4bb4..0000000 --- a/a2a-rs/tests/integration_test.rs +++ /dev/null @@ -1,255 +0,0 @@ -//! Integration tests for the A2A protocol - -#![cfg(all(feature = "http-client", feature = "http-server"))] - -mod common; - -use a2a_rs::{ - adapter::{ - DefaultRequestProcessor, HttpClient, HttpServer, InMemoryTaskStorage, SimpleAgentInfo, - }, - domain::{Message, Part, TaskState}, - services::AsyncA2AClient, -}; -use common::TestBusinessHandler; -use reqwest::Client; -use serde_json::Value; -use std::time::Duration; -use tokio::sync::oneshot; - -/// Test a complete HTTP client-server interaction -#[tokio::test] -async fn test_http_client_server_interaction() { - // Create a storage - let storage = InMemoryTaskStorage::new(); - - // Create business handler with the storage - let handler = TestBusinessHandler::with_storage(storage); - - // Create agent info for the processor - let test_agent_info = SimpleAgentInfo::new( - "test-agent".to_string(), - "http://localhost:8182".to_string(), - ); - - // Create a processor - let processor = DefaultRequestProcessor::with_handler(handler, test_agent_info); - - // Create an agent info provider - let agent_info = SimpleAgentInfo::new( - "Test Agent".to_string(), - "http://localhost:8182".to_string(), - ) - .with_description("Test A2A agent for integration tests".to_string()) - .with_provider( - "Test Organization".to_string(), - "https://example.org".to_string(), - ) - .with_documentation_url("https://example.org/docs".to_string()) - .with_push_notifications() - .add_skill( - "test".to_string(), - "Test Skill".to_string(), - Some("A test skill".to_string()), - ); - - // Create the server - let server = HttpServer::new(processor, agent_info, "127.0.0.1:8182".to_string()); - - // Create a shutdown channel - let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); - - // Start the server in a separate task - let server_handle = tokio::spawn(async move { - tokio::select! { - _ = server.start() => {}, - _ = shutdown_rx => { - // Server will be dropped and shut down - } - } - }); - - // Give the server time to start - tokio::time::sleep(Duration::from_millis(100)).await; - - // Create the client - let client = HttpClient::new("http://localhost:8182".to_string()); - - // Test 1: Get agent card using direct HTTP request - let http_client = Client::new(); - let response = http_client - .get("http://localhost:8182/agent-card") - .send() - .await - .expect("Failed to fetch agent card"); - - let agent_card: Value = response.json().await.expect("Failed to parse agent card"); - assert_eq!(agent_card["name"].as_str().unwrap(), "Test Agent"); - assert!( - agent_card["capabilities"]["pushNotifications"] - .as_bool() - .unwrap_or(false) - ); - assert!( - !agent_card["capabilities"]["streaming"] - .as_bool() - .unwrap_or(false) - ); - - // Test 2: Get skills using direct HTTP request - let response = http_client - .get("http://localhost:8182/skills") - .send() - .await - .expect("Failed to fetch skills"); - - let skills: Vec = response.json().await.expect("Failed to parse skills"); - assert_eq!(skills.len(), 1); - assert_eq!(skills[0]["id"].as_str().unwrap(), "test"); - - // Test 3: Get skill by ID using direct HTTP request - let response = http_client - .get("http://localhost:8182/skills/test") - .send() - .await - .expect("Failed to fetch skill"); - - let skill: Value = response.json().await.expect("Failed to parse skill"); - assert_eq!(skill["name"].as_str().unwrap(), "Test Skill"); - - // Test 4: Send task message - let task_id = format!("task-{}", uuid::Uuid::new_v4()); - let message_id = format!("msg-{}", uuid::Uuid::new_v4()); - let message = Message::user_text("Hello, A2A agent!".to_string(), message_id); - let task = client - .send_task_message(&task_id, &message, None, None) - .await - .expect("Failed to send task message"); - - assert_eq!(task.id, task_id); - assert_eq!(task.status.state, TaskState::Working); - - // Test 5: Get task - let task = client - .get_task(&task_id, None) - .await - .expect("Failed to get task"); - assert_eq!(task.id, task_id); - assert!(!task.history.is_empty()); - - // Test 6: Get task with limited history - let task_limited = client - .get_task(&task_id, Some(0)) - .await - .expect("Failed to get task with limited history"); - assert_eq!(task_limited.id, task_id); - assert!(task_limited.history.is_empty()); - - // Test 7: Cancel task - println!("About to cancel task with ID: {}", task_id); - let canceled_task = client - .cancel_task(&task_id) - .await - .expect("Failed to cancel task"); - println!("Received task after cancellation: {:?}", canceled_task); - println!("Task state: {:?}", canceled_task.status.state); - assert_eq!(canceled_task.status.state, TaskState::Canceled); - - // Shut down the server - shutdown_tx - .send(()) - .expect("Failed to send shutdown signal"); - - // Wait for the server to shut down - server_handle.await.expect("Server task failed"); -} - -#[tokio::test] -async fn test_message_types() { - use a2a_rs::domain::part; - - // Create a message with text part - let message_id = format!("msg-{}", uuid::Uuid::new_v4()); - let mut message = Message::user_text("Hello, A2A agent!".to_string(), message_id); - - // Add a data part - let data_val: buffa_types::google::protobuf::Value = - serde_json::from_value(serde_json::json!({ - "key": "value" - })) - .unwrap(); - let data_part = Part::data(data_val); - message.add_part(data_part); - - // Add a file part - let file_part = Part::file_from_bytes( - b"Hello, world!".to_vec(), - Some("greeting.txt".to_string()), - Some("text/plain".to_string()), - ); - message - .add_part_validated(file_part) - .expect("Failed to add file part"); - - // Verify message parts - assert_eq!(message.parts.len(), 3); - - // Verify part types - match message.parts[0].content.as_ref() { - Some(part::Content::Text(text)) => assert_eq!(text, "Hello, A2A agent!"), - _ => panic!("Expected Text part"), - } - - match message.parts[1].content.as_ref() { - Some(part::Content::Data(data)) => { - let data_json = serde_json::to_value(&**data).unwrap(); - assert_eq!(data_json["key"], "value"); - } - _ => panic!("Expected Data part"), - } - - match message.parts[2].content.as_ref() { - Some(part::Content::Raw(bytes)) => { - assert_eq!(bytes, b"Hello, world!"); - assert_eq!(message.parts[2].filename, "greeting.txt"); - assert_eq!(message.parts[2].media_type, "text/plain"); - } - _ => panic!("Expected Raw part"), - } -} - -/// Test task history functionality -#[tokio::test] -async fn test_task_history() { - // Create a new task - let context_id = format!("ctx-{}", uuid::Uuid::new_v4()); - let mut task = a2a_rs::domain::Task::new("test-task-1".to_string(), context_id); - - // Create messages - let msg_id1 = format!("msg-{}", uuid::Uuid::new_v4()); - let msg_id2 = format!("msg-{}", uuid::Uuid::new_v4()); - let msg_id3 = format!("msg-{}", uuid::Uuid::new_v4()); - let message1 = Message::user_text("Message 1".to_string(), msg_id1); - let message2 = Message::agent_text("Message 2".to_string(), msg_id2); - let message3 = Message::user_text("Message 3".to_string(), msg_id3); - - // Update the task with messages - task.update_status(TaskState::Working, Some(message1)); - task.update_status(TaskState::Working, Some(message2)); - task.update_status(TaskState::Working, Some(message3)); - - // Verify history has all messages - assert!(!task.history.is_empty()); - let history = &task.history; - assert_eq!(history.len(), 3); - - // Test history truncation - let task_limited = task.with_limited_history(Some(2)); - assert!(!task_limited.history.is_empty()); - let history_limited = &task_limited.history; - assert_eq!(history_limited.len(), 2); - - // Test removing history entirely - let task_no_history = task.with_limited_history(Some(0)); - assert!(task_no_history.history.is_empty()); -} diff --git a/a2a-rs/tests/jsonrpc_client_interop_test.rs b/a2a-rs/tests/jsonrpc_client_interop_test.rs new file mode 100644 index 0000000..9223fea --- /dev/null +++ b/a2a-rs/tests/jsonrpc_client_interop_test.rs @@ -0,0 +1,393 @@ +//! In-process interop round-trip: the JSON-RPC **client** against the JSON-RPC +//! **server**, over a real socket. +//! +//! This proves byte-compatibility of [`JsonRpcClient`] with +//! [`JsonRpcAdapter`](a2a_rs::adapter::JsonRpcAdapter): the client's JSON-RPC +//! envelopes + ProtoJSON bodies are exactly what the server decodes, and the +//! client's SSE reassembly parses exactly what the server emits. A real +//! `TcpListener` (not `tower::oneshot`) is required because the streaming path +//! drives `reqwest` over a live connection. + +#![cfg(all(feature = "jsonrpc-client", feature = "jsonrpc-server"))] + +mod common; + +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use axum::{Json, Router, routing::get}; +use common::TestBusinessHandler; +use futures::{Stream, StreamExt, stream}; + +use a2a_rs::adapter::{InMemoryTaskStorage, JsonRpcAdapter, SimpleAgentInfo, jsonrpc_router}; +use a2a_rs::domain::{ + A2AError, AgentCard, AgentInterface, Message, TaskArtifactUpdateEvent, + TaskPushNotificationConfig, TaskState, TaskStatus, TaskStatusUpdateEvent, +}; +use a2a_rs::port::AsyncStreamingHandler; +use a2a_rs::port::streaming_handler::{SeqEvent, Subscriber}; +use a2a_rs::{JsonRpcClient, StreamItem, Transport, connect, default_registry}; + +// --------------------------------------------------------------------------- +// A streaming handler whose pull-streams are empty but valid, so `subscribe` +// emits the initial task snapshot then completes. (InMemoryTaskStorage returns +// `UnsupportedOperation` from `combined_update_stream`, which would fail +// `subscribe`.) +// --------------------------------------------------------------------------- + +type StatusStream = Pin> + Send>>; +type ArtifactStream = Pin> + Send>>; +type CombinedStream = Pin> + Send>>; + +#[derive(Clone)] +struct EmptyStreamHandler; + +#[async_trait] +impl AsyncStreamingHandler for EmptyStreamHandler { + async fn add_status_subscriber( + &self, + _task_id: &str, + _subscriber: Box + Send + Sync>, + ) -> Result { + Ok("status-sub".to_string()) + } + async fn add_artifact_subscriber( + &self, + _task_id: &str, + _subscriber: Box + Send + Sync>, + ) -> Result { + Ok("artifact-sub".to_string()) + } + async fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> { + Ok(()) + } + async fn remove_task_subscribers(&self, _task_id: &str) -> Result<(), A2AError> { + Ok(()) + } + async fn get_subscriber_count(&self, _task_id: &str) -> Result { + Ok(0) + } + async fn broadcast_status_update( + &self, + _task_id: &str, + _update: TaskStatusUpdateEvent, + ) -> Result<(), A2AError> { + Ok(()) + } + async fn broadcast_artifact_update( + &self, + _task_id: &str, + _update: TaskArtifactUpdateEvent, + ) -> Result<(), A2AError> { + Ok(()) + } + async fn status_update_stream(&self, _task_id: &str) -> Result { + Ok(Box::pin(stream::empty())) + } + async fn artifact_update_stream(&self, _task_id: &str) -> Result { + Ok(Box::pin(stream::empty())) + } + async fn combined_update_stream( + &self, + _task_id: &str, + _from_event_id: Option, + ) -> Result { + Ok(Box::pin(stream::empty())) + } +} + +// --------------------------------------------------------------------------- +// Server harness +// --------------------------------------------------------------------------- + +/// Spawn the JSON-RPC server (with an agent-card route) on an ephemeral port and +/// return its base URL. +async fn spawn_server() -> String { + let handler = TestBusinessHandler::with_storage(InMemoryTaskStorage::new()); + let agent_info = SimpleAgentInfo::new("interop".to_string(), "http://localhost".to_string()); + let adapter = Arc::new( + JsonRpcAdapter::with_handler(handler, agent_info) + .with_streaming_handler(EmptyStreamHandler), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let base = format!("http://{}", listener.local_addr().unwrap()); + + let card = AgentCard { + supported_interfaces: vec![AgentInterface { + url: base.clone(), + protocol_binding: "JSONRPC".to_string(), + protocol_version: "1.0".to_string(), + ..Default::default() + }], + ..Default::default() + }; + + let app: Router = jsonrpc_router(adapter).route( + "/.well-known/agent-card.json", + get(move || { + let card = card.clone(); + async move { Json(card) } + }), + ); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + base +} + +fn message() -> Message { + Message::user_text("hello".to_string(), "m1".to_string()) +} + +/// Spawn a server whose streaming backend is a real (shared) in-memory handler, +/// returning the base URL and a handle to broadcast through the same backend. +async fn spawn_server_streaming() -> (String, TestBusinessHandler) { + let handler = TestBusinessHandler::with_storage(InMemoryTaskStorage::new()); + let agent_info = SimpleAgentInfo::new("interop".to_string(), "http://localhost".to_string()); + let adapter = Arc::new( + JsonRpcAdapter::with_handler(handler.clone(), agent_info) + .with_streaming_handler(handler.clone()), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let base = format!("http://{}", listener.local_addr().unwrap()); + let app: Router = jsonrpc_router(adapter); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + (base, handler) +} + +fn status_update(task_id: &str, state: TaskState) -> TaskStatusUpdateEvent { + TaskStatusUpdateEvent { + task_id: task_id.to_string(), + context_id: "ctx".to_string(), + kind: "status-update".to_string(), + status: TaskStatus::new(state, None), + metadata: None, + } +} + +/// End-to-end Last-Event-ID resumption: the server emits SSE `id:` fields, the +/// client parses them, and reconnecting with `Last-Event-ID` replays only the +/// buffered events after that id (preceded by the initial task snapshot). +#[tokio::test] +async fn subscribe_resumes_from_last_event_id() { + let (base, handler) = spawn_server_streaming().await; + let client = JsonRpcClient::new(base); + + // Create the task so subscribe emits an initial snapshot. + client + .send_task_message("task-resume", &message(), None, None) + .await + .unwrap(); + + // Two updates broadcast before any subscriber — buffered as event ids 1 and 2. + handler + .broadcast_status_update( + "task-resume", + status_update("task-resume", TaskState::Working), + ) + .await + .unwrap(); + handler + .broadcast_status_update( + "task-resume", + status_update("task-resume", TaskState::Completed), + ) + .await + .unwrap(); + + // First subscription replays everything (id > 0); discover the id the server + // assigned to the Completed event (the message handler may emit its own + // events too, so we don't assume absolute ids). + let mut all = client + .subscribe_to_task("task-resume", None, Some("0")) + .await + .unwrap(); + let mut completed_id = None; + for _ in 0..16 { + match tokio::time::timeout(Duration::from_secs(2), all.next()).await { + Ok(Some(Ok(ev))) => { + if let StreamItem::StatusUpdate(e) = &ev.item { + if e.status.state == ::buffa::EnumValue::from(TaskState::Completed) { + completed_id = ev.event_id; + break; + } + } + } + _ => break, + } + } + let completed_id = completed_id.expect("should observe the Completed event with an id"); + drop(all); + + // Resume from just before Completed: only that event replays, after the snapshot. + let mut stream = client + .subscribe_to_task("task-resume", None, Some(&(completed_id - 1).to_string())) + .await + .unwrap(); + + let mut got = Vec::new(); + while got.len() < 2 { + let ev = tokio::time::timeout(Duration::from_secs(5), stream.next()) + .await + .expect("event within 5s") + .expect("stream not empty") + .expect("ok event"); + got.push(ev); + } + + // First: initial task snapshot (no id). Second: the replayed Completed event. + assert!( + matches!(got[0].item, StreamItem::Task(_)), + "first must be the snapshot" + ); + assert_eq!(got[0].event_id, None); + assert_eq!( + got[1].event_id, + Some(completed_id), + "only the Completed event should replay after Last-Event-ID = completed-1" + ); + match &got[1].item { + StreamItem::StatusUpdate(e) => { + assert_eq!( + e.status.state, + ::buffa::EnumValue::from(TaskState::Completed) + ) + } + other => panic!("expected StatusUpdate, got {other:?}"), + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn unary_roundtrip_send_get_list_cancel() { + let base = spawn_server().await; + let client = JsonRpcClient::new(base); + + // send → returns a task + let task = client + .send_task_message("task-1", &message(), None, None) + .await + .unwrap(); + let id = task.id.clone(); + assert!(!id.is_empty()); + + // get → same task + let got = client.get_task(&id, None).await.unwrap(); + assert_eq!(got.id, id); + + // list → contains it + let listed = client.list_tasks(&Default::default()).await.unwrap(); + assert!( + listed.tasks.iter().any(|t| t.id == id), + "listed tasks should contain {id}" + ); + + // cancel → same task + let canceled = client.cancel_task(&id).await.unwrap(); + assert_eq!(canceled.id, id); +} + +#[tokio::test] +async fn push_config_lifecycle() { + let base = spawn_server().await; + let client = JsonRpcClient::new(base); + + let task = client + .send_task_message("task-pc", &message(), None, None) + .await + .unwrap(); + let id = task.id.clone(); + + let config = TaskPushNotificationConfig { + task_id: id.clone(), + id: "cfg-1".to_string(), + url: "https://example.com/webhook".to_string(), + token: "tok".to_string(), + ..Default::default() + }; + + client.set_task_push_notification(&config).await.unwrap(); + + let configs = client.list_push_notification_configs(&id).await.unwrap(); + assert!( + !configs.is_empty(), + "config list should be non-empty after create" + ); + + let got = client + .get_push_notification_config(&id, "cfg-1") + .await + .unwrap(); + assert_eq!(got.url, "https://example.com/webhook"); + + client + .delete_push_notification_config(&id, "cfg-1") + .await + .unwrap(); +} + +#[tokio::test] +async fn subscribe_yields_initial_task_over_sse() { + let base = spawn_server().await; + let client = JsonRpcClient::new(base); + + let task = client + .send_task_message("task-sub", &message(), None, None) + .await + .unwrap(); + let id = task.id.clone(); + + let mut stream = client.subscribe_to_task(&id, None, None).await.unwrap(); + + // First SSE event must be the initial task snapshot — proving the client's + // SSE reassembly + JSON-RPC frame + StreamResponse union decode all work. + let first = tokio::time::timeout(Duration::from_secs(5), stream.next()) + .await + .expect("subscribe stream should yield within 5s") + .expect("subscribe stream should not be empty") + .expect("first event should be Ok"); + + match first.item { + StreamItem::Task(t) => assert_eq!(t.id, id), + other => panic!("expected initial Task snapshot, got {other:?}"), + } +} + +#[tokio::test] +async fn connect_negotiates_jsonrpc_from_card() { + let base = spawn_server().await; + + // connect() fetches the card and negotiates; the card only offers JSONRPC. + let transport = connect(&base, &default_registry()).await.unwrap(); + assert_eq!(transport.protocol(), "JSONRPC"); + + let task = transport + .send_task_message("task-neg", &message(), None, None) + .await + .unwrap(); + let got = transport.get_task(&task.id, None).await.unwrap(); + assert_eq!(got.id, task.id); +} + +#[tokio::test] +async fn get_task_not_found_maps_to_typed_error() { + let base = spawn_server().await; + let client = JsonRpcClient::new(base); + + let err = client.get_task("does-not-exist", None).await.unwrap_err(); + assert!( + matches!(err, A2AError::TaskNotFound(_)), + "expected TaskNotFound, got {err:?}" + ); +} diff --git a/a2a-rs/tests/jsonrpc_dispatch_test.rs b/a2a-rs/tests/jsonrpc_dispatch_test.rs new file mode 100644 index 0000000..8e5b6c4 --- /dev/null +++ b/a2a-rs/tests/jsonrpc_dispatch_test.rs @@ -0,0 +1,250 @@ +//! Behavioral tests for the JSON-RPC adapter's method dispatch. +//! +//! Drives [`JsonRpcAdapter::handle_unary`] against an in-memory handler and +//! asserts the JSON-RPC envelopes + ProtoJSON result bodies that an +//! off-the-shelf A2A client would see. + +#![cfg(feature = "jsonrpc-server")] + +mod common; + +use a2a_rs::adapter::transport::jsonrpc::{JsonRpcId, JsonRpcRequest, error_code, methods}; +use a2a_rs::adapter::{InMemoryTaskStorage, JsonRpcAdapter, SimpleAgentInfo}; +use common::TestBusinessHandler; +use serde_json::{Value, json}; + +fn adapter() -> JsonRpcAdapter { + let handler = TestBusinessHandler::with_storage(InMemoryTaskStorage::new()); + let agent_info = SimpleAgentInfo::new("test-agent".to_string(), "http://localhost".to_string()); + JsonRpcAdapter::with_handler(handler, agent_info) +} + +fn request(method: &str, params: Value) -> JsonRpcRequest { + JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: JsonRpcId::Num(1), + method: method.to_string(), + params: Some(params), + } +} + +fn send_message_params(task_id: &str) -> Value { + json!({ + "message": { + "messageId": "m1", + "role": "ROLE_USER", + "parts": [{ "text": "hello" }], + "taskId": task_id, + } + }) +} + +#[tokio::test] +async fn send_message_returns_task_union() { + let resp = adapter() + .handle_unary(request( + methods::SEND_MESSAGE, + send_message_params("task-1"), + )) + .await; + let value = serde_json::to_value(&resp).unwrap(); + + assert_eq!(value["jsonrpc"], "2.0"); + assert_eq!(value["id"], 1); + assert!(value.get("error").is_none(), "unexpected error: {value:?}"); + // Field-presence union: result is `{ "task": { ... } }`, no discriminator. + let task = &value["result"]["task"]; + assert_eq!(task["id"], "task-1"); + // State is a SCREAMING_SNAKE proto-name string (the exact value depends on + // the handler; just assert the ProtoJSON enum shape). + assert!( + task["status"]["state"] + .as_str() + .is_some_and(|s| s.starts_with("TASK_STATE_")), + "unexpected status: {:?}", + task["status"], + ); +} + +#[tokio::test] +async fn get_task_round_trips() { + let a = adapter(); + a.handle_unary(request( + methods::SEND_MESSAGE, + send_message_params("task-2"), + )) + .await; + + let resp = a + .handle_unary(request(methods::GET_TASK, json!({ "id": "task-2" }))) + .await; + let value = serde_json::to_value(&resp).unwrap(); + assert!(value.get("error").is_none(), "unexpected error: {value:?}"); + // GetTask result is a bare Task (not a union). + assert_eq!(value["result"]["id"], "task-2"); +} + +#[tokio::test] +async fn cancel_task_returns_canceled_state() { + let a = adapter(); + a.handle_unary(request( + methods::SEND_MESSAGE, + send_message_params("task-3"), + )) + .await; + + let resp = a + .handle_unary(request(methods::CANCEL_TASK, json!({ "id": "task-3" }))) + .await; + let value = serde_json::to_value(&resp).unwrap(); + assert!(value.get("error").is_none(), "unexpected error: {value:?}"); + assert_eq!(value["result"]["id"], "task-3"); +} + +#[tokio::test] +async fn unknown_method_is_method_not_found() { + let resp = adapter() + .handle_unary(request("NoSuchMethod", json!({}))) + .await; + let value = serde_json::to_value(&resp).unwrap(); + assert!(value.get("result").is_none()); + assert_eq!(value["error"]["code"], error_code::METHOD_NOT_FOUND); +} + +#[tokio::test] +async fn invalid_params_is_invalid_params() { + // `message` is required on SendMessageRequest's wire shape; an int is invalid. + let resp = adapter() + .handle_unary(request(methods::SEND_MESSAGE, json!({ "message": 42 }))) + .await; + let value = serde_json::to_value(&resp).unwrap(); + assert_eq!(value["error"]["code"], error_code::INVALID_PARAMS); +} + +#[tokio::test] +async fn missing_message_is_invalid_params() { + let resp = adapter() + .handle_unary(request(methods::SEND_MESSAGE, json!({}))) + .await; + let value = serde_json::to_value(&resp).unwrap(); + assert_eq!(value["error"]["code"], error_code::INVALID_PARAMS); +} + +#[tokio::test] +async fn get_missing_task_is_task_not_found() { + let resp = adapter() + .handle_unary(request(methods::GET_TASK, json!({ "id": "nope" }))) + .await; + let value = serde_json::to_value(&resp).unwrap(); + assert_eq!(value["error"]["code"], error_code::TASK_NOT_FOUND); +} + +#[tokio::test] +async fn list_tasks_returns_response_envelope() { + let a = adapter(); + a.handle_unary(request( + methods::SEND_MESSAGE, + send_message_params("task-4"), + )) + .await; + + let resp = a + .handle_unary(request(methods::LIST_TASKS, json!({}))) + .await; + let value = serde_json::to_value(&resp).unwrap(); + assert!(value.get("error").is_none(), "unexpected error: {value:?}"); + assert!(value["result"]["tasks"].is_array()); +} + +// --------------------------------------------------------------------------- +// Server-side CallInterceptor chain +// --------------------------------------------------------------------------- + +mod interceptors { + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + + use a2a_rs::domain::A2AError; + use a2a_rs::port::{CallContext, CallInterceptor, CallSide}; + use async_trait::async_trait; + use serde_json::json; + + use super::{adapter, methods, request}; + + /// Records how often each hook fired and whether `after` saw an error. + #[derive(Clone, Default)] + struct Counting { + before: Arc, + after_ok: Arc, + after_err: Arc, + } + + #[async_trait] + impl CallInterceptor for Counting { + async fn before(&self, ctx: &CallContext) -> Result<(), A2AError> { + assert_eq!(ctx.side, CallSide::Server); + self.before.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + async fn after(&self, _ctx: &CallContext, outcome: Result<(), &A2AError>) { + match outcome { + Ok(()) => self.after_ok.fetch_add(1, Ordering::SeqCst), + Err(_) => self.after_err.fetch_add(1, Ordering::SeqCst), + }; + } + } + + /// A `before` that always short-circuits the call. + struct Rejecting; + + #[async_trait] + impl CallInterceptor for Rejecting { + async fn before(&self, _ctx: &CallContext) -> Result<(), A2AError> { + Err(A2AError::UnsupportedOperation( + "rejected by interceptor".to_string(), + )) + } + } + + #[tokio::test] + async fn before_and_after_wrap_each_dispatch() { + let counter = Counting::default(); + let a = adapter().with_interceptor(counter.clone()); + + // A successful call: after observes Ok. + a.handle_unary(request( + methods::SEND_MESSAGE, + super::send_message_params("ti-1"), + )) + .await; + // A failing call (missing task): after observes Err. + let resp = a + .handle_unary(request(methods::GET_TASK, json!({ "id": "ghost" }))) + .await; + let value = serde_json::to_value(&resp).unwrap(); + assert!(value.get("error").is_some(), "expected an error: {value:?}"); + + assert_eq!(counter.before.load(Ordering::SeqCst), 2); + assert_eq!(counter.after_ok.load(Ordering::SeqCst), 1); + assert_eq!(counter.after_err.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn rejecting_before_short_circuits_dispatch() { + // Rejecting runs first; the real method never executes. + let a = adapter() + .with_interceptor(Rejecting) + .with_interceptor(Counting::default()); + + let resp = a + .handle_unary(request(methods::GET_TASK, json!({ "id": "task-x" }))) + .await; + let value = serde_json::to_value(&resp).unwrap(); + // The short-circuit error surfaces as the JSON-RPC error, not a task. + assert_eq!( + value["error"]["message"], + "Unsupported operation: rejected by interceptor" + ); + assert!(value.get("result").is_none() || value["result"].is_null()); + } +} diff --git a/a2a-rs/tests/jsonrpc_router_test.rs b/a2a-rs/tests/jsonrpc_router_test.rs new file mode 100644 index 0000000..32df179 --- /dev/null +++ b/a2a-rs/tests/jsonrpc_router_test.rs @@ -0,0 +1,338 @@ +//! End-to-end tests for the JSON-RPC / REST **routers** (the surface the +//! dispatch tests don't reach). +//! +//! [`jsonrpc_dispatch_test`] drives [`JsonRpcAdapter::handle_unary`] directly; +//! this file stands up the real `axum` routers and drives them with +//! `tower::ServiceExt::oneshot`, so it covers the parts that only exist at the +//! router layer: REST path/query extraction, the `/tasks/{id}/cancel` slash +//! alias, HTTP status mapping, and — most importantly — the two **SSE framings** +//! (`jsonrpc_sse` wraps each event in a JSON-RPC envelope; `rest_sse` emits the +//! bare ProtoJSON `StreamResponse`). + +#![cfg(feature = "jsonrpc-server")] + +mod common; + +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use axum::body::{Body, to_bytes}; +use axum::http::{Request, StatusCode, header::CONTENT_TYPE}; +use common::TestBusinessHandler; +use futures::{Stream, StreamExt, stream}; +use serde_json::{Value, json}; +use tower::ServiceExt; + +use a2a_rs::adapter::{ + InMemoryTaskStorage, JsonRpcAdapter, SimpleAgentInfo, jsonrpc_router, rest_router, +}; +use a2a_rs::domain::{A2AError, TaskArtifactUpdateEvent, TaskStatusUpdateEvent}; +use a2a_rs::port::AsyncStreamingHandler; +use a2a_rs::port::streaming_handler::{SeqEvent, Subscriber}; + +/// A streaming handler whose pull-streams are empty but valid. +/// +/// `InMemoryTaskStorage` models streaming as subscriber push and returns +/// `UnsupportedOperation` from `combined_update_stream`, so `TaskService`'s +/// stream methods can't run against it. These tests only need the SSE *framing* +/// to be exercised — the initial task snapshot is emitted by the adapter ahead +/// of the (here empty) update stream — so a handler that returns an empty stream +/// is enough to drive `open_stream` to success. +#[derive(Clone)] +struct EmptyStreamHandler; + +type StatusStream = Pin> + Send>>; +type ArtifactStream = Pin> + Send>>; +type CombinedStream = Pin> + Send>>; + +#[async_trait] +impl AsyncStreamingHandler for EmptyStreamHandler { + async fn add_status_subscriber( + &self, + _task_id: &str, + _subscriber: Box + Send + Sync>, + ) -> Result { + Ok("status-sub".to_string()) + } + + async fn add_artifact_subscriber( + &self, + _task_id: &str, + _subscriber: Box + Send + Sync>, + ) -> Result { + Ok("artifact-sub".to_string()) + } + + async fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> { + Ok(()) + } + + async fn remove_task_subscribers(&self, _task_id: &str) -> Result<(), A2AError> { + Ok(()) + } + + async fn get_subscriber_count(&self, _task_id: &str) -> Result { + Ok(0) + } + + async fn broadcast_status_update( + &self, + _task_id: &str, + _update: TaskStatusUpdateEvent, + ) -> Result<(), A2AError> { + Ok(()) + } + + async fn broadcast_artifact_update( + &self, + _task_id: &str, + _update: TaskArtifactUpdateEvent, + ) -> Result<(), A2AError> { + Ok(()) + } + + async fn status_update_stream(&self, _task_id: &str) -> Result { + Ok(Box::pin(stream::empty::< + Result, + >())) + } + + async fn artifact_update_stream(&self, _task_id: &str) -> Result { + Ok(Box::pin(stream::empty::< + Result, + >())) + } + + async fn combined_update_stream( + &self, + _task_id: &str, + _from_event_id: Option, + ) -> Result { + Ok(Box::pin(stream::empty::>())) + } +} + +/// An adapter wired with a working (empty) streaming backend. +fn streaming_adapter() -> Arc { + let handler = TestBusinessHandler::with_storage(InMemoryTaskStorage::new()); + let agent_info = + SimpleAgentInfo::new("router-test".to_string(), "http://localhost".to_string()); + Arc::new( + JsonRpcAdapter::with_handler(handler, agent_info) + .with_streaming_handler(EmptyStreamHandler), + ) +} + +/// Build an adapter backed by a real in-memory streaming handler so the SSE +/// methods emit the initial task snapshot. +fn adapter() -> Arc { + let handler = TestBusinessHandler::with_storage(InMemoryTaskStorage::new()); + let agent_info = + SimpleAgentInfo::new("router-test".to_string(), "http://localhost".to_string()); + Arc::new( + JsonRpcAdapter::with_handler(handler.clone(), agent_info).with_streaming_handler(handler), + ) +} + +fn send_message_body(task_id: &str) -> Value { + json!({ + "message": { + "messageId": "m1", + "role": "ROLE_USER", + "parts": [{ "text": "hello" }], + "taskId": task_id, + } + }) +} + +fn post(uri: &str, body: &Value) -> Request { + Request::builder() + .method("POST") + .uri(uri) + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_vec(body).unwrap())) + .unwrap() +} + +fn get(uri: &str) -> Request { + Request::builder() + .method("GET") + .uri(uri) + .body(Body::empty()) + .unwrap() +} + +/// Drive a request through the REST router and return `(status, json_body)`. +async fn rest_call(adapter: &Arc, req: Request) -> (StatusCode, Value) { + let resp = rest_router(adapter.clone()).oneshot(req).await.unwrap(); + let status = resp.status(); + let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap(); + let value = serde_json::from_slice(&bytes).unwrap_or(Value::Null); + (status, value) +} + +/// Drive a request through the JSON-RPC router and return `(status, json_body)`. +async fn jsonrpc_call(adapter: &Arc, body: &Value) -> (StatusCode, Value) { + let resp = jsonrpc_router(adapter.clone()) + .oneshot(post("/", body)) + .await + .unwrap(); + let status = resp.status(); + let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap(); + let value = serde_json::from_slice(&bytes).unwrap_or(Value::Null); + (status, value) +} + +/// Read the first SSE event's `data:` payload as JSON, with a timeout so a +/// keep-alive stream that never yields fails the test rather than hanging. +async fn first_sse_event(resp: axum::response::Response) -> Value { + assert_eq!( + resp.status(), + StatusCode::OK, + "SSE endpoint should return 200" + ); + let ct = resp + .headers() + .get(CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + assert!( + ct.starts_with("text/event-stream"), + "expected SSE content-type, got {ct:?}" + ); + + let mut stream = resp.into_body().into_data_stream(); + let mut buf = String::new(); + let deadline = tokio::time::sleep(Duration::from_secs(3)); + tokio::pin!(deadline); + loop { + tokio::select! { + _ = &mut deadline => panic!("timed out waiting for an SSE data line; buffered: {buf:?}"), + chunk = stream.next() => { + let chunk = chunk.expect("stream ended before a data line").expect("stream error"); + buf.push_str(&String::from_utf8_lossy(&chunk)); + if let Some(line) = buf.lines().find(|l| l.starts_with("data:")) { + let payload = line.trim_start_matches("data:").trim(); + return serde_json::from_str(payload).expect("SSE data is not JSON"); + } + } + } + } +} + +// --- REST unary ------------------------------------------------------------ + +#[tokio::test] +async fn rest_send_then_get_round_trips() { + let a = adapter(); + + let (status, body) = rest_call(&a, post("/message:send", &send_message_body("t1"))).await; + assert_eq!(status, StatusCode::OK); + // SendMessageResponse field-presence union: `{ "task": { ... } }`. + assert_eq!(body["task"]["id"], "t1"); + + let (status, body) = rest_call(&a, get("/tasks/t1")).await; + assert_eq!(status, StatusCode::OK); + // GetTask returns a bare Task, not a union. + assert_eq!(body["id"], "t1"); +} + +#[tokio::test] +async fn rest_cancel_slash_alias_works() { + let a = adapter(); + rest_call(&a, post("/message:send", &send_message_body("t2"))).await; + + // The canonical `/tasks/{id}:cancel` colon form is unroutable in matchit; + // the adapter serves the slash alias instead. Official clients accept both. + let (status, body) = rest_call(&a, post("/tasks/t2/cancel", &json!({}))).await; + assert_eq!(status, StatusCode::OK); + assert_eq!(body["id"], "t2"); +} + +#[tokio::test] +async fn rest_get_missing_task_is_404() { + let a = adapter(); + let (status, _body) = rest_call(&a, get("/tasks/nope")).await; + assert_eq!(status, StatusCode::NOT_FOUND); +} + +#[tokio::test] +async fn rest_list_tasks_via_query() { + let a = adapter(); + rest_call(&a, post("/message:send", &send_message_body("t3"))).await; + + let (status, body) = rest_call(&a, get("/tasks?pageSize=10")).await; + assert_eq!(status, StatusCode::OK); + assert!(body["tasks"].is_array()); +} + +// --- JSON-RPC unary -------------------------------------------------------- + +#[tokio::test] +async fn jsonrpc_send_message_envelope() { + let a = adapter(); + let body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "SendMessage", + "params": send_message_body("j1"), + }); + let (status, resp) = jsonrpc_call(&a, &body).await; + assert_eq!(status, StatusCode::OK); + assert_eq!(resp["jsonrpc"], "2.0"); + assert_eq!(resp["id"], 1); + assert!(resp.get("error").is_none(), "unexpected error: {resp:?}"); + assert_eq!(resp["result"]["task"]["id"], "j1"); +} + +#[tokio::test] +async fn jsonrpc_rejects_wrong_version() { + let a = adapter(); + let body = json!({ "jsonrpc": "1.0", "id": 1, "method": "GetTask", "params": { "id": "x" } }); + let (status, resp) = jsonrpc_call(&a, &body).await; + assert_eq!(status, StatusCode::OK); // JSON-RPC errors ride in the body, not the HTTP status + assert_eq!(resp["error"]["code"], -32600); // INVALID_REQUEST +} + +// --- SSE framing (the part only the router exercises) ---------------------- + +#[tokio::test] +async fn jsonrpc_stream_frames_events_in_envelopes() { + let a = streaming_adapter(); + let body = json!({ + "jsonrpc": "2.0", + "id": 7, + "method": "SendStreamingMessage", + "params": send_message_body("s1"), + }); + let resp = jsonrpc_router(a.clone()) + .oneshot(post("/", &body)) + .await + .unwrap(); + let event = first_sse_event(resp).await; + + // JSON-RPC SSE: each event is a full response envelope whose `result` is the + // tag-free `StreamResponse` union — here the initial task snapshot. + assert_eq!(event["jsonrpc"], "2.0"); + assert_eq!(event["id"], 7); + assert_eq!(event["result"]["task"]["id"], "s1"); +} + +#[tokio::test] +async fn rest_stream_frames_bare_protojson() { + let a = streaming_adapter(); + let resp = rest_router(a.clone()) + .oneshot(post("/message:stream", &send_message_body("s2"))) + .await + .unwrap(); + let event = first_sse_event(resp).await; + + // REST SSE has no envelope: the event data IS the bare `StreamResponse`. + assert!( + event.get("jsonrpc").is_none(), + "REST SSE must not carry a JSON-RPC envelope" + ); + assert_eq!(event["task"]["id"], "s2"); +} diff --git a/a2a-rs/tests/jsonrpc_wire_test.rs b/a2a-rs/tests/jsonrpc_wire_test.rs new file mode 100644 index 0000000..ac78f69 --- /dev/null +++ b/a2a-rs/tests/jsonrpc_wire_test.rs @@ -0,0 +1,192 @@ +//! Wire-format probe + golden tests for the JSON-RPC / HTTP+JSON adapter. +//! +//! These tests pin down whether the `buffa`-generated domain types serialize as +//! canonical **ProtoJSON** (the wire format the A2A spec and the official SDK +//! use). If they do, the JSON-RPC adapter can serialize the generated types +//! directly (plan "Option A"); the only hand-written serde it adds is the +//! tag-free field-presence unions. +//! +//! The `probe_*` tests print the actual JSON so the exact shape (Timestamp +//! format, `Struct`/metadata shape, `bytes` encoding) is visible in test output. + +#![cfg(feature = "server")] + +use a2a_rs::domain::{Message, Part, Task, TaskState, TaskStatus}; + +/// Recursively sort object keys so two JSON values compare modulo key order. +fn canonical(mut v: serde_json::Value) -> serde_json::Value { + fn sort(v: &mut serde_json::Value) { + match v { + serde_json::Value::Object(map) => { + let mut sorted: std::collections::BTreeMap = + std::mem::take(map).into_iter().collect(); + for val in sorted.values_mut() { + sort(val); + } + *map = sorted.into_iter().collect(); + } + serde_json::Value::Array(arr) => arr.iter_mut().for_each(sort), + _ => {} + } + } + sort(&mut v); + v +} + +// --------------------------------------------------------------------------- +// Probes — print real output to settle R1 (Timestamp) and R2 (Struct/metadata). +// --------------------------------------------------------------------------- + +#[test] +fn timestamp_serializes_as_rfc3339() { + // R1: `google.protobuf.Timestamp` must be an RFC3339 string, not + // `{seconds, nanos}`. If this ever regresses, any Timestamp-bearing type + // needs an Option-B wire conversion. + let status = TaskStatus { + state: buffa::EnumValue::from(TaskState::TASK_STATE_WORKING), + timestamp: buffa::MessageField::some(buffa_types::google::protobuf::Timestamp { + seconds: 1_700_000_000, + nanos: 0, + ..Default::default() + }), + ..Default::default() + }; + let json = serde_json::to_value(&status).unwrap(); + assert_eq!( + canonical(json), + canonical(serde_json::json!({ + "state": "TASK_STATE_WORKING", + "timestamp": "2023-11-14T22:13:20Z", + })), + ); +} + +#[test] +fn metadata_struct_serializes_as_bare_object() { + // R2: `google.protobuf.Struct` must be a bare JSON object, not + // `{fields: {...}}`. Note proto `Struct` numbers are doubles (42 -> 42.0). + let mut message = Message::user_text("hello".to_string(), "msg-1".to_string()); + let struct_val: buffa_types::google::protobuf::Struct = serde_json::from_value( + serde_json::json!({ "foo": "bar", "n": 42, "nested": { "a": true } }), + ) + .unwrap(); + message.metadata = buffa::MessageField::some(struct_val); + let json = serde_json::to_value(&message).unwrap(); + assert_eq!( + canonical(json), + canonical(serde_json::json!({ + "messageId": "msg-1", + "role": "ROLE_USER", + "parts": [{ "text": "hello" }], + "metadata": { "foo": "bar", "n": 42.0, "nested": { "a": true } }, + })), + ); +} + +#[test] +fn bytes_part_serializes_as_base64_under_raw() { + let part = Part { + content: Some(a2a_rs::domain::generated::part::Content::Raw(vec![ + 1, 2, 3, 255, + ])), + ..Default::default() + }; + let json = serde_json::to_value(&part).unwrap(); + assert_eq!( + canonical(json), + canonical(serde_json::json!({ "raw": "AQID/w==" })) + ); +} + +#[test] +fn enums_and_oneof_are_tag_free_proto_names() { + let message = Message::agent_text("hi".to_string(), "m2".to_string()); + let json = serde_json::to_value(&message).unwrap(); + // Role must be the SCREAMING_SNAKE proto name, not an int. + assert_eq!(json["role"], serde_json::json!("ROLE_AGENT")); + // Text part flattens to {"text": "hi"} with no discriminator tag. + assert_eq!(json["parts"][0]["text"], serde_json::json!("hi")); +} + +#[test] +fn message_round_trips_through_protojson() { + // Round-trip catches alias/`skip_if`/`null_as_default` asymmetry: a wire + // body deserializes into the domain type and re-serializes byte-identically. + let wire = serde_json::json!({ + "messageId": "m3", + "contextId": "ctx", + "role": "ROLE_USER", + "parts": [{ "text": "round" }, { "raw": "AQID/w==" }], + }); + let message: Message = serde_json::from_value(wire.clone()).unwrap(); + let back = serde_json::to_value(&message).unwrap(); + assert_eq!(canonical(back), canonical(wire)); +} + +#[test] +fn task_serializes_as_protojson() { + let task = Task { + id: "task-1".to_string(), + context_id: "ctx-1".to_string(), + status: buffa::MessageField::some(TaskStatus { + state: buffa::EnumValue::from(TaskState::TASK_STATE_COMPLETED), + ..Default::default() + }), + ..Default::default() + }; + let json = serde_json::to_value(&task).unwrap(); + assert_eq!(json["id"], serde_json::json!("task-1")); + assert_eq!(json["contextId"], serde_json::json!("ctx-1")); + assert_eq!( + json["status"]["state"], + serde_json::json!("TASK_STATE_COMPLETED") + ); + // proto3 default fields (empty artifacts/history) must be omitted. + assert!( + json.get("artifacts").is_none(), + "empty artifacts should be omitted" + ); +} + +#[test] +fn canonical_sorts_keys() { + let a = serde_json::json!({ "b": 1, "a": 2 }); + let b = serde_json::json!({ "a": 2, "b": 1 }); + assert_eq!(canonical(a), canonical(b)); +} + +#[test] +fn generated_response_types_are_field_presence_unions() { + // The generated `SendMessageResponse`/`StreamResponse` oneofs already + // serialize tag-free, so the adapter reuses them as the JSON-RPC `result` + // rather than hand-writing union serde. + use a2a_rs::domain::generated::{ + SendMessageResponse, StreamResponse, TaskStatusUpdateEvent, send_message_response, + stream_response, + }; + let r = SendMessageResponse { + payload: Some(send_message_response::Payload::Task(Box::new(Task { + id: "t1".into(), + ..Default::default() + }))), + ..Default::default() + }; + assert_eq!( + canonical(serde_json::to_value(&r).unwrap()), + canonical(serde_json::json!({ "task": { "id": "t1" } })), + ); + + let s = StreamResponse { + payload: Some(stream_response::Payload::StatusUpdate(Box::new( + TaskStatusUpdateEvent { + task_id: "t1".into(), + ..Default::default() + }, + ))), + ..Default::default() + }; + assert_eq!( + canonical(serde_json::to_value(&s).unwrap()), + canonical(serde_json::json!({ "statusUpdate": { "taskId": "t1" } })), + ); +} diff --git a/a2a-rs/tests/proto_vendor_sync_test.rs b/a2a-rs/tests/proto_vendor_sync_test.rs new file mode 100644 index 0000000..fb3a48d --- /dev/null +++ b/a2a-rs/tests/proto_vendor_sync_test.rs @@ -0,0 +1,75 @@ +//! Guards against silent drift between the vendored protos and the spec mirror. +//! +//! `a2a-rs/proto/` holds the trimmed proto set that `build.rs` compiles and that +//! `cargo publish` packages (the spec mirror in `spec/` is *not* packaged). Those +//! files duplicate `spec/a2a.proto` + the handful of `spec/google/api/*.proto` +//! they import, so they can drift apart unnoticed. This test fails when any +//! vendored file no longer matches its `spec/` counterpart byte-for-byte — update +//! both together, or this is the failure that catches it. +//! +//! When the `spec/` mirror is absent (e.g. inside a packaged/published crate +//! where only `proto/` ships), the check is skipped rather than failed. + +use std::fs; +use std::path::{Path, PathBuf}; + +/// Collect every file under `dir`, relative to `dir`. +fn files_under(dir: &Path) -> Vec { + let mut out = Vec::new(); + let mut stack = vec![dir.to_path_buf()]; + while let Some(d) = stack.pop() { + for entry in fs::read_dir(&d).unwrap_or_else(|e| panic!("read_dir {d:?}: {e}")) { + let path = entry.unwrap().path(); + if path.is_dir() { + stack.push(path); + } else { + out.push(path.strip_prefix(dir).unwrap().to_path_buf()); + } + } + } + out.sort(); + out +} + +#[test] +fn vendored_protos_match_spec() { + let crate_dir = Path::new(env!("CARGO_MANIFEST_DIR")); + let vendored = crate_dir.join("proto"); + let spec = crate_dir.join("..").join("spec"); + + if !spec.is_dir() { + eprintln!("skipping: spec/ mirror not present at {spec:?}"); + return; + } + + let mut problems = Vec::new(); + for rel in files_under(&vendored) { + let vendored_file = vendored.join(&rel); + let spec_file = spec.join(&rel); + + match fs::read(&spec_file) { + Ok(spec_bytes) => { + let vendored_bytes = fs::read(&vendored_file).unwrap(); + if spec_bytes != vendored_bytes { + problems.push(format!( + " {} differs from spec/{}", + rel.display(), + rel.display() + )); + } + } + Err(_) => problems.push(format!( + " {} has no spec/ counterpart (spec/{} missing)", + rel.display(), + rel.display() + )), + } + } + + assert!( + problems.is_empty(), + "vendored protos in a2a-rs/proto/ have drifted from spec/:\n{}\n\ + Re-sync the two trees (update both `spec/` and `a2a-rs/proto/`).", + problems.join("\n") + ); +} diff --git a/a2a-rs/tests/push_notification_crud_test.rs b/a2a-rs/tests/push_notification_crud_test.rs deleted file mode 100644 index 6d955b9..0000000 --- a/a2a-rs/tests/push_notification_crud_test.rs +++ /dev/null @@ -1,392 +0,0 @@ -//! Integration tests for push notification config CRUD endpoints (v1.0.0) -//! -//! Tests the spec-compliant implementation of: -//! - tasks/pushNotificationConfig/list -//! - tasks/pushNotificationConfig/get -//! - tasks/pushNotificationConfig/delete - -#![cfg(all(feature = "http-client", feature = "http-server"))] - -mod common; - -use a2a_rs::{ - TaskPushNotificationConfig, - adapter::{ - DefaultRequestProcessor, HttpClient, HttpServer, InMemoryTaskStorage, SimpleAgentInfo, - }, - port::{AsyncNotificationManager, AsyncTaskManager}, - services::AsyncA2AClient, -}; -use common::TestBusinessHandler; -use std::time::Duration; -use tokio::sync::oneshot; - -/// Helper to set up a server with a task -async fn setup_server(port: u16) -> (oneshot::Sender<()>, InMemoryTaskStorage) { - let storage = InMemoryTaskStorage::new(); - let storage_clone = storage.clone(); - - let handler = TestBusinessHandler::with_storage(storage); - - let test_agent_info = SimpleAgentInfo::new( - "test-agent".to_string(), - format!("http://localhost:{}", port), - ); - - let processor = DefaultRequestProcessor::with_handler(handler, test_agent_info); - - let agent_info = SimpleAgentInfo::new( - "Push Config Test Agent".to_string(), - format!("http://localhost:{}", port), - ); - - let server = HttpServer::new(processor, agent_info, format!("127.0.0.1:{}", port)); - - let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); - - tokio::spawn(async move { - tokio::select! { - _ = server.start() => {}, - _ = shutdown_rx => {} - } - }); - - tokio::time::sleep(Duration::from_millis(100)).await; - - (shutdown_tx, storage_clone) -} - -#[tokio::test] -async fn test_list_push_notification_configs_empty() { - let port = 9070; - let (shutdown, storage) = setup_server(port).await; - - // Create a task WITHOUT any push notification config - storage - .create_task("task_no_configs", "test_context") - .await - .unwrap(); - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // List configs for task with no configs - let result = client - .list_push_notification_configs("task_no_configs") - .await; - - shutdown.send(()).ok(); - - // Verify successful response with empty array - assert!(result.is_ok(), "List should succeed even with no configs"); - let configs = result.unwrap(); - assert_eq!(configs.len(), 0, "Should have 0 configs"); -} - -#[tokio::test] -async fn test_set_and_list_push_notification_config() { - let port = 9071; - let (shutdown, storage) = setup_server(port).await; - - // Create a task - storage - .create_task("task_with_config", "test_context") - .await - .unwrap(); - - // Set a push notification config using the storage API - let config = TaskPushNotificationConfig { - tenant: String::new(), - task_id: "task_with_config".to_string(), - id: "config_1".to_string(), - url: "https://example.com/webhook1".to_string(), - token: "secret_token_123".to_string(), - authentication: None.into(), - ..Default::default() - }; - - storage.set_task_notification(&config).await.unwrap(); - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // List configs - let result = client - .list_push_notification_configs("task_with_config") - .await; - - shutdown.send(()).ok(); - - // Verify successful response - assert!( - result.is_ok(), - "List configs should succeed: {:?}", - result.err() - ); - let configs = result.unwrap(); - - // Should have 1 config (current implementation supports only 1 per task) - assert_eq!(configs.len(), 1, "Should have 1 config"); - - // Verify config details - let returned_config = &configs[0]; - assert_eq!(returned_config.task_id, "task_with_config"); - assert_eq!(returned_config.id, "config_1".to_string()); - assert_eq!(returned_config.url, "https://example.com/webhook1"); - assert_eq!(returned_config.token, "secret_token_123".to_string()); -} - -#[tokio::test] -async fn test_get_push_notification_config() { - let port = 9072; - let (shutdown, storage) = setup_server(port).await; - - // Create a task - storage - .create_task("task_get_config", "test_context") - .await - .unwrap(); - - // Set a push notification config - let config = TaskPushNotificationConfig { - tenant: String::new(), - task_id: "task_get_config".to_string(), - id: "config_abc".to_string(), - url: "https://example.com/notifications".to_string(), - token: "bearer_token_xyz".to_string(), - authentication: None.into(), - ..Default::default() - }; - - storage.set_task_notification(&config).await.unwrap(); - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Get the config (note: pushNotificationConfigId is optional, current impl ignores it) - let result = client - .get_push_notification_config("task_get_config", "config_abc") - .await; - - shutdown.send(()).ok(); - - // Verify successful response - assert!( - result.is_ok(), - "Get config should succeed: {:?}", - result.err() - ); - let returned_config = result.unwrap(); - - // Verify config details - assert_eq!(returned_config.task_id, "task_get_config"); - assert_eq!(returned_config.id, "config_abc".to_string()); - assert_eq!(returned_config.url, "https://example.com/notifications"); - assert_eq!(returned_config.token, "bearer_token_xyz".to_string()); -} - -#[tokio::test] -async fn test_get_push_notification_config_not_found() { - let port = 9073; - let (shutdown, storage) = setup_server(port).await; - - // Create a task WITHOUT a config - storage - .create_task("task_no_config", "test_context") - .await - .unwrap(); - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Try to get non-existent config - let result = client - .get_push_notification_config("task_no_config", "nonexistent") - .await; - - shutdown.send(()).ok(); - - // Verify error is returned - assert!(result.is_err(), "Get nonexistent config should fail"); -} - -#[tokio::test] -async fn test_delete_push_notification_config() { - let port = 9074; - let (shutdown, storage) = setup_server(port).await; - - // Create a task - storage - .create_task("task_delete", "test_context") - .await - .unwrap(); - - // Set a push notification config - let config = TaskPushNotificationConfig { - tenant: String::new(), - task_id: "task_delete".to_string(), - id: "config_to_delete".to_string(), - url: "https://example.com/webhook".to_string(), - token: String::new(), - authentication: None.into(), - ..Default::default() - }; - - storage.set_task_notification(&config).await.unwrap(); - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Verify config exists - let configs_before = client - .list_push_notification_configs("task_delete") - .await - .unwrap(); - assert_eq!( - configs_before.len(), - 1, - "Should have 1 config before deletion" - ); - - // Delete the config - let result = client - .delete_push_notification_config("task_delete", "config_to_delete") - .await; - - // Verify successful deletion - assert!( - result.is_ok(), - "Delete config should succeed: {:?}", - result.err() - ); - - // Verify config is gone - let configs_after = client - .list_push_notification_configs("task_delete") - .await - .unwrap(); - - shutdown.send(()).ok(); - - assert_eq!( - configs_after.len(), - 0, - "Should have 0 configs after deletion" - ); -} - -#[tokio::test] -async fn test_delete_nonexistent_push_config() { - let port = 9075; - let (shutdown, storage) = setup_server(port).await; - - // Create a task without config - storage - .create_task("task_empty", "test_context") - .await - .unwrap(); - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Try to delete non-existent config (DELETE is idempotent, so this should succeed) - let result = client - .delete_push_notification_config("task_empty", "nonexistent") - .await; - - shutdown.send(()).ok(); - - // Idempotent delete: succeeds even if config doesn't exist - assert!( - result.is_ok(), - "Delete is idempotent, should succeed even for nonexistent config" - ); -} - -#[tokio::test] -async fn test_delete_push_config_from_nonexistent_task() { - let port = 9076; - let (shutdown, _storage) = setup_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Try to delete config from non-existent task (DELETE is idempotent, so this should succeed) - let result = client - .delete_push_notification_config("nonexistent_task", "config_1") - .await; - - shutdown.send(()).ok(); - - // Idempotent delete: succeeds even if task doesn't exist - assert!( - result.is_ok(), - "Delete is idempotent, should succeed even for nonexistent task" - ); -} - -#[tokio::test] -async fn test_push_notification_config_with_authentication() { - let port = 9077; - let (shutdown, storage) = setup_server(port).await; - - // Create a task - storage - .create_task("task_with_auth", "test_context") - .await - .unwrap(); - - // Set a push notification config with authentication - let config = TaskPushNotificationConfig { - tenant: String::new(), - task_id: "task_with_auth".to_string(), - id: "config_auth".to_string(), - url: "https://example.com/secure-webhook".to_string(), - token: "validation_token".to_string(), - authentication: Some(a2a_rs::domain::PushNotificationAuthenticationInfo { - scheme: "Bearer".to_string(), - credentials: "secret_credentials".to_string(), - ..Default::default() - }) - .into(), - ..Default::default() - }; - - storage.set_task_notification(&config).await.unwrap(); - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // List and verify authentication is preserved - let result = client - .list_push_notification_configs("task_with_auth") - .await; - - shutdown.send(()).ok(); - - assert!(result.is_ok()); - let configs = result.unwrap(); - assert_eq!(configs.len(), 1); - - let returned_config = &configs[0]; - assert!(returned_config.authentication.as_option().is_some()); - - let auth = returned_config.authentication.as_option().unwrap(); - assert_eq!(auth.scheme, "Bearer".to_string()); - assert_eq!(auth.credentials, "secret_credentials".to_string()); -} - -#[tokio::test] -async fn test_list_configs_for_nonexistent_task() { - let port = 9078; - let (shutdown, _storage) = setup_server(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Try to list configs for non-existent task - let result = client - .list_push_notification_configs("nonexistent_task") - .await; - - shutdown.send(()).ok(); - - // Should succeed with empty array (spec doesn't require task to exist for listing) - assert!( - result.is_ok(), - "List should succeed even for nonexistent task" - ); - let configs = result.unwrap(); - assert_eq!(configs.len(), 0, "Should return empty array"); -} diff --git a/a2a-rs/tests/push_notification_test.rs b/a2a-rs/tests/push_notification_test.rs deleted file mode 100644 index c64898e..0000000 --- a/a2a-rs/tests/push_notification_test.rs +++ /dev/null @@ -1,274 +0,0 @@ -//! Push notification tests - -#![cfg(all(feature = "http-client", feature = "http-server"))] - -mod common; - -use a2a_rs::{ - TaskPushNotificationConfig, - adapter::{ - DefaultRequestProcessor, HttpClient, HttpServer, InMemoryTaskStorage, - PushNotificationSender, SimpleAgentInfo, - }, - domain::{A2AError, Message, Part, TaskArtifactUpdateEvent, TaskStatusUpdateEvent}, - services::AsyncA2AClient, -}; -use async_trait::async_trait; -use common::TestBusinessHandler; -use std::sync::{Arc, Mutex}; -use std::time::Duration; -use tokio::sync::oneshot; - -/// Mock push notification sender for testing -#[derive(Clone)] -struct MockPushNotificationSender { - status_updates: Arc>>, - artifact_updates: Arc>>, -} - -impl MockPushNotificationSender { - fn new() -> Self { - Self { - status_updates: Arc::new(Mutex::new(Vec::new())), - artifact_updates: Arc::new(Mutex::new(Vec::new())), - } - } - - fn get_status_updates(&self) -> Vec { - self.status_updates.lock().unwrap().clone() - } - - #[allow(dead_code)] - fn get_artifact_updates(&self) -> Vec { - self.artifact_updates.lock().unwrap().clone() - } -} - -#[async_trait] -impl PushNotificationSender for MockPushNotificationSender { - async fn send_status_update( - &self, - config: &a2a_rs::domain::TaskPushNotificationConfig, - event: &TaskStatusUpdateEvent, - ) -> Result<(), A2AError> { - // Record the update - let update = format!( - "Status update for task {} to URL {}", - event.task_id, config.url - ); - self.status_updates.lock().unwrap().push(update); - Ok(()) - } - - async fn send_artifact_update( - &self, - config: &a2a_rs::domain::TaskPushNotificationConfig, - event: &TaskArtifactUpdateEvent, - ) -> Result<(), A2AError> { - // Record the update - let update = format!( - "Artifact update for task {} to URL {}", - event.task_id, config.url - ); - self.artifact_updates.lock().unwrap().push(update); - Ok(()) - } -} - -/// Test push notification functionality -#[tokio::test] -async fn test_push_notifications() { - // Create a mock push notification sender - let push_sender = MockPushNotificationSender::new(); - let push_sender_clone = push_sender.clone(); - - // Create a storage with the push sender - let storage = InMemoryTaskStorage::with_push_sender(push_sender_clone); - - // Create business handler with the storage - let handler = TestBusinessHandler::with_storage(storage); - - // Create agent info for the processor - let test_agent_info = SimpleAgentInfo::new( - "test-agent".to_string(), - "http://localhost:8184".to_string(), - ); - - // Create a processor - let processor = DefaultRequestProcessor::with_handler(handler, test_agent_info); - - // Create an agent info provider - let agent_info = SimpleAgentInfo::new( - "Push Test Agent".to_string(), - "http://localhost:8184".to_string(), - ) - .with_push_notifications() - .with_state_transition_history(); - - // Create the server - let server = HttpServer::new(processor, agent_info, "127.0.0.1:8184".to_string()); - - // Create a shutdown channel - let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); - - // Start the server in a separate task - let server_handle = tokio::spawn(async move { - tokio::select! { - _ = server.start() => {}, - _ = shutdown_rx => { - // Server will be dropped and shut down - } - } - }); - - // Give the server time to start - tokio::time::sleep(Duration::from_millis(100)).await; - - // Create the client - let client = HttpClient::new("http://localhost:8184".to_string()); - - // Test 1: Set push notification with ID (v1.0.0 feature) - let task_id = format!("push-task-{}", uuid::Uuid::new_v4()); - let push_config_id = "config-123".to_string(); - let push_config = a2a_rs::domain::TaskPushNotificationConfig { - tenant: String::new(), - task_id: task_id.clone(), - id: push_config_id.clone(), - url: "https://example.com/webhook".to_string(), - token: "test-token".to_string(), - authentication: None.into(), - ..Default::default() - }; - - let result = client.set_task_push_notification(&push_config).await; - assert!(result.is_ok(), "Failed to set push notification with ID"); - - // Test 2: Send a task message - let message_id = format!("msg-{}", uuid::Uuid::new_v4()); - let message = Message::user_text("Hello, Push Notification Agent!".to_string(), message_id); - let _task = client - .send_task_message(&task_id, &message, None, None) - .await - .expect("Failed to send task message"); - - // Give time for push notifications to be processed - tokio::time::sleep(Duration::from_millis(100)).await; - - let artifact_part = Part::text("Artifact content".to_string()); - - let _artifact = a2a_rs::domain::Artifact { - artifact_id: format!("artifact-{}", uuid::Uuid::new_v4()), - name: "test-artifact".to_string(), - description: "A test artifact".to_string(), - parts: vec![artifact_part], - metadata: None.into(), - extensions: Vec::new(), - ..Default::default() - }; - - let artifact_message_id = format!("msg-{}", uuid::Uuid::new_v4()); - let artifact_message = Message::builder() - .message_id(artifact_message_id) - .context_id("default".to_string()) - .role(a2a_rs::domain::Role::Agent) - .build(); - - // Send the artifact message - let _task = client - .send_task_message(&task_id, &artifact_message, None, None) - .await - .expect("Failed to send artifact message"); - - // Give time for push notifications to be processed - tokio::time::sleep(Duration::from_millis(100)).await; - - // Test 4: Cancel the task - let _canceled_task = client - .cancel_task(&task_id) - .await - .expect("Failed to cancel task"); - - // Give time for push notifications to be processed - tokio::time::sleep(Duration::from_millis(100)).await; - - // Verify that push notifications were sent - let status_updates = push_sender.get_status_updates(); - println!("Status updates: {:?}", status_updates); - assert!( - !status_updates.is_empty(), - "Should have sent at least one status update" - ); - - // Test 5: Test multiple push notification configs (v1.0.0 feature) - // Set a second push notification config with a different ID - let task_id_multi = format!("push-task-multi-{}", uuid::Uuid::new_v4()); - let push_config_1 = a2a_rs::domain::TaskPushNotificationConfig { - tenant: String::new(), - task_id: task_id_multi.clone(), - id: "config-1".to_string(), - url: "https://example.com/webhook1".to_string(), - token: "token-1".to_string(), - authentication: None.into(), - ..Default::default() - }; - let push_config_2 = a2a_rs::domain::TaskPushNotificationConfig { - tenant: String::new(), - task_id: task_id_multi.clone(), - id: "config-2".to_string(), - url: "https://example.com/webhook2".to_string(), - token: "token-2".to_string(), - authentication: None.into(), - ..Default::default() - }; - - // Set both configs - let _ = client.set_task_push_notification(&push_config_1).await; - let _ = client.set_task_push_notification(&push_config_2).await; - - // Verify that push notifications were sent with both configs - println!("Successfully set multiple push notification configs for task"); - - // Shut down the server - shutdown_tx - .send(()) - .expect("Failed to send shutdown signal"); - - // Wait for the server to shut down - server_handle.await.expect("Server task failed"); -} - -/// Test push notification config with ID field (v1.0.0) -#[tokio::test] -async fn test_push_notification_config_id() { - // Create a config with an ID - let config_with_id = TaskPushNotificationConfig { - tenant: String::new(), - task_id: "dummy".to_string(), - id: "unique-config-123".to_string(), - url: "https://example.com/webhook".to_string(), - token: "bearer-token".to_string(), - authentication: None.into(), - ..Default::default() - }; - - // Serialize and verify ID is present - let config_json = serde_json::to_value(&config_with_id).unwrap(); - assert_eq!(config_json["id"], "unique-config-123"); - assert_eq!(config_json["url"], "https://example.com/webhook"); - - // Create a config without an ID (should still be valid) - let config_without_id = TaskPushNotificationConfig { - tenant: String::new(), - task_id: "dummy".to_string(), - id: String::new(), - url: "https://example.com/webhook".to_string(), - token: "bearer-token".to_string(), - authentication: None.into(), - ..Default::default() - }; - - // Serialize and verify ID is not present or empty when None/empty in proto - let config_json = serde_json::to_value(&config_without_id).unwrap(); - assert!(config_json.get("id").is_none_or(|v| v.as_str() == Some(""))); - assert_eq!(config_json["url"], "https://example.com/webhook"); -} diff --git a/a2a-rs/tests/retry_transport_test.rs b/a2a-rs/tests/retry_transport_test.rs new file mode 100644 index 0000000..016ea82 --- /dev/null +++ b/a2a-rs/tests/retry_transport_test.rs @@ -0,0 +1,211 @@ +//! Unit tests for the resilient streaming core (`subscribe_resilient` / +//! `RetryingTransport`) against a scripted fake [`Transport`]. + +#![cfg(feature = "client")] + +use std::collections::VecDeque; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::Mutex; +use std::time::Duration; + +use async_trait::async_trait; +use futures::{Stream, StreamExt}; + +use a2a_rs::domain::{ + A2AError, ListTasksParams, ListTasksResult, Message, RetryPolicy, Task, + TaskPushNotificationConfig, TaskState, TaskStatus, TaskStatusUpdateEvent, +}; +use a2a_rs::port::{StreamEvent, StreamItem, Transport}; +use a2a_rs::subscribe_resilient; + +type EventStream = Pin> + Send>>; +/// One successful subscribe yields these events; a `None` script entry makes the +/// subscribe call itself fail (a connection error). +type Script = Vec>; + +#[derive(Default)] +struct FakeInner { + scripts: Mutex>>, + calls: Mutex, + seen_resume: Mutex>>, +} + +#[derive(Clone, Default)] +struct FakeTransport { + inner: Arc, +} + +impl FakeTransport { + fn with_scripts(scripts: Vec>) -> Self { + Self { + inner: Arc::new(FakeInner { + scripts: Mutex::new(scripts.into()), + ..Default::default() + }), + } + } + fn calls(&self) -> u32 { + *self.inner.calls.lock().unwrap() + } + fn seen_resume(&self) -> Vec> { + self.inner.seen_resume.lock().unwrap().clone() + } +} + +fn fast_policy(max_retries: u32) -> RetryPolicy { + RetryPolicy { + base_delay: Duration::from_millis(1), + max_delay: Duration::from_millis(2), + max_retries, + jitter_ms: 0, + } +} + +fn working(id: u64) -> StreamEvent { + StreamEvent::new( + Some(id), + StreamItem::StatusUpdate(TaskStatusUpdateEvent { + task_id: "t".to_string(), + context_id: "c".to_string(), + kind: "status-update".to_string(), + status: TaskStatus::new(TaskState::Working, None), + metadata: None, + }), + ) +} + +fn terminal_task() -> StreamEvent { + let task = Task::builder() + .id("t".to_string()) + .status(TaskStatus::new(TaskState::Completed, None)) + .build(); + StreamEvent::untagged(StreamItem::Task(task)) +} + +#[async_trait] +impl Transport for FakeTransport { + fn protocol(&self) -> &str { + "FAKE" + } + async fn send_task_message( + &self, + _: &str, + _: &Message, + _: Option<&str>, + _: Option, + ) -> Result { + unimplemented!() + } + async fn get_task(&self, _: &str, _: Option) -> Result { + unimplemented!() + } + async fn cancel_task(&self, _: &str) -> Result { + unimplemented!() + } + async fn set_task_push_notification( + &self, + _: &TaskPushNotificationConfig, + ) -> Result { + unimplemented!() + } + async fn get_task_push_notification( + &self, + _: &str, + ) -> Result { + unimplemented!() + } + async fn list_tasks(&self, _: &ListTasksParams) -> Result { + unimplemented!() + } + async fn list_push_notification_configs( + &self, + _: &str, + ) -> Result, A2AError> { + unimplemented!() + } + async fn get_push_notification_config( + &self, + _: &str, + _: &str, + ) -> Result { + unimplemented!() + } + async fn delete_push_notification_config(&self, _: &str, _: &str) -> Result<(), A2AError> { + unimplemented!() + } + async fn subscribe_to_task( + &self, + _task_id: &str, + _history_length: Option, + last_event_id: Option<&str>, + ) -> Result { + *self.inner.calls.lock().unwrap() += 1; + self.inner + .seen_resume + .lock() + .unwrap() + .push(last_event_id.map(|s| s.to_string())); + match self.inner.scripts.lock().unwrap().pop_front() { + Some(Some(script)) => Ok(Box::pin(futures::stream::iter(script))), + _ => Err(A2AError::Internal("connect failed".to_string())), + } + } +} + +/// The core retries failed connects with backoff, then forwards events from the +/// first successful connect and ends on the terminal task. +#[tokio::test] +async fn retries_then_succeeds() { + let fake = FakeTransport::with_scripts(vec![None, None, Some(vec![Ok(terminal_task())])]); + let mut stream = subscribe_resilient(Arc::new(fake.clone()), "t", None, None, fast_policy(5)); + + let items: Vec<_> = collect(&mut stream).await; + assert_eq!(items.len(), 1, "one terminal event"); + assert!(matches!(items[0], StreamItem::Task(_))); + assert_eq!(fake.calls(), 3, "two failures + one success"); +} + +/// After `max_retries` consecutive failed connects the stream yields a final +/// error and ends. +#[tokio::test] +async fn gives_up_after_max_retries() { + let fake = FakeTransport::with_scripts(vec![None, None, None, None, None]); + let mut stream = subscribe_resilient(Arc::new(fake.clone()), "t", None, None, fast_policy(2)); + + let mut last = None; + while let Some(item) = stream.next().await { + last = Some(item); + } + assert!(matches!(last, Some(Err(_))), "ends with an error item"); + // attempt 0 (initial) + retries 1,2 = 3 connects, then attempt 3 > 2 gives up. + assert_eq!(fake.calls(), 3); +} + +/// On reconnect the core echoes the last observed event id as `Last-Event-ID`. +#[tokio::test] +async fn reconnect_threads_last_event_id() { + let fake = FakeTransport::with_scripts(vec![ + Some(vec![Ok(working(5))]), // first connect: one event id 5, then ends + Some(vec![Ok(terminal_task())]), // reconnect: terminal + ]); + let mut stream = subscribe_resilient(Arc::new(fake.clone()), "t", None, None, fast_policy(5)); + + let items = collect(&mut stream).await; + assert_eq!(items.len(), 2); + assert!(matches!(items[0], StreamItem::StatusUpdate(_))); + assert!(matches!(items[1], StreamItem::Task(_))); + // First connect carries no resume id; the reconnect echoes id 5. + assert_eq!(fake.seen_resume(), vec![None, Some("5".to_string())]); +} + +async fn collect(stream: &mut EventStream) -> Vec { + let mut out = Vec::new(); + while let Some(item) = stream.next().await { + match item { + Ok(ev) => out.push(ev.item), + Err(_) => break, + } + } + out +} diff --git a/a2a-rs/tests/spec_compliance_test.rs b/a2a-rs/tests/spec_compliance_test.rs deleted file mode 100644 index 0e29f9b..0000000 --- a/a2a-rs/tests/spec_compliance_test.rs +++ /dev/null @@ -1,1067 +0,0 @@ -//! A2A Protocol Specification Compliance Tests -//! -//! This module validates that our Rust types match the JSON Schema definitions -//! in the A2A specification files located in ../spec/ - -mod common; - -use a2a_rs::{ - adapter::SimpleAgentInfo, - domain::{Message, Part, TaskState}, -}; -use jsonschema::{Draft, Validator}; -use serde_json::{Value, json}; -use std::fs; - -/// Load and compile a JSON Schema from the spec directory -#[allow(dead_code)] -fn load_schema(filename: &str) -> Validator { - let schema_path = format!("../spec/{}", filename); - let schema_content = fs::read_to_string(&schema_path) - .unwrap_or_else(|_| panic!("Failed to read schema file: {}", schema_path)); - - let schema: Value = serde_json::from_str(&schema_content) - .unwrap_or_else(|_| panic!("Failed to parse schema JSON: {}", filename)); - - Validator::options() - .with_draft(Draft::Draft7) - .build(&schema) - .unwrap_or_else(|_| panic!("Failed to compile schema: {}", filename)) -} - -/// Extract a specific definition from a schema file with all definitions context -fn extract_definition(schema_content: &str, definition_name: &str) -> Value { - let schema: Value = serde_json::from_str(schema_content).unwrap(); - let _definition = schema["definitions"][definition_name].clone(); - - // Create a new schema with the specific definition as root but keep all definitions - json!({ - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "definitions": schema["definitions"], - "$ref": format!("#/definitions/{}", definition_name) - }) -} - -#[tokio::test] -async fn test_agent_card_compliance() { - use a2a_rs::services::AgentInfoProvider; - // Create a sample AgentCard using our SimpleAgentInfo - let agent_info = SimpleAgentInfo::new( - "Test Agent".to_string(), - "https://api.example.com".to_string(), - ) - .with_description("A test agent for A2A protocol compliance".to_string()) - .with_version("1.0.0".to_string()) - .with_provider( - "Test Organization".to_string(), - "https://example.org".to_string(), - ) - .with_documentation_url("https://docs.example.org".to_string()) - .with_streaming() - .with_push_notifications() - .with_state_transition_history() - .add_skill( - "echo".to_string(), - "Echo Skill".to_string(), - Some("Echoes input back to user".to_string()), - ) - .add_skill( - "translate".to_string(), - "Translation".to_string(), - Some("Translates text between languages".to_string()), - ); - - let agent_card = agent_info.get_agent_card().await.unwrap(); - - // Serialize to JSON - let agent_card_json = serde_json::to_value(&agent_card).unwrap(); - println!( - "AgentCard JSON: {}", - serde_json::to_string_pretty(&agent_card_json).unwrap() - ); - - // Load the agent schema - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let agent_card_schema = extract_definition(&schema_content, "AgentCard"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(&agent_card_schema) - .expect("Failed to compile AgentCard schema"); - - // Validate against schema - let result = schema.validate(&agent_card_json); - if let Err(errors) = result { - for error in errors { - eprintln!("AgentCard validation error: {}", error); - eprintln!("Instance path: {}", error.instance_path); - } - panic!("AgentCard does not comply with A2A specification"); - } -} - -#[test] -fn test_message_compliance() { - // Create a comprehensive message with all part types - let message_id = uuid::Uuid::new_v4().to_string(); - let mut message = Message::user_text("Hello, agent!".to_string(), message_id.clone()); - - // Add a data part - let data_val: buffa_types::google::protobuf::Value = serde_json::from_value(json!({ - "key": "value", - "number": 42, - "nested": { - "array": [1, 2, 3] - } - })) - .unwrap(); - let data_part = Part::data(data_val); - message.add_part(data_part); - - // Add a file part - let file_part = Part::file_from_bytes( - "SGVsbG8gV29ybGQ=".to_string().into_bytes(), // "Hello World" in base64 - Some("test.txt".to_string()), - Some("text/plain".to_string()), - ); - message.add_part_validated(file_part).unwrap(); - - // Set context and task IDs - message.context_id = "ctx-123".to_string(); - message.task_id = "task-456".to_string(); - - // Serialize to JSON - let message_json = serde_json::to_value(&message).unwrap(); - println!( - "Message JSON: {}", - serde_json::to_string_pretty(&message_json).unwrap() - ); - - // Load and validate against Message schema - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let message_schema = extract_definition(&schema_content, "Message"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(&message_schema) - .expect("Failed to compile Message schema"); - - let result = schema.validate(&message_json); - if let Err(errors) = result { - for error in errors { - eprintln!("Message validation error: {}", error); - eprintln!("Instance path: {}", error.instance_path); - } - panic!("Message does not comply with A2A specification"); - } -} - -#[test] -fn test_task_compliance() { - // Create a task - let context_id = "ctx-789".to_string(); - use a2a_rs::domain::Task; - let mut task = Task::new("task-987".to_string(), context_id.clone()); - - // Add history messages - let msg1 = Message::user_text("Initial message".to_string(), "msg-1".to_string()); - let msg2 = Message::agent_text("Agent response".to_string(), "msg-2".to_string()); - - task.update_status(TaskState::Working, Some(msg1)); - task.update_status(TaskState::Completed, Some(msg2)); - - // Serialize to JSON - let task_json = serde_json::to_value(&task).unwrap(); - println!( - "Task JSON: {}", - serde_json::to_string_pretty(&task_json).unwrap() - ); - - // Load and validate against Task schema - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let task_schema = extract_definition(&schema_content, "Task"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(&task_schema) - .expect("Failed to compile Task schema"); - - let result = schema.validate(&task_json); - if let Err(errors) = result { - for error in errors { - eprintln!("Task validation error: {}", error); - eprintln!("Instance path: {}", error.instance_path); - } - panic!("Task does not comply with A2A specification"); - } -} - -#[test] -fn test_task_states_compliance() { - // Test all valid task states according to the specification - let valid_states = [ - TaskState::Submitted, - TaskState::Working, - TaskState::InputRequired, - TaskState::Completed, - TaskState::Canceled, - TaskState::Failed, - TaskState::Rejected, - TaskState::AuthRequired, - TaskState::Unknown, - ]; - - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let task_state_schema = extract_definition(&schema_content, "TaskState"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(&task_state_schema) - .expect("Failed to compile TaskState schema"); - - for state in &valid_states { - let state_json = serde_json::to_value(state).unwrap(); - - let result = schema.validate(&state_json); - if let Err(errors) = result { - for error in errors { - eprintln!("TaskState {:?} validation error: {}", state, error); - } - panic!( - "TaskState {:?} does not comply with A2A specification", - state - ); - } - } -} - -#[test] -fn test_error_codes_compliance() { - // Test that our error codes match the specification - - // Standard JSON-RPC errors - let jsonrpc_errors = vec![ - (-32700, "Parse error"), - (-32600, "Invalid Request"), - (-32601, "Method not found"), - (-32602, "Invalid params"), - (-32603, "Internal error"), - ]; - - // A2A-specific errors (v1.0.0 includes new -32007 error) - let a2a_errors = vec![ - (-32001, "Task not found"), - (-32002, "Task not cancelable"), - (-32003, "Push notifications not supported"), - (-32004, "Operation not supported"), - (-32005, "Content type not supported"), - (-32006, "Invalid agent response"), - (-32007, "Authenticated Extended Card is not configured"), - ]; - - // All error codes should be documented in the spec - let all_errors = [jsonrpc_errors, a2a_errors].concat(); - - for (code, message) in all_errors { - println!("Checking error code {} with message: {}", code, message); - // This validates that our error codes align with the specification - // The actual validation would depend on how we structure our error types - } -} - -#[test] -fn test_authenticated_extended_card_error() { - use a2a_rs::domain::error::A2AError; - - let error = A2AError::AuthenticatedExtendedCardNotConfigured; - let jsonrpc_error = error.to_jsonrpc_error(); - - assert_eq!(jsonrpc_error["code"], -32007); - assert_eq!( - jsonrpc_error["message"], - "Authenticated Extended Card is not configured" - ); -} - -#[test] -fn test_agent_card_v100_fields() { - use a2a_rs::domain::{AgentCapabilities, AgentCard, AgentCardSignature, AgentInterface}; - use std::collections::HashMap; - - // Create an AgentCard with all v1.0.0 fields using the v1.0.0 builder - let header_struct = { - let mut header = HashMap::new(); - header.insert( - "alg".to_string(), - serde_json::Value::String("RS256".to_string()), - ); - let header_val = serde_json::to_value(header).unwrap(); - serde_json::from_value(header_val).unwrap() - }; - - let card = AgentCard::builder() - .name("Test Agent v1.0.0".to_string()) - .description("Agent with v1.0.0 features".to_string()) - .url("https://api.example.com/jsonrpc".to_string()) - .version("2.0.0".to_string()) - .protocol_version("0.3.0".to_string()) - .preferred_transport("JSONRPC".to_string()) - .capabilities(AgentCapabilities::default()) - .default_input_modes(vec!["text".to_string()]) - .default_output_modes(vec!["text".to_string()]) - .additional_interfaces(vec![ - AgentInterface { - url: "https://api.example.com/grpc".to_string(), - protocol_binding: "GRPC".to_string(), - protocol_version: "0.3.0".to_string(), - ..Default::default() - }, - AgentInterface { - url: "https://api.example.com/http".to_string(), - protocol_binding: "HTTP+JSON".to_string(), - protocol_version: "0.3.0".to_string(), - ..Default::default() - }, - ]) - .icon_url("https://example.com/icon.png".to_string()) - .signatures(vec![AgentCardSignature { - protected: "eyJhbGciOiJSUzI1NiJ9".to_string(), - signature: "cC4hiUPoj9Eetdgtv3hF80EGrhuB__dzERat0XF9g2VtQgr9PJbu3XOiZj5RZmh7AAuHIm4Bh-0Qc_lF5YKt_O8W2Fp5jujGbds9uJdbF9CUAr7t1dnZcAcQjbKBYNX4BAynRFdiuB--f_nZLgrnbyTyWzO75vRK5h6xBArLIARNPvkSjtQBMHlb1L07Qe7K0GarZRmB_eSN9383LcOLn6_dO--xi12jzDwusC-eOkHWEsqtFZESc6BfI7noOPqvhJ1phCnvWh6IeYI2w9QOYEUipUTI8np6LbgGY9Fs98rqVt5AXLIhWkWywlVmtVrBp0igcN_IoypGlUPQGe77Rw".to_string(), - header: buffa::MessageField::some(header_struct), - ..Default::default() - }]) - .skills(vec![]) - .build(); - - println!( - "DEBUG: card.supported_interfaces = {:?}", - card.supported_interfaces - ); - // Test protocol_version (should default to "0.3.0") - assert_eq!(card.protocol_version(), "0.3.0"); - - // Test preferred_transport (should default to "JSONRPC") - assert_eq!(card.preferred_transport(), "JSONRPC"); - - // Serialize and validate - let card_json = serde_json::to_value(&card).unwrap(); - println!( - "AgentCard v1.0.0: {}", - serde_json::to_string_pretty(&card_json).unwrap() - ); - - // Verify the v1.0.0 fields are present - assert_eq!( - card_json["supportedInterfaces"][0]["protocolVersion"], - "0.3.0" - ); - assert_eq!( - card_json["supportedInterfaces"][0]["protocolBinding"], - "JSONRPC" - ); - assert!(card_json["supportedInterfaces"].is_array()); - assert_eq!( - card_json["supportedInterfaces"][1]["protocolBinding"], - "GRPC" - ); - assert_eq!(card_json["iconUrl"], "https://example.com/icon.png"); - assert!(card_json["signatures"].is_array()); - - // Validate against schema - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let agent_card_schema = extract_definition(&schema_content, "AgentCard"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(&agent_card_schema) - .expect("Failed to compile AgentCard schema"); - - let result = schema.validate(&card_json); - if let Err(errors) = result { - for error in errors { - eprintln!("AgentCard v1.0.0 validation error: {}", error); - eprintln!("Instance path: {}", error.instance_path); - } - panic!("AgentCard v1.0.0 does not comply with A2A specification"); - } -} - -#[test] -fn test_agent_capabilities_extensions() { - use a2a_rs::domain::{AgentCapabilities, AgentExtension}; - use std::collections::HashMap; - - let mut capabilities = AgentCapabilities::default(); - - // Add extensions - let mut ext_params = HashMap::new(); - ext_params.insert( - "version".to_string(), - serde_json::Value::String("1.0".to_string()), - ); - let ext_params_val = serde_json::to_value(&ext_params).unwrap(); - let ext_params_struct: buffa_types::google::protobuf::Struct = - serde_json::from_value(ext_params_val).unwrap(); - - capabilities.extensions = vec![ - AgentExtension { - uri: "https://example.com/extensions/custom-auth".to_string(), - description: "Custom authentication extension".to_string(), - required: true, - params: buffa::MessageField::some(ext_params_struct), - ..Default::default() - }, - AgentExtension { - uri: "https://example.com/extensions/advanced-features".to_string(), - description: "Advanced features extension".to_string(), - required: false, - params: buffa::MessageField::none(), - ..Default::default() - }, - ]; - - // Serialize and verify - let capabilities_json = serde_json::to_value(&capabilities).unwrap(); - println!( - "AgentCapabilities with extensions: {}", - serde_json::to_string_pretty(&capabilities_json).unwrap() - ); - - assert!(capabilities_json["extensions"].is_array()); - assert_eq!( - capabilities_json["extensions"][0]["uri"], - "https://example.com/extensions/custom-auth" - ); - assert_eq!(capabilities_json["extensions"][0]["required"], true); - assert!(capabilities_json["extensions"][1]["required"].is_null()); - - // Validate against schema - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let capabilities_schema = extract_definition(&schema_content, "AgentCapabilities"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(&capabilities_schema) - .expect("Failed to compile AgentCapabilities schema"); - - let result = schema.validate(&capabilities_json); - if let Err(errors) = result { - for error in errors { - eprintln!("AgentCapabilities validation error: {}", error); - } - panic!("AgentCapabilities does not comply with A2A specification"); - } -} - -#[test] -fn test_agent_skill_security() { - use a2a_rs::domain::AgentSkill; - use std::collections::HashMap; - - // Create a skill with security requirements - let mut security_req = HashMap::new(); - security_req.insert( - "oauth2".to_string(), - vec!["read:data".to_string(), "write:data".to_string()], - ); - - let skill = AgentSkill::new( - "secure-operation".to_string(), - "Secure Operation".to_string(), - "An operation requiring OAuth2 authentication".to_string(), - vec!["security".to_string(), "auth".to_string()], - ) - .with_security(vec![security_req]); - - // Serialize and verify - let skill_json = serde_json::to_value(&skill).unwrap(); - println!( - "AgentSkill with security: {}", - serde_json::to_string_pretty(&skill_json).unwrap() - ); - - assert!(skill_json["securityRequirements"].is_array()); - assert_eq!( - skill_json["securityRequirements"][0]["schemes"]["oauth2"]["list"][0], - "read:data" - ); - assert_eq!( - skill_json["securityRequirements"][0]["schemes"]["oauth2"]["list"][1], - "write:data" - ); - - // Validate against schema - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let skill_schema = extract_definition(&schema_content, "AgentSkill"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(&skill_schema) - .expect("Failed to compile AgentSkill schema"); - - let result = schema.validate(&skill_json); - if let Err(errors) = result { - for error in errors { - eprintln!("AgentSkill validation error: {}", error); - } - panic!("AgentSkill does not comply with A2A specification"); - } -} - -#[test] -fn test_message_extensions_field() { - use a2a_rs::domain::Message; - - // Create a message with extensions - let mut message = Message::user_text( - "Test message with extensions".to_string(), - format!("msg-{}", uuid::Uuid::new_v4()), - ); - - message.extensions = vec![ - "https://example.com/extensions/custom-protocol".to_string(), - "https://example.com/extensions/advanced-features".to_string(), - ]; - - // Serialize and verify - let message_json = serde_json::to_value(&message).unwrap(); - println!( - "Message with extensions: {}", - serde_json::to_string_pretty(&message_json).unwrap() - ); - - assert!(message_json["extensions"].is_array()); - assert_eq!( - message_json["extensions"][0], - "https://example.com/extensions/custom-protocol" - ); - - // Validate against schema - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let message_schema = extract_definition(&schema_content, "Message"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(&message_schema) - .expect("Failed to compile Message schema"); - - let result = schema.validate(&message_json); - if let Err(errors) = result { - for error in errors { - eprintln!("Message with extensions validation error: {}", error); - } - panic!("Message does not comply with A2A specification"); - } -} - -#[test] -fn test_artifact_extensions_field() { - use a2a_rs::domain::{Artifact, Part}; - - // Create an artifact with extensions - let artifact = Artifact { - artifact_id: format!("artifact-{}", uuid::Uuid::new_v4()), - name: "Test Artifact".to_string(), - description: "Artifact with extension support".to_string(), - parts: vec![Part::text("Artifact content".to_string())], - metadata: None.into(), - extensions: vec!["https://example.com/extensions/artifact-encryption".to_string()], - ..Default::default() - }; - - // Serialize and verify - let artifact_json = serde_json::to_value(&artifact).unwrap(); - println!( - "Artifact with extensions: {}", - serde_json::to_string_pretty(&artifact_json).unwrap() - ); - - assert!(artifact_json["extensions"].is_array()); - assert_eq!( - artifact_json["extensions"][0], - "https://example.com/extensions/artifact-encryption" - ); - - // Validate against schema - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let artifact_schema = extract_definition(&schema_content, "Artifact"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(&artifact_schema) - .expect("Failed to compile Artifact schema"); - - let result = schema.validate(&artifact_json); - if let Err(errors) = result { - for error in errors { - eprintln!("Artifact with extensions validation error: {}", error); - } - panic!("Artifact does not comply with A2A specification"); - } -} - -#[test] -fn test_mutual_tls_security_scheme() { - use a2a_rs::domain::SecurityScheme; - - let mtls_scheme = - SecurityScheme::mutual_tls(Some("Client certificate authentication".to_string())); - - // Serialize and verify - let scheme_json = serde_json::to_value(&mtls_scheme).unwrap(); - println!( - "MutualTLS SecurityScheme: {}", - serde_json::to_string_pretty(&scheme_json).unwrap() - ); - - assert!(scheme_json.get("mtlsSecurityScheme").is_some()); - assert_eq!( - scheme_json["mtlsSecurityScheme"]["description"], - "Client certificate authentication" - ); - - // Validate against schema - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let security_schema = extract_definition(&schema_content, "SecurityScheme"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(&security_schema) - .expect("Failed to compile SecurityScheme schema"); - - let result = schema.validate(&scheme_json); - if let Err(errors) = result { - for error in errors { - eprintln!("MutualTLS SecurityScheme validation error: {}", error); - } - panic!("MutualTLS SecurityScheme does not comply with A2A specification"); - } -} - -#[test] -fn test_oauth2_with_metadata_url() { - use a2a_rs::domain::{ClientCredentialsOAuthFlow, OAuthFlows, SecurityScheme}; - use std::collections::HashMap; - - let mut scopes = HashMap::new(); - scopes.insert("read:data".to_string(), "Read access to data".to_string()); - scopes.insert("write:data".to_string(), "Write access to data".to_string()); - - let flow = ClientCredentialsOAuthFlow { - token_url: "https://auth.example.com/token".to_string(), - refresh_url: "https://auth.example.com/refresh".to_string(), - scopes, - ..Default::default() - }; - let flows = OAuthFlows::client_credentials(flow); - let oauth2_scheme = SecurityScheme::oauth2( - flows, - Some("OAuth2 with metadata discovery".to_string()), - Some("https://auth.example.com/.well-known/oauth-authorization-server".to_string()), - ); - - // Serialize and verify - let scheme_json = serde_json::to_value(&oauth2_scheme).unwrap(); - println!( - "OAuth2 with metadata URL: {}", - serde_json::to_string_pretty(&scheme_json).unwrap() - ); - - assert!(scheme_json.get("oauth2SecurityScheme").is_some()); - assert_eq!( - scheme_json["oauth2SecurityScheme"]["oauth2MetadataUrl"], - "https://auth.example.com/.well-known/oauth-authorization-server" - ); - - // Validate against schema - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let security_schema = extract_definition(&schema_content, "SecurityScheme"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(&security_schema) - .expect("Failed to compile SecurityScheme schema"); - - let result = schema.validate(&scheme_json); - if let Err(errors) = result { - for error in errors { - eprintln!("OAuth2 SecurityScheme validation error: {}", error); - } - panic!("OAuth2 SecurityScheme does not comply with A2A specification"); - } -} - -#[test] -fn test_list_tasks_params() { - use a2a_rs::domain::{ListTasksParams, TaskState}; - - let params = ListTasksParams { - context_id: Some("ctx-123".to_string()), - status: Some(TaskState::Working), - page_size: Some(25), - page_token: None, - history_length: Some(10), - include_artifacts: Some(true), - status_timestamp_after: Some("2024-01-01T00:00:00Z".to_string()), // 2024-01-01 00:00:00 UTC - metadata: None, - }; - - // Serialize and verify - let params_json = serde_json::to_value(¶ms).unwrap(); - println!( - "ListTasksParams: {}", - serde_json::to_string_pretty(¶ms_json).unwrap() - ); - - assert_eq!(params_json["contextId"], "ctx-123"); - assert_eq!(params_json["status"], "TASK_STATE_WORKING"); - assert_eq!(params_json["pageSize"], 25); - assert_eq!(params_json["historyLength"], 10); - assert_eq!(params_json["includeArtifacts"], true); - assert_eq!(params_json["statusTimestampAfter"], "2024-01-01T00:00:00Z"); - - // Validate against schema - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let params_schema = extract_definition(&schema_content, "ListTasksParams"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(¶ms_schema) - .expect("Failed to compile ListTasksParams schema"); - - let result = schema.validate(¶ms_json); - if let Err(errors) = result { - for error in errors { - eprintln!("ListTasksParams validation error: {}", error); - } - panic!("ListTasksParams does not comply with A2A specification"); - } -} - -#[test] -fn test_push_notification_config_with_id() { - use a2a_rs::domain::TaskPushNotificationConfig; - let config = TaskPushNotificationConfig { - tenant: String::new(), - task_id: "dummy".to_string(), - - id: "config-abc123".to_string(), - url: "https://client.example.com/webhook".to_string(), - token: "bearer-token-xyz".to_string(), - authentication: None.into(), - ..Default::default() - }; - - // Serialize and verify - let config_json = serde_json::to_value(&config).unwrap(); - println!( - "PushNotificationConfig with id: {}", - serde_json::to_string_pretty(&config_json).unwrap() - ); - - assert_eq!(config_json["id"], "config-abc123"); - assert_eq!(config_json["url"], "https://client.example.com/webhook"); - - // Validate against schema - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let config_schema = extract_definition(&schema_content, "PushNotificationConfig"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(&config_schema) - .expect("Failed to compile PushNotificationConfig schema"); - - let result = schema.validate(&config_json); - if let Err(errors) = result { - for error in errors { - eprintln!("PushNotificationConfig validation error: {}", error); - } - panic!("PushNotificationConfig does not comply with A2A specification"); - } -} - -#[test] -fn test_message_send_configuration_optional_accepted_output_modes() { - use a2a_rs::domain::MessageSendConfiguration; - - // Test that acceptedOutputModes is now optional in v1.0.0 - let config = MessageSendConfiguration { - accepted_output_modes: None, // This should be valid now - history_length: Some(10), - push_notification_config: None, - blocking: Some(false), - }; - - // Serialize and verify - let config_json = serde_json::to_value(&config).unwrap(); - println!( - "MessageSendConfiguration: {}", - serde_json::to_string_pretty(&config_json).unwrap() - ); - - // acceptedOutputModes should not be serialized when None - assert!(config_json.get("acceptedOutputModes").is_none()); - assert_eq!(config_json["historyLength"], 10); - - // Validate against schema - let schema_content = fs::read_to_string("../spec/specification.json") - .expect("Failed to read specification.json"); - let config_schema = extract_definition(&schema_content, "MessageSendConfiguration"); - - let schema = Validator::options() - .with_draft(Draft::Draft7) - .build(&config_schema) - .expect("Failed to compile MessageSendConfiguration schema"); - - let result = schema.validate(&config_json); - if let Err(errors) = result { - for error in errors { - eprintln!("MessageSendConfiguration validation error: {}", error); - } - panic!("MessageSendConfiguration does not comply with A2A specification"); - } -} - -#[cfg(test)] -mod property_based_tests { - use super::*; - use proptest::prelude::*; - - proptest! { - #[test] - fn message_serialization_roundtrip( - text in ".*", - message_id in ".*", - role in prop::sample::select(vec!["user", "agent"]), - ) { - let message = if role == "user" { - Message::user_text(text.clone(), message_id.clone()) - } else { - Message::agent_text(text.clone(), message_id.clone()) - }; - - // Serialize and deserialize - let json = serde_json::to_value(&message).unwrap(); - let deserialized: Message = serde_json::from_value(json).unwrap(); - - // Check that essential properties are preserved - prop_assert_eq!(message.message_id, deserialized.message_id); - prop_assert_eq!(message.role, deserialized.role); - prop_assert_eq!(message.parts.len(), deserialized.parts.len()); - } - - #[test] - fn task_id_validation(task_id in ".*") { - if !task_id.is_empty() { - let context_id = "ctx-test".to_string(); - use a2a_rs::domain::Task; - let task = Task::new(task_id.clone(), context_id); - prop_assert_eq!(task.id, task_id); - } - } - } -} - -// Priority 3: Error Handling and Validation Tests - -#[tokio::test] -async fn test_task_list_page_size_validation() { - use a2a_rs::{ - adapter::{ - DefaultRequestProcessor, HttpClient, HttpServer, InMemoryTaskStorage, SimpleAgentInfo, - }, - services::AsyncA2AClient, - }; - use common::TestBusinessHandler; - use std::time::Duration; - use tokio::sync::oneshot; - - let port = 9500; - let storage = InMemoryTaskStorage::new(); - let handler = TestBusinessHandler::with_storage(storage); - let agent_info = SimpleAgentInfo::new( - "Validation Test Agent".to_string(), - format!("http://localhost:{}", port), - ); - - let processor = DefaultRequestProcessor::with_handler(handler, agent_info.clone()); - let server = HttpServer::new(processor, agent_info, format!("127.0.0.1:{}", port)); - let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); - - tokio::spawn(async move { - tokio::select! { - _ = server.start() => {}, - _ = shutdown_rx => {} - } - }); - - tokio::time::sleep(Duration::from_millis(100)).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Test page_size > 100 (should clamp to 100, not error) - let params = a2a_rs::domain::ListTasksParams { - page_size: Some(150), - ..Default::default() - }; - - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - // According to spec, page_size should be clamped, not return error - assert_eq!(result.page_size, 100); - - // Test page_size < 1 (should clamp to 1, not error) - let params = a2a_rs::domain::ListTasksParams { - page_size: Some(0), - ..Default::default() - }; - - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - // According to spec, page_size should be clamped, not return error - assert_eq!(result.page_size, 1); - - shutdown_tx.send(()).ok(); -} - -#[test] -fn test_all_a2a_error_codes_defined() { - use a2a_rs::domain::A2AError; - - // Test all A2A-specific error codes from the specification - let errors = vec![ - (A2AError::TaskNotFound("test-task".to_string()), -32001), - (A2AError::TaskNotCancelable("test-task".to_string()), -32002), - (A2AError::PushNotificationNotSupported, -32003), - (A2AError::UnsupportedOperation("test".to_string()), -32004), - ( - A2AError::ContentTypeNotSupported("test/type".to_string()), - -32005, - ), - (A2AError::InvalidAgentResponse("test".to_string()), -32006), - (A2AError::AuthenticatedExtendedCardNotConfigured, -32007), - ]; - - for (error, expected_code) in errors { - let jsonrpc_error = error.to_jsonrpc_error(); - assert_eq!( - jsonrpc_error["code"], expected_code, - "Error {:?} should have code {}", - error, expected_code - ); - assert!( - !jsonrpc_error["message"].as_str().unwrap().is_empty(), - "Error message should not be empty" - ); - } -} - -#[test] -fn test_jsonrpc_error_structure_compliance() { - use a2a_rs::domain::A2AError; - - let error = A2AError::TaskNotFound("task-123".to_string()); - let jsonrpc_error = error.to_jsonrpc_error(); - - // Verify JSON-RPC error structure - assert!(jsonrpc_error.is_object(), "Error should be an object"); - assert!( - jsonrpc_error.get("code").is_some(), - "Error must have code field" - ); - assert!( - jsonrpc_error.get("message").is_some(), - "Error must have message field" - ); - - // code should be an integer - assert!( - jsonrpc_error["code"].is_i64(), - "Error code must be an integer" - ); - - // message should be a string - assert!( - jsonrpc_error["message"].is_string(), - "Error message must be a string" - ); -} - -#[test] -fn test_task_state_transitions_validation() { - use a2a_rs::domain::{Message, Task, TaskState}; - - let task_id = "task-transition-test".to_string(); - let context_id = "ctx-test".to_string(); - let mut task = Task::new(task_id.clone(), context_id); - - // Valid state transitions - let valid_transitions = vec![ - (TaskState::Submitted, TaskState::Working), - (TaskState::Working, TaskState::InputRequired), - (TaskState::InputRequired, TaskState::Working), - (TaskState::Working, TaskState::Completed), - (TaskState::Working, TaskState::Failed), - (TaskState::Working, TaskState::Canceled), - ]; - - for (_from_state, to_state) in valid_transitions { - let msg = Message::agent_text( - format!("Transitioning to {:?}", to_state), - format!("msg-{}", uuid::Uuid::new_v4()), - ); - task.update_status(to_state, Some(msg)); - assert_eq!( - task.status.state, to_state, - "Task should transition to {:?}", - to_state - ); - } -} - -#[test] -fn test_error_code_ranges() { - // Verify that error codes follow the specification ranges - - // Standard JSON-RPC errors: -32700 to -32603 - let jsonrpc_codes = vec![-32700, -32600, -32601, -32602, -32603]; - - for code in jsonrpc_codes { - assert!( - (-32700..=-32600).contains(&code), - "JSON-RPC error code {} should be in range -32700 to -32600", - code - ); - } - - // A2A-specific errors: -32001 to -32007 - let a2a_codes = vec![-32001, -32002, -32003, -32004, -32005, -32006, -32007]; - - for code in a2a_codes { - assert!( - (-32007..=-32001).contains(&code), - "A2A error code {} should be in range -32007 to -32001", - code - ); - } -} diff --git a/a2a-rs/tests/sqlx_storage_test.rs b/a2a-rs/tests/sqlx_storage_test.rs index 2dfbaba..198db19 100644 --- a/a2a-rs/tests/sqlx_storage_test.rs +++ b/a2a-rs/tests/sqlx_storage_test.rs @@ -4,11 +4,21 @@ mod sqlx_tests { use a2a_rs::adapter::storage::{DatabaseConfig, SqlxTaskStorage}; use a2a_rs::domain::TaskState; - use a2a_rs::port::{AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager}; + use a2a_rs::port::{ + AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskLifecycle, AsyncTaskQuery, + AsyncTaskVersioning, + }; use a2a_rs::{A2AError, TaskPushNotificationConfig}; use std::sync::Arc; use uuid::Uuid; + fn tid(s: &str) -> a2a_rs::domain::TaskId { + s.parse().unwrap() + } + fn cid(s: &str) -> a2a_rs::domain::ContextId { + s.parse().unwrap() + } + async fn create_test_storage() -> Result { // Use SQLite in-memory for tests let config = DatabaseConfig::builder() @@ -26,28 +36,28 @@ mod sqlx_tests { let context_id = "test-context"; // Test task creation - let task = storage.create_task(&task_id, context_id).await?; + let task = storage.create(&tid(&task_id), &cid(context_id)).await?; assert_eq!(task.id, task_id); assert_eq!(task.context_id, context_id); assert_eq!(task.status.state, TaskState::Submitted); // Test task existence - assert!(storage.task_exists(&task_id).await?); - assert!(!storage.task_exists("non-existent").await?); + assert!(storage.exists(&tid(&task_id)).await?); + assert!(!storage.exists(&tid("non-existent")).await?); // Test status updates let working_task = storage - .update_task_status(&task_id, TaskState::Working, None) + .update_status(&tid(&task_id), TaskState::Working, None) .await?; assert_eq!(working_task.status.state, TaskState::Working); let completed_task = storage - .update_task_status(&task_id, TaskState::Completed, None) + .update_status(&tid(&task_id), TaskState::Completed, None) .await?; assert_eq!(completed_task.status.state, TaskState::Completed); // Test task retrieval with history - let retrieved_task = storage.get_task(&task_id, Some(10)).await?; + let retrieved_task = storage.get(&tid(&task_id), Some(10)).await?; assert_eq!(retrieved_task.id, task_id); assert_eq!(retrieved_task.status.state, TaskState::Completed); // Should have history: Submitted -> Working -> Completed @@ -63,17 +73,17 @@ mod sqlx_tests { let task_id = Uuid::new_v4().to_string(); // Create and start working on task - storage.create_task(&task_id, "test-context").await?; + storage.create(&tid(&task_id), &cid("test-context")).await?; storage - .update_task_status(&task_id, TaskState::Working, None) + .update_status(&tid(&task_id), TaskState::Working, None) .await?; // Cancel the working task - let canceled_task = storage.cancel_task(&task_id).await?; + let canceled_task = storage.cancel(&tid(&task_id)).await?; assert_eq!(canceled_task.status.state, TaskState::Canceled); // Verify cancellation was successful - let task_with_history = storage.get_task(&task_id, None).await?; + let task_with_history = storage.get(&tid(&task_id), None).await?; assert_eq!(task_with_history.status.state, TaskState::Canceled); // Note: We're not fully implementing history loading in this version // In a full implementation, you'd verify the cancellation message was added @@ -87,16 +97,16 @@ mod sqlx_tests { let task_id = Uuid::new_v4().to_string(); // Create, work on, and complete task - storage.create_task(&task_id, "test-context").await?; + storage.create(&tid(&task_id), &cid("test-context")).await?; storage - .update_task_status(&task_id, TaskState::Working, None) + .update_status(&tid(&task_id), TaskState::Working, None) .await?; storage - .update_task_status(&task_id, TaskState::Completed, None) + .update_status(&tid(&task_id), TaskState::Completed, None) .await?; // Try to cancel completed task - should fail - let result = storage.cancel_task(&task_id).await; + let result = storage.cancel(&tid(&task_id)).await; assert!(result.is_err()); if let Err(A2AError::TaskNotCancelable(_)) = result { @@ -114,10 +124,10 @@ mod sqlx_tests { let task_id = Uuid::new_v4().to_string(); // Create first task - storage.create_task(&task_id, "test-context").await?; + storage.create(&tid(&task_id), &cid("test-context")).await?; // Try to create duplicate - should fail - let result = storage.create_task(&task_id, "test-context").await; + let result = storage.create(&tid(&task_id), &cid("test-context")).await; assert!(result.is_err()); if let Err(A2AError::TaskNotFound(_)) = result { @@ -138,24 +148,24 @@ mod sqlx_tests { let task_id = Uuid::new_v4().to_string(); // Create task and make several status changes - storage.create_task(&task_id, "test-context").await?; + storage.create(&tid(&task_id), &cid("test-context")).await?; storage - .update_task_status(&task_id, TaskState::Working, None) + .update_status(&tid(&task_id), TaskState::Working, None) .await?; storage - .update_task_status(&task_id, TaskState::InputRequired, None) + .update_status(&tid(&task_id), TaskState::InputRequired, None) .await?; storage - .update_task_status(&task_id, TaskState::Working, None) + .update_status(&tid(&task_id), TaskState::Working, None) .await?; storage - .update_task_status(&task_id, TaskState::Completed, None) + .update_status(&tid(&task_id), TaskState::Completed, None) .await?; // Note: We're not fully implementing history loading in this version // In a full implementation, you'd test history limits here - let _task_limited = storage.get_task(&task_id, Some(3)).await?; - let _task_full = storage.get_task(&task_id, None).await?; + let _task_limited = storage.get(&tid(&task_id), Some(3)).await?; + let _task_full = storage.get(&tid(&task_id), None).await?; Ok(()) } @@ -166,7 +176,7 @@ mod sqlx_tests { let task_id = Uuid::new_v4().to_string(); // Create task first - storage.create_task(&task_id, "test-context").await?; + storage.create(&tid(&task_id), &cid("test-context")).await?; // Set push notification config let config = TaskPushNotificationConfig { @@ -179,20 +189,38 @@ mod sqlx_tests { ..Default::default() }; - let set_config = storage.set_task_notification(&config).await?; + let set_config = storage.set_config(&config).await?; assert_eq!(set_config.task_id, task_id); assert_eq!(set_config.url, "https://example.com/webhook"); // Get push notification config - let retrieved_config = storage.get_task_notification(&task_id).await?; + let retrieved_config = storage + .get_config(&a2a_rs::domain::GetTaskPushNotificationConfigParams { + id: task_id.clone(), + push_notification_config_id: None, + metadata: None, + }) + .await?; assert_eq!(retrieved_config.task_id, task_id); assert_eq!(retrieved_config.url, "https://example.com/webhook"); // Remove push notification config - storage.remove_task_notification(&task_id).await?; + storage + .delete_config(&a2a_rs::domain::DeleteTaskPushNotificationConfigParams { + id: task_id.clone(), + push_notification_config_id: String::new(), + metadata: None, + }) + .await?; // Verify it's removed - let result = storage.get_task_notification(&task_id).await; + let result = storage + .get_config(&a2a_rs::domain::GetTaskPushNotificationConfigParams { + id: task_id.clone(), + push_notification_config_id: None, + metadata: None, + }) + .await; assert!(result.is_err()); Ok(()) @@ -227,23 +255,25 @@ mod sqlx_tests { Ok(()) } + /// Subscriber management is not a storage responsibility: it lives in + /// `InMemoryStreamingHandler`. This pins the registry semantics on that + /// adapter. #[tokio::test] async fn test_streaming_subscribers() -> Result<(), Box> { - let storage = create_test_storage().await?; - let task_id = Uuid::new_v4().to_string(); + use a2a_rs::InMemoryStreamingHandler; - // Create task - storage.create_task(&task_id, "test-context").await?; + let streaming = InMemoryStreamingHandler::new(); + let task_id = Uuid::new_v4().to_string(); - // Test subscriber count - let count = storage.get_subscriber_count(&task_id).await?; + // No subscribers registered for an unknown task. + let count = streaming.get_subscriber_count(&task_id).await?; assert_eq!(count, 0); - // Test removing non-existent subscribers - storage.remove_task_subscribers(&task_id).await?; + // Removing subscribers for a task with none is a no-op. + streaming.remove_task_subscribers(&task_id).await?; - // Test unsupported operations - let result = storage.remove_subscription("fake-id").await; + // Removal by subscription ID is unsupported by the in-memory handler. + let result = streaming.remove_subscription("fake-id").await; assert!(matches!(result, Err(A2AError::UnsupportedOperation(_)))); Ok(()) @@ -260,13 +290,13 @@ mod sqlx_tests { let handle = tokio::spawn(async move { let task_id = format!("concurrent-task-{}", i); let task = storage_clone - .create_task(&task_id, "concurrent-context") + .create(&tid(&task_id), &cid("concurrent-context")) .await?; storage_clone - .update_task_status(&task_id, TaskState::Working, None) + .update_status(&tid(&task_id), TaskState::Working, None) .await?; storage_clone - .update_task_status(&task_id, TaskState::Completed, None) + .update_status(&tid(&task_id), TaskState::Completed, None) .await?; Ok::<_, A2AError>(task) }); @@ -282,8 +312,8 @@ mod sqlx_tests { // Verify all tasks exist for i in 0..10 { let task_id = format!("concurrent-task-{}", i); - assert!(storage.task_exists(&task_id).await?); - let task = storage.get_task(&task_id, None).await?; + assert!(storage.exists(&tid(&task_id)).await?); + let task = storage.get(&tid(&task_id), None).await?; assert_eq!(task.status.state, TaskState::Completed); } @@ -315,12 +345,12 @@ mod sqlx_tests { // Create some tasks for i in 0..5 { let task_id = format!("task-{}", i); - storage.create_task(&task_id, "test-context").await?; + storage.create(&tid(&task_id), &cid("test-context")).await?; } // List all tasks let params = a2a_rs::domain::ListTasksParams::default(); - let result = storage.list_tasks_v3(¶ms).await?; + let result = storage.list(¶ms).await?; assert_eq!(result.total_size, 5, "Should have 5 tasks"); assert_eq!(result.tasks.len(), 5, "Should return 5 tasks"); @@ -334,15 +364,15 @@ mod sqlx_tests { let storage = create_test_storage().await?; // Create tasks in different contexts and states - storage.create_task("task-a-1", "context-a").await?; - storage.create_task("task-a-2", "context-a").await?; - storage.create_task("task-b-1", "context-b").await?; + storage.create(&tid("task-a-1"), &cid("context-a")).await?; + storage.create(&tid("task-a-2"), &cid("context-a")).await?; + storage.create(&tid("task-b-1"), &cid("context-b")).await?; storage - .update_task_status("task-a-1", TaskState::Working, None) + .update_status(&tid("task-a-1"), TaskState::Working, None) .await?; storage - .update_task_status("task-a-2", TaskState::Completed, None) + .update_status(&tid("task-a-2"), TaskState::Completed, None) .await?; // Filter by context @@ -350,7 +380,7 @@ mod sqlx_tests { context_id: Some("context-a".to_string()), ..Default::default() }; - let result = storage.list_tasks_v3(¶ms).await?; + let result = storage.list(¶ms).await?; assert_eq!(result.total_size, 2, "Should have 2 tasks in context-a"); // Filter by status @@ -358,7 +388,7 @@ mod sqlx_tests { status: Some(TaskState::Working), ..Default::default() }; - let result = storage.list_tasks_v3(¶ms).await?; + let result = storage.list(¶ms).await?; assert_eq!(result.total_size, 1, "Should have 1 working task"); Ok(()) @@ -371,7 +401,7 @@ mod sqlx_tests { // Create 10 tasks for i in 0..10 { storage - .create_task(&format!("task-{}", i), "test-context") + .create(&tid(&format!("task-{}", i)), &cid("test-context")) .await?; } @@ -380,7 +410,7 @@ mod sqlx_tests { page_size: Some(3), ..Default::default() }; - let page1 = storage.list_tasks_v3(¶ms).await?; + let page1 = storage.list(¶ms).await?; assert_eq!(page1.tasks.len(), 3, "Should return 3 tasks"); assert!( !page1.next_page_token.is_empty(), @@ -393,7 +423,7 @@ mod sqlx_tests { page_token: Some(page1.next_page_token.clone()), ..Default::default() }; - let page2 = storage.list_tasks_v3(¶ms).await?; + let page2 = storage.list(¶ms).await?; assert_eq!(page2.tasks.len(), 3, "Should return 3 tasks"); Ok(()) @@ -405,7 +435,7 @@ mod sqlx_tests { let task_id = Uuid::new_v4().to_string(); // Create task first - storage.create_task(&task_id, "test-context").await?; + storage.create(&tid(&task_id), &cid("test-context")).await?; // Set push notification config let config = TaskPushNotificationConfig { @@ -417,7 +447,7 @@ mod sqlx_tests { authentication: None.into(), ..Default::default() }; - storage.set_task_notification(&config).await?; + storage.set_config(&config).await?; // Get specific config let get_params = a2a_rs::domain::GetTaskPushNotificationConfigParams { @@ -425,7 +455,7 @@ mod sqlx_tests { push_notification_config_id: Some("config-1".to_string()), metadata: None, }; - let retrieved = storage.get_push_notification_config(&get_params).await?; + let retrieved = storage.get_config(&get_params).await?; assert_eq!(retrieved.url, "https://example.com/webhook"); assert_eq!(retrieved.token, "test-token"); @@ -434,7 +464,7 @@ mod sqlx_tests { id: task_id.clone(), metadata: None, }; - let configs = storage.list_push_notification_configs(&list_params).await?; + let configs = storage.list_configs(&list_params).await?; assert_eq!(configs.len(), 1, "Should have 1 config"); // Delete config @@ -443,12 +473,10 @@ mod sqlx_tests { push_notification_config_id: "config-1".to_string(), metadata: None, }; - storage - .delete_push_notification_config(&delete_params) - .await?; + storage.delete_config(&delete_params).await?; // Verify deleted - let configs = storage.list_push_notification_configs(&list_params).await?; + let configs = storage.list_configs(&list_params).await?; assert_eq!(configs.len(), 0, "Config should be deleted"); Ok(()) @@ -460,7 +488,7 @@ mod sqlx_tests { let task_id = Uuid::new_v4().to_string(); // Create task - storage.create_task(&task_id, "test-context").await?; + storage.create(&tid(&task_id), &cid("test-context")).await?; // Set multiple configs let config1 = TaskPushNotificationConfig { @@ -482,19 +510,71 @@ mod sqlx_tests { ..Default::default() }; - storage.set_task_notification(&config1).await?; - storage.set_task_notification(&config2).await?; + storage.set_config(&config1).await?; + storage.set_config(&config2).await?; // List should return both let list_params = a2a_rs::domain::ListTaskPushNotificationConfigsParams { id: task_id.clone(), metadata: None, }; - let configs = storage.list_push_notification_configs(&list_params).await?; + let configs = storage.list_configs(&list_params).await?; assert_eq!(configs.len(), 2, "Should have 2 configs"); Ok(()) } + + #[tokio::test] + async fn test_optimistic_concurrency_versioning() -> Result<(), Box> { + let storage = create_test_storage().await?; + let task_id = Uuid::new_v4().to_string(); + + // A freshly created task starts at version 1. + storage.create(&tid(&task_id), &cid("ctx")).await?; + assert_eq!(storage.version(&tid(&task_id)).await?, 1); + + // Every unversioned mutation bumps the version too. + storage + .update_status(&tid(&task_id), TaskState::Working, None) + .await?; + let snapshot = storage.get_versioned(&tid(&task_id), None).await?; + assert_eq!(snapshot.version, 2); + assert_eq!(snapshot.task.status.state, TaskState::Working); + + // A conditional update with the stale version is rejected, untouched. + let stale = storage + .update_status_checked(&tid(&task_id), 1, TaskState::Completed, None) + .await; + match stale { + Err(A2AError::VersionConflict { + expected, actual, .. + }) => { + assert_eq!(expected, 1); + assert_eq!(actual, 2); + } + other => panic!("expected VersionConflict, got {other:?}"), + } + // State is unchanged after the rejected update. + assert_eq!( + storage.get(&tid(&task_id), None).await?.status.state, + TaskState::Working + ); + + // A conditional update with the current version succeeds and bumps. + let updated = storage + .update_status_checked(&tid(&task_id), 2, TaskState::Completed, None) + .await?; + assert_eq!(updated.version, 3); + assert_eq!(updated.task.status.state, TaskState::Completed); + + // Versioning ops on a missing task report TaskNotFound. + assert!(matches!( + storage.version(&tid("ghost")).await, + Err(A2AError::TaskNotFound(_)) + )); + + Ok(()) + } } #[cfg(not(feature = "sqlx-storage"))] diff --git a/a2a-rs/tests/task_list_test.rs b/a2a-rs/tests/task_list_test.rs deleted file mode 100644 index 9d18a49..0000000 --- a/a2a-rs/tests/task_list_test.rs +++ /dev/null @@ -1,619 +0,0 @@ -//! Integration tests for tasks/list endpoint (v1.0.0) - -#![cfg(all(feature = "http-client", feature = "http-server"))] - -mod common; - -use a2a_rs::{ - adapter::{ - DefaultRequestProcessor, HttpClient, HttpServer, InMemoryTaskStorage, SimpleAgentInfo, - }, - domain::{ListTasksParams, Message, TaskState}, - port::AsyncTaskManager, - services::AsyncA2AClient, -}; -use common::TestBusinessHandler; -use std::time::Duration; -use tokio::sync::oneshot; - -async fn setup_server_with_tasks(port: u16) -> (oneshot::Sender<()>, InMemoryTaskStorage) { - let storage = InMemoryTaskStorage::new(); - let storage_clone = storage.clone(); - - // Create business handler with the storage - let handler = TestBusinessHandler::with_storage(storage); - - // Create agent info for the processor - let test_agent_info = SimpleAgentInfo::new( - "test-agent".to_string(), - format!("http://localhost:{}", port), - ); - - // Create a processor - let processor = DefaultRequestProcessor::with_handler(handler, test_agent_info); - - // Create an agent info provider - let agent_info = SimpleAgentInfo::new( - "Task List Test Agent".to_string(), - format!("http://localhost:{}", port), - ) - .with_state_transition_history(); - - // Create the server - let server = HttpServer::new(processor, agent_info, format!("127.0.0.1:{}", port)); - - // Create a shutdown channel - let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); - - // Start the server in a separate task - tokio::spawn(async move { - tokio::select! { - _ = server.start() => {}, - _ = shutdown_rx => {} - } - }); - - // Give the server time to start - tokio::time::sleep(Duration::from_millis(100)).await; - - (shutdown_tx, storage_clone) -} - -#[tokio::test] -async fn test_task_list_basic() { - let port = 9001; - let (shutdown_tx, _storage) = setup_server_with_tasks(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Create several tasks - let task_ids = vec![ - format!("task-{}", uuid::Uuid::new_v4()), - format!("task-{}", uuid::Uuid::new_v4()), - format!("task-{}", uuid::Uuid::new_v4()), - ]; - - for task_id in &task_ids { - let message = Message::user_text( - format!("Test message for {}", task_id), - format!("msg-{}", uuid::Uuid::new_v4()), - ); - client - .send_task_message(task_id, &message, None, None) - .await - .expect("Failed to create task"); - } - - // Give time for tasks to be created - tokio::time::sleep(Duration::from_millis(100)).await; - - // List all tasks - let params = ListTasksParams::default(); - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - // Verify we got all tasks - assert!(result.total_size >= 3, "Should have at least 3 tasks"); - assert!(result.tasks.len() >= 3, "Should return at least 3 tasks"); - assert_eq!(result.page_size, 50, "Default page size should be 50"); - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_task_list_filter_by_context() { - let port = 9002; - let (shutdown_tx, _storage) = setup_server_with_tasks(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - let context_a = "context-a"; - let context_b = "context-b"; - - // Create tasks in different contexts (using session_id as context) - for i in 0..3 { - let task_id = format!("task-a-{}", i); - let message = Message::user_text( - format!("Message in context A - {}", i), - format!("msg-{}", uuid::Uuid::new_v4()), - ); - client - .send_task_message(&task_id, &message, Some(context_a), None) - .await - .expect("Failed to create task in context A"); - } - - for i in 0..2 { - let task_id = format!("task-b-{}", i); - let message = Message::user_text( - format!("Message in context B - {}", i), - format!("msg-{}", uuid::Uuid::new_v4()), - ); - client - .send_task_message(&task_id, &message, Some(context_b), None) - .await - .expect("Failed to create task in context B"); - } - - tokio::time::sleep(Duration::from_millis(100)).await; - - // List tasks in context A - let params = ListTasksParams { - context_id: Some(context_a.to_string()), - ..Default::default() - }; - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - // All returned tasks should be in context A - assert_eq!(result.total_size, 3, "Should have 3 tasks in context A"); - for task in &result.tasks { - assert_eq!(task.context_id, context_a, "Task should be in context A"); - } - - // List tasks in context B - let params = ListTasksParams { - context_id: Some(context_b.to_string()), - ..Default::default() - }; - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - assert_eq!(result.total_size, 2, "Should have 2 tasks in context B"); - for task in &result.tasks { - assert_eq!(task.context_id, context_b, "Task should be in context B"); - } - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_task_list_filter_by_status() { - let port = 9003; - let (shutdown_tx, storage) = setup_server_with_tasks(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Create tasks and set different states - let task_id_1 = format!("task-{}", uuid::Uuid::new_v4()); - let task_id_2 = format!("task-{}", uuid::Uuid::new_v4()); - let task_id_3 = format!("task-{}", uuid::Uuid::new_v4()); - - // Create tasks - for task_id in &[&task_id_1, &task_id_2, &task_id_3] { - let message = Message::user_text( - "Test message".to_string(), - format!("msg-{}", uuid::Uuid::new_v4()), - ); - client - .send_task_message(task_id, &message, None, None) - .await - .expect("Failed to create task"); - } - - tokio::time::sleep(Duration::from_millis(50)).await; - - // Update task states directly through storage - storage - .update_task_status(&task_id_1, TaskState::Working, None) - .await - .expect("Failed to update task 1"); - storage - .update_task_status(&task_id_2, TaskState::Completed, None) - .await - .expect("Failed to update task 2"); - // task_id_3 remains in Submitted state - - tokio::time::sleep(Duration::from_millis(50)).await; - - // Filter by Working status - let params = ListTasksParams { - status: Some(TaskState::Working), - ..Default::default() - }; - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - assert!( - result.total_size >= 1, - "Should have at least 1 working task" - ); - for task in &result.tasks { - assert_eq!( - task.status.state, - TaskState::Working, - "Task should be in Working state" - ); - } - - // Filter by Completed status - let params = ListTasksParams { - status: Some(TaskState::Completed), - ..Default::default() - }; - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - assert!( - result.total_size >= 1, - "Should have at least 1 completed task" - ); - for task in &result.tasks { - assert_eq!( - task.status.state, - TaskState::Completed, - "Task should be in Completed state" - ); - } - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_task_list_pagination() { - let port = 9004; - let (shutdown_tx, _storage) = setup_server_with_tasks(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Create 10 tasks - for i in 0..10 { - let task_id = format!("task-{}", i); - let message = Message::user_text( - format!("Message {}", i), - format!("msg-{}", uuid::Uuid::new_v4()), - ); - client - .send_task_message(&task_id, &message, None, None) - .await - .expect("Failed to create task"); - } - - tokio::time::sleep(Duration::from_millis(100)).await; - - // Get first page with page_size 3 - let params = ListTasksParams { - page_size: Some(3), - ..Default::default() - }; - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - assert_eq!(result.page_size, 3, "Page size should be 3"); - assert_eq!(result.tasks.len(), 3, "Should return exactly 3 tasks"); - assert!(result.total_size >= 10, "Total size should be at least 10"); - assert!( - !result.next_page_token.is_empty(), - "Should have next page token" - ); - - // Get second page using next_page_token - let params = ListTasksParams { - page_size: Some(3), - page_token: Some(result.next_page_token.clone()), - ..Default::default() - }; - let result_page2 = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks page 2"); - - assert_eq!(result_page2.page_size, 3, "Page size should be 3"); - assert_eq!(result_page2.tasks.len(), 3, "Should return exactly 3 tasks"); - - // Verify tasks are different between pages - let page1_ids: Vec<_> = result.tasks.iter().map(|t| &t.id).collect(); - let page2_ids: Vec<_> = result_page2.tasks.iter().map(|t| &t.id).collect(); - - for id in &page2_ids { - assert!( - !page1_ids.contains(id), - "Page 2 should have different tasks than page 1" - ); - } - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_task_list_page_size_clamping() { - let port = 9005; - let (shutdown_tx, _storage) = setup_server_with_tasks(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // Create one task - let message = Message::user_text( - "Test message".to_string(), - format!("msg-{}", uuid::Uuid::new_v4()), - ); - client - .send_task_message("task-1", &message, None, None) - .await - .expect("Failed to create task"); - - tokio::time::sleep(Duration::from_millis(50)).await; - - // Test page_size is clamped to 1-100 - let params = ListTasksParams { - page_size: Some(150), // Should be clamped to 100 - ..Default::default() - }; - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - assert_eq!(result.page_size, 100, "Page size should be clamped to 100"); - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_task_list_history_length() { - let port = 9006; - let (shutdown_tx, _storage) = setup_server_with_tasks(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - let task_id = format!("task-{}", uuid::Uuid::new_v4()); - - // Create task and add multiple messages to history - for i in 0..5 { - let message = Message::user_text( - format!("Message {}", i), - format!("msg-{}", uuid::Uuid::new_v4()), - ); - client - .send_task_message(&task_id, &message, None, None) - .await - .expect("Failed to send message"); - } - - tokio::time::sleep(Duration::from_millis(100)).await; - - // List tasks with history_length = 2 - let params = ListTasksParams { - history_length: Some(2), - ..Default::default() - }; - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - let task = result - .tasks - .iter() - .find(|t| t.id == task_id) - .expect("Task not found"); - - if !task.history.is_empty() { - assert_eq!( - task.history.len(), - 2, - "History should be limited to 2 messages" - ); - } - - // List tasks with history_length = 0 (no history) - let params = ListTasksParams { - history_length: Some(0), - ..Default::default() - }; - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - let task = result - .tasks - .iter() - .find(|t| t.id == task_id) - .expect("Task not found"); - - assert!( - task.history.is_empty(), - "History should be empty when history_length is 0" - ); - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_task_list_include_artifacts() { - let port = 9007; - let (shutdown_tx, storage) = setup_server_with_tasks(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - let task_id = format!("task-{}", uuid::Uuid::new_v4()); - - // Create task - let message = Message::user_text( - "Test message".to_string(), - format!("msg-{}", uuid::Uuid::new_v4()), - ); - client - .send_task_message(&task_id, &message, None, None) - .await - .expect("Failed to create task"); - - // Add artifact through storage - let mut task = storage - .get_task(&task_id, None) - .await - .expect("Failed to get task"); - - task.add_artifact(a2a_rs::domain::Artifact { - artifact_id: format!("artifact-{}", uuid::Uuid::new_v4()), - name: "Test Artifact".to_string(), - description: String::new(), - parts: vec![a2a_rs::domain::Part::text("Artifact content".to_string())], - metadata: None.into(), - extensions: Vec::new(), - ..Default::default() - }); - - // Update task in storage - let task_state = match task.status.state { - ::buffa::EnumValue::Known(s) => s, - _ => TaskState::Unknown, - }; - storage - .update_task_status(&task_id, task_state, None) - .await - .ok(); - - tokio::time::sleep(Duration::from_millis(50)).await; - - // List tasks with include_artifacts = false (default) - let params = ListTasksParams { - include_artifacts: Some(false), - ..Default::default() - }; - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - let listed_task = result - .tasks - .iter() - .find(|t| t.id == task_id) - .expect("Task not found"); - - assert!( - listed_task.artifacts.is_empty(), - "Artifacts should be excluded when include_artifacts is false" - ); - - // List tasks with include_artifacts = true - let params = ListTasksParams { - include_artifacts: Some(true), - ..Default::default() - }; - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - let _listed_task = result - .tasks - .iter() - .find(|t| t.id == task_id) - .expect("Task not found"); - - // Note: Artifact inclusion depends on whether artifacts were persisted - // through the storage layer. The add_artifact call above only modifies - // a local copy, so the storage may not have the artifact. - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_task_list_combined_filters() { - let port = 9008; - let (shutdown_tx, storage) = setup_server_with_tasks(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - let context_id = "test-context-combined"; - - // Create multiple tasks in same context (using session_id) - let task_ids: Vec = (0..5) - .map(|_i| format!("combined-task-{}", uuid::Uuid::new_v4())) - .collect(); - - for task_id in &task_ids { - let message = Message::user_text( - format!("Message for {}", task_id), - format!("msg-{}", uuid::Uuid::new_v4()), - ); - client - .send_task_message(task_id, &message, Some(context_id), None) - .await - .expect("Failed to create task"); - } - - tokio::time::sleep(Duration::from_millis(100)).await; - - // All tasks start as Working (set by DefaultMessageHandler) - // Update some tasks to other states - leave task_ids[0] and task_ids[1] as Working - storage - .update_task_status(&task_ids[2], TaskState::Completed, None) - .await - .ok(); - storage - .update_task_status(&task_ids[3], TaskState::Completed, None) - .await - .ok(); - storage - .update_task_status(&task_ids[4], TaskState::Failed, None) - .await - .ok(); - - tokio::time::sleep(Duration::from_millis(50)).await; - - // Filter by both context and status - let params = ListTasksParams { - context_id: Some(context_id.to_string()), - status: Some(TaskState::Working), - page_size: Some(10), - ..Default::default() - }; - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - assert_eq!( - result.total_size, 2, - "Should have 2 working tasks in context" - ); - for task in &result.tasks { - assert_eq!(task.context_id, context_id); - assert_eq!(task.status.state, TaskState::Working); - } - - shutdown_tx.send(()).ok(); -} - -#[tokio::test] -async fn test_task_list_empty_results() { - let port = 9009; - let (shutdown_tx, _storage) = setup_server_with_tasks(port).await; - - let client = HttpClient::new(format!("http://localhost:{}", port)); - - // List tasks with filter that matches nothing - let params = ListTasksParams { - context_id: Some("non-existent-context".to_string()), - ..Default::default() - }; - let result = client - .list_tasks(¶ms) - .await - .expect("Failed to list tasks"); - - assert_eq!(result.total_size, 0, "Should have no tasks"); - assert_eq!(result.tasks.len(), 0, "Should return empty array"); - assert!( - result.next_page_token.is_empty(), - "Should have empty next page token" - ); - - shutdown_tx.send(()).ok(); -} diff --git a/a2a-rs/tests/transport_negotiation_test.rs b/a2a-rs/tests/transport_negotiation_test.rs new file mode 100644 index 0000000..088e1e2 --- /dev/null +++ b/a2a-rs/tests/transport_negotiation_test.rs @@ -0,0 +1,221 @@ +//! Unit tests for client-side transport negotiation. +//! +//! These drive [`TransportNegotiator`] with **fake** factories (no network) to +//! pin the ranking algorithm: client preference (registration order) dominates +//! card order, a failing `create` falls through to the next compatible +//! interface, an unknown protocol errors, and the major-version filter skips +//! incompatible interfaces. + +#![cfg(feature = "client")] + +use std::pin::Pin; + +use async_trait::async_trait; +use futures::Stream; + +use a2a_rs::domain::{ + A2AError, AgentCard, AgentInterface, ListTasksParams, ListTasksResult, Message, Task, + TaskPushNotificationConfig, +}; +use a2a_rs::{StreamEvent, Transport, TransportFactory, TransportNegotiator}; + +/// A no-op transport that only reports its protocol — its RPC methods are never +/// called in negotiation tests. +struct DummyTransport { + proto: &'static str, +} + +#[async_trait] +impl Transport for DummyTransport { + fn protocol(&self) -> &str { + self.proto + } + async fn send_task_message( + &self, + _: &str, + _: &Message, + _: Option<&str>, + _: Option, + ) -> Result { + unimplemented!() + } + async fn get_task(&self, _: &str, _: Option) -> Result { + unimplemented!() + } + async fn cancel_task(&self, _: &str) -> Result { + unimplemented!() + } + async fn set_task_push_notification( + &self, + _: &TaskPushNotificationConfig, + ) -> Result { + unimplemented!() + } + async fn get_task_push_notification( + &self, + _: &str, + ) -> Result { + unimplemented!() + } + async fn list_tasks(&self, _: &ListTasksParams) -> Result { + unimplemented!() + } + async fn list_push_notification_configs( + &self, + _: &str, + ) -> Result, A2AError> { + unimplemented!() + } + async fn get_push_notification_config( + &self, + _: &str, + _: &str, + ) -> Result { + unimplemented!() + } + async fn delete_push_notification_config(&self, _: &str, _: &str) -> Result<(), A2AError> { + unimplemented!() + } + async fn subscribe_to_task( + &self, + _: &str, + _: Option, + _: Option<&str>, + ) -> Result> + Send>>, A2AError> { + unimplemented!() + } +} + +/// A factory that builds a [`DummyTransport`] for one protocol, optionally +/// failing `create` to exercise fall-through. +struct FakeFactory { + proto: &'static str, + fail: bool, +} + +#[async_trait] +impl TransportFactory for FakeFactory { + fn protocol(&self) -> &str { + self.proto + } + async fn create( + &self, + _card: &AgentCard, + _iface: &AgentInterface, + ) -> Result, A2AError> { + if self.fail { + Err(A2AError::Internal("boom".to_string())) + } else { + Ok(Box::new(DummyTransport { proto: self.proto })) + } + } +} + +fn iface(proto: &str, version: &str) -> AgentInterface { + AgentInterface { + url: format!("http://localhost/{proto}"), + protocol_binding: proto.to_string(), + protocol_version: version.to_string(), + ..Default::default() + } +} + +fn card(interfaces: Vec) -> AgentCard { + AgentCard { + supported_interfaces: interfaces, + ..Default::default() + } +} + +#[tokio::test] +async fn prefers_client_order_over_card_order() { + let negotiator = TransportNegotiator::new() + .with(FakeFactory { + proto: "CONNECTRPC", + fail: false, + }) + .with(FakeFactory { + proto: "JSONRPC", + fail: false, + }); + // Card lists JSONRPC first, but the client prefers CONNECTRPC. + let c = card(vec![iface("JSONRPC", "1.0"), iface("CONNECTRPC", "1.0")]); + let transport = negotiator.negotiate(&c).await.unwrap(); + assert_eq!(transport.protocol(), "CONNECTRPC"); +} + +#[tokio::test] +async fn falls_through_on_create_failure() { + let negotiator = TransportNegotiator::new() + .with(FakeFactory { + proto: "CONNECTRPC", + fail: true, + }) + .with(FakeFactory { + proto: "JSONRPC", + fail: false, + }); + let c = card(vec![iface("CONNECTRPC", "1.0"), iface("JSONRPC", "1.0")]); + let transport = negotiator.negotiate(&c).await.unwrap(); + assert_eq!(transport.protocol(), "JSONRPC"); +} + +#[tokio::test] +async fn unknown_protocol_errors() { + let negotiator = TransportNegotiator::new().with(FakeFactory { + proto: "JSONRPC", + fail: false, + }); + // `Box` isn't `Debug`, so match the Result rather than `unwrap_err`. + let result = negotiator + .negotiate(&card(vec![iface("GRPC", "1.0")])) + .await; + assert!(matches!(result, Err(A2AError::UnsupportedOperation(_)))); +} + +#[tokio::test] +async fn empty_interfaces_errors() { + let negotiator = TransportNegotiator::new().with(FakeFactory { + proto: "JSONRPC", + fail: false, + }); + assert!(negotiator.negotiate(&card(vec![])).await.is_err()); +} + +#[tokio::test] +async fn skips_incompatible_major_version() { + let negotiator = TransportNegotiator::new().with(FakeFactory { + proto: "JSONRPC", + fail: false, + }); + // v2.x is not compatible with this client. + assert!( + negotiator + .negotiate(&card(vec![iface("JSONRPC", "2.0")])) + .await + .is_err() + ); + // v1.x is. + let transport = negotiator + .negotiate(&card(vec![iface("JSONRPC", "1.5")])) + .await + .unwrap(); + assert_eq!(transport.protocol(), "JSONRPC"); +} + +#[test] +fn supported_lists_protocols_in_preference_order() { + let negotiator = TransportNegotiator::new() + .with(FakeFactory { + proto: "CONNECTRPC", + fail: false, + }) + .with(FakeFactory { + proto: "JSONRPC", + fail: false, + }); + assert_eq!( + negotiator.supported().collect::>(), + vec!["CONNECTRPC", "JSONRPC"] + ); +} diff --git a/release-plz.toml b/release-plz.toml new file mode 100644 index 0000000..f0ad312 --- /dev/null +++ b/release-plz.toml @@ -0,0 +1,65 @@ +# release-plz configuration — https://release-plz.dev/docs/config +# +# Tag convention: per-crate `{package}-v{version}` (release-plz's default, made +# explicit here). Each crate tags and releases itself; there is no umbrella `v*` +# tag. The `release` job (see .github/workflows/release-plz.yml) runs on push to +# master and publishes only the crates whose Cargo.toml version is ahead of +# crates.io, tagging each one. A released `a2a-agents-v*` tag drives the binary +# build in release-binaries.yml. +# +# Why this file exists: 0.3.0 produced wrong changelog compare-links because the +# repo carried both per-crate tags and a manual umbrella `v0.3.0` tag, so the +# links pointed at inconsistent refs. Standardizing on per-crate tags + the +# release-plz `release_link` variable (which derives the compare URL from the +# tag template below) keeps them correct. + +[workspace] +# Per-crate tag + GitHub release names. +git_tag_name = "{{ package }}-v{{ version }}" +git_release_name = "{{ package }} v{{ version }}" +# All crates publish together; keep the changelog/PR per-crate. +git_release_body = "{{ changelog }}" + +[changelog] +header = """ +# Changelog + +All notable changes to this project will be documented in this file. +""" + +# `release_link` is supplied by release-plz and resolves to the correct +# `compare/{prev}...{curr}` URL for the per-crate tag template above — this is +# what fixes the 0.3.0 broken compare-links. +body = """ +## [{{ version | trim_start_matches(pat="v") }}]\ +{% if release_link %}({{ release_link }}){% endif %} - {{ timestamp | date(format="%Y-%m-%d") }} +{% for group, commits in commits | group_by(attribute="group") %} +### {{ group | upper_first }} +{% for commit in commits %} +- {% if commit.scope %}*({{ commit.scope }})* {% endif %}{{ commit.message | upper_first }}\ +{% endfor %} +{% endfor %} +""" +trim = true + +# Conventional-commit routing. Noise types (ci, build, style, chore) and any +# "fmt"/"clippy"/"fixed CI" cleanup commits are skipped from the user-facing +# changelog; the substance (feat/fix/perf/refactor/docs) is grouped. +commit_parsers = [ + { message = "(?i)^(chore\\(release\\)|release v|prepare release)", skip = true }, + { message = "(?i)(rustfmt|cargo fmt|^fmt\\b| fmt\\b|clippy|fixed ci|fix ci|ci:|^ci\\b)", skip = true }, + { message = "^feat", group = "Added" }, + { message = "^fix", group = "Fixed" }, + { message = "^perf", group = "Performance" }, + { message = "^refactor", group = "Changed" }, + { message = "^docs?", group = "Documentation" }, + { message = "^test", skip = true }, + { message = "^build", skip = true }, + { message = "^style", skip = true }, + { message = "^ci", skip = true }, + { message = "^chore", skip = true }, + { message = ".*", group = "Other" }, +] +# Drop merge commits. +filter_commits = false +protect_breaking_commits = true