Skip to content

Commit 2148000

Browse files
authored
Merge pull request #147 from monarch-initiative/prep_for_0.2.4
Prep for 0.2.4 release
2 parents 98a1cd6 + 70aab15 commit 2148000

19 files changed

+2998
-2901
lines changed

poetry.lock

Lines changed: 2801 additions & 2788 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "curategpt"
3-
version = "0.2.3"
3+
version = "0.2.4"
44
description = "CurateGPT"
55
authors = ["Chris Mungall <cjmungall@lbl.gov>", "Carlo Kroll <ckroll95@gmail.com>", "Harshad Hegde <hhegde@lbl.gov>", "J. Harry Caufield <jhc@lbl.gov>"]
66
license = "BSD-3"

src/curategpt/agents/agent_utils.py

Lines changed: 6 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -32,7 +32,8 @@ def select_from_options_prompt(
3232
"""
3333
if prompt_template is None:
3434
if query is None:
35-
raise ValueError("Either query or prompt_template must be specified.")
35+
raise ValueError(
36+
"Either query or prompt_template must be specified.")
3637
prompt_template = "I will first give background facts, then ask a question."
3738
prompt_template += "Use the background fact to answer\n"
3839
prompt_template += "---\nBackground facts:\n"
@@ -53,7 +54,8 @@ def select_from_options_prompt(
5354
current_length = 0
5455
for obj, _, _obj_meta in kb_results:
5556
i += 1
56-
obj_text = yaml.dump({k: v for k, v in obj.items() if v}, sort_keys=False)
57+
obj_text = yaml.dump(
58+
{k: v for k, v in obj.items() if v}, sort_keys=False)
5759
references[str(i)] = obj_text
5860
objects[str(i)] = obj
5961
if id_field and id_field in obj:
@@ -65,7 +67,8 @@ def select_from_options_prompt(
6567
prompt = prompt_template.format(body="".join(texts), query=query)
6668
logger.info(f"Prompt: {prompt}")
6769
estimated_length = estimate_num_tokens([prompt])
68-
logger.debug(f"Max tokens {model.model_id}: {max_tokens_by_model(model.model_id)}")
70+
logger.debug(
71+
f"Max tokens {model.model_id}: {max_tokens_by_model(model.model_id)}")
6972
# TODO: use a more precise estimate of the length
7073
if estimated_length + 300 < max_tokens_by_model(model.model_id):
7174
break

src/curategpt/agents/base_agent.py

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -28,4 +28,5 @@ class BaseAgent(ABC): # noqa: B024
2828
"""Engine performing LLM operations, including extracting from prompt responses"""
2929

3030
def search(self):
31-
raise NotImplementedError("Search method must be implemented by subclass")
31+
raise NotImplementedError(
32+
"Search method must be implemented by subclass")

src/curategpt/agents/concept_recognition_agent.py

Lines changed: 9 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,7 @@
11
"""Annotation (Concept Recognition) in texts."""
22

33
import logging
4+
import re
45
from dataclasses import dataclass
56
from enum import Enum
67
from typing import Dict, List, Optional, Tuple
@@ -129,9 +130,6 @@ class AnnotatedText(BaseModel):
129130
"""
130131

131132

132-
import re
133-
134-
135133
def parse_annotations(text, marker_char: str = None) -> List[CONCEPT]:
136134
"""
137135
Parse annotations from text.
@@ -283,7 +281,8 @@ def ground_concept(
283281
)
284282
spans.append(span)
285283
# spans = parse_spans(response.text(), concept_dict)
286-
ann = GroundingResult(input_text=text, annotated_text=response.text(), spans=spans)
284+
ann = GroundingResult(
285+
input_text=text, annotated_text=response.text(), spans=spans)
287286
return ann
288287

289288
def annotate(
@@ -340,7 +339,8 @@ def annotate_two_pass(
340339
**kwargs,
341340
)
342341
if not concept.spans:
343-
logger.debug(f"Unable to ground concept {term} in category {category}")
342+
logger.debug(
343+
f"Unable to ground concept {term} in category {category}")
344344
continue
345345
main_span = concept.spans[0]
346346
spans.append(
@@ -376,7 +376,8 @@ def annotate_inline(
376376

377377
logger.info(f"Anns: {anns}")
378378
spans = [
379-
Span(text=ann[0], concept_id=ann[1], concept_label=concept_dict.get(ann[1], None))
379+
Span(text=ann[0], concept_id=ann[1],
380+
concept_label=concept_dict.get(ann[1], None))
380381
for ann in anns
381382
]
382383
return AnnotatedText(input_text=text, spans=spans, annotated_text=response.text())
@@ -436,7 +437,8 @@ def _label_id_pairs_prompt_section(
436437
if not id:
437438
raise ValueError(f"Object {obj} has no ID field {id_field}")
438439
if not label:
439-
raise ValueError(f"Object {obj} has no label field {label_field}")
440+
raise ValueError(
441+
f"Object {obj} has no label field {label_field}")
440442
prompt += f"{label} // {id} \n"
441443
concept_pairs.append((id, label))
442444
return concept_pairs, prompt

src/curategpt/agents/dase_agent.py

Lines changed: 3 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -65,7 +65,8 @@ class DatabaseAugmentedStructuredExtraction(BaseAgent):
6565
background_document_limit: int = 3
6666
"""Number of background documents to use. TODO: more sophisticated way to estimate."""
6767

68-
default_masked_fields: List[str] = field(default_factory=lambda: ["original_id"])
68+
default_masked_fields: List[str] = field(
69+
default_factory=lambda: ["original_id"])
6970

7071
def extract(
7172
self,
@@ -135,7 +136,7 @@ def extract(
135136
):
136137
obj_text = obj_meta["document"]
137138
# TODO: use tiktoken to estimate
138-
obj_text = obj_text[0 : self.max_background_document_size]
139+
obj_text = obj_text[0: self.max_background_document_size]
139140
docs.append(obj_text)
140141
if generate_background:
141142
# prompt = f"Generate a comprehensive description about the {target_class} with {context_property} = {seed}"

src/curategpt/agents/dragon_agent.py

Lines changed: 16 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -67,7 +67,8 @@ class DragonAgent(BaseAgent):
6767
background_document_limit: int = 3
6868
"""Number of background documents to use. TODO: more sophisticated way to estimate."""
6969

70-
default_masked_fields: List[str] = field(default_factory=lambda: ["original_id"])
70+
default_masked_fields: List[str] = field(
71+
default_factory=lambda: ["original_id"])
7172

7273
def complete(
7374
self,
@@ -139,10 +140,12 @@ def generate_input_str(obj: Union[str, Dict], prefix="Structured representation
139140
elif isinstance(obj, str):
140141
return f"{prefix} {target_class} with {context_property} = {obj}"
141142
else:
142-
raise ValueError(f"Invalid type for obj: {type(obj)} // {obj}")
143+
raise ValueError(
144+
f"Invalid type for obj: {type(obj)} // {obj}")
143145

144146
annotated_examples = []
145-
seed_search_term = seed if isinstance(seed, str) else yaml.safe_dump(seed, sort_keys=True)
147+
seed_search_term = seed if isinstance(
148+
seed, str) else yaml.safe_dump(seed, sort_keys=True)
146149
logger.debug(f"Searching for seed: {seed_search_term}")
147150
for obj, _, _obj_meta in self.knowledge_source.search(
148151
seed_search_term,
@@ -167,7 +170,8 @@ def generate_input_str(obj: Union[str, Dict], prefix="Structured representation
167170
f"Num examples={len(annotated_examples)}"
168171
)
169172
continue
170-
ae = AnnotatedObject(object=obj_predicted_part, annotations={"text": input_text})
173+
ae = AnnotatedObject(object=obj_predicted_part,
174+
annotations={"text": input_text})
171175
annotated_examples.append(ae)
172176
if not annotated_examples:
173177
logger.error(f"No suitable examples found for seed: {seed}")
@@ -181,7 +185,7 @@ def generate_input_str(obj: Union[str, Dict], prefix="Structured representation
181185
):
182186
obj_text = obj_meta["document"]
183187
# TODO: use tiktoken to estimate
184-
obj_text = obj_text[0 : self.max_background_document_size]
188+
obj_text = obj_text[0: self.max_background_document_size]
185189
docs.append(obj_text)
186190
gen_text = generate_input_str(seed)
187191
if generate_background:
@@ -249,7 +253,8 @@ def generate_all(
249253
continue
250254
curr_val = obj.get(field_to_predict, None)
251255
if missing_only and curr_val:
252-
logger.debug(f"Skipping; {field_to_predict} already present: {curr_val}")
256+
logger.debug(
257+
f"Skipping; {field_to_predict} already present: {curr_val}")
253258
continue
254259
ao = self.complete(obj, collection=collection, **kwargs)
255260
yield PredictedFieldValue(
@@ -286,7 +291,7 @@ def generate_queries(self, context_property="name", n=5, **kwargs) -> List[str]:
286291
response = self.extractor.model.prompt(prompt, system=system)
287292
txt = response.text()
288293
if "." in txt:
289-
txt = txt[0 : txt.index(".")]
294+
txt = txt[0: txt.index(".")]
290295
suggestions = txt.split(",")
291296
suggestions = [x.strip() for x in suggestions]
292297
ids_norm = [x.lower() for x in ids]
@@ -309,7 +314,8 @@ def review(
309314
:param obj:
310315
"""
311316
if fields_to_predict and not context_property:
312-
raise ValueError("context_property is required if fields_to_predict")
317+
raise ValueError(
318+
"context_property is required if fields_to_predict")
313319

314320
pk_val = obj.get(primary_key, None)
315321

@@ -360,7 +366,8 @@ def _obj_as_str(obj: dict) -> str:
360366
if not isinstance(ao.object, dict):
361367
logger.warning(f"Expected dict, got {ao.object}")
362368
if isinstance(ao.object, list):
363-
logger.warning(f"Taking first element of list of len {len(ao.object)}")
369+
logger.warning(
370+
f"Taking first element of list of len {len(ao.object)}")
364371
ao.object = ao.object[0]
365372
if isinstance(ao.object, dict):
366373
if context_property:

src/curategpt/agents/evidence_agent.py

Lines changed: 6 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -38,7 +38,8 @@ class EvidenceAgent(BaseAgent):
3838

3939
chat_agent: Union[ChatAgent, BaseWrapper] = None
4040

41-
evidence_update_policy: EvidenceUpdatePolicyEnum = field(default=EvidenceUpdatePolicyEnum.skip)
41+
evidence_update_policy: EvidenceUpdatePolicyEnum = field(
42+
default=EvidenceUpdatePolicyEnum.skip)
4243

4344
def find_evidence(self, obj: Union[str, Dict]) -> ChatResponse:
4445
obj_as_str = obj if isinstance(obj, str) else object_as_yaml(obj)
@@ -78,7 +79,8 @@ def find_evidence_simple(self, query: str, limit: int = 10, **kwargs) -> Optiona
7879
current_length = 0
7980
for obj, _, _obj_meta in kb_results:
8081
i += 1
81-
obj_text = yaml.dump({k: v for k, v in obj.items() if v}, sort_keys=False)
82+
obj_text = yaml.dump(
83+
{k: v for k, v in obj.items() if v}, sort_keys=False)
8284
references[str(i)] = obj_text
8385
texts.append(f"## Reference\n{obj_text}")
8486
current_length += len(obj_text)
@@ -97,7 +99,8 @@ def find_evidence_simple(self, query: str, limit: int = 10, **kwargs) -> Optiona
9799
break
98100
else:
99101
# remove least relevant
100-
logger.debug(f"Removing least relevant of {len(kb_results)}: {kb_results[-1]}")
102+
logger.debug(
103+
f"Removing least relevant of {len(kb_results)}: {kb_results[-1]}")
101104
if not kb_results:
102105
raise ValueError(f"Prompt too long: {prompt}.")
103106
kb_results.pop()

src/curategpt/agents/huggingface_agent.py

Lines changed: 31 additions & 22 deletions
Original file line number · Diff line number · Diff line change
@@ -13,6 +13,7 @@
1313
HF_DOWNLOAD_PATH = Path(__file__).resolve().parents[4]
1414
HF_DOWNLOAD_PATH = HF_DOWNLOAD_PATH / "hf_download"
1515

16+
1617
@dataclass
1718
class HuggingFaceAgent:
1819

@@ -38,13 +39,17 @@ def upload(self, objects, metadata, repo_id, private=False, **kwargs):
3839
metadata_file = "metadata.yaml"
3940

4041
try:
41-
df = pd.DataFrame(data=[(obj[0], obj[2]['_embeddings'], obj[2]['document']) for obj in objects])
42+
df = pd.DataFrame(
43+
data=[(obj[0], obj[2]['_embeddings'], obj[2]['document']) for obj in objects])
4244
except Exception as e:
43-
raise ValueError(f"Creation of Dataframe not successful: {e}") from e
45+
raise ValueError(
46+
f"Creation of Dataframe not successful: {e}") from e
4447

4548
with ExitStack() as stack:
46-
tmp_parquet = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".parquet", delete=True))
47-
tmp_yaml = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".yaml", delete=True))
49+
tmp_parquet = stack.enter_context(
50+
tempfile.NamedTemporaryFile(suffix=".parquet", delete=True))
51+
tmp_yaml = stack.enter_context(
52+
tempfile.NamedTemporaryFile(suffix=".yaml", delete=True))
4853

4954
embedding_path = tmp_parquet.name
5055
metadata_path = tmp_yaml.name
@@ -56,8 +61,8 @@ def upload(self, objects, metadata, repo_id, private=False, **kwargs):
5661
self._create_repo(repo_id, private=private)
5762

5863
self._upload_files(repo_id, {
59-
embedding_path : repo_id + "/" + embedding_file,
60-
metadata_path : repo_id + "/" + metadata_file
64+
embedding_path: repo_id + "/" + embedding_file,
65+
metadata_path: repo_id + "/" + metadata_file
6166
})
6267

6368
def upload_duckdb(self, objects, metadata, repo_id, private=False, **kwargs):
@@ -74,13 +79,17 @@ def upload_duckdb(self, objects, metadata, repo_id, private=False, **kwargs):
7479
embedding_file = "embeddings.parquet"
7580
metadata_file = "metadata.yaml"
7681
try:
77-
df = pd.DataFrame(data=[(obj[0], obj[2]['_embeddings'], obj[2]['documents']) for obj in objects])
82+
df = pd.DataFrame(
83+
data=[(obj[0], obj[2]['_embeddings'], obj[2]['documents']) for obj in objects])
7884
except Exception as e:
79-
raise ValueError(f"Creation of Dataframe not successful: {e}") from e
85+
raise ValueError(
86+
f"Creation of Dataframe not successful: {e}") from e
8087

8188
with ExitStack() as stack:
82-
tmp_parquet = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".parquet", delete=True))
83-
tmp_yaml = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".yaml", delete=True))
89+
tmp_parquet = stack.enter_context(
90+
tempfile.NamedTemporaryFile(suffix=".parquet", delete=True))
91+
tmp_yaml = stack.enter_context(
92+
tempfile.NamedTemporaryFile(suffix=".yaml", delete=True))
8493

8594
embedding_path = tmp_parquet.name
8695
metadata_path = tmp_yaml.name
@@ -92,8 +101,8 @@ def upload_duckdb(self, objects, metadata, repo_id, private=False, **kwargs):
92101
self._create_repo(repo_id, private=private)
93102

94103
self._upload_files(repo_id, {
95-
embedding_path : repo_id + "/" + embedding_file,
96-
metadata_path : repo_id + "/" + metadata_file
104+
embedding_path: repo_id + "/" + embedding_file,
105+
metadata_path: repo_id + "/" + metadata_file
97106
})
98107

99108
def _create_repo(self, repo_id: str, private: bool = False):
@@ -104,13 +113,15 @@ def _create_repo(self, repo_id: str, private: bool = False):
104113
:param private: Whether the repository is private.
105114
"""
106115
try:
107-
create_repo(repo_id=repo_id, token=self.token, repo_type="dataset", private=private)
108-
logger.info(f"Repository {repo_id} created successfully on Hugging Face.")
116+
create_repo(repo_id=repo_id, token=self.token,
117+
repo_type="dataset", private=private)
118+
logger.info(
119+
f"Repository {repo_id} created successfully on Hugging Face.")
109120
except Exception as e:
110-
logger.error(f"Failed to create repository {repo_id} on Hugging Face: {e}")
121+
logger.error(
122+
f"Failed to create repository {repo_id} on Hugging Face: {e}")
111123
raise
112124

113-
114125
def _upload_files(self, repo_id: str, files: Dict[str, str]):
115126
"""
116127
Upload files to a Hugging Face repository.
@@ -126,9 +137,11 @@ def _upload_files(self, repo_id: str, files: Dict[str, str]):
126137
repo_id=repo_id,
127138
repo_type="dataset",
128139
)
129-
logger.info(f"Uploaded {local_path} to {repo_path} in {repo_id}")
140+
logger.info(
141+
f"Uploaded {local_path} to {repo_path} in {repo_id}")
130142
except Exception as e:
131-
logger.error(f"Failed to upload files to {repo_id} on Hugging Face: {e}")
143+
logger.error(
144+
f"Failed to upload files to {repo_id} on Hugging Face: {e}")
132145
raise
133146

134147
def cached_download(
@@ -145,7 +158,3 @@ def cached_download(
145158
)
146159

147160
return download_path
148-
149-
150-
151-

0 commit comments

Comments (0)