Skip to content

Commit e5dfb97

Browse files
committed
span algebra
Signed-off-by: Mandana Vaziri <[email protected]>
1 parent 655bd38 commit e5dfb97

File tree

7 files changed

+259
-179
lines changed

7 files changed

+259
-179
lines changed

src/pdl/pdl_ast.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pydantic.json_schema import SkipJsonSchema
2626
from typing_extensions import TypeAliasType
2727

28-
from .pdl_context import DependentContext
28+
from .pdl_context import PDLContext
2929
from .pdl_lazy import PdlDict, PdlLazy
3030

3131

@@ -42,7 +42,7 @@ def _ensure_lower(value):
4242

4343

4444
LazyMessage: TypeAlias = PdlLazy[dict[str, Any]]
45-
LazyMessages: TypeAlias = DependentContext
45+
LazyMessages: TypeAlias = PDLContext
4646

4747

4848
class BlockKind(StrEnum):

src/pdl/pdl_context.py

Lines changed: 103 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,72 @@
1+
from collections.abc import Sequence
12
from enum import StrEnum
2-
from typing import Any
3+
from typing import Any, Callable
34

4-
from .pdl_lazy import (
5-
PdlDict,
6-
PdlList,
7-
)
5+
from .pdl_lazy import PdlApply, PdlDict, PdlLazy, PdlList
86

97

108
class SerializeMode(StrEnum):
119
LITELLM = "litellm"
1210
GRANITEIO = "graniteio"
1311

1412

15-
class PDLContext:
13+
class PDLContext(Sequence):
1614

1715
def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]:
1816
return []
1917

18+
def __add__(self, value: "PDLContext"):
19+
return IndependentContext([self, value])
2020

21-
class BaseMessage(PDLContext):
22-
message: PdlDict[str, Any]
21+
def __mul__(self, value: "PDLContext"):
22+
return DependentContext([self, value])
2323

24-
def __init__(self, message: dict[str, Any]):
25-
if "role" not in message:
26-
assert False
27-
if "content" not in message:
28-
assert False
29-
self.message = PdlDict(message)
24+
def __len__(self):
25+
return 0
26+
27+
def __getitem__(self, index: int | slice): # pyright: ignore
28+
return []
29+
30+
31+
class SingletonContext(PDLContext):
32+
message: PdlLazy[dict[str, Any]]
33+
34+
def __init__(self, message: PdlLazy[dict[str, Any]]):
35+
self.message = message
3036

3137
def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]:
3238
result = self.message.result()
3339
return [result]
3440

41+
def __len__(self): # pyright: ignore
42+
return 1
3543

36-
class IndependentContext(PDLContext):
37-
context: PdlList[PDLContext]
44+
def __getitem__(self, index: int | slice): # pyright: ignore
45+
if index == 0:
46+
return self.message.result()
47+
print(index)
48+
assert False
49+
50+
def __repr__(self): # pyright: ignore
51+
return str(self.message.result())
3852

39-
def __init__(self, context: PdlList[PDLContext]):
40-
self.context = context
53+
54+
class IndependentContext(PDLContext):
55+
context: PdlLazy[list[PDLContext]]
56+
57+
def __init__(self, context: list[PDLContext]):
58+
ret: list[PDLContext] = []
59+
for item in context:
60+
if isinstance(item, IndependentContext):
61+
ret += item.context.data
62+
elif isinstance(item, SingletonContext):
63+
ret += [item]
64+
else:
65+
# Not all elements of the list are Independent, so return
66+
self.context = PdlList(context)
67+
return
68+
# All elements of the list are Independent
69+
self.context = PdlList(ret)
4170

4271
def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]:
4372
result = self.context.result()
@@ -47,31 +76,74 @@ def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]:
4776
return [{"independent": flat}]
4877
return flat
4978

79+
def __len__(self): # pyright: ignore
80+
return len(self.context.result())
81+
82+
def __getitem__(self, index: int | slice): # pyright: ignore
83+
return self.serialize(SerializeMode.LITELLM)[index]
84+
85+
def __repr__(self): # pyright: ignore
86+
ret = "{"
87+
ret += ",".join([i.__repr__() for i in self.context.result()])
88+
return ret + "}"
5089

51-
class DependentContext(PDLContext):
52-
context: PdlList[PDLContext]
5390

54-
def __init__(self, context: PdlList[PDLContext]):
55-
self.context = context
91+
class DependentContext(PDLContext):
92+
context: PdlLazy[list[PDLContext]]
93+
94+
def __init__(self, context: list[PDLContext]):
95+
ret: list[PDLContext] = []
96+
for item in context:
97+
if isinstance(item, DependentContext):
98+
ret += item.context.data
99+
elif isinstance(item, SingletonContext):
100+
ret += [item]
101+
else:
102+
# Not all elements of the list are Dependent, so return
103+
self.context = PdlList(context)
104+
return
105+
# All elements of the list are Dependent
106+
self.context = PdlList(ret)
56107

57108
def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]:
58109
result = self.context.result()
59110
contexts = [m.serialize(mode) for m in result]
60-
return [x for xs in contexts for x in xs]
111+
res = [x for xs in contexts for x in xs]
112+
return res
113+
114+
def __len__(self): # pyright: ignore
115+
return len(self.context.result())
116+
117+
def __getitem__(self, index: int | slice): # pyright: ignore
118+
return self.serialize(SerializeMode.LITELLM)[index]
119+
120+
def __repr__(self): # pyright: ignore
121+
ret = "["
122+
ret += ",".join([i.__repr__() for i in self.context.result()])
123+
return ret + "]"
61124

62125

63126
def deserialize(
64127
context: list[dict[str, Any]],
65128
) -> DependentContext: # Only support dependent for now
66-
ret: DependentContext = DependentContext(PdlList([]))
129+
ret: DependentContext = DependentContext([])
67130
for message in context:
68131
if isinstance(message, dict):
69-
if "role" not in message:
70-
assert False
71-
if "content" not in message:
72-
assert False
73-
ret = DependentContext(PdlList([ret, BaseMessage(message)]))
132+
ret = ret * SingletonContext(PdlDict(message))
74133
else:
75-
ret = DependentContext(PdlList([ret, message]))
76-
134+
ret = ret * message
77135
return ret
136+
137+
138+
def add_done_callback(
139+
f: Callable, p: PDLContext
140+
): # Assuming that f is the identity function
141+
match p:
142+
case SingletonContext(message=m):
143+
p.message = PdlApply(f, m)
144+
case DependentContext(context=c):
145+
p.context = PdlApply(f, c)
146+
case IndependentContext(context=c):
147+
p.context = PdlApply(f, c)
148+
case _:
149+
assert False

0 commit comments

Comments
 (0)