Skip to content

Commit f4fb97b

Browse files
authored
Merge pull request #1628 from SYED-M-HUSSAIN/fix/gsm8k-tuple-response-attributeerror
Fix/gsm8k tuple response attribute error
2 parents 4d3ed75 + b75f228 commit f4fb97b

1 file changed

Lines changed: 35 additions & 6 deletions

File tree

deepeval/benchmarks/gsm8k/gsm8k.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, Dict
1+
from typing import List, Optional, Dict, Union
22
from tqdm import tqdm
33

44
from deepeval.dataset import Golden
@@ -52,7 +52,10 @@ def evaluate(self, model: DeepEvalBaseLLM) -> Dict:
5252
for idx, golden in enumerate(
5353
tqdm(goldens, desc=f"Processing {self.n_problems} problems")
5454
):
55-
prediction, score = self.predict(model, golden).values()
55+
result = self.predict(model, golden)
56+
prediction = result["prediction"]
57+
score = result["score"]
58+
5659
if score:
5760
overall_correct_predictions += 1
5861
predictions_row.append(
@@ -94,14 +97,17 @@ def predict(self, model: DeepEvalBaseLLM, golden: Golden) -> Dict:
9497
)
9598

9699
# Enforced model generation
100+
prediction = None
97101
try:
98102
res: NumberSchema = model.generate(
99103
prompt=prompt, schema=NumberSchema
100104
)
101-
prediction = str(res.answer)
102-
except TypeError:
105+
prediction = self._extract_prediction_from_response(res)
106+
except (TypeError, AttributeError) as e:
107+
103108
prompt += f"\n\n{self.confinement_instructions}"
104-
prediction = model.generate(prompt)
109+
res = model.generate(prompt)
110+
prediction = self._extract_prediction_from_response(res)
105111

106112
# For native models, shouldn't happen but just in case
107113
if isinstance(prediction, tuple):
@@ -114,6 +120,29 @@ def predict(self, model: DeepEvalBaseLLM, golden: Golden) -> Dict:
114120

115121
return {"prediction": prediction, "score": score}
116122

123+
def _extract_prediction_from_response(self, res) -> str:
124+
"""
125+
Extract prediction from model response, handling various response types.
126+
"""
127+
# Case 1: Response has .answer attribute (NumberSchema case)
128+
if hasattr(res, 'answer'):
129+
return str(res.answer)
130+
131+
# Case 2: Response is a tuple
132+
elif isinstance(res, tuple):
133+
return self._extract_from_tuple(res)
134+
135+
else:
136+
return str(res)
137+
138+
def _extract_from_tuple(self, res: tuple) -> str:
139+
"""Extract prediction from tuple response."""
140+
if len(res) == 0:
141+
return ""
142+
first_elem = res[0]
143+
if hasattr(first_elem, 'answer'):
144+
return str(first_elem.answer)
145+
117146
def load_benchmark_dataset(self) -> List[Golden]:
118147
from datasets import load_dataset
119148

@@ -171,4 +200,4 @@ def print_verbose_logs(
171200
print("")
172201
print("=" * 70)
173202

174-
return verbose_logs
203+
return verbose_logs

0 commit comments

Comments (0)