Skip to content

Commit a412ff0

Browse files
saitcakmak and meta-codesync[bot]
authored and committed
LLM Provider & Message Abstractions (#4826)
Summary:
Pull Request resolved: #4826

This diff implements a simple `LLMProvider` protocol (to be followed by a `LiteLLMProvider` implementation for easy access to many providers) and an `LLMMessage` dataclass. These classes support easy typing for any LLM usage within Ax, and enable storage of the inputs and conversations.

Reviewed By: lena-kashtelyan
Differential Revision: D90788634
fbshipit-source-id: 750ccbe916d57224f07a04f88e4d6648cb5fa69b
1 parent bc00476 commit a412ff0

5 files changed

Lines changed: 177 additions & 0 deletions

File tree

ax/core/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
from ax.core.data import Data
1313
from ax.core.experiment import Experiment
1414
from ax.core.generator_run import GeneratorRun
15+
from ax.core.llm_provider import LLMMessage, LLMProvider
1516
from ax.core.metric import Metric
1617
from ax.core.objective import MultiObjective, Objective
1718
from ax.core.observation import ObservationFeatures
@@ -46,6 +47,8 @@
4647
"Experiment",
4748
"FixedParameter",
4849
"GeneratorRun",
50+
"LLMMessage",
51+
"LLMProvider",
4952
"Metric",
5053
"MultiObjective",
5154
"MultiObjectiveOptimizationConfig",

ax/core/llm_provider.py

Lines changed: 82 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,82 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
"""
10+
LLM Provider protocol and message types for Ax.
11+
12+
This module defines the core abstractions for LLM integration in Ax:
13+
- LLMMessage: Unified message type for conversations and responses
14+
- LLMProvider: Protocol defining the interface for LLM providers
15+
"""
16+
17+
from dataclasses import dataclass, field
18+
from typing import Any, Literal, Protocol, runtime_checkable
19+
20+
21+
@dataclass
class LLMMessage:
    """Represents a single message in a conversation.

    This unified class handles both input messages and LLM responses.
    For assistant responses (role="assistant"), the metadata field
    captures information about the generation.

    Note that ``role`` is only constrained by its type annotation; no
    runtime validation of the value is performed by this class.

    Attributes:
        role: Message role - "system", "user", or "assistant"
        content: Message content/text
        metadata: Additional metadata. Defaults to a new empty dict for
            each instance.
            For assistant responses, this may include:
            - "usage": Token usage statistics
            - "finish_reason": Reason for generation completion (e.g., "stop")
    """

    role: Literal["system", "user", "assistant"]
    content: str
    # `default_factory` creates a fresh dict per message, so mutating one
    # message's metadata never leaks into another instance.
    metadata: dict[str, Any] = field(default_factory=dict)
41+
42+
43+
@runtime_checkable
class LLMProvider(Protocol):
    """Protocol defining the interface for LLM providers.

    Any class implementing this protocol can be used as an LLM provider in Ax.
    This enables easy integration of custom LLM backends without requiring
    inheritance from a base class.

    Implementations must provide:
        - generate(): method to generate responses from messages

    Because the protocol is ``runtime_checkable``, ``isinstance(obj, LLMProvider)``
    works at runtime. That check only verifies that a ``generate`` attribute
    exists; it does not validate the method's signature or return type.

    Example:
        >>> class MyCustomProvider:
        ...     def generate(
        ...         self,
        ...         messages: list[LLMMessage],
        ...         **kwargs: Any,
        ...     ) -> LLMMessage:
        ...         # Custom implementation
        ...         return LLMMessage(role="assistant", content="response")
        ...
        >>> # Type checker will accept this as LLMProvider
        >>> provider: LLMProvider = MyCustomProvider()
    """

    def generate(
        self,
        messages: list[LLMMessage],
        **kwargs: Any,
    ) -> LLMMessage:
        """Generate a response from a sequence of messages.

        Args:
            messages: List of conversation messages with roles and content
            **kwargs: Provider-specific parameters (e.g., temperature, max_tokens)

        Returns:
            LLMMessage with role="assistant" containing the generated response
        """
        ...

ax/core/tests/test_llm_provider.py

Lines changed: 81 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,81 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import Any, Literal
10+
11+
from ax.core.llm_provider import LLMMessage, LLMProvider
12+
from ax.utils.common.testutils import TestCase
13+
14+
15+
class LLMMessageTest(TestCase):
    """Unit tests for the ``LLMMessage`` dataclass."""

    def test_llm_message(self) -> None:
        """Test LLMMessage creation for each role and the default metadata."""
        test_cases: list[tuple[Literal["user", "system", "assistant"], str]] = [
            ("user", "Hello"),
            ("system", "You are helpful"),
            ("assistant", "Hi there"),
        ]
        for role, content in test_cases:
            with self.subTest(role=role):
                msg = LLMMessage(role=role, content=content)
                self.assertEqual(msg.role, role)
                self.assertEqual(msg.content, content)
                # No metadata was passed, so the default empty dict is expected.
                self.assertEqual(msg.metadata, {})

    def test_llm_message_with_metadata(self) -> None:
        """Test LLMMessage with assistant metadata (response case)."""
        msg = LLMMessage(
            role="assistant",
            content="Hello world",
            metadata={
                "finish_reason": "stop",
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 20,
                    "total_tokens": 30,
                },
            },
        )
        self.assertEqual(msg.role, "assistant")
        self.assertEqual(msg.content, "Hello world")
        self.assertEqual(msg.metadata["finish_reason"], "stop")
        self.assertEqual(msg.metadata["usage"]["total_tokens"], 30)

        # With minimal fields
        msg_minimal = LLMMessage(role="assistant", content="Hello")
        self.assertEqual(msg_minimal.metadata, {})
