Skip to content

Commit aa7d19a

Browse files
committed
enhance(BA-6053): preserve str start_command form in model definitions
Stop coercing string ``start_command`` values into argv lists via ``shlex.split`` in the API and config validators. The kernel runner already wraps ``str`` commands as ``[shell, "-c", str]`` at exec time (``kernel/service.py``), so passing the original string through gives users full shell semantics — line continuations, ``$VAR`` expansion, pipes — and removes the need to manually strip backslashes from copy-pasted multi-line vendor recipes (e.g. vLLM). - ``ModelServiceConfig.start_command`` and the draft mirror now accept ``str | list[str] | None``; the previous ``shlex.split`` validator is removed. - ``{model_path}`` substitution and ``with_args_appended`` handle both forms; preset ARGS appended to a string command are shell-quoted via ``shlex.join``. - DTOs in ``common/dto/manager/v2/deployment`` are widened to match. The GraphQL types expose the field as ``JSON | None`` (the same pattern already used for ``ModelMetadata.version``) since GraphQL cannot natively represent a scalar/list union. Existing list-form definitions are unaffected. resolves #11624
1 parent ddb766c commit aa7d19a

6 files changed

Lines changed: 129 additions & 42 deletions

File tree

changes/11624.enhance.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Preserve the original `str` form of `start_command` in model definitions all the way to the kernel runner so the kernel runner's existing `[shell, "-c", str]` wrapping handles shell semantics (line continuations, `$VAR` expansion, pipes) — eliminating the need to manually strip backslashes from copy-pasted multi-line vendor recipes (e.g. vLLM).

src/ai/backend/common/config.py

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
AliasChoices,
1515
ConfigDict,
1616
Field,
17-
field_validator,
1817
)
1918

2019
from . import validators as tx
@@ -144,16 +143,6 @@ def snake_to_kebab_case(string: str) -> str:
144143
agent_selector_config_iv = t.Dict({}) | agent_selector_globalconfig_iv
145144

146145

