@@ -357,7 +357,9 @@ def _read_from_cache(
357357
358358 def _write_to_cache (
359359 self ,
360- cache_key : Literal [VERSION_CACHE_KEY , LABEL_CACHE_KEY , HASH_CACHE_KEY , BRANCH_CACHE_KEY ],
360+ cache_key : Literal [
361+ VERSION_CACHE_KEY , LABEL_CACHE_KEY , HASH_CACHE_KEY , BRANCH_CACHE_KEY
362+ ],
361363 hash : str ,
362364 version : Optional [str ] = None ,
363365 label : Optional [str ] = None ,
@@ -459,11 +461,7 @@ def _load_from_cache_with_progress(
459461 Raises if unable to load from cache.
460462 """
461463 cached_prompt = self ._read_from_cache (
462- self .alias ,
463- version = version ,
464- label = label ,
465- hash = hash ,
466- branch = branch
464+ self .alias , version = version , label = label , hash = hash , branch = branch
467465 )
468466 if not cached_prompt :
469467 raise ValueError ("Unable to fetch prompt and load from cache" )
@@ -521,7 +519,11 @@ def pull(
521519 if refresh :
522520 # Check if we need to bootstrap the cache
523521 cached_prompt = self ._read_from_cache (
524- self .alias , version = version , label = label , hash = hash , branch = branch
522+ self .alias ,
523+ version = version ,
524+ label = label ,
525+ hash = hash ,
526+ branch = branch ,
525527 )
526528 if cached_prompt is None :
527529 # No cache exists, so we should write after fetching to bootstrap
@@ -537,13 +539,18 @@ def pull(
537539 if refresh :
538540 loop = _get_or_create_polling_loop ()
539541 asyncio .run_coroutine_threadsafe (
540- self .create_polling_task (version , label , hash , branch , refresh ), loop
542+ self .create_polling_task (version , label , hash , branch , refresh ),
543+ loop ,
541544 )
542545
543546 if default_to_cache :
544547 try :
545548 cached_prompt = self ._read_from_cache (
546- self .alias , version = version , label = label , hash = hash , branch = branch
549+ self .alias ,
550+ version = version ,
551+ label = label ,
552+ hash = hash ,
553+ branch = branch ,
547554 )
548555 if cached_prompt :
549556 with self ._lock :
@@ -592,7 +599,9 @@ def pull(
592599 HINT_TEXT = f"version={ version } "
593600 else :
594601 branch_name = branch or self .branch
595- HINT_TEXT = f"hash={ hash or 'latest' } , branch={ branch_name or 'main' } "
602+ HINT_TEXT = (
603+ f"hash={ hash or 'latest' } , branch={ branch_name or 'main' } "
604+ )
596605
597606 task_id = progress .add_task (
598607 f"Pulling [rgb(106,0,255)]'{ self .alias } ' ({ HINT_TEXT } )[/rgb(106,0,255)] from Confident AI..." ,
@@ -627,7 +636,7 @@ def pull(
627636 "alias" : self .alias ,
628637 "hash" : hash or "latest" ,
629638 },
630- params = {"branch" : branch or self .branch }
639+ params = {"branch" : branch or self .branch },
631640 )
632641
633642 response = PromptHttpResponse (
@@ -653,7 +662,7 @@ def pull(
653662 version = version ,
654663 label = label ,
655664 hash = hash ,
656- branch = branch
665+ branch = branch ,
657666 )
658667 return
659668 raise
@@ -861,28 +870,34 @@ def update(
861870
862871 def get_branches (self ) -> List [PromptBranch ]:
863872 if not self .alias :
864- raise ValueError ("Prompt alias is not set. Please set an alias to continue." )
865-
873+ raise ValueError (
874+ "Prompt alias is not set. Please set an alias to continue."
875+ )
876+
866877 api = Api (api_key = self .confident_api_key )
867-
878+
868879 data , _ = api .send_request (
869880 method = HttpMethods .GET ,
870881 endpoint = Endpoints .PROMPTS_BRANCHES_ENDPOINT ,
871- url_params = {"alias" : self .alias }
882+ url_params = {"alias" : self .alias },
872883 )
873-
884+
874885 response = PromptBranchesHttpResponse (** data )
875886 return response .branches or []
876887
877888 def create_branch (self , branch : str , _verbose : Optional [bool ] = True ):
878889 if not self .alias :
879- raise ValueError ("Prompt alias is not set. Please set an alias to continue." )
880-
890+ raise ValueError (
891+ "Prompt alias is not set. Please set an alias to continue."
892+ )
893+
881894 api = Api (api_key = self .confident_api_key )
882-
895+
883896 body = PromptCreateBranchRequest (branch = branch )
884897 try :
885- body_dict = body .model_dump (by_alias = True , exclude_none = True , mode = "json" )
898+ body_dict = body .model_dump (
899+ by_alias = True , exclude_none = True , mode = "json"
900+ )
886901 except AttributeError :
887902 body_dict = body .dict (by_alias = True , exclude_none = True )
888903
@@ -892,7 +907,7 @@ def create_branch(self, branch: str, _verbose: Optional[bool] = True):
892907 url_params = {"alias" : self .alias },
893908 body = body_dict ,
894909 )
895-
910+
896911 self .branch = branch
897912
898913 if _verbose :
@@ -902,29 +917,35 @@ def create_branch(self, branch: str, _verbose: Optional[bool] = True):
902917 f"[link={ link } ]{ link } [/link]"
903918 )
904919
905- def update_branch (self , name : str , branch : Optional [str ] = None , _verbose : Optional [bool ] = True ):
920+ def update_branch (
921+ self ,
922+ name : str ,
923+ branch : Optional [str ] = None ,
924+ _verbose : Optional [bool ] = True ,
925+ ):
906926 if not self .alias :
907- raise ValueError ("Prompt alias is not set. Please set an alias to continue." )
908-
927+ raise ValueError (
928+ "Prompt alias is not set. Please set an alias to continue."
929+ )
930+
909931 branch_to_update = branch or self .branch
910932 if branch_to_update == "main" :
911933 raise ValueError ("Cannot update the name of the main branch." )
912934
913935 api = Api (api_key = self .confident_api_key )
914-
936+
915937 body = PromptUpdateBranchRequest (name = name )
916938 try :
917- body_dict = body .model_dump (by_alias = True , exclude_none = True , mode = "json" )
939+ body_dict = body .model_dump (
940+ by_alias = True , exclude_none = True , mode = "json"
941+ )
918942 except AttributeError :
919943 body_dict = body .dict (by_alias = True , exclude_none = True )
920944
921945 api .send_request (
922946 method = HttpMethods .PUT ,
923947 endpoint = Endpoints .PROMPTS_BRANCH_ENDPOINT ,
924- url_params = {
925- "alias" : self .alias ,
926- "name" : branch_to_update
927- },
948+ url_params = {"alias" : self .alias , "name" : branch_to_update },
928949 body = body_dict ,
929950 )
930951
@@ -934,12 +955,18 @@ def update_branch(self, name: str, branch: Optional[str] = None, _verbose: Optio
934955
935956 if _verbose :
936957 console = Console ()
937- console .print (f"✅ Successfully renamed branch '{ branch_to_update } ' to '{ name } '." )
958+ console .print (
959+ f"✅ Successfully renamed branch '{ branch_to_update } ' to '{ name } '."
960+ )
938961
939- def delete_branch (self , branch : Optional [str ] = None , _verbose : Optional [bool ] = True ):
962+ def delete_branch (
963+ self , branch : Optional [str ] = None , _verbose : Optional [bool ] = True
964+ ):
940965 if not self .alias :
941- raise ValueError ("Prompt alias is not set. Please set an alias to continue." )
942-
966+ raise ValueError (
967+ "Prompt alias is not set. Please set an alias to continue."
968+ )
969+
943970 branch_to_delete = branch or self .branch
944971 if branch_to_delete == "main" :
945972 raise ValueError ("Cannot delete the main branch." )
@@ -949,10 +976,7 @@ def delete_branch(self, branch: Optional[str] = None, _verbose: Optional[bool] =
949976 api .send_request (
950977 method = HttpMethods .DELETE ,
951978 endpoint = Endpoints .PROMPTS_BRANCH_ENDPOINT ,
952- url_params = {
953- "alias" : self .alias ,
954- "name" : branch_to_delete
955- },
979+ url_params = {"alias" : self .alias , "name" : branch_to_delete },
956980 )
957981
958982 # If we deleted the branch this instance is currently tracking, safely fall back to tracking "main"
@@ -961,7 +985,9 @@ def delete_branch(self, branch: Optional[str] = None, _verbose: Optional[bool] =
961985
962986 if _verbose :
963987 console = Console ()
964- console .print (f"✅ Successfully deleted branch '{ branch_to_delete } '." )
988+ console .print (
989+ f"✅ Successfully deleted branch '{ branch_to_delete } '."
990+ )
965991
966992 ############################################
967993 ### Polling
@@ -1063,7 +1089,7 @@ async def poll(
10631089 "alias" : self .alias ,
10641090 "hash" : hash or "latest" ,
10651091 },
1066- params = {"branch" : branch or self .branch }
1092+ params = {"branch" : branch or self .branch },
10671093 )
10681094
10691095 response = PromptHttpResponse (
@@ -1079,7 +1105,7 @@ async def poll(
10791105 output_type = data .get ("outputType" , None ),
10801106 output_schema = data .get ("outputSchema" , None ),
10811107 tools = data .get ("tools" , None ),
1082- branch = data .get ("branch" , None )
1108+ branch = data .get ("branch" , None ),
10831109 )
10841110
10851111 # Update the cache with fresh data from server
0 commit comments