52+
53+
54+
class LLMProviderProtocolTest(TestCase):
    """Unit tests for structural conformance to the ``LLMProvider`` protocol."""

    def test_protocol_compliance(self) -> None:
        """Test that custom classes can implement the LLMProvider protocol."""

        class MockProvider:
            """A mock provider that implements the LLMProvider protocol."""

            def generate(
                self,
                messages: list[LLMMessage],
                **kwargs: Any,
            ) -> LLMMessage:
                # Echo the last message's content so the caller can verify
                # that the input was actually consumed.
                return LLMMessage(
                    role="assistant",
                    content=f"Mock response to: {messages[-1].content}",
                )

        provider = MockProvider()

        # Test that it's recognized as implementing the protocol
        self.assertIsInstance(provider, LLMProvider)

        # Test that it works
        response = provider.generate(
            messages=[LLMMessage(role="user", content="Hello")]
        )
        self.assertEqual(response.role, "assistant")
        self.assertIn("Hello", response.content)

ax/storage/json_store/registry.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from ax.core.data import Data
3131
from ax.core.evaluations_to_data import DataType
3232
from ax.core.generator_run import GeneratorRun
33+
from ax.core.llm_provider import LLMMessage
3334
from ax.core.map_metric import MapMetric
3435
from ax.core.metric import Metric
3536
from ax.core.multi_type_experiment import MultiTypeExperiment
@@ -336,6 +337,7 @@
336337
# name linked to the new corresponding class
337338
"ListSurrogate": Surrogate,
338339
"L2NormMetric": L2NormMetric,
340+
"LLMMessage": LLMMessage,
339341
"LogNormalPrior": LogNormalPrior,
340342
"MapData": Data,
341343
"MapMetric": MapMetric,

ax/storage/json_store/tests/test_json_store.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434
from ax.core.auxiliary import AuxiliaryExperimentPurpose
3535
from ax.core.data import Data
3636
from ax.core.generator_run import GeneratorRun
37+
from ax.core.llm_provider import LLMMessage
3738
from ax.core.metric import Metric
3839
from ax.core.objective import Objective
3940
from ax.core.observation import ObservationFeatures
@@ -348,6 +349,14 @@
348349
("HierarchicalSearchSpace", get_hierarchical_search_space),
349350
("ImprovementGlobalStoppingStrategy", get_improvement_global_stopping_strategy),
350351
("Interval", get_interval),
352+
(
353+
"LLMMessage",
354+
lambda: LLMMessage(
355+
role="assistant",
356+
content="Hello!",
357+
metadata={"finish_reason": "stop", "usage": {"total_tokens": 10}},
358+
),
359+
),
351360
("MapData", get_map_data),
352361
("MapMetric", partial(get_map_metric, name="test")),
353362
("Metric", get_metric),

0 commit comments

Comments
 (0)