147-
def _normalize_start_command(value: Any) -> Any:
148-
"""Coerce legacy ``str`` ``start_command`` into argv list via
149-
:func:`shlex.split`. Lists, ``None``, and other types pass through so
150-
the schema's strict check rejects them.
151-
"""
152-
if isinstance(value, str):
153-
return shlex.split(value)
154-
return value
155-
156-
157146
model_definition_iv = t.Dict({
158147
t.Key("models"): t.List(
159148
t.Dict({
@@ -170,8 +159,9 @@ def _normalize_start_command(value: Any) -> Any:
170159
t.Key("args"): t.Dict().allow_extra("*"),
171160
})
172161
),
173-
t.Key("start_command", default=None): (t.Null | t.String | t.List(t.String))
174-
>> _normalize_start_command,
162+
# A ``str`` ``start_command`` is preserved as-is so the
163+
# kernel runner can wrap it as ``[shell, "-c", str]``.
164+
t.Key("start_command", default=None): t.Null | t.String | t.List(t.String),
175165
t.Key("shell", default="/bin/bash"): t.String,
176166
t.Key("port"): t.ToInt[1:],
177167
t.Key("health_check", default=None): t.Null
@@ -258,14 +248,23 @@ class ModelServiceConfig(BaseConfigModel):
258248
default_factory=list,
259249
description="List of pre-start actions to execute before starting the model service.",
260250
)
261-
start_command: list[str] | None = Field(
251+
start_command: str | list[str] | None = Field(
262252
default=None,
263253
description=(
264-
"Argv list to start the model service. ``{model_path}`` in any "
265-
"token is replaced per-token with the resolved ``model_path`` "
266-
"before launch. ``None`` falls back to the image's default CMD."
254+
"Command to start the model service. Two forms are accepted: "
255+
"an argv list (``list[str]``) which is exec'ed directly, or a "
256+
"single shell script string which the kernel runner wraps as "
257+
"``[shell, '-c', str]`` (giving the user full shell semantics "
258+
"such as line continuations, ``$VAR`` expansion, and pipes). "
259+
"``{model_path}`` is replaced with the resolved ``model_path`` "
260+
"before launch — per-token for the list form, in-place for the "
261+
"string form. ``None`` falls back to the image's default CMD."
267262
),
268-
examples=[["python", "service.py"], ["vllm", "serve", "{model_path}"]],
263+
examples=[
264+
["python", "service.py"],
265+
["vllm", "serve", "{model_path}"],
266+
"vllm serve {model_path} --tensor-parallel-size 2",
267+
],
269268
)
270269
shell: str = Field(
271270
default="/bin/bash",
@@ -282,11 +281,6 @@ class ModelServiceConfig(BaseConfigModel):
282281
description="Health check configuration for the model service.",
283282
)
284283

285-
@field_validator("start_command", mode="before")
286-
@classmethod
287-
def _coerce_start_command(cls, value: Any) -> Any:
288-
return _normalize_start_command(value)
289-
290284

291285
class ModelMetadata(BaseConfigModel):
292286
author: str | None = Field(
@@ -498,7 +492,12 @@ def health_check_config(self) -> ModelHealthCheck | None:
498492

499493
def with_args_appended(self, args: list[str]) -> ModelDefinition:
500494
"""Return a copy with ``args`` appended to each model's
501-
``service.start_command`` as separate argv tokens.
495+
``service.start_command``.
496+
497+
For the list form, ``args`` are concatenated as separate argv
498+
tokens. For the string form, ``args`` are shell-quoted via
499+
:func:`shlex.join` and appended after a single space so they are
500+
parsed by the same shell that runs the user's script.
502501
503502
Models with ``service is None`` are passed through unchanged;
504503
a model whose ``start_command`` is ``None`` receives ``args``
@@ -512,9 +511,14 @@ def with_args_appended(self, args: list[str]) -> ModelDefinition:
512511
if model.service is None:
513512
new_models.append(model)
514513
continue
515-
existing = model.service.start_command or []
514+
existing = model.service.start_command
515+
merged: str | list[str]
516+
if isinstance(existing, str):
517+
merged = f"{existing} {shlex.join(args)}"
518+
else:
519+
merged = (existing or []) + args
516520
new_service = model.service.model_copy(
517-
update={"start_command": existing + args},
521+
update={"start_command": merged},
518522
)
519523
new_models.append(model.model_copy(update={"service": new_service}))
520524
return self.model_copy(update={"models": new_models})
@@ -548,16 +552,11 @@ def to_resolved(self) -> ModelHealthCheck:
548552

549553
class ModelServiceConfigDraft(BaseConfigModel):
550554
pre_start_actions: list[PreStartAction] | None = None
551-
start_command: list[str] | None = None
555+
start_command: str | list[str] | None = None
552556
shell: str | None = None
553557
port: int | None = None
554558
health_check: ModelHealthCheckDraft | None = None
555559

556-
@field_validator("start_command", mode="before")
557-
@classmethod
558-
def _coerce_start_command(cls, value: Any) -> Any:
559-
return _normalize_start_command(value)
560-
561560
def to_resolved(self) -> ModelServiceConfig:
562561
# Drop unset (None) scalars so the strict type's ``Field(default=...)``
563562
# declarations remain the single source of truth for default values;
@@ -583,9 +582,15 @@ def to_resolved(self) -> ModelConfig:
583582
# ``start_command`` are resolved here, at the same moment the
584583
# draft becomes a strict ``ModelConfig`` and ``model_path`` is
585584
# finalized. Placeholders therefore never propagate downstream.
586-
service.start_command = [
587-
token.replace("{model_path}", self.model_path) for token in service.start_command
588-
]
585+
if isinstance(service.start_command, str):
586+
service.start_command = service.start_command.replace(
587+
"{model_path}", self.model_path
588+
)
589+
else:
590+
service.start_command = [
591+
token.replace("{model_path}", self.model_path)
592+
for token in service.start_command
593+
]
589594
payload = self.model_dump(exclude_none=True, exclude={"service"})
590595
payload["service"] = service
591596
return ModelConfig.model_validate(payload)

src/ai/backend/common/dto/manager/v2/deployment/request.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ class ModelMetadataInput(BaseRequestModel):
148148

149149
class ModelServiceConfigInput(BaseRequestModel):
150150
pre_start_actions: list[PreStartAction] | None = None
151-
start_command: list[str] | None = None
151+
start_command: str | list[str] | None = None
152152
shell: str | None = None
153153
port: int | None = None
154154
health_check: ModelHealthCheckInput | None = None

src/ai/backend/common/dto/manager/v2/deployment/types.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,14 @@ class ModelServiceConfigInfoDTO(BaseResponseModel):
299299
default_factory=list,
300300
description="List of pre-start actions to execute before starting the model service.",
301301
)
302-
start_command: list[str] | None = Field(
303-
default=None, description="Command to start the model service."
302+
start_command: str | list[str] | None = Field(
303+
default=None,
304+
description=(
305+
"Command to start the model service. A list is exec'ed directly "
306+
"as argv; a string is wrapped as ``[shell, '-c', str]`` by the "
307+
"kernel runner so shell semantics (line continuations, ``$VAR`` "
308+
"expansion, pipes) apply."
309+
),
304310
)
305311
shell: str = Field(
306312
default="/bin/bash",

src/ai/backend/manager/api/gql/deployment/types/revision.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,14 @@ class ModelServiceConfigGQL:
383383
pre_start_actions: list[PreStartActionGQL] = gql_field(
384384
description="List of pre-start actions to execute before starting the model service."
385385
)
386-
start_command: list[str] | None = gql_field(
387-
description="Command to start the model service.", default=None
386+
start_command: JSON | None = gql_field(
387+
description=(
388+
"Command to start the model service. A JSON array (``list[str]``) "
389+
"is exec'ed directly as argv; a JSON string is wrapped as "
390+
"``[shell, '-c', str]`` by the kernel runner so shell semantics "
391+
"(line continuations, ``$VAR`` expansion, pipes) apply."
392+
),
393+
default=None,
388394
)
389395
shell: str = gql_field(description="Shell configured for the model service.")
390396
port: int = gql_field(description="Port number for the model service.")
@@ -891,8 +897,14 @@ class ModelServiceConfigInputGQL(PydanticInputMixin[ModelServiceConfigInputDTO])
891897
description="List of pre-start actions to execute before starting the model service.",
892898
default=None,
893899
)
894-
start_command: list[str] | None = gql_field(
895-
description="Command to start the model service.", default=None
900+
start_command: JSON | None = gql_field(
901+
description=(
902+
"Command to start the model service. A JSON array (``list[str]``) "
903+
"is exec'ed directly as argv; a JSON string is wrapped as "
904+
"``[shell, '-c', str]`` by the kernel runner so shell semantics "
905+
"(line continuations, ``$VAR`` expansion, pipes) apply."
906+
),
907+
default=None,
896908
)
897909
shell: str | None = gql_field(
898910
description="Shell configured for the model service.", default=None

tests/unit/common/test_config.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,45 @@ def test_to_resolved_leaves_start_command_with_no_placeholder_unchanged(self) ->
304304
assert service is not None
305305
assert service.start_command == ["my-server", "--bind", "0.0.0.0"]
306306

307+
def test_to_resolved_preserves_str_start_command(self) -> None:
308+
# When ``start_command`` is given as a single shell-script string,
309+
# the validator must NOT split it into argv tokens; the kernel
310+
# runner is responsible for wrapping it via ``[shell, "-c", str]``.
311+
# This protects copy-pasted multi-line vendor recipes (e.g. vLLM)
312+
# where backslash-newline must be interpreted by the shell, not
313+
# by ``shlex.split``.
314+
script = "vllm serve {model_path} \\\n --tensor-parallel-size 2"
315+
draft = ModelDefinitionDraft.model_validate({
316+
"models": [
317+
{
318+
"name": "demo",
319+
"model_path": "/data",
320+
"service": {
321+
"start_command": script,
322+
"port": 8000,
323+
},
324+
}
325+
]
326+
})
327+
328+
resolved = draft.to_resolved()
329+
330+
service = resolved.models[0].service
331+
assert service is not None
332+
assert isinstance(service.start_command, str)
333+
assert service.start_command == ("vllm serve /data \\\n --tensor-parallel-size 2")
334+
335+
def test_model_service_config_accepts_str_start_command(self) -> None:
336+
# Strict ``ModelServiceConfig`` (post-merge) accepts the str form
337+
# too — required so the resolved revision can carry it all the
338+
# way to the agent and kernel runner unchanged.
339+
service = ModelServiceConfig.model_validate({
340+
"start_command": "python -m foo --bar",
341+
"port": 8000,
342+
})
343+
344+
assert service.start_command == "python -m foo --bar"
345+
307346

308347
class TestModelDefinitionWithArgsAppended:
309348
@pytest.fixture
@@ -438,3 +477,27 @@ async def test_each_model_receives_args(
438477
assert first is not None and second is not None
439478
assert first.start_command == ["a", "--shared", "true"]
440479
assert second.start_command == ["b", "--shared", "true"]
480+
481+
async def test_appends_args_to_str_start_command_with_shell_quoting(self) -> None:
482+
# Preset ARGS are appended to a string ``start_command`` after a
483+
# single space, with each argument passed through ``shlex.join``
484+
# so values containing whitespace or quotes are safely parsed by
485+
# the same shell that will run the script.
486+
definition = ModelDefinition(
487+
models=[
488+
ModelConfig(
489+
name="demo",
490+
model_path="/models",
491+
service=ModelServiceConfig(
492+
port=8000,
493+
start_command="vllm serve /models",
494+
),
495+
)
496+
]
497+
)
498+
499+
result = definition.with_args_appended(["--prompt", "hello world"])
500+
501+
service = result.models[0].service
502+
assert service is not None
503+
assert service.start_command == "vllm serve /models --prompt 'hello world'"

0 commit comments

Comments
 (0)