1- from typing import Generator , Optional
1+ from typing import Generator , Optional , Union , List , Dict , Optional
22import json
33from dataclasses import dataclass
44from enum import Enum
5- from typing import List , Dict , Optional
65import requests
76
87
@@ -37,36 +36,67 @@ def to_dict(self) -> Dict[str, str]:
3736@dataclass
3837class Model :
3938 """Represents an LLM model"""
39+
4040 name : str
4141
4242
4343@dataclass
4444class ProviderModels :
4545 """Groups models by provider"""
46+
4647 provider : Provider
4748 models : List [Model ]
4849
4950
5051@dataclass
5152class ResponseTokens :
5253 """Response tokens structure as defined in the API spec"""
54+
5355 role : str
5456 model : str
5557 content : str
5658
59+ @classmethod
60+ def from_dict (cls , data : dict ) -> "ResponseTokens" :
61+ """Create ResponseTokens from dictionary data
62+
63+ Args:
64+ data: Dictionary containing response data
65+
66+ Returns:
67+ ResponseTokens instance
68+
69+ Raises:
70+ TypeError: If data is not a dictionary
71+ ValueError: If required fields are missing
72+ """
73+ if not isinstance (data , dict ):
74+ raise TypeError (f"Expected dict, got { type (data )} " )
75+
76+ required = ["role" , "model" , "content" ]
77+ missing = [field for field in required if field not in data ]
78+
79+ if missing :
80+ raise ValueError (
81+ f"Missing required arguments: {
82+ ', ' .join (missing )} "
83+ )
84+
85+ return cls (role = data ["role" ], model = data ["model" ], content = data ["content" ])
86+
5787
5888@dataclass
5989class GenerateResponse :
6090 """Response structure for token generation"""
91+
6192 provider : str
6293 response : ResponseTokens
6394
6495 @classmethod
65- def from_dict (cls , data : dict ) -> ' GenerateResponse' :
96+ def from_dict (cls , data : dict ) -> " GenerateResponse" :
6697 """Create GenerateResponse from dictionary data"""
6798 return cls (
68- provider = data .get ('provider' , '' ),
69- response = ResponseTokens (** data .get ('response' , {}))
99+ provider = data .get ("provider" , "" ), response = ResponseTokens (** data .get ("response" , {}))
70100 )
71101
72102
@@ -86,9 +116,79 @@ def list_models(self) -> List[ProviderModels]:
86116 response .raise_for_status ()
87117 return response .json ()
88118
119+ def _parse_sse_chunk (self , chunk : bytes ) -> dict :
120+ """Parse an SSE message chunk into structured event data
121+
122+ Args:
123+ chunk: Raw SSE message chunk in bytes format
124+
125+ Returns:
126+ dict: Parsed SSE message with event type and data fields
127+
128+ Raises:
129+ json.JSONDecodeError: If chunk format or content is invalid
130+ """
131+ if not isinstance (chunk , bytes ):
132+ raise TypeError (f"Expected bytes, got { type (chunk )} " )
133+
134+ try :
135+ decoded = chunk .decode ("utf-8" )
136+ message = {}
137+
138+ for line in (l .strip () for l in decoded .split ("\n " ) if l .strip ()):
139+ if line .startswith ("event: " ):
140+ message ["event" ] = line .removeprefix ("event: " )
141+ elif line .startswith ("data: " ):
142+ try :
143+ json_str = line .removeprefix ("data: " )
144+ data = json .loads (json_str )
145+ if not isinstance (data , dict ):
146+ raise json .JSONDecodeError (
147+ f"Invalid SSE data format - expected object, got: {
148+ json_str } " ,
149+ json_str ,
150+ 0 ,
151+ )
152+ message ["data" ] = data
153+ except json .JSONDecodeError as e :
154+ raise json .JSONDecodeError (f"Invalid SSE JSON: { json_str } " , e .doc , e .pos )
155+
156+ if not message .get ("data" ):
157+ raise json .JSONDecodeError (
158+ f"Missing or invalid data field in SSE message: {
159+ decoded } " ,
160+ decoded ,
161+ 0 ,
162+ )
163+
164+ return message
165+
166+ except UnicodeDecodeError as e :
167+ raise json .JSONDecodeError (
168+ f"Invalid UTF-8 encoding in SSE chunk: {
169+ chunk !r} " ,
170+ str (chunk ),
171+ 0 ,
172+ )
173+
174+ def _parse_json_line (self , line : bytes ) -> ResponseTokens :
175+ """Parse a single JSON line into GenerateResponse"""
176+ try :
177+ decoded_line = line .decode ("utf-8" )
178+ data = json .loads (decoded_line )
179+ return ResponseTokens .from_dict (data )
180+ except UnicodeDecodeError as e :
181+ raise json .JSONDecodeError (f"Invalid UTF-8 encoding: { line } " , str (line ), 0 )
182+ except json .JSONDecodeError as e :
183+ raise json .JSONDecodeError (
184+ f"Invalid JSON response: {
185+ decoded_line } " ,
186+ e .doc ,
187+ e .pos ,
188+ )
189+
89190 def generate_content (self , provider : Provider , model : str , messages : List [Message ]) -> Dict :
90- payload = {"model" : model , "messages" : [
91- msg .to_dict () for msg in messages ]}
191+ payload = {"model" : model , "messages" : [msg .to_dict () for msg in messages ]}
92192
93193 response = self .session .post (
94194 f"{ self .base_url } /llms/{ provider .value } /generate" , json = payload
@@ -97,12 +197,8 @@ def generate_content(self, provider: Provider, model: str, messages: List[Messag
97197 return response .json ()
98198
99199 def generate_content_stream (
100- self ,
101- provider : Provider ,
102- model : str ,
103- messages : List [Message ],
104- use_sse : bool = False
105- ) -> Generator [Union [GenerateResponse , dict ], None , None ]:
200+ self , provider : Provider , model : str , messages : List [Message ], use_sse : bool = False
201+ ) -> Generator [Union [ResponseTokens , dict ], None , None ]:
106202 """Stream content generation from the model
107203
108204 Args:
@@ -112,33 +208,37 @@ def generate_content_stream(
112208 use_sse: Whether to use Server-Sent Events format
113209
114210 Yields:
115- Either GenerateResponse objects (for raw JSON) or dicts (for SSE)
211+ Either ResponseTokens objects (for raw JSON) or dicts (for SSE)
116212 """
117213 payload = {
118214 "model" : model ,
119215 "messages" : [msg .to_dict () for msg in messages ],
120216 "stream" : True ,
121- "ssevents" : use_sse
217+ "ssevents" : use_sse ,
122218 }
123219
124- with self .session .post (
125- f"{ self .base_url } /llms/{ provider .value } /generate" ,
126- json = payload ,
127- stream = True
128- ) as response :
129- response .raise_for_status ()
220+ response = self .session .post (
221+ f"{ self .base_url } /llms/{ provider .value } /generate" , json = payload , stream = True
222+ )
223+ response .raise_for_status ()
224+
225+ if use_sse :
226+ buffer = []
130227
131228 for line in response .iter_lines ():
132- if line :
133- if use_sse and line .startswith (b'data: ' ):
134- # Handle SSE format
135- data = json .loads (line .decode (
136- 'utf-8' ).replace ('data: ' , '' ))
137- yield data
138- else :
139- # Handle raw JSON format
140- data = json .loads (line )
141- yield GenerateResponse .from_dict (data )
229+ if not line :
230+ if buffer :
231+ chunk = b"\n " .join (buffer )
232+ yield self ._parse_sse_chunk (chunk )
233+ buffer = []
234+ continue
235+
236+ buffer .append (line )
237+ else :
238+ for line in response .iter_lines ():
239+ if not line :
240+ continue
241+ yield self ._parse_json_line (line )
142242
143243 def health_check (self ) -> bool :
144244 """Check if the API is healthy"""
0 commit comments