33from __future__ import annotations
44
55from typing import (
6+ TYPE_CHECKING ,
67 Any ,
78 AsyncContextManager ,
9+ Awaitable ,
10+ Callable ,
11+ Dict ,
812 Generic ,
13+ List ,
14+ Literal ,
15+ Mapping ,
16+ MutableMapping ,
917 Optional ,
1018 Sequence ,
19+ Tuple ,
1120 Type ,
1221 TypeVar ,
22+ cast ,
1323)
1424
25+ if TYPE_CHECKING :
26+ from typing import Self
27+
1528from agentlightning .types import (
1629 Attempt ,
30+ FilterField ,
1731 FilterOptions ,
1832 PaginatedResult ,
1933 ResourcesUpdate ,
@@ -36,13 +50,13 @@ def primary_keys(self) -> Sequence[str]:
3650 raise NotImplementedError ()
3751
3852 def __repr__ (self ) -> str :
39- return f"<{ self .__class__ .__name__ } [{ self .item_type ().__name__ } ] ( { self . size () } ) >"
53+ return f"<{ self .__class__ .__name__ } [{ self .item_type ().__name__ } ]>"
4054
4155 def item_type (self ) -> Type [T ]:
4256 """Get the type of the items in the collection."""
4357 raise NotImplementedError ()
4458
45- def size (self ) -> int :
59+ async def size (self ) -> int :
4660 """Get the number of items in the collection."""
4761 raise NotImplementedError ()
4862
@@ -132,7 +146,7 @@ class Queue(Generic[T]):
132146 """Behaves like a deque. Supporting appending items to the end and popping items from the front."""
133147
134148 def __repr__ (self ) -> str :
135- return f"<{ self .__class__ .__name__ } [{ self .item_type ().__name__ } ] ( { self . size () } ) >"
149+ return f"<{ self .__class__ .__name__ } [{ self .item_type ().__name__ } ]>"
136150
137151 def item_type (self ) -> Type [T ]:
138152 """Get the type of the items in the queue."""
@@ -177,7 +191,7 @@ async def peek(self, limit: int = 1) -> Sequence[T]:
177191 """
178192 raise NotImplementedError ()
179193
180- def size (self ) -> int :
194+ async def size (self ) -> int :
181195 """Get the number of items in the queue."""
182196 raise NotImplementedError ()
183197
@@ -186,7 +200,7 @@ class KeyValue(Generic[K, V]):
186200 """Behaves like a dictionary. Supporting addition, updating, and deletion of items."""
187201
188202 def __repr__ (self ) -> str :
189- return f"<{ self .__class__ .__name__ } ( { self . size () } ) >"
203+ return f"<{ self .__class__ .__name__ } >"
190204
191205 async def has (self , key : K ) -> bool :
192206 """Check if the given key is in the dictionary."""
@@ -204,7 +218,7 @@ async def pop(self, key: K, default: V | None = None) -> V | None:
204218 """Pop the value for the given key, or the default value if the key is not found."""
205219 raise NotImplementedError ()
206220
207- def size (self ) -> int :
221+ async def size (self ) -> int :
208222 """Get the number of items in the dictionary."""
209223 raise NotImplementedError ()
210224
@@ -251,7 +265,7 @@ def span_sequence_ids(self) -> KeyValue[str, int]:
251265 """Dictionary (counter) of span sequence IDs."""
252266 raise NotImplementedError ()
253267
254- def atomic (self , * args : Any , ** kwargs : Any ) -> AsyncContextManager [None ]:
268+ def atomic (self , * args : Any , ** kwargs : Any ) -> AsyncContextManager [Self ]:
255269 """Perform a atomic operation on the collections.
256270
257271 Subclass may use args and kwargs to support multiple levels of atomicity.
@@ -261,3 +275,82 @@ def atomic(self, *args: Any, **kwargs: Any) -> AsyncContextManager[None]:
261275 **kwargs: Keyword arguments to pass to the operation.
262276 """
263277 raise NotImplementedError ()
278+
279+ async def execute (self , callback : Callable [[Self ], Awaitable [T ]]) -> T :
280+ """Execute the given callback within an atomic operation."""
281+ async with self .atomic () as collections :
282+ return await callback (collections )
283+
284+
285+ FilterMap = Mapping [str , FilterField ]
286+
287+
288+ def merge_must_filters (target : MutableMapping [str , FilterField ], definition : Any ) -> None :
289+ """Normalize a `_must` filter group into the provided mapping.
290+
291+ Mainly for validation purposes.
292+ """
293+ if definition is None :
294+ return
295+
296+ entries : List [Mapping [str , FilterField ]] = []
297+ if isinstance (definition , Mapping ):
298+ entries .append (cast (Mapping [str , FilterField ], definition ))
299+ elif isinstance (definition , Sequence ) and not isinstance (definition , (str , bytes )):
300+ for entry in definition : # type: ignore
301+ if not isinstance (entry , Mapping ):
302+ raise TypeError ("Each `_must` entry must be a mapping of field names to operators" )
303+ entries .append (cast (Mapping [str , FilterField ], entry ))
304+ else :
305+ raise TypeError ("`_must` filters must be provided as a mapping or sequence of mappings" )
306+
307+ for entry in entries :
308+ for field_name , ops in entry .items ():
309+ existing = target .get (field_name , {})
310+ merged_ops : Dict [str , Any ] = dict (existing )
311+ for op_name , expected in ops .items ():
312+ if op_name in merged_ops :
313+ raise ValueError (f"Duplicate operator '{ op_name } ' for field '{ field_name } ' in must filters" )
314+ merged_ops [op_name ] = expected
315+ target [field_name ] = cast (FilterField , merged_ops )
316+
317+
318+ def normalize_filter_options (
319+ filter_options : Optional [FilterOptions ],
320+ ) -> Tuple [Optional [FilterMap ], Optional [FilterMap ], Literal ["and" , "or" ]]:
321+ """Convert FilterOptions to the internal structure and resolve aggregate logic."""
322+ if not filter_options :
323+ return None , None , "and"
324+
325+ aggregate = cast (Literal ["and" , "or" ], filter_options .get ("_aggregate" , "and" ))
326+ if aggregate not in ("and" , "or" ):
327+ raise ValueError (f"Unsupported filter aggregate '{ aggregate } '" )
328+
329+ # Extract normalized filters and must filters from the filter options.
330+ normalized : Dict [str , FilterField ] = {}
331+ must_filters : Dict [str , FilterField ] = {}
332+ for field_name , ops in filter_options .items ():
333+ if field_name == "_aggregate" :
334+ continue
335+ if field_name == "_must" :
336+ merge_must_filters (must_filters , ops )
337+ continue
338+ normalized [field_name ] = cast (FilterField , dict (ops )) # type: ignore
339+
340+ return (normalized or None , must_filters or None , aggregate )
341+
342+
343+ def resolve_sort_options (sort : Optional [SortOptions ]) -> Tuple [Optional [str ], Literal ["asc" , "desc" ]]:
344+ """Extract sort field/order from the caller-provided SortOptions."""
345+ if not sort :
346+ return None , "asc"
347+
348+ sort_name = sort .get ("name" )
349+ if not sort_name :
350+ raise ValueError ("Sort options must include a 'name' field" )
351+
352+ sort_order = sort .get ("order" , "asc" )
353+ if sort_order not in ("asc" , "desc" ):
354+ raise ValueError (f"Unsupported sort order '{ sort_order } '" )
355+
356+ return sort_name , sort_order
0 commit comments