1
+ from collections .abc import Sequence
1
2
from enum import StrEnum
2
- from typing import Any
3
+ from typing import Any , Callable
3
4
4
- from .pdl_lazy import (
5
- PdlDict ,
6
- PdlList ,
7
- )
5
+ from .pdl_lazy import PdlApply , PdlDict , PdlLazy , PdlList
8
6
9
7
10
8
class SerializeMode (StrEnum ):
11
9
LITELLM = "litellm"
12
10
GRANITEIO = "graniteio"
13
11
14
12
15
- class PDLContext :
13
+ class PDLContext ( Sequence ) :
16
14
17
15
def serialize (self , mode : SerializeMode ) -> list [dict [str , Any ]]:
18
16
return []
19
17
18
+ def __add__ (self , value : "PDLContext" ):
19
+ return IndependentContext ([self , value ])
20
20
21
- class BaseMessage ( PDLContext ):
22
- message : PdlDict [ str , Any ]
21
+ def __mul__ ( self , value : " PDLContext" ):
22
+ return DependentContext ([ self , value ])
23
23
24
- def __init__ (self , message : dict [str , Any ]):
25
- if "role" not in message :
26
- assert False
27
- if "content" not in message :
28
- assert False
29
- self .message = PdlDict (message )
24
+ def __len__ (self ):
25
+ return 0
26
+
27
+ def __getitem__ (self , index : int | slice ): # pyright: ignore
28
+ return []
29
+
30
+
31
+ class SingletonContext (PDLContext ):
32
+ message : PdlLazy [dict [str , Any ]]
33
+
34
+ def __init__ (self , message : PdlLazy [dict [str , Any ]]):
35
+ self .message = message
30
36
31
37
def serialize (self , mode : SerializeMode ) -> list [dict [str , Any ]]:
32
38
result = self .message .result ()
33
39
return [result ]
34
40
41
+ def __len__ (self ): # pyright: ignore
42
+ return 1
35
43
36
- class IndependentContext (PDLContext ):
37
- context : PdlList [PDLContext ]
44
+ def __getitem__ (self , index : int | slice ): # pyright: ignore
45
+ if index == 0 :
46
+ return self .message .result ()
47
+ print (index )
48
+ assert False
49
+
50
+ def __repr__ (self ): # pyright: ignore
51
+ return str (self .message .result ())
38
52
39
- def __init__ (self , context : PdlList [PDLContext ]):
40
- self .context = context
53
+
54
+ class IndependentContext (PDLContext ):
55
+ context : PdlLazy [list [PDLContext ]]
56
+
57
+ def __init__ (self , context : list [PDLContext ]):
58
+ ret : list [PDLContext ] = []
59
+ for item in context :
60
+ if isinstance (item , IndependentContext ):
61
+ ret += item .context .data
62
+ elif isinstance (item , SingletonContext ):
63
+ ret += [item ]
64
+ else :
65
+ # Not all elements of the list are Independent, so return
66
+ self .context = PdlList (context )
67
+ return
68
+ # All elements of the list are Independent
69
+ self .context = PdlList (ret )
41
70
42
71
def serialize (self , mode : SerializeMode ) -> list [dict [str , Any ]]:
43
72
result = self .context .result ()
@@ -47,31 +76,74 @@ def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]:
47
76
return [{"independent" : flat }]
48
77
return flat
49
78
79
+ def __len__ (self ): # pyright: ignore
80
+ return len (self .context .result ())
81
+
82
+ def __getitem__ (self , index : int | slice ): # pyright: ignore
83
+ return self .serialize (SerializeMode .LITELLM )[index ]
84
+
85
+ def __repr__ (self ): # pyright: ignore
86
+ ret = "{"
87
+ ret += "," .join ([i .__repr__ () for i in self .context .result ()])
88
+ return ret + "}"
50
89
51
- class DependentContext (PDLContext ):
52
- context : PdlList [PDLContext ]
53
90
54
- def __init__ (self , context : PdlList [PDLContext ]):
55
- self .context = context
91
+ class DependentContext (PDLContext ):
92
+ context : PdlLazy [list [PDLContext ]]
93
+
94
+ def __init__ (self , context : list [PDLContext ]):
95
+ ret : list [PDLContext ] = []
96
+ for item in context :
97
+ if isinstance (item , DependentContext ):
98
+ ret += item .context .data
99
+ elif isinstance (item , SingletonContext ):
100
+ ret += [item ]
101
+ else :
102
+ # Not all elements of the list are Dependent, so return
103
+ self .context = PdlList (context )
104
+ return
105
+ # All elements of the list are Dependent
106
+ self .context = PdlList (ret )
56
107
57
108
def serialize (self , mode : SerializeMode ) -> list [dict [str , Any ]]:
58
109
result = self .context .result ()
59
110
contexts = [m .serialize (mode ) for m in result ]
60
- return [x for xs in contexts for x in xs ]
111
+ res = [x for xs in contexts for x in xs ]
112
+ return res
113
+
114
+ def __len__ (self ): # pyright: ignore
115
+ return len (self .context .result ())
116
+
117
+ def __getitem__ (self , index : int | slice ): # pyright: ignore
118
+ return self .serialize (SerializeMode .LITELLM )[index ]
119
+
120
+ def __repr__ (self ): # pyright: ignore
121
+ ret = "["
122
+ ret += "," .join ([i .__repr__ () for i in self .context .result ()])
123
+ return ret + "]"
61
124
62
125
63
126
def deserialize (
64
127
context : list [dict [str , Any ]],
65
128
) -> DependentContext : # Only support dependent for now
66
- ret : DependentContext = DependentContext (PdlList ([]) )
129
+ ret : DependentContext = DependentContext ([] )
67
130
for message in context :
68
131
if isinstance (message , dict ):
69
- if "role" not in message :
70
- assert False
71
- if "content" not in message :
72
- assert False
73
- ret = DependentContext (PdlList ([ret , BaseMessage (message )]))
132
+ ret = ret * SingletonContext (PdlDict (message ))
74
133
else :
75
- ret = DependentContext (PdlList ([ret , message ]))
76
-
134
+ ret = ret * message
77
135
return ret
136
+
137
+
138
+ def add_done_callback (
139
+ f : Callable , p : PDLContext
140
+ ): # Assuming that f is the identity function
141
+ match p :
142
+ case SingletonContext (message = m ):
143
+ p .message = PdlApply (f , m )
144
+ case DependentContext (context = c ):
145
+ p .context = PdlApply (f , c )
146
+ case IndependentContext (context = c ):
147
+ p .context = PdlApply (f , c )
148
+ case _:
149
+ assert False
0 commit comments