Skip to content

Commit 2a69bf1

Browse files
author
liuyuan90
committed
fix(schema): coerce string protocolVersion in InitializeRequest
Some ACP clients (notably Zed) send a date string like "2024-11-05" as the protocolVersion instead of an integer. The Rust SDK already handles this gracefully — its Deserialize impl maps any string to V0 with the comment "Old versions used strings". The Python SDK rejected strings outright, causing the agent process to crash on the very first handshake. Changes: - Add `_coerce_protocol_version` field_validator to `InitializeRequest` in `src/acp/schema.py` that maps non-integer values to 1 (current stable version), mirroring the Rust SDK's lenient behaviour. - Add `CLASS_VALIDATOR_INJECTIONS` table and `_inject_field_validators` post-processing step to `scripts/gen_schema.py` so the validator is re-applied automatically on future schema regenerations. - Add `_ensure_pydantic_import` helper used by the injection step to add `field_validator` to the generated pydantic import line. Ref: https://github.com/agentclientprotocol/rust-sdk/blob/main/crates/agent-client-protocol-schema/src/version.rs
1 parent df72173 commit 2a69bf1

2 files changed

Lines changed: 86 additions & 1 deletion

File tree

scripts/gen_schema.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,31 @@
120120
),
121121
)
122122

123+
# Classes that need a field_validator injected after generation.
124+
# Each entry: (class_name, field_name, validator_body)
125+
# The validator_body is the full method source (indented 4 spaces inside the class).
126+
CLASS_VALIDATOR_INJECTIONS: tuple[tuple[str, str, str], ...] = (
127+
(
128+
"InitializeRequest",
129+
"protocol_version",
130+
textwrap.dedent("""\
131+
@field_validator("protocol_version", mode="before")
132+
@classmethod
133+
def _coerce_protocol_version(cls, v: Any) -> int:
134+
# Some clients (e.g. Zed) send a date string like "2024-11-05" instead
135+
# of an integer. The Rust SDK treats any string as version 0; we map it
136+
# to 1 (the current stable version) so the connection is not rejected.
137+
# See: https://github.com/agentclientprotocol/rust-sdk/blob/main/crates/agent-client-protocol-schema/src/version.rs
138+
if isinstance(v, int):
139+
return v
140+
try:
141+
return int(v)
142+
except (TypeError, ValueError):
143+
return 1
144+
"""),
145+
),
146+
)
147+
123148

124149
@dataclass(frozen=True)
125150
class _ProcessingStep:
@@ -182,6 +207,7 @@ def postprocess_generated_schema(output_path: Path) -> list[str]:
182207
_ProcessingStep("apply default overrides", _apply_default_overrides),
183208
_ProcessingStep("attach description comments", _add_description_comments),
184209
_ProcessingStep("ensure custom BaseModel", _ensure_custom_base_model),
210+
_ProcessingStep("inject field validators", _inject_field_validators),
185211
)
186212

187213
for step in processing_steps:
@@ -338,6 +364,51 @@ def __getattr__(self, item: str) -> Any:
338364
return "\n".join(lines) + "\n"
339365

340366

367+
def _ensure_pydantic_import(content: str, name: str) -> str:
368+
"""Add *name* to the ``from pydantic import ...`` line if not already present."""
369+
lines = content.splitlines()
370+
for idx, line in enumerate(lines):
371+
if not line.startswith("from pydantic import "):
372+
continue
373+
imports = [part.strip() for part in line[len("from pydantic import "):].split(",")]
374+
if name not in imports:
375+
imports.append(name)
376+
lines[idx] = "from pydantic import " + ", ".join(imports)
377+
return "\n".join(lines) + "\n"
378+
return content
379+
380+
381+
def _inject_field_validators(content: str) -> str:
382+
"""Inject field_validator methods into classes listed in CLASS_VALIDATOR_INJECTIONS."""
383+
for class_name, _field_name, validator_body in CLASS_VALIDATOR_INJECTIONS:
384+
# Ensure field_validator is imported from pydantic.
385+
content = _ensure_pydantic_import(content, "field_validator")
386+
387+
# Find the end of the class body and append the validator before the next class.
388+
class_pattern = re.compile(
389+
rf"(class {class_name}\(BaseModel\):)(.*?)(?=\nclass |\Z)",
390+
re.DOTALL,
391+
)
392+
393+
def _append_validator(
394+
match: re.Match[str],
395+
_body: str = validator_body,
396+
_class: str = class_name,
397+
) -> str:
398+
header, block = match.group(1), match.group(2)
399+
# Indent the validator body by 4 spaces to sit inside the class.
400+
indented = "\n" + textwrap.indent(_body.rstrip(), " ")
401+
return header + block + indented + "\n"
402+
403+
content, count = class_pattern.subn(_append_validator, content, count=1)
404+
if count == 0:
405+
print(
406+
f"Warning: class {class_name} not found for validator injection",
407+
file=sys.stderr,
408+
)
409+
return content
410+
411+
341412
def _apply_field_overrides(content: str) -> str:
342413
for class_name, field_name, new_type, optional in FIELD_TYPE_OVERRIDES:
343414
if optional:

src/acp/schema.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from enum import Enum
77
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
88

9-
from pydantic import BaseModel as _BaseModel, Field, RootModel, ConfigDict
9+
from pydantic import BaseModel as _BaseModel, Field, RootModel, ConfigDict, field_validator
1010

1111
PermissionOptionKind = Literal["allow_once", "allow_always", "reject_once", "reject_always"]
1212
PlanEntryPriority = Literal["high", "medium", "low"]
@@ -1588,6 +1588,20 @@ class InitializeRequest(BaseModel):
15881588
),
15891589
]
15901590

1591+
@field_validator("protocol_version", mode="before")
1592+
@classmethod
1593+
def _coerce_protocol_version(cls, v: Any) -> int:
1594+
# Some clients (e.g. Zed) send a date string like "2024-11-05" instead
1595+
# of an integer. The Rust SDK treats any string as version 0; we map it
1596+
# to 1 (the current stable version) so the connection is not rejected.
1597+
# See: https://github.com/agentclientprotocol/rust-sdk/blob/main/crates/agent-client-protocol-schema/src/version.rs
1598+
if isinstance(v, int):
1599+
return v
1600+
try:
1601+
return int(v)
1602+
except (TypeError, ValueError):
1603+
return 1
1604+
15911605

15921606
class KillTerminalRequest(BaseModel):
15931607
# The _meta property is reserved by ACP to allow clients and agents to attach additional

0 commit comments

Comments
 (0)