Skip to content

Commit a6c4092

Browse files
authored
Merge pull request #51 from m-misiura/migrate-to-llama-stack-api-shim
fix: Fix new llama-stack-api module resolution
2 parents dd12f87 + 4e5e504 commit a6c4092

File tree

10 files changed

+199
-49
lines changed

10 files changed

+199
-49
lines changed

llama_stack_provider_trustyai_fms/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import logging
22
from typing import Any
33

4-
# Import Safety API
5-
from llama_stack_api.safety import Safety
6-
from llama_stack_api.datatypes import Api
4+
from .compat import Api, Safety
75

86
# First import the provider spec to ensure registration
97
from .provider import get_provider_spec
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""Compatibility helpers for importing llama_stack APIs across versions.
2+
3+
The llama_stack APIs were moved under a separate `llama_stack_api` package
4+
upstream. Prefer the new package layout and fall back to the legacy one so
5+
this provider can run with both.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
# Try new llama_stack_api package first, fall back to legacy llama_stack
11+
try: # Current dedicated llama_stack_api package (preferred)
12+
from llama_stack_api.datatypes import (
13+
Api,
14+
ProviderSpec,
15+
RemoteProviderSpec,
16+
ShieldsProtocolPrivate,
17+
)
18+
from llama_stack_api.inference import (
19+
OpenAIAssistantMessageParam,
20+
OpenAIDeveloperMessageParam,
21+
OpenAIMessageParam,
22+
OpenAISystemMessageParam,
23+
OpenAIToolMessageParam,
24+
OpenAIUserMessageParam,
25+
SystemMessage,
26+
ToolResponseMessage,
27+
UserMessage,
28+
)
29+
from llama_stack_api.resource import ResourceType
30+
from llama_stack_api.safety import (
31+
ModerationObject,
32+
ModerationObjectResults,
33+
RunShieldResponse,
34+
Safety,
35+
SafetyViolation,
36+
ShieldStore,
37+
ViolationLevel,
38+
)
39+
from llama_stack_api.schema_utils import json_schema_type
40+
from llama_stack_api.shields import ListShieldsResponse, Shield, Shields
41+
42+
except ModuleNotFoundError: # Legacy llama_stack layout
43+
from llama_stack.apis.datatypes import Api
44+
from llama_stack.apis.inference import (
45+
CompletionMessage,
46+
Message,
47+
SystemMessage,
48+
ToolResponseMessage,
49+
UserMessage,
50+
)
51+
from llama_stack.apis.resource import ResourceType
52+
from llama_stack.apis.safety import (
53+
RunShieldResponse,
54+
Safety,
55+
SafetyViolation,
56+
ShieldStore,
57+
ViolationLevel,
58+
)
59+
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
60+
from llama_stack.providers.datatypes import (
61+
ProviderSpec,
62+
RemoteProviderSpec,
63+
ShieldsProtocolPrivate,
64+
)
65+
from llama_stack.schema_utils import json_schema_type
66+
67+
OpenAIMessageParam = Message
68+
OpenAIUserMessageParam = UserMessage
69+
OpenAISystemMessageParam = SystemMessage
70+
OpenAIToolMessageParam = ToolResponseMessage
71+
OpenAIAssistantMessageParam = CompletionMessage
72+
OpenAIDeveloperMessageParam = SystemMessage  # the developer message role did not exist in legacy llama_stack; map to system
73+
74+
try:
75+
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
76+
except ImportError:
77+
class ModerationObject:
78+
"""Placeholder for legacy versions without ModerationObject"""
79+
pass
80+
81+
class ModerationObjectResults:
82+
"""Placeholder for legacy versions without ModerationObjectResults"""
83+
pass
84+
85+
86+
__all__ = [
87+
"Api",
88+
"ListShieldsResponse",
89+
"ModerationObject",
90+
"ModerationObjectResults",
91+
"OpenAIAssistantMessageParam",
92+
"OpenAIDeveloperMessageParam",
93+
"OpenAIMessageParam",
94+
"OpenAISystemMessageParam",
95+
"OpenAIToolMessageParam",
96+
"OpenAIUserMessageParam",
97+
"ProviderSpec",
98+
"RemoteProviderSpec",
99+
"ResourceType",
100+
"RunShieldResponse",
101+
"Safety",
102+
"SafetyViolation",
103+
"Shield",
104+
"Shields",
105+
"ShieldsProtocolPrivate",
106+
"ShieldStore",
107+
"SystemMessage",
108+
"ToolResponseMessage",
109+
"UserMessage",
110+
"ViolationLevel",
111+
"json_schema_type",
112+
]

