Skip to content

Commit d31f345

Browse files
authored
Merge pull request trustyai-explainability#31 from m-misiura/RemoteProviderSpec
Update provider.py to ensure it can be loaded using `RemoteProviderSpec` introduced in lls 0.23.0
2 parents 3895458 + ad7fb49 commit d31f345

File tree

3 files changed

+90
-41
lines changed

3 files changed

+90
-41
lines changed
Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,56 @@
11
import logging
22

3-
from llama_stack.providers.datatypes import (
4-
AdapterSpec,
5-
Api,
6-
ProviderSpec,
7-
remote_provider_spec,
8-
)
9-
10-
# Add logging at the top
3+
from llama_stack.providers.datatypes import Api, ProviderSpec
4+
115
logger = logging.getLogger(__name__)
126

7+
try:
8+
from llama_stack.providers.datatypes import AdapterSpec, remote_provider_spec
9+
10+
USE_LEGACY = True
11+
logger.debug("Using legacy remote_provider_spec")
12+
except ImportError:
13+
from llama_stack.providers.datatypes import RemoteProviderSpec
14+
15+
USE_LEGACY = False
16+
logger.debug("Using new RemoteProviderSpec")
17+
1318

1419
def get_provider_spec() -> ProviderSpec:
15-
return remote_provider_spec(
16-
api=Api.safety,
17-
adapter=AdapterSpec(
18-
adapter_type="trustyai_fms",
20+
if USE_LEGACY:
21+
return remote_provider_spec(
22+
api=Api.safety,
23+
adapter=AdapterSpec(
24+
adapter_type="trustyai_fms",
25+
config_class="llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig",
26+
module="llama_stack_provider_trustyai_fms",
27+
),
28+
)
29+
else:
30+
return RemoteProviderSpec(
31+
api=Api.safety,
32+
provider_type="remote::trustyai_fms",
1933
config_class="llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig",
2034
module="llama_stack_provider_trustyai_fms",
21-
),
22-
)
35+
adapter_type="trustyai_fms",
36+
)
2337

2438

2539
def get_shields_provider_spec() -> ProviderSpec:
26-
spec = remote_provider_spec(
27-
api=Api.shields,
28-
adapter=AdapterSpec(
29-
adapter_type="trustyai_fms",
40+
if USE_LEGACY:
41+
return remote_provider_spec(
42+
api=Api.shields,
43+
adapter=AdapterSpec(
44+
adapter_type="trustyai_fms",
45+
config_class="llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig",
46+
module="llama_stack_provider_trustyai_fms",
47+
),
48+
)
49+
else:
50+
return RemoteProviderSpec(
51+
api=Api.shields,
52+
provider_type="remote::trustyai_fms",
3053
config_class="llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig",
3154
module="llama_stack_provider_trustyai_fms",
32-
),
33-
)
34-
# Add debug logging
35-
logger.debug(f"Returning shields provider spec: {spec}")
36-
return spec
55+
adapter_type="trustyai_fms",
56+
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "llama-stack-provider-trustyai-fms"
7-
version = "0.2.2"
7+
version = "0.2.3"
88
description = "Remote safety provider for Llama Stack integrating FMS Guardrails Orchestrator and community detectors"
99
authors = [
1010
{name = "GitHub: m-misiura"}

tests/unit/test_provider.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec
1+
from llama_stack.providers.datatypes import Api, ProviderSpec
2+
3+
try:
4+
from llama_stack.providers.datatypes import AdapterSpec
5+
6+
USE_LEGACY = True
7+
except ImportError:
8+
from llama_stack.providers.datatypes import RemoteProviderSpec
9+
10+
USE_LEGACY = False
211

312
from llama_stack_provider_trustyai_fms.provider import (
413
get_provider_spec,
@@ -9,26 +18,46 @@
918
class TestProviderFunctions:
1019
def test_get_provider_spec(self):
1120
spec = get_provider_spec()
12-
1321
assert isinstance(spec, ProviderSpec)
1422
assert spec.api == Api.safety
15-
assert isinstance(spec.adapter, AdapterSpec)
16-
assert spec.adapter.adapter_type == "trustyai_fms"
17-
assert (
18-
spec.adapter.config_class
19-
== "llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig"
20-
)
21-
assert spec.adapter.module == "llama_stack_provider_trustyai_fms"
23+
24+
if USE_LEGACY:
25+
assert isinstance(spec.adapter, AdapterSpec)
26+
assert spec.adapter.adapter_type == "trustyai_fms"
27+
assert (
28+
spec.adapter.config_class
29+
== "llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig"
30+
)
31+
assert spec.adapter.module == "llama_stack_provider_trustyai_fms"
32+
else:
33+
assert isinstance(spec, RemoteProviderSpec)
34+
assert spec.provider_type == "remote::trustyai_fms"
35+
assert spec.adapter_type == "trustyai_fms"
36+
assert (
37+
spec.config_class
38+
== "llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig"
39+
)
40+
assert spec.module == "llama_stack_provider_trustyai_fms"
2241

2342
def test_get_shields_provider_spec(self):
2443
spec = get_shields_provider_spec()
25-
2644
assert isinstance(spec, ProviderSpec)
2745
assert spec.api == Api.shields
28-
assert isinstance(spec.adapter, AdapterSpec)
29-
assert spec.adapter.adapter_type == "trustyai_fms"
30-
assert (
31-
spec.adapter.config_class
32-
== "llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig"
33-
)
34-
assert spec.adapter.module == "llama_stack_provider_trustyai_fms"
46+
47+
if USE_LEGACY:
48+
assert isinstance(spec.adapter, AdapterSpec)
49+
assert spec.adapter.adapter_type == "trustyai_fms"
50+
assert (
51+
spec.adapter.config_class
52+
== "llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig"
53+
)
54+
assert spec.adapter.module == "llama_stack_provider_trustyai_fms"
55+
else:
56+
assert isinstance(spec, RemoteProviderSpec)
57+
assert spec.provider_type == "remote::trustyai_fms"
58+
assert spec.adapter_type == "trustyai_fms"
59+
assert (
60+
spec.config_class
61+
== "llama_stack_provider_trustyai_fms.config.FMSSafetyProviderConfig"
62+
)
63+
assert spec.module == "llama_stack_provider_trustyai_fms"

0 commit comments

Comments
 (0)