
Commit c91f032

Merge pull request #32 from VectorInstitute/add_extended_pipeline
Add dataset package.
2 parents ae91449 + b4ce05c commit c91f032

File tree

2 files changed: +142 -0 lines changed

openpmcvl/granular/dataset/__init__.py

Whitespace-only changes.

openpmcvl/granular/dataset/dataset.py

+142
@@ -0,0 +1,142 @@
import json
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms

class SubfigureDataset(Dataset):
    def __init__(self, data_list, transform=None):
        """
        PyTorch Dataset class to load images from subfig_path and apply transformations.

        Args:
            data_list (List[Dict]): List of dictionaries with dataset information.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.data_list = data_list
        self.transform = transform

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        item = self.data_list[idx]
        subfig_path = item["subfig_path"]
        image = Image.open(subfig_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, idx

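For context, a minimal usage sketch of SubfigureDataset. The subfig_path values, the fixed-size resize, and the batch size are invented placeholders, not part of this commit; they only illustrate how the (image, idx) pairs come back through a standard DataLoader.

# Hypothetical usage sketch; paths and transform are assumptions, not from this commit.
from torch.utils.data import DataLoader
from torchvision import transforms

data_list = [
    {"subfig_path": "subfigs/PMC123_fig1_a.png"},  # placeholder paths
    {"subfig_path": "subfigs/PMC123_fig1_b.png"},
]
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = SubfigureDataset(data_list, transform=transform)
loader = DataLoader(dataset, batch_size=2)
for images, indices in loader:
    # images: (2, 3, 224, 224) float tensor; indices map each image back into data_list
    print(images.shape, indices)
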
class Fig_Separation_Dataset(Dataset):
    def __init__(
        self,
        filepath,
        only_medical=True,
        normalization=False,
        start=0,
        end=-1,
        input_size=512,
    ):
        self.images = []  # list of {'path':'xxx/xxx.png', 'w':256, 'h':256}
        if normalization:
            self.image_transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ]
            )
        else:
            self.image_transform = transforms.Compose([transforms.ToTensor()])

        # preprocessing
        lines = open(filepath).readlines()
        dataset = [json.loads(line) for line in lines]

        if only_medical:
            dataset = [data for data in dataset if data["is_medical"]]

        dataset = dataset[start : len(dataset)]
        filtered_compound_fig_num = 0
        print(f"Total {len(dataset)} Compound Figures.")
        count = start

        for datum in tqdm(dataset):
            image_info = {}
            image_info["path"] = datum["image_path"]
            image_info["id"] = datum["id"]
            image_info["index"] = count
            image_info["w"] = datum["width"]
            image_info["h"] = datum["height"]
            count += 1

            self.images.append(image_info)
            filtered_compound_fig_num += 1

        self.input_size = input_size

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        unpadded_image = Image.open(self.images[index]["path"]).convert("RGB")
        unpadded_image = self.image_transform(unpadded_image)

        return (
            unpadded_image,
            self.images[index]["h"],
            self.images[index]["w"],
            self.images[index]["id"],
            self.images[index]["index"],
            self.input_size,
        )

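Fig_Separation_Dataset reads one JSON object per line of filepath. Judging from the keys accessed in __init__ above, each record needs at least image_path, id, width, and height, plus an is_medical flag when only_medical=True. A minimal sketch follows; the file name and field values are invented for illustration, not taken from this commit.

# One illustrative JSONL record (keys match the code above; values are made up):
# {"id": "PMC123_fig1", "image_path": "figures/PMC123_fig1.jpg",
#  "width": 1024, "height": 768, "is_medical": true}

# Hypothetical construction; "figures_meta.jsonl" is a placeholder path.
separation_dataset = Fig_Separation_Dataset(
    filepath="figures_meta.jsonl",
    only_medical=True,
    normalization=False,
    input_size=512,
)
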
def fig_separation_collate(data):
    """
    Args:
        data: list of samples as returned by Fig_Separation_Dataset.__getitem__()

    Returns
    -------
    images: tensor (bs, 3, input_size, input_size)
    other info: unpadded sizes, image ids, dataset indices, original images
    """
    pad_imgs = []
    unpadded_hws = []
    image_ids = []
    image_index = []
    unpadded_images = []

    for sample in data:
        unpadded_image, unpadded_h, unpadded_w, sample_id, index, input_size = sample
        image_ids.append(sample_id)
        image_index.append(index)
        unpadded_hws.append([unpadded_h, unpadded_w])

        _, h, w = unpadded_image.shape
        scale = min(input_size / h, input_size / w)
        resize_transform = transforms.Resize([round(scale * h), round(scale * w)])
        resized_img = resize_transform(unpadded_image)  # resize to fit within input_size
        pad = (0, input_size - round(scale * w), 0, input_size - round(scale * h))
        padded_img = F.pad(
            resized_img, pad, "constant", 0
        )  # pad right/bottom to input_size x input_size
        pad_imgs.append(padded_img)

        unpadded_images.append(unpadded_image)  # [bs * (3, h, w)]

    pad_imgs = torch.stack(pad_imgs, dim=0)  # (bs, 3, input_size, input_size)

    return {
        "image": pad_imgs,
        "unpadded_hws": unpadded_hws,
        "image_id": image_ids,
        "image_index": image_index,
        "original_image": unpadded_images,
    }
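
A sketch of wiring the dataset and collate function into a standard PyTorch DataLoader, reusing the separation_dataset object from the construction sketch above. The batch size and worker count are arbitrary choices, not values taken from this commit.

from torch.utils.data import DataLoader

loader = DataLoader(
    separation_dataset,  # from the hypothetical construction sketch above
    batch_size=4,
    num_workers=2,
    collate_fn=fig_separation_collate,
)
for batch in loader:
    # batch["image"]: (4, 3, 512, 512) after per-image resize and zero padding
    # batch["unpadded_hws"]: original (h, w) pairs, useful for mapping results back
    print(batch["image"].shape, batch["image_id"])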
