@@ -305,7 +305,12 @@ class EditorAPIContext:
305305 user_agent : str
306306
307307 token : str | None
308+ subscription_topic : str | None = None
308309 logger : logging .Logger
310+
311+ # Note: this is always set when calling `login` or `me` and updated using SSE.
312+ # If you use a subscription topic it is *not* updated, use the value
313+ # from the response metadata or call `me`.
309314 credits : int | None = None
310315
311316 _client : httpx .AsyncClient | None
@@ -317,6 +322,7 @@ class EditorAPIContext:
317322
318323 def __init__ (
319324 self ,
325+ * ,
320326 credentials : str | None = None ,
321327 api_key : str | None = None ,
322328 user : str | None = None ,
@@ -325,12 +331,14 @@ def __init__(
325331 priority : Priority = "standard" ,
326332 verify : bool | str = True ,
327333 default_timeout : float = 60.0 ,
334+ subscription_topic : str | None = None ,
328335 user_agent : str | None = None ,
329336 ) -> None :
330337 self .base_url = base_url or "https://api.finegrain.ai/editor"
331338 self .priority = priority
332339 self .verify = verify
333340 self .default_timeout = default_timeout
341+ self .subscription_topic = subscription_topic
334342
335343 if credentials is not None :
336344 if (m := API_KEY_PATTERN .match (credentials )) is not None :
@@ -407,6 +415,12 @@ async def login(self) -> None:
407415 self .credits = r ["user" ]["credits" ]
408416 self .token = r ["token" ]
409417
418+ async def me (self ) -> dict [str , Any ]:
419+ response = await self .request ("GET" , "auth/me" )
420+ r = response .json ()
421+ self .credits = r ["credits" ]
422+ return r
423+
410424 async def request (
411425 self ,
412426 method : Literal ["GET" , "POST" ],
@@ -441,7 +455,11 @@ async def _q() -> httpx.Response:
441455 return r
442456
443457 async def get_sub_url (self ) -> str :
444- response = await self .request ("POST" , "sub-auth" )
458+ if self .subscription_topic is not None :
459+ params = {"subscription_topic" : self .subscription_topic }
460+ else :
461+ params = None
462+ response = await self .request ("POST" , "sub-auth" , json = params )
445463 jdata = response .json ()
446464 sub_token = jdata ["token" ]
447465 self ._ping_interval = float (jdata .get ("ping_interval" , 0.0 ))
@@ -557,6 +575,8 @@ async def call_skill(
557575 timeout = timeout or self .default_timeout
558576 user_timeout = max (int (timeout ), 1 )
559577 params = {"priority" : self .priority , "user_timeout" : user_timeout } | (params or {})
578+ if self .subscription_topic is not None :
579+ params ["subscription_topic" ] = self .subscription_topic
560580 response = await self .request ("POST" , f"skills/{ url } " , json = params )
561581 state_id : StateID = response .json ()["state" ]
562582 status = await self .sse_await (state_id , timeout = timeout )
@@ -926,6 +946,8 @@ async def _create_state(
926946 data : dict [str , str ] = {}
927947 if file_url is not None :
928948 data ["file_url" ] = file_url
949+ if self .ctx .subscription_topic is not None :
950+ data ["subscription_topic" ] = self .ctx .subscription_topic
929951 if meta is not None :
930952 data ["meta" ] = json .dumps (meta )
931953 response = await self .ctx .request ("POST" , "state/create" , files = files , data = data )
0 commit comments