llama_stack_provider_trustyai_fms/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from typing import Any
88
from urllib.parse import urlparse
99

10-
from llama_stack_api.schema_utils import json_schema_type
1110
from pydantic import BaseModel, Field, model_validator
1211

12+
from .compat import json_schema_type
13+
1314
# Make sure to export all classes at the module level
1415
__all__ = [
1516
"MessageType",

llama_stack_provider_trustyai_fms/detectors/base.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,30 @@
1212
from urllib.parse import urlparse
1313

1414
import httpx
15-
from llama_stack_api.inference import (
15+
16+
from ..compat import (
17+
ListShieldsResponse,
18+
ModerationObject,
19+
ModerationObjectResults,
1620
OpenAIAssistantMessageParam,
1721
OpenAIDeveloperMessageParam,
1822
OpenAIMessageParam,
1923
OpenAISystemMessageParam,
2024
OpenAIToolMessageParam,
2125
OpenAIUserMessageParam,
22-
SystemMessage,
23-
ToolResponseMessage,
24-
UserMessage,
25-
)
26-
from llama_stack_api.resource import ResourceType
27-
from llama_stack_api.safety import (
28-
ModerationObject,
29-
ModerationObjectResults,
26+
ResourceType,
3027
RunShieldResponse,
3128
Safety,
3229
SafetyViolation,
30+
Shield,
31+
Shields,
32+
ShieldsProtocolPrivate,
3333
ShieldStore,
34+
SystemMessage,
35+
ToolResponseMessage,
36+
UserMessage,
3437
ViolationLevel,
3538
)
36-
from llama_stack_api.shields import ListShieldsResponse, Shield, Shields
37-
from llama_stack_api.datatypes import ShieldsProtocolPrivate
38-
3939
from ..config import (
4040
BaseDetectorConfig,
4141
ChatDetectorConfig,
@@ -251,7 +251,9 @@ def _should_process_message(self, message: OpenAIMessageParam) -> bool:
251251
)
252252
return is_supported
253253

254-
def _filter_messages(self, messages: list[OpenAIMessageParam]) -> list[OpenAIMessageParam]:
254+
def _filter_messages(
255+
self, messages: list[OpenAIMessageParam]
256+
) -> list[OpenAIMessageParam]:
255257
"""Filter messages based on configured message types"""
256258
return [msg for msg in messages if self._should_process_message(msg)]
257259

@@ -1865,7 +1867,9 @@ async def _get_shield_id_from_model(self, model: str) -> str:
18651867
)
18661868
return matching_shields[0]
18671869

1868-
def _convert_input_to_messages(self, texts: str | list[str]) -> list[OpenAIMessageParam]:
1870+
def _convert_input_to_messages(
1871+
self, texts: str | list[str]
1872+
) -> list[OpenAIMessageParam]:
18691873
"""Convert string input(s) to UserMessage objects."""
18701874
if isinstance(texts, str):
18711875
inputs = [texts]

llama_stack_provider_trustyai_fms/detectors/chat.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
from dataclasses import dataclass
55
from typing import Any, cast
66

7-
from llama_stack_api.inference import OpenAIMessageParam
8-
from llama_stack_api.safety import RunShieldResponse
9-
7+
from ..compat import OpenAIMessageParam, RunShieldResponse
108
from ..config import ChatDetectorConfig
119
from .base import (
1210
BaseDetector,

llama_stack_provider_trustyai_fms/detectors/content.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
import logging
44
from typing import Any, cast
55

6-
from llama_stack_api.inference import OpenAIMessageParam
7-
from llama_stack_api.safety import RunShieldResponse
8-
6+
from ..compat import OpenAIMessageParam, RunShieldResponse
97
from ..config import (
108
ContentDetectorConfig,
119
)
Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,71 @@
11
import logging
22

3-
from llama_stack_api.datatypes import Api, ProviderSpec, RemoteProviderSpec
3+
from .compat import Api, ProviderSpec, RemoteProviderSpec
44

55
logger = logging.getLogger(__name__)
66

7+
# Check if we're using very old format (0.2.20-0.2.22 with AdapterSpec)
8+
try:
9+
from llama_stack.providers.datatypes import AdapterSpec
10+
11+
USE_ADAPTER_SPEC = True
12+
except ImportError:
13+
# 0.3.3+ and llama-stack-api don't have AdapterSpec
14+
AdapterSpec = None
15+
USE_ADAPTER_SPEC = False
16+
717

818
def get_provider_spec() -> ProviderSpec:
919
"""Get provider specification for Safety API.
1020
11-
Returns RemoteProviderSpec for llama-stack-api >= 0.1.0.
12-
This provider requires llama-stack-api and is not compatible with legacy llama-stack.
21+
Compatible with llama-stack >= 0.2.20 and llama-stack-api >= 0.1.0.
1322
"""
14-
return RemoteProviderSpec(
15-
api=Api.safety,
16-
provider_type="remote::trustyai_fms",
17-
config_class="llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig",
18-
module="llama_stack_provider_trustyai_fms",
19-
adapter_type="trustyai_fms",
20-
)
23+
if USE_ADAPTER_SPEC:
24+
# Very old (0.2.20-0.2.22): uses adapter field with AdapterSpec
25+
return RemoteProviderSpec(
26+
api=Api.safety,
27+
adapter=AdapterSpec(
28+
adapter_type="trustyai_fms",
29+
config_class="llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig",
30+
module="llama_stack_provider_trustyai_fms",
31+
),
32+
config_class="llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig",
33+
provider_type="remote::trustyai_fms",
34+
)
35+
else:
36+
# Newer (0.3.3+) and new llama-stack-api: uses adapter_type field
37+
return RemoteProviderSpec(
38+
api=Api.safety,
39+
provider_type="remote::trustyai_fms",
40+
config_class="llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig",
41+
module="llama_stack_provider_trustyai_fms",
42+
adapter_type="trustyai_fms",
43+
)
2144

2245

2346
def get_shields_provider_spec() -> ProviderSpec:
2447
"""Get provider specification for Shields API.
2548
26-
Returns RemoteProviderSpec for llama-stack-api >= 0.1.0.
27-
This provider requires llama-stack-api and is not compatible with legacy llama-stack.
49+
Compatible with llama-stack >= 0.2.20 and llama-stack-api >= 0.1.0.
2850
"""
29-
return RemoteProviderSpec(
30-
api=Api.shields,
31-
provider_type="remote::trustyai_fms",
32-
config_class="llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig",
33-
module="llama_stack_provider_trustyai_fms",
34-
adapter_type="trustyai_fms",
35-
)
51+
if USE_ADAPTER_SPEC:
52+
# Very old (0.2.20-0.2.22): uses adapter field with AdapterSpec
53+
return RemoteProviderSpec(
54+
api=Api.shields,
55+
adapter=AdapterSpec(
56+
adapter_type="trustyai_fms",
57+
config_class="llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig",
58+
module="llama_stack_provider_trustyai_fms",
59+
),
60+
config_class="llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig",
61+
provider_type="remote::trustyai_fms",
62+
)
63+
else:
64+
# Newer (0.3.3+) and new llama-stack-api: uses adapter_type field
65+
return RemoteProviderSpec(
66+
api=Api.shields,
67+
provider_type="remote::trustyai_fms",
68+
config_class="llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig",
69+
module="llama_stack_provider_trustyai_fms",
70+
adapter_type="trustyai_fms",
71+
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "llama-stack-provider-trustyai-fms"
7-
version = "0.3.0"
7+
version = "0.3.1"
88
description = "Remote safety provider for Llama Stack integrating FMS Guardrails Orchestrator and community detectors"
99
authors = [
1010
{name = "GitHub: m-misiura"}

tests/unit/test_detectors.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
import pytest
44
import pytest_asyncio
5-
from llama_stack.apis.inference import UserMessage
6-
from llama_stack.apis.safety import RunShieldResponse
75

6+
from llama_stack_provider_trustyai_fms.compat import RunShieldResponse, UserMessage
87
from llama_stack_provider_trustyai_fms.config import (
98
ChatDetectorConfig,
109
ContentDetectorConfig,

tests/unit/test_provider.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
from llama_stack.providers.datatypes import Api, ProviderSpec
1+
from llama_stack_provider_trustyai_fms.compat import (
2+
Api,
3+
ProviderSpec,
4+
RemoteProviderSpec,
5+
)
26

37
try:
48
from llama_stack.providers.datatypes import AdapterSpec
59

610
USE_LEGACY = True
711
except ImportError:
8-
from llama_stack.providers.datatypes import RemoteProviderSpec
9-
12+
# New llama-stack-api doesn't have AdapterSpec or llama_stack.providers.datatypes
13+
AdapterSpec = None
1014
USE_LEGACY = False
1115

1216
from llama_stack_provider_trustyai_fms.provider import (

0 commit comments

Comments
 (0)