Commit ab185a5

MongoDB-based Lightning Store (#323)
1 parent d581cbc commit ab185a5

20 files changed: +2456 −749 lines

20 files changed

+2456
-749
lines changed

.github/workflows/tests-full.yml

Lines changed: 37 additions & 2 deletions
@@ -47,6 +47,7 @@ jobs:
       - uses: actions/checkout@v4
         with:
           ref: ${{ github.event_name == 'repository_dispatch' && github.event.client_payload.pr_ref || (github.event.pull_request.number && format('refs/pull/{0}/merge', github.event.pull_request.number)) || github.ref }}
+
       - uses: astral-sh/setup-uv@v7
         with:
           enable-cache: true
@@ -55,10 +56,10 @@ jobs:
         run: uv lock --upgrade
         if: matrix.setup-script == 'latest'
       - name: Sync dependencies (latest)
-        run: uv sync --frozen --no-default-groups --extra apo --group dev --group agents --group torch-gpu-stable
+        run: uv sync --frozen --no-default-groups --extra apo --extra mongo --group dev --group agents --group torch-gpu-stable
         if: matrix.setup-script == 'latest'
       - name: Sync dependencies (stable & legacy)
-        run: uv sync --frozen --no-default-groups --extra apo --group dev --group agents --group torch-gpu-${{ matrix.setup-script }}
+        run: uv sync --frozen --no-default-groups --extra apo --extra mongo --group dev --group agents --group torch-gpu-${{ matrix.setup-script }}
         if: matrix.setup-script != 'latest'
       - name: Freeze dependencies
         run: |
@@ -81,6 +82,36 @@ jobs:
       - name: Build dashboard
         run: cd dashboard && npm run build

+      - name: Start MongoDB container
+        run: |
+          set -euo pipefail
+          cat /etc/security/limits.conf
+          docker run -d \
+            --name mongodb-test \
+            --ulimit nofile=65535:65535 \
+            -p 27017:27017 \
+            mongo:8.2 \
+            --replSet test-rs
+
+          # Wait for mongod to come up
+          for i in $(seq 1 30); do
+            if docker exec mongodb-test mongosh --quiet --eval 'db.runCommand({ ping: 1 })' >/dev/null 2>&1; then
+              echo "Mongo is up"
+              break
+            fi
+            echo "Waiting for Mongo..."
+            sleep 2
+          done
+
+          # Init replica set (simple single-node)
+          docker exec mongodb-test mongosh --quiet --eval '
+            rs.initiate({
+              _id: "test-rs",
+              members: [{ _id: 0, host: "localhost:27017" }]
+            })
+          '
+        shell: bash
+
       - name: Launch LiteLLM Proxy
         run: |
           ./scripts/litellm_run.sh
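
The single-node replica set is the important detail here: MongoDB only permits multi-document transactions and change streams when a replica set is configured, which the new Mongo-backed store presumably relies on. Below is a minimal sketch, not part of the commit, of how a test could confirm the container is reachable; it assumes `pymongo` is available (likely pulled in by the new `mongo` extra).

```python
# Hedged sketch: verify the CI replica set started above is reachable.
# AGL_TEST_MONGO_URI is exported later in this workflow; the fallback mirrors
# the container configuration (single-node replica set "test-rs").
import os

from pymongo import MongoClient

uri = os.environ.get("AGL_TEST_MONGO_URI", "mongodb://localhost:27017/?replicaSet=test-rs")
client = MongoClient(uri, serverSelectionTimeoutMS=5000)

client.admin.command("ping")  # raises ServerSelectionTimeoutError if mongod is down
status = client.admin.command("replSetGetStatus")
print(status["set"])  # expected: "test-rs"
```
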
@@ -95,6 +126,8 @@ jobs:
           PYTEST_ADDOPTS: "--color=yes"
           OPENAI_BASE_URL: http://localhost:12306/
           OPENAI_API_KEY: dummy
+          AGL_TEST_MONGO_URI: mongodb://localhost:27017/?replicaSet=test-rs
+

   minimal-examples:
     if: >
@@ -160,6 +193,7 @@ jobs:
           source .venv/bin/activate
           cd examples/minimal
           python write_traces.py otel
+          sleep 5

       - name: Write Traces via AgentOps Tracer
         env:
@@ -170,6 +204,7 @@ jobs:
           source .venv/bin/activate
           cd examples/minimal
           python write_traces.py agentops
+          sleep 5

       - name: Write Traces via Otel Tracer with Client
         run: |

.github/workflows/tests.yml

Lines changed: 2 additions & 1 deletion
@@ -37,6 +37,7 @@ jobs:
           uv sync --frozen \
             --extra apo \
             --extra verl \
+            --extra mongo \
             --group dev \
             --group torch-cpu \
             --group torch-stable \
@@ -166,7 +167,7 @@ jobs:

       - name: Run tests
         run: |
-          uv run pytest -v --durations=0 tests
+          uv run pytest -v --durations=0 tests -m "not mongo"
         env:
           PYTEST_ADDOPTS: "--color=yes"
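
The new `-m "not mongo"` selection implies that Mongo-backed tests carry a `mongo` pytest marker and run only in the full-test workflow, where `AGL_TEST_MONGO_URI` is exported. A hedged sketch of how such a test might be gated follows; the fixture and test names are illustrative, not taken from the repository.

```python
# Illustrative only: the marker name `mongo` and the env var match the
# workflows above, but the fixture/test names are assumptions.
import os

import pytest

pytestmark = pytest.mark.mongo  # deselected in tests.yml via `-m "not mongo"`


@pytest.fixture()
def mongo_uri() -> str:
    uri = os.environ.get("AGL_TEST_MONGO_URI")
    if not uri:
        pytest.skip("AGL_TEST_MONGO_URI is not set")
    return uri


def test_mongo_ping(mongo_uri: str) -> None:
    from pymongo import MongoClient

    MongoClient(mongo_uri).admin.command("ping")
```
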

agentlightning/llm_proxy.py

Lines changed: 0 additions & 1 deletion
@@ -853,7 +853,6 @@ async def openai_stream_generator(self, response_json: Dict[str, Any]) -> AsyncG
        )  # e.g., "stop", "length", "tool_calls", "content_filter"

        def sse_chunk(obj: Dict[str, Any]) -> str:
-            print("sse_chunk: ", obj)
            return f"data: {json.dumps(obj, ensure_ascii=False)}\n\n"

        # 1) initial chunk with the role
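
For reference, `sse_chunk` above frames each payload as a Server-Sent Events message: a `data: <json>` line terminated by a blank line. A small self-contained check of that framing, not taken from the repository:

```python
# Illustrative: reproduces the SSE framing used by sse_chunk above.
import json
from typing import Any, Dict


def sse_chunk(obj: Dict[str, Any]) -> str:
    return f"data: {json.dumps(obj, ensure_ascii=False)}\n\n"


assert sse_chunk({"choices": []}) == 'data: {"choices": []}\n\n'
```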

agentlightning/store/client_server.py

Lines changed: 20 additions & 1 deletion
@@ -799,8 +799,26 @@ async def _call_store_method(self, method_name: str, *args: Any, **kwargs: Any)
                 # wait_for_rollouts can block for a long time; avoid holding the lock
                 # so other requests can make progress while we wait.
                 return await getattr(self.store, method_name)(*args, **kwargs)
-            with self._lock:
+
+            # If the store is already thread-safe, we can call the method directly.
+            # Acquiring the threading lock here would block the event loop if it is
+            # already held by another thread (for example, the HTTP server thread).
+            # A potential fix would be needed to make that work, for example:
+            # ```
+            # acquired = self._lock.acquire(blocking=False)
+            # if not acquired:
+            #     await asyncio.to_thread(self._lock.acquire)
+            # try:
+            #     return await getattr(self.store, method_name)(*args, **kwargs)
+            # finally:
+            #     self._lock.release()
+            # ```
+            # Or we can just bypass the lock for thread-safe stores.
+            if self.store is not None and self.store.capabilities.get("thread_safe", False):
                 return await getattr(self.store, method_name)(*args, **kwargs)
+            else:
+                with self._lock:
+                    return await getattr(self.store, method_name)(*args, **kwargs)
         if self._client is None:
             self._client = LightningStoreClient(self.endpoint)
         return await getattr(self._client, method_name)(*args, **kwargs)
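
The comment added above sketches an alternative to bypassing the lock: acquire the `threading.Lock` off the event loop so a lock held by another thread (for example, the HTTP server thread) does not stall other asyncio tasks. A self-contained sketch of that pattern is below; the helper name and signature are illustrative, not part of the commit.

```python
# Hedged sketch of the non-blocking lock acquisition described in the comment
# above; `call_locked` is a hypothetical helper, not repository code.
import asyncio
import threading
from typing import Any, Awaitable, Callable


async def call_locked(lock: threading.Lock, call: Callable[[], Awaitable[Any]]) -> Any:
    # Fast path: try to take the lock without blocking the event loop.
    if not lock.acquire(blocking=False):
        # Slow path: wait for the lock in a worker thread so other tasks keep running.
        await asyncio.to_thread(lock.acquire)
    try:
        return await call()
    finally:
        lock.release()
```
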
@@ -1605,6 +1623,7 @@ async def add_otel_span(
            attempt_id=attempt_id,
            sequence_id=sequence_id,
        )
+        print("created span", span)
        await self.add_span(span)
        return span

agentlightning/store/collection/base.py

Lines changed: 100 additions & 7 deletions
@@ -3,17 +3,31 @@
 from __future__ import annotations

 from typing import (
+    TYPE_CHECKING,
     Any,
     AsyncContextManager,
+    Awaitable,
+    Callable,
+    Dict,
     Generic,
+    List,
+    Literal,
+    Mapping,
+    MutableMapping,
     Optional,
     Sequence,
+    Tuple,
     Type,
     TypeVar,
+    cast,
 )

+if TYPE_CHECKING:
+    from typing import Self
+
 from agentlightning.types import (
     Attempt,
+    FilterField,
     FilterOptions,
     PaginatedResult,
     ResourcesUpdate,
@@ -36,13 +50,13 @@ def primary_keys(self) -> Sequence[str]:
         raise NotImplementedError()

     def __repr__(self) -> str:
-        return f"<{self.__class__.__name__}[{self.item_type().__name__}] ({self.size()})>"
+        return f"<{self.__class__.__name__}[{self.item_type().__name__}]>"

     def item_type(self) -> Type[T]:
         """Get the type of the items in the collection."""
         raise NotImplementedError()

-    def size(self) -> int:
+    async def size(self) -> int:
         """Get the number of items in the collection."""
         raise NotImplementedError()

@@ -132,7 +146,7 @@ class Queue(Generic[T]):
     """Behaves like a deque. Supporting appending items to the end and popping items from the front."""

     def __repr__(self) -> str:
-        return f"<{self.__class__.__name__}[{self.item_type().__name__}] ({self.size()})>"
+        return f"<{self.__class__.__name__}[{self.item_type().__name__}]>"

     def item_type(self) -> Type[T]:
         """Get the type of the items in the queue."""
@@ -177,7 +191,7 @@ async def peek(self, limit: int = 1) -> Sequence[T]:
         """
         raise NotImplementedError()

-    def size(self) -> int:
+    async def size(self) -> int:
         """Get the number of items in the queue."""
         raise NotImplementedError()

@@ -186,7 +200,7 @@ class KeyValue(Generic[K, V]):
     """Behaves like a dictionary. Supporting addition, updating, and deletion of items."""

     def __repr__(self) -> str:
-        return f"<{self.__class__.__name__} ({self.size()})>"
+        return f"<{self.__class__.__name__}>"

     async def has(self, key: K) -> bool:
         """Check if the given key is in the dictionary."""
@@ -204,7 +218,7 @@ async def pop(self, key: K, default: V | None = None) -> V | None:
         """Pop the value for the given key, or the default value if the key is not found."""
         raise NotImplementedError()

-    def size(self) -> int:
+    async def size(self) -> int:
         """Get the number of items in the dictionary."""
         raise NotImplementedError()

@@ -251,7 +265,7 @@ def span_sequence_ids(self) -> KeyValue[str, int]:
         """Dictionary (counter) of span sequence IDs."""
         raise NotImplementedError()

-    def atomic(self, *args: Any, **kwargs: Any) -> AsyncContextManager[None]:
+    def atomic(self, *args: Any, **kwargs: Any) -> AsyncContextManager[Self]:
         """Perform an atomic operation on the collections.

         Subclasses may use args and kwargs to support multiple levels of atomicity.
@@ -261,3 +275,82 @@ def atomic(self, *args: Any, **kwargs: Any) -> AsyncContextManager[None]:
             **kwargs: Keyword arguments to pass to the operation.
         """
         raise NotImplementedError()
+
+    async def execute(self, callback: Callable[[Self], Awaitable[T]]) -> T:
+        """Execute the given callback within an atomic operation."""
+        async with self.atomic() as collections:
+            return await callback(collections)
+
+
+FilterMap = Mapping[str, FilterField]
+
+
+def merge_must_filters(target: MutableMapping[str, FilterField], definition: Any) -> None:
+    """Normalize a `_must` filter group into the provided mapping.
+
+    Mainly for validation purposes.
+    """
+    if definition is None:
+        return
+
+    entries: List[Mapping[str, FilterField]] = []
+    if isinstance(definition, Mapping):
+        entries.append(cast(Mapping[str, FilterField], definition))
+    elif isinstance(definition, Sequence) and not isinstance(definition, (str, bytes)):
+        for entry in definition:  # type: ignore
+            if not isinstance(entry, Mapping):
+                raise TypeError("Each `_must` entry must be a mapping of field names to operators")
+            entries.append(cast(Mapping[str, FilterField], entry))
+    else:
+        raise TypeError("`_must` filters must be provided as a mapping or sequence of mappings")
+
+    for entry in entries:
+        for field_name, ops in entry.items():
+            existing = target.get(field_name, {})
+            merged_ops: Dict[str, Any] = dict(existing)
+            for op_name, expected in ops.items():
+                if op_name in merged_ops:
+                    raise ValueError(f"Duplicate operator '{op_name}' for field '{field_name}' in must filters")
+                merged_ops[op_name] = expected
+            target[field_name] = cast(FilterField, merged_ops)
+
+
+def normalize_filter_options(
+    filter_options: Optional[FilterOptions],
+) -> Tuple[Optional[FilterMap], Optional[FilterMap], Literal["and", "or"]]:
+    """Convert FilterOptions to the internal structure and resolve aggregate logic."""
+    if not filter_options:
+        return None, None, "and"
+
+    aggregate = cast(Literal["and", "or"], filter_options.get("_aggregate", "and"))
+    if aggregate not in ("and", "or"):
+        raise ValueError(f"Unsupported filter aggregate '{aggregate}'")
+
+    # Extract normalized filters and must filters from the filter options.
+    normalized: Dict[str, FilterField] = {}
+    must_filters: Dict[str, FilterField] = {}
+    for field_name, ops in filter_options.items():
+        if field_name == "_aggregate":
+            continue
+        if field_name == "_must":
+            merge_must_filters(must_filters, ops)
+            continue
+        normalized[field_name] = cast(FilterField, dict(ops))  # type: ignore
+
+    return (normalized or None, must_filters or None, aggregate)
+
+
+def resolve_sort_options(sort: Optional[SortOptions]) -> Tuple[Optional[str], Literal["asc", "desc"]]:
+    """Extract sort field/order from the caller-provided SortOptions."""
+    if not sort:
+        return None, "asc"
+
+    sort_name = sort.get("name")
+    if not sort_name:
+        raise ValueError("Sort options must include a 'name' field")
+
+    sort_order = sort.get("order", "asc")
+    if sort_order not in ("asc", "desc"):
+        raise ValueError(f"Unsupported sort order '{sort_order}'")
+
+    return sort_name, sort_order
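
The filter helpers at the end of the file translate caller-supplied `FilterOptions` into a `(filters, must_filters, aggregate)` triple and validate `_must` groups. A hedged usage sketch is below; the field names and the operator key (`eq`) are illustrative, and the import path is inferred from the file location.

```python
# Illustrative inputs; FilterOptions/FilterField come from agentlightning.types
# and are treated here as plain mappings of field -> {operator: expected}.
from agentlightning.store.collection.base import (
    normalize_filter_options,
    resolve_sort_options,
)

options = {
    "status": {"eq": "succeeded"},
    "_must": [{"rollout_id": {"eq": "rollout-1"}}],
    "_aggregate": "or",
}

filters, must, aggregate = normalize_filter_options(options)
# filters   == {"status": {"eq": "succeeded"}}
# must      == {"rollout_id": {"eq": "rollout-1"}}
# aggregate == "or"

name, order = resolve_sort_options({"name": "start_time", "order": "desc"})
# name == "start_time", order == "desc"
```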