Skip to content

Commit db0cfba

Browse files
committed
reformat
1 parent 7505ed4 commit db0cfba

3 files changed

Lines changed: 83 additions & 49 deletions

File tree

deepeval/prompt/api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,15 +272,19 @@ class PromptApi(BaseModel):
272272
id: str
273273
type: PromptType
274274

275+
275276
class PromptBranch(BaseModel):
276277
id: str
277278
name: str
278279

280+
279281
class PromptBranchesHttpResponse(BaseModel):
280282
branches: List[PromptBranch]
281283

284+
282285
class PromptCreateBranchRequest(BaseModel):
283286
branch: str
284287

288+
285289
class PromptUpdateBranchRequest(BaseModel):
286290
name: str

deepeval/prompt/prompt.py

Lines changed: 68 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,9 @@ def _read_from_cache(
357357

358358
def _write_to_cache(
359359
self,
360-
cache_key: Literal[VERSION_CACHE_KEY, LABEL_CACHE_KEY, HASH_CACHE_KEY, BRANCH_CACHE_KEY],
360+
cache_key: Literal[
361+
VERSION_CACHE_KEY, LABEL_CACHE_KEY, HASH_CACHE_KEY, BRANCH_CACHE_KEY
362+
],
361363
hash: str,
362364
version: Optional[str] = None,
363365
label: Optional[str] = None,
@@ -459,11 +461,7 @@ def _load_from_cache_with_progress(
459461
Raises if unable to load from cache.
460462
"""
461463
cached_prompt = self._read_from_cache(
462-
self.alias,
463-
version=version,
464-
label=label,
465-
hash=hash,
466-
branch=branch
464+
self.alias, version=version, label=label, hash=hash, branch=branch
467465
)
468466
if not cached_prompt:
469467
raise ValueError("Unable to fetch prompt and load from cache")
@@ -521,7 +519,11 @@ def pull(
521519
if refresh:
522520
# Check if we need to bootstrap the cache
523521
cached_prompt = self._read_from_cache(
524-
self.alias, version=version, label=label, hash=hash, branch=branch
522+
self.alias,
523+
version=version,
524+
label=label,
525+
hash=hash,
526+
branch=branch,
525527
)
526528
if cached_prompt is None:
527529
# No cache exists, so we should write after fetching to bootstrap
@@ -537,13 +539,18 @@ def pull(
537539
if refresh:
538540
loop = _get_or_create_polling_loop()
539541
asyncio.run_coroutine_threadsafe(
540-
self.create_polling_task(version, label, hash, branch, refresh), loop
542+
self.create_polling_task(version, label, hash, branch, refresh),
543+
loop,
541544
)
542545

543546
if default_to_cache:
544547
try:
545548
cached_prompt = self._read_from_cache(
546-
self.alias, version=version, label=label, hash=hash, branch=branch
549+
self.alias,
550+
version=version,
551+
label=label,
552+
hash=hash,
553+
branch=branch,
547554
)
548555
if cached_prompt:
549556
with self._lock:
@@ -592,7 +599,9 @@ def pull(
592599
HINT_TEXT = f"version={version}"
593600
else:
594601
branch_name = branch or self.branch
595-
HINT_TEXT = f"hash={hash or 'latest'}, branch={branch_name or 'main'}"
602+
HINT_TEXT = (
603+
f"hash={hash or 'latest'}, branch={branch_name or 'main'}"
604+
)
596605

597606
task_id = progress.add_task(
598607
f"Pulling [rgb(106,0,255)]'{self.alias}' ({HINT_TEXT})[/rgb(106,0,255)] from Confident AI...",
@@ -627,7 +636,7 @@ def pull(
627636
"alias": self.alias,
628637
"hash": hash or "latest",
629638
},
630-
params={"branch": branch or self.branch}
639+
params={"branch": branch or self.branch},
631640
)
632641

633642
response = PromptHttpResponse(
@@ -653,7 +662,7 @@ def pull(
653662
version=version,
654663
label=label,
655664
hash=hash,
656-
branch=branch
665+
branch=branch,
657666
)
658667
return
659668
raise
@@ -861,28 +870,34 @@ def update(
861870

862871
def get_branches(self) -> List[PromptBranch]:
863872
if not self.alias:
864-
raise ValueError("Prompt alias is not set. Please set an alias to continue.")
865-
873+
raise ValueError(
874+
"Prompt alias is not set. Please set an alias to continue."
875+
)
876+
866877
api = Api(api_key=self.confident_api_key)
867-
878+
868879
data, _ = api.send_request(
869880
method=HttpMethods.GET,
870881
endpoint=Endpoints.PROMPTS_BRANCHES_ENDPOINT,
871-
url_params={"alias": self.alias}
882+
url_params={"alias": self.alias},
872883
)
873-
884+
874885
response = PromptBranchesHttpResponse(**data)
875886
return response.branches or []
876887

877888
def create_branch(self, branch: str, _verbose: Optional[bool] = True):
878889
if not self.alias:
879-
raise ValueError("Prompt alias is not set. Please set an alias to continue.")
880-
890+
raise ValueError(
891+
"Prompt alias is not set. Please set an alias to continue."
892+
)
893+
881894
api = Api(api_key=self.confident_api_key)
882-
895+
883896
body = PromptCreateBranchRequest(branch=branch)
884897
try:
885-
body_dict = body.model_dump(by_alias=True, exclude_none=True, mode="json")
898+
body_dict = body.model_dump(
899+
by_alias=True, exclude_none=True, mode="json"
900+
)
886901
except AttributeError:
887902
body_dict = body.dict(by_alias=True, exclude_none=True)
888903

@@ -892,7 +907,7 @@ def create_branch(self, branch: str, _verbose: Optional[bool] = True):
892907
url_params={"alias": self.alias},
893908
body=body_dict,
894909
)
895-
910+
896911
self.branch = branch
897912

898913
if _verbose:
@@ -902,29 +917,35 @@ def create_branch(self, branch: str, _verbose: Optional[bool] = True):
902917
f"[link={link}]{link}[/link]"
903918
)
904919

905-
def update_branch(self, name: str, branch: Optional[str] = None, _verbose: Optional[bool] = True):
920+
def update_branch(
921+
self,
922+
name: str,
923+
branch: Optional[str] = None,
924+
_verbose: Optional[bool] = True,
925+
):
906926
if not self.alias:
907-
raise ValueError("Prompt alias is not set. Please set an alias to continue.")
908-
927+
raise ValueError(
928+
"Prompt alias is not set. Please set an alias to continue."
929+
)
930+
909931
branch_to_update = branch or self.branch
910932
if branch_to_update == "main":
911933
raise ValueError("Cannot update the name of the main branch.")
912934

913935
api = Api(api_key=self.confident_api_key)
914-
936+
915937
body = PromptUpdateBranchRequest(name=name)
916938
try:
917-
body_dict = body.model_dump(by_alias=True, exclude_none=True, mode="json")
939+
body_dict = body.model_dump(
940+
by_alias=True, exclude_none=True, mode="json"
941+
)
918942
except AttributeError:
919943
body_dict = body.dict(by_alias=True, exclude_none=True)
920944

921945
api.send_request(
922946
method=HttpMethods.PUT,
923947
endpoint=Endpoints.PROMPTS_BRANCH_ENDPOINT,
924-
url_params={
925-
"alias": self.alias,
926-
"name": branch_to_update
927-
},
948+
url_params={"alias": self.alias, "name": branch_to_update},
928949
body=body_dict,
929950
)
930951

@@ -934,12 +955,18 @@ def update_branch(self, name: str, branch: Optional[str] = None, _verbose: Optio
934955

935956
if _verbose:
936957
console = Console()
937-
console.print(f"✅ Successfully renamed branch '{branch_to_update}' to '{name}'.")
958+
console.print(
959+
f"✅ Successfully renamed branch '{branch_to_update}' to '{name}'."
960+
)
938961

939-
def delete_branch(self, branch: Optional[str] = None, _verbose: Optional[bool] = True):
962+
def delete_branch(
963+
self, branch: Optional[str] = None, _verbose: Optional[bool] = True
964+
):
940965
if not self.alias:
941-
raise ValueError("Prompt alias is not set. Please set an alias to continue.")
942-
966+
raise ValueError(
967+
"Prompt alias is not set. Please set an alias to continue."
968+
)
969+
943970
branch_to_delete = branch or self.branch
944971
if branch_to_delete == "main":
945972
raise ValueError("Cannot delete the main branch.")
@@ -949,10 +976,7 @@ def delete_branch(self, branch: Optional[str] = None, _verbose: Optional[bool] =
949976
api.send_request(
950977
method=HttpMethods.DELETE,
951978
endpoint=Endpoints.PROMPTS_BRANCH_ENDPOINT,
952-
url_params={
953-
"alias": self.alias,
954-
"name": branch_to_delete
955-
},
979+
url_params={"alias": self.alias, "name": branch_to_delete},
956980
)
957981

958982
# If we deleted the branch this instance is currently tracking, safely fall back to tracking "main"
@@ -961,7 +985,9 @@ def delete_branch(self, branch: Optional[str] = None, _verbose: Optional[bool] =
961985

962986
if _verbose:
963987
console = Console()
964-
console.print(f"✅ Successfully deleted branch '{branch_to_delete}'.")
988+
console.print(
989+
f"✅ Successfully deleted branch '{branch_to_delete}'."
990+
)
965991

966992
############################################
967993
### Polling
@@ -1063,7 +1089,7 @@ async def poll(
10631089
"alias": self.alias,
10641090
"hash": hash or "latest",
10651091
},
1066-
params={"branch": branch or self.branch}
1092+
params={"branch": branch or self.branch},
10671093
)
10681094

10691095
response = PromptHttpResponse(
@@ -1079,7 +1105,7 @@ async def poll(
10791105
output_type=data.get("outputType", None),
10801106
output_schema=data.get("outputSchema", None),
10811107
tools=data.get("tools", None),
1082-
branch=data.get("branch", None)
1108+
branch=data.get("branch", None),
10831109
)
10841110

10851111
# Update the cache with fresh data from server

tests/test_confident/test_prompt.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -619,14 +619,14 @@ def test_create_branch(self):
619619
branch_names = [branch.name for branch in branches]
620620

621621
assert new_branch_name in branch_names
622-
622+
623623
def test_update_branch(self):
624624
UUID = str(uuid.uuid4())
625625
old_branch_name = f"old-branch-{UUID}"
626626
new_branch_name = f"new-branch-{UUID}"
627627

628628
prompt = Prompt(alias=self.BRANCH_ALIAS)
629-
629+
630630
prompt.create_branch(branch=old_branch_name)
631631

632632
# Pull all branches
@@ -663,7 +663,6 @@ def test_delete_branch(self):
663663
assert new_branch_name not in new_branch_names
664664

665665

666-
667666
class TestPromptList:
668667
ALIAS = "test_prompt_list"
669668
ALIAS_WITH_INTERPOLATION_TYPE = "test_prompt_list_interpolation_type"
@@ -1231,11 +1230,16 @@ def test_branch_push(self):
12311230
"""Test pushing to a new branch and main branch by default"""
12321231
prompt = Prompt(alias=self.BRANCH_ALIAS)
12331232
# Push to main branch
1234-
prompt.push(messages=[PromptMessage(role="user", content="New branch push")])
1233+
prompt.push(
1234+
messages=[PromptMessage(role="user", content="New branch push")]
1235+
)
12351236
first_branch_hash = prompt._hash
12361237

12371238
# Push to different branch
1238-
prompt.push(messages=[PromptMessage(role="user", content="New branch push")], branch=self.BRANCH_NAME)
1239+
prompt.push(
1240+
messages=[PromptMessage(role="user", content="New branch push")],
1241+
branch=self.BRANCH_NAME,
1242+
)
12391243
second_branch_hash = prompt._hash
12401244

12411245
main_commits = prompt._get_commits(branch="main")
@@ -1259,14 +1263,14 @@ def test_create_branch(self):
12591263
branch_names = [branch.name for branch in branches]
12601264

12611265
assert new_branch_name in branch_names
1262-
1266+
12631267
def test_update_branch(self):
12641268
UUID = str(uuid.uuid4())
12651269
old_branch_name = f"old-branch-{UUID}"
12661270
new_branch_name = f"new-branch-{UUID}"
12671271

12681272
prompt = Prompt(alias=self.BRANCH_ALIAS)
1269-
1273+
12701274
prompt.create_branch(branch=old_branch_name)
12711275

12721276
# Pull all branches

0 commit comments

Comments
 (0)