-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathdataset_utils.py
More file actions
150 lines (125 loc) · 5.47 KB
/
dataset_utils.py
File metadata and controls
150 lines (125 loc) · 5.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""
Simple Dataset Utilities for Bio-Medical AI Competition
This module contains only the essential CureBenchDataset class and related utilities
for loading bio-medical datasets in the competition starter kit.
Note: Data should be preprocessed using preprocess_data.py to add dataset_type fields
before using this module.
"""
import json
import os
import sys
try:
    import torch
    from torch.utils.data import DataLoader, Dataset
except ImportError:
    print("Warning: PyTorch not available. Some features may not work.")

    # Lightweight placeholders so this module still imports, and classes can
    # still subclass ``Dataset``, when PyTorch is not installed.
    class Dataset:
        """Placeholder base class used when PyTorch is unavailable."""

    class DataLoader:
        """Placeholder that accepts any constructor arguments and does nothing."""

        def __init__(self, *args, **kwargs):
            return None
def read_and_process_json_file(file_path):
    """
    Read a JSON or line-delimited JSON (JSONL) file into a list of records.

    The file is first parsed as JSONL (one JSON value per non-blank line);
    every line that is itself a JSON array is flattened into the result, and
    any other line contributes a single element. If that fails, the whole
    file is parsed as one JSON document: a top-level array is returned as-is
    and any other value is wrapped in a one-element list, so the result has
    the same shape whether or not the document was pretty-printed.

    Args:
        file_path (str): Path of the file to read.

    Returns:
        list: The parsed records, or ``[]`` if the file is missing, is not
        valid JSON/JSONL, or cannot be read (an error message is printed).
    """
    try:
        # utf-8-sig reads plain UTF-8 unchanged and also strips a leading
        # byte-order mark, which json.loads would otherwise reject.
        with open(file_path, 'r', encoding='utf-8-sig') as file:
            try:
                records = []
                for line in file:
                    if not line.strip():
                        continue
                    value = json.loads(line)
                    # Check every line (not just the first): an array line
                    # contributes its elements, anything else is one record.
                    if isinstance(value, list):
                        records.extend(value)
                    else:
                        records.append(value)
                return records
            except json.JSONDecodeError:
                # Not JSONL (e.g. a pretty-printed document): parse the whole file.
                file.seek(0)
                data = json.loads(file.read())
                # Match the JSONL path, where a single-line object yields [obj].
                return data if isinstance(data, list) else [data]
    except FileNotFoundError:
        print(f"Error: The file {file_path} was not found.")
        return []
    except json.JSONDecodeError as e:
        print(f"Error: Failed to decode JSON from {file_path}: {e}")
        return []
    except Exception as e:
        # Best-effort loader: report any other failure and return no records.
        print(f"Error: Unexpected error reading {file_path}: {e}")
        return []
class CureBenchDataset(Dataset):
    """
    Dataset of competition questions loaded from a JSON / JSONL file.

    Each record must provide ``id`` and ``question_type``; the question text
    is read from ``question`` and the reference answer from ``correct_answer``
    (falling back to ``answer``). Supported ``question_type`` values:

    - ``multi_choice``: the sorted options are appended to the question text.
    - ``open_ended_multi_choice``: the question is posed without options, and
      a separate meta-question (question plus options) is built for mapping
      an open-ended answer back to an option label.
    - ``open_ended``: the question is wrapped in an open-ended prompt.

    Example usage:
        dataset = CureBenchDataset("fda_data.json")
        question_type, id_value, question, answer, meta_question = dataset[0]
    """

    def __init__(self, json_file):
        """
        Load all records from ``json_file``.

        Args:
            json_file (str): Path to a JSON or line-delimited JSON file.
        """
        self.data = read_and_process_json_file(json_file)
        if not self.data:
            print(f"Warning: No data loaded from {json_file}")
            self.data = []
            return
        # A single top-level JSON object is one record; it must not be
        # indexed by integer like a list of records.
        if isinstance(self.data, dict):
            self.data = [self.data]
        print(f"CureBenchDataset initialized with {len(self.data)} examples")

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _format_options(options):
        """Render an options mapping as ``"<label>: <text>"`` lines sorted by label."""
        return '\n'.join(f"{label}: {options[label]}" for label in sorted(options))

    def __getitem__(self, idx):
        """
        Return one formatted example.

        Returns:
            tuple: ``(question_type, id_value, question, answer, meta_question)``.
            ``meta_question`` is ``""`` except for ``open_ended_multi_choice``.

        Raises:
            IndexError: if ``idx`` is past the end of the dataset.
            KeyError: if the record lacks ``question_type`` or ``id`` (or
                ``options`` for the two multiple-choice types).
            ValueError: for an unsupported ``question_type``.
        """
        if idx >= len(self.data):
            raise IndexError(f"Index {idx} out of range for dataset of size {len(self.data)}")
        item = self.data[idx]

        question_type = item['question_type']
        question = item.get('question', '')
        answer = item.get('correct_answer', item.get('answer', ''))
        id_value = item['id']
        meta_question = ""

        if question_type == 'multi_choice':
            options_list = self._format_options(item['options'])
            question = f"{question}\n{options_list}"
        elif question_type == 'open_ended_multi_choice':
            options_list = self._format_options(item['options'])
            # Coerce to str, as the other branches do through their f-strings.
            question = f"{question}"
            meta_question = f"The following is a multiple choice question about medicine and the agent's open-ended answer to the question. Convert the agent's answer to the final answer format using the corresponding option label, e.g., 'A', 'B', 'C', 'D', 'E' or 'None'. \n\nQuestion: {question}\n{options_list}\n\n"
        elif question_type == 'open_ended':
            question = f"The following is an open-ended question about medicine. Provide a comprehensive answer.\n\nQuestion: {question}\n\nAnswer:"
        else:
            raise ValueError(f"Unsupported question type: {question_type}")
        return question_type, id_value, question, answer, meta_question
def build_dataset(dataset_path=None):
    """
    Build the competition dataset from a JSON / JSONL file.

    This is the main entry point used by the competition framework to load
    datasets; it prints the path and wraps ``CureBenchDataset``.

    Args:
        dataset_path (str): Path to the dataset file. Although it defaults to
            ``None``, a real path is required: with ``None`` the file cannot
            be opened, the loader reports the error, and the returned dataset
            is empty.

    Returns:
        CureBenchDataset: The loaded dataset.
    """
    print("dataset_path:", dataset_path)
    dataset = CureBenchDataset(dataset_path)
    return dataset