Skip to content

Commit dc2397b

Browse files
committed
Refactor annotation extraction to jointly extract all at once, enable parallel processing, and fix cache
1 parent a44d9da commit dc2397b

9 files changed

Lines changed: 1187 additions & 203 deletions

File tree

autonima/annotation/client.py

Lines changed: 187 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22

33
import json
44
import logging
5-
from typing import Dict, Any, Optional
5+
from typing import List, Dict, Any
66
from pydantic import BaseModel
7-
import sys
8-
import os
97
from ..llm.client import GenericLLMClient
108
from .schema import AnalysisMetadata, AnnotationCriteriaConfig, AnnotationDecision
119

@@ -16,6 +14,22 @@ class AnnotationDecisionOutput(BaseModel):
1614
"""Output schema for annotation decision."""
1715
include: bool
1816
reasoning: str
17+
inclusion_criteria_applied: List[str] = []
18+
exclusion_criteria_applied: List[str] = []
19+
20+
21+
class MultiAnnotationDecisionOutput(BaseModel):
22+
"""Output schema for multiple annotation decisions."""
23+
annotation_name: str
24+
include: bool
25+
reasoning: str
26+
inclusion_criteria_applied: List[str] = []
27+
exclusion_criteria_applied: List[str] = []
28+
29+
30+
class MultiAnnotationDecisionOutputList(BaseModel):
31+
"""Output schema for list of multiple annotation decisions."""
32+
decisions: List[MultiAnnotationDecisionOutput]
1933

2034

2135
class AnnotationClient:
@@ -25,6 +39,63 @@ def __init__(self):
2539
"""Initialize the annotation client."""
2640
self._client = GenericLLMClient()
2741

42+
def _generate_function_schema(
43+
self,
44+
model_class: BaseModel,
45+
function_name: str
46+
) -> Dict[str, Any]:
47+
"""Generate OpenAI function schema from Pydantic model.
48+
49+
Args:
50+
model_class: Pydantic model class
51+
function_name: Name for the function
52+
53+
Returns:
54+
Dict representing the OpenAI function schema
55+
"""
56+
schema = model_class.model_json_schema()
57+
58+
# Convert JSON schema to OpenAI function schema
59+
properties = {}
60+
required = []
61+
62+
for field_name, field_info in schema.get("properties", {}).items():
63+
properties[field_name] = {
64+
"type": field_info["type"],
65+
"description": field_info.get("description", "")
66+
}
67+
68+
# Handle enum values
69+
if "enum" in field_info:
70+
properties[field_name]["enum"] = field_info["enum"]
71+
72+
# Handle array items
73+
if field_info["type"] == "array" and "items" in field_info:
74+
properties[field_name]["items"] = field_info["items"]
75+
76+
# Handle numeric constraints
77+
if field_info["type"] == "number":
78+
if "minimum" in field_info:
79+
properties[field_name]["minimum"] = field_info["minimum"]
80+
if "maximum" in field_info:
81+
properties[field_name]["maximum"] = field_info["maximum"]
82+
83+
# Get required fields
84+
required = schema.get("required", [])
85+
86+
# Create description for the function
87+
description = "Make an annotation decision for a neuroimaging analysis"
88+
89+
return {
90+
"name": function_name,
91+
"description": description,
92+
"parameters": {
93+
"type": "object",
94+
"properties": properties,
95+
"required": required
96+
}
97+
}
98+
2899
def make_decision(
29100
self,
30101
metadata: AnalysisMetadata,
@@ -49,30 +120,41 @@ def make_decision(
49120
metadata_fields = getattr(criteria, 'metadata_fields', None)
50121
prompt = create_annotation_prompt(metadata, criteria, metadata_fields)
51122

52-
# Get the response from the LLM
53-
response_text = self.chat_completion(
123+
# Generate function schema from Pydantic model
124+
func_name = "make_annotation_decision"
125+
function_schema = self._generate_function_schema(
126+
AnnotationDecisionOutput,
127+
func_name
128+
)
129+
130+
# Call the LLM API with function calling
131+
response = self._client.client.chat.completions.create(
132+
model=model,
54133
messages=[
55134
{
56135
"role": "system",
57-
"content": "You are a neuroimaging meta-analysis expert."
136+
"content": (
137+
"You are a neuroimaging meta-analysis expert. "
138+
"Respond using the make_annotation_decision function."
139+
)
58140
},
59141
{
60142
"role": "user",
61143
"content": prompt
62144
}
63145
],
64-
model=model,
65-
response_format={"type": "json_object"}
146+
functions=[function_schema],
147+
function_call={"name": func_name}
66148
)
67149

68-
# Parse the response
69-
try:
70-
response_data = json.loads(response_text)
71-
decision_output = AnnotationDecisionOutput(**response_data)
72-
except Exception as e:
73-
logger.warning(f"Failed to parse JSON response: {e}. Response: {response}")
74-
# Try to extract the information manually
75-
decision_output = self._parse_response_manually(response)
150+
# Extract the function call result
151+
function_call = response.choices[0].message.function_call
152+
if not function_call:
153+
raise ValueError("No function call returned from API")
154+
155+
# Parse the result
156+
result_dict = json.loads(function_call.arguments)
157+
decision_output = AnnotationDecisionOutput(**result_dict)
76158

77159
# Create the annotation decision
78160
decision = AnnotationDecision(
@@ -81,7 +163,9 @@ def make_decision(
81163
study_id=metadata.study_id,
82164
include=decision_output.include,
83165
reasoning=decision_output.reasoning,
84-
model_used=model
166+
model_used=model,
167+
inclusion_criteria_applied=decision_output.inclusion_criteria_applied,
168+
exclusion_criteria_applied=decision_output.exclusion_criteria_applied
85169
)
86170

87171
return decision
@@ -98,51 +182,99 @@ def make_decision(
98182
model_used=model
99183
)
100184

101-
def _parse_response_manually(self, response: str) -> AnnotationDecisionOutput:
185+
def make_multi_decision(
186+
self,
187+
metadata: AnalysisMetadata,
188+
criteria_list: List[AnnotationCriteriaConfig],
189+
model: str = "gpt-4o-mini"
190+
) -> List[AnnotationDecision]:
102191
"""
103-
Attempt to parse the response manually if JSON parsing fails.
192+
Make decisions about whether an analysis should be included in multiple annotations.
104193
105194
Args:
106-
response: Raw response string from the LLM
195+
metadata: Analysis metadata
196+
criteria_list: List of annotation criteria configurations
197+
model: LLM model to use
107198
108199
Returns:
109-
Annotation decision output
200+
List of annotation decisions
110201
"""
111-
# Default to excluding if we can't parse
112-
include = False
113-
reasoning = "Failed to parse response"
202+
if not criteria_list:
203+
return []
114204

115-
# Simple heuristics to extract information
116-
response_lower = response.lower()
117-
118-
# Look for inclusion indicators
119-
if '"include": true' in response_lower or '"include":true' in response_lower:
120-
include = True
121-
elif '"include": false' in response_lower or '"include":false' in response_lower:
122-
include = False
123-
124-
# Try to extract reasoning
125-
if '"reasoning":' in response:
126-
try:
127-
# Find the reasoning part
128-
start = response.find('"reasoning":') + len('"reasoning":')
129-
if response[start] == '"':
130-
start += 1
131-
end = response.find('"', start)
132-
reasoning = response[start:end]
133-
else:
134-
# Handle non-string reasoning (shouldn't happen with our prompt)
135-
end = response.find('}', start)
136-
if end == -1:
137-
end = len(response)
138-
reasoning = response[start:end].strip()
139-
# Remove trailing comma if present
140-
if reasoning.endswith(','):
141-
reasoning = reasoning[:-1]
142-
except Exception:
143-
pass
144-
145-
return AnnotationDecisionOutput(include=include, reasoning=reasoning)
205+
try:
206+
# Create the prompt for all annotations at once
207+
from .prompts import create_multi_annotation_prompt
208+
# Get metadata_fields from the first criteria or use default
209+
metadata_fields = getattr(criteria_list[0], 'metadata_fields', None)
210+
prompt = create_multi_annotation_prompt(metadata, criteria_list, metadata_fields)
211+
212+
# Generate function schema from Pydantic model
213+
func_name = "make_multi_annotation_decisions"
214+
function_schema = self._generate_function_schema(
215+
MultiAnnotationDecisionOutputList,
216+
func_name
217+
)
218+
219+
# Call the LLM API with function calling
220+
response = self._client.client.chat.completions.create(
221+
model=model,
222+
messages=[
223+
{
224+
"role": "system",
225+
"content": (
226+
"You are a neuroimaging meta-analysis expert. "
227+
"Respond using the make_multi_annotation_decisions function."
228+
)
229+
},
230+
{
231+
"role": "user",
232+
"content": prompt
233+
}
234+
],
235+
functions=[function_schema],
236+
function_call={"name": func_name}
237+
)
238+
239+
# Extract the function call result
240+
function_call = response.choices[0].message.function_call
241+
if not function_call:
242+
raise ValueError("No function call returned from API")
243+
244+
# Parse the result
245+
result_dict = json.loads(function_call.arguments)
246+
decision_list_output = MultiAnnotationDecisionOutputList(**result_dict)
247+
decision_outputs = decision_list_output.decisions
248+
249+
# Create the annotation decisions
250+
decisions = []
251+
for i, decision_output in enumerate(decision_outputs):
252+
if i < len(criteria_list):
253+
criteria = criteria_list[i]
254+
decision = AnnotationDecision(
255+
annotation_name=decision_output.annotation_name or criteria.name,
256+
analysis_id=metadata.analysis_id,
257+
study_id=metadata.study_id,
258+
include=decision_output.include,
259+
reasoning=decision_output.reasoning,
260+
model_used=model,
261+
inclusion_criteria_applied=decision_output.inclusion_criteria_applied,
262+
exclusion_criteria_applied=decision_output.exclusion_criteria_applied
263+
)
264+
decisions.append(decision)
265+
266+
# If we didn't get enough responses, fill in with individual decisions
267+
while len(decisions) < len(criteria_list):
268+
criteria = criteria_list[len(decisions)]
269+
decision = self.make_decision(metadata, criteria, model)
270+
decisions.append(decision)
271+
272+
return decisions
273+
274+
except Exception as e:
275+
logger.error(f"Error making multi annotation decisions: {e}")
276+
# Return individual decisions as fallback
277+
return [self.make_decision(metadata, criteria, model) for criteria in criteria_list]
146278

147279
def chat_completion(self, messages, model, response_format=None):
148280
"""

0 commit comments

Comments
 (0)