@@ -204,33 +204,60 @@ async def delete_router(self, router_id: int) -> Router | None:
204204 await self .postgres_session .execute (delete (RouterTable ).where (RouterTable .id == router_id ))
205205 return router
206206
207- async def update_router (self , router : Router ) -> Router | RouterNameAlreadyExistsError :
208- db_user_id = None if router .user_id == MASTER_USER_ID else router .user_id
207+ async def update_router (self , router_to_update : Router ) -> Router | RouterNameAlreadyExistsError :
208+ db_user_id = None if router_to_update .user_id == MASTER_USER_ID else router_to_update .user_id
209209
210210 try :
211- await self . postgres_session . execute (
211+ update_query = (
212212 update (RouterTable )
213- .where (RouterTable .id == router .id )
213+ .where (RouterTable .id == router_to_update .id )
214214 .values (
215215 user_id = db_user_id ,
216- name = router .name ,
217- type = router .type .value ,
218- load_balancing_strategy = router .load_balancing_strategy .value ,
219- cost_prompt_tokens = router .cost_prompt_tokens ,
220- cost_completion_tokens = router .cost_completion_tokens ,
216+ name = router_to_update .name ,
217+ type = router_to_update .type .value ,
218+ load_balancing_strategy = router_to_update .load_balancing_strategy .value ,
219+ cost_prompt_tokens = router_to_update .cost_prompt_tokens ,
220+ cost_completion_tokens = router_to_update .cost_completion_tokens ,
221+ )
222+ .returning (
223+ RouterTable .id ,
224+ RouterTable .name ,
225+ RouterTable .user_id ,
226+ RouterTable .type ,
227+ RouterTable .load_balancing_strategy ,
228+ RouterTable .cost_prompt_tokens ,
229+ RouterTable .cost_completion_tokens ,
230+ cast (func .extract ("epoch" , RouterTable .created ), Integer ).label ("created" ),
231+ cast (func .extract ("epoch" , RouterTable .updated ), Integer ).label ("updated" ),
221232 )
222233 )
234+ result = await self .postgres_session .execute (update_query )
235+ row = result .one ()
223236
224- if router .aliases is not None :
225- await self .postgres_session .execute (delete (RouterAliasTable ).where (RouterAliasTable .router_id == router .id ))
226- if router .aliases :
237+ if router_to_update .aliases is not None :
238+ await self .postgres_session .execute (delete (RouterAliasTable ).where (RouterAliasTable .router_id == router_to_update .id ))
239+ if router_to_update .aliases :
227240 await self .postgres_session .execute (
228241 insert (RouterAliasTable ),
229- [{"value" : alias , "router_id" : router .id } for alias in router .aliases ],
242+ [{"value" : alias , "router_id" : router_to_update .id } for alias in router_to_update .aliases ],
230243 )
231244 except IntegrityError as e :
232245 if "router_name_key" in str (e .orig ):
233- return RouterNameAlreadyExistsError (name = router .name )
246+ return RouterNameAlreadyExistsError (name = router_to_update .name )
234247 raise
235248
236- return router
249+ return Router (
250+ id = row .id ,
251+ name = row .name ,
252+ user_id = router_to_update .user_id ,
253+ type = RouterType (row .type ),
254+ aliases = router_to_update .aliases ,
255+ load_balancing_strategy = RouterLoadBalancingStrategy (row .load_balancing_strategy ),
256+ vector_size = router_to_update .vector_size ,
257+ max_context_length = router_to_update .max_context_length ,
258+ cost_prompt_tokens = row .cost_prompt_tokens or 0.0 ,
259+ cost_completion_tokens = row .cost_completion_tokens or 0.0 ,
260+ providers = router_to_update .providers ,
261+ created = row .created ,
262+ updated = row .updated ,
263+ )
0 commit comments