Skip to content

Commit 86e2366

Browse files
committed
update the bench to support parquet
1 parent 4f0a4bf commit 86e2366

File tree

1 file changed

+14 -11 lines changed

1 file changed

+14 -11 lines changed

camel/benchmarks/gaia.py

Lines changed: 14 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
from pathlib import Path
2323
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
2424

25+
import pandas as pd
2526
from tqdm import tqdm
2627

2728
from camel.agents import ChatAgent
@@ -181,15 +182,17 @@ def load(self, force_download=False):
181182
# Load metadata for both validation and test datasets
182183
for path, label in zip([valid_dir, test_dir], ["valid", "test"]):
183184
self._data[label] = []
184-
with open(path / "metadata.jsonl", "r") as f:
185-
lines = f.readlines()
186-
for line in lines:
187-
data = json.loads(line)
188-
if data["task_id"] == "0-0-0-0-0":
189-
continue
190-
if data["file_name"]:
191-
data["file_name"] = path / data["file_name"]
192-
self._data[label].append(data)
185+
metadata_file = path / "metadata.parquet"
186+
df = pd.read_parquet(metadata_file)
187+
for _, row in df.iterrows():
188+
data = row.to_dict()
189+
if data["task_id"] == "0-0-0-0-0":
190+
continue
191+
# convert level to int (parquet stores as string)
192+
data["Level"] = int(data["Level"])
193+
if data["file_name"]:
194+
data["file_name"] = path / data["file_name"]
195+
self._data[label].append(data)
193196
return self
194197

195198
@property
@@ -333,7 +336,7 @@ def _process_result(
333336
}
334337
self._results.append(result_data)
335338
file_obj.write(
336-
json.dumps(result_data, indent=2) + "\n", ensure_ascii=False
339+
json.dumps(result_data, indent=2, ensure_ascii=False) + "\n"
337340
)
338341
file_obj.flush()
339342

@@ -354,7 +357,7 @@ def _handle_error(
354357
}
355358
self._results.append(error_data)
356359
file_obj.write(
357-
json.dumps(error_data, indent=2) + "\n", ensure_ascii=False
360+
json.dumps(error_data, indent=2, ensure_ascii=False) + "\n"
358361
)
359362
file_obj.flush()
360363

0 commit comments

Comments (0)