Skip to content

Commit fc9e950

Browse files
Merge pull request #28 from VectorInstitute/improve_pmc2m
Cleaned PMC-2M+inline Dataset
2 parents 2549e35 + c3da505 commit fc9e950

File tree

11 files changed

+192
-23
lines changed

11 files changed

+192
-23
lines changed

openpmcvl/experiment/configs/experiment/biomedclip_ppr.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ trainer:
108108
callbacks:
109109
model_checkpoint:
110110
monitor: val/loss
111-
save_top_k: 1
111+
save_top_k: -1
112112
save_last: True
113113
every_n_epochs: 1
114114
dirpath: /checkpoint/${oc.env:USER}/${oc.env:SLURM_JOB_ID} # only works on Vector SLURM environment

openpmcvl/experiment/configs/experiment/pmcoa2_matched_512.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ seed: 0
3131
datasets:
3232
train:
3333
pmc2m_sum:
34-
split: train_clean
34+
split: train_clean_sep
3535
val:
3636
pmc2m_sum:
37-
split: valid_clean
37+
split: valid_clean_sep
3838
transform:
3939
job_type: eval
4040
test:
4141
pmc2m_sum:
42-
split: test_clean
42+
split: test_clean_sep
4343
transform:
4444
job_type: eval
4545

openpmcvl/experiment/datasets/pmc2m_sum.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(
4141
] = None,
4242
) -> None:
4343
"""Initialize the dataset."""
44-
data_path = os.path.join(root_dir, f"{split}.jsonl")
44+
data_path = os.path.join(root_dir, "clean", f"{split}.jsonl")
4545
with open(data_path, encoding="utf-8") as file:
4646
entries = [json.loads(line) for line in file.readlines()]
4747
self.entries = entries
@@ -62,17 +62,16 @@ def __getitem__(self, idx: int) -> Example:
6262
try:
6363
with Image.open(entry["image_fullpath"]) as img:
6464
image = img.convert("RGB")
65+
with open(entry["caption_fullpath"], encoding="utf-8") as file:
66+
caption = file.read()
6567
except Exception as e:
6668
print(
67-
f"Error loading image for entry {idx}: image_path={entry['image_fullpath']}",
69+
f"Error loading image or caption for entry {idx}: image_path={entry['image_fullpath']} caption_path={entry['caption_fullpath']}",
6870
e,
6971
)
7072
idx = (idx + 1) % len(self.entries)
7173
return self.__getitem__(idx)
7274

73-
# load text
74-
caption = " ".join([entry["caption"], entry["intext_refs_summary"]])
75-
7675
# apply transform and tokenization
7776
if self.transform is not None:
7877
image = self.transform(image)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""PMC-2M with summarized inline references Dataset."""
2+
3+
import json
4+
import os
5+
from typing import Callable, Dict, Literal, Optional, Union
6+
7+
import torch
8+
from mmlearn.conf import external_store
9+
from mmlearn.constants import EXAMPLE_INDEX_KEY
10+
from mmlearn.datasets.core import Modalities
11+
from mmlearn.datasets.core.example import Example
12+
from omegaconf import MISSING
13+
from PIL import Image
14+
from torch.utils.data import Dataset
15+
from torchvision.transforms import ToTensor
16+
17+
18+
@external_store(group="datasets", root_dir=os.getenv("PMC2M_SUMM_ROOT_DIR", MISSING))
19+
class PMC2MSum(Dataset[Example]):
20+
"""PMC-2M with summarized inline references dataset.
21+
22+
Parameters
23+
----------
24+
root_dir : str
25+
Path to the root folder containing jsonl file with data entries.
26+
split : {"train", "valid", "test"}
27+
Dataset split.
28+
transform : Optional[Callable], default=None
29+
Transform applied to images.
30+
tokenizer : Optional[Callable], default=None
31+
Function applied to textual captions.
32+
"""
33+
34+
def __init__(
35+
self,
36+
root_dir: str,
37+
split: Literal["train", "valid", "test"] = "train",
38+
transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
39+
tokenizer: Optional[
40+
Callable[[str], Union[torch.Tensor, Dict[str, torch.Tensor]]]
41+
] = None,
42+
) -> None:
43+
"""Initialize the dataset."""
44+
data_path = os.path.join(root_dir, f"{split}.jsonl")
45+
with open(data_path, encoding="utf-8") as file:
46+
entries = [json.loads(line) for line in file.readlines()]
47+
self.entries = entries
48+
49+
self.root_dir = root_dir
50+
51+
if transform is None:
52+
self.transform = ToTensor()
53+
else:
54+
self.transform = transform
55+
56+
self.tokenizer = tokenizer
57+
58+
def __getitem__(self, idx: int) -> Example:
59+
"""Return the idx'th data sample."""
60+
entry = self.entries[idx]
61+
# load image
62+
try:
63+
with Image.open(entry["image_fullpath"]) as img:
64+
image = img.convert("RGB")
65+
except Exception as e:
66+
print(
67+
f"Error loading image for entry {idx}: image_path={entry['image_fullpath']}",
68+
e,
69+
)
70+
idx = (idx + 1) % len(self.entries)
71+
return self.__getitem__(idx)
72+
73+
# load text
74+
caption = " ".join([entry["caption"], entry["intext_refs_summary"]])
75+
76+
# apply transform and tokenization
77+
if self.transform is not None:
78+
image = self.transform(image)
79+
80+
tokens = self.tokenizer(caption) if self.tokenizer is not None else None
81+
82+
example = Example(
83+
{
84+
Modalities.RGB.name: image,
85+
Modalities.TEXT.name: caption,
86+
EXAMPLE_INDEX_KEY: idx,
87+
}
88+
)
89+
90+
if tokens is not None:
91+
if isinstance(tokens, dict): # output of HFTokenizer
92+
assert (
93+
Modalities.TEXT.name in tokens
94+
), f"Missing key `{Modalities.TEXT.name}` in tokens."
95+
example.update(tokens)
96+
else:
97+
example[Modalities.TEXT.name] = tokens
98+
99+
return example
100+
101+
def __len__(self) -> int:
102+
"""Return the length of the dataset."""
103+
return len(self.entries)

openpmcvl/experiment/scripts/eval/pmc_oa_2/ppr.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ mmlearn_run --multirun hydra.launcher.mem_gb=0 \
77
hydra.launcher.tasks_per_node=4 \
88
hydra.launcher.nodes=1 \
99
hydra.launcher.stderr_to_stdout=true \
10-
hydra.launcher.timeout_min=900 \
10+
hydra.launcher.timeout_min=420 \
1111
'+hydra.launcher.additional_parameters={export: ALL}' \
1212
'hydra.searchpath=[pkg://openpmcvl.experiment.configs]' \
1313
+experiment=biomedclip_ppr \
@@ -30,7 +30,7 @@ mmlearn_run --multirun hydra.launcher.mem_gb=0 \
3030
hydra.launcher.tasks_per_node=4 \
3131
hydra.launcher.nodes=1 \
3232
hydra.launcher.stderr_to_stdout=true \
33-
hydra.launcher.timeout_min=900 \
33+
hydra.launcher.timeout_min=20 \
3434
'+hydra.launcher.additional_parameters={export: ALL}' \
3535
'hydra.searchpath=[pkg://openpmcvl.experiment.configs]' \
3636
+experiment=biomedclip_ppr \

openpmcvl/experiment/scripts/eval/roco/ppr.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ mmlearn_run --multirun hydra.launcher.mem_gb=0 \
77
hydra.launcher.tasks_per_node=4 \
88
hydra.launcher.nodes=1 \
99
hydra.launcher.stderr_to_stdout=true \
10-
hydra.launcher.timeout_min=900 \
10+
hydra.launcher.timeout_min=420 \
1111
'+hydra.launcher.additional_parameters={export: ALL}' \
1212
'hydra.searchpath=[pkg://openpmcvl.experiment.configs]' \
1313
+experiment=biomedclip_ppr \
@@ -30,7 +30,7 @@ mmlearn_run --multirun hydra.launcher.mem_gb=0 \
3030
hydra.launcher.tasks_per_node=4 \
3131
hydra.launcher.nodes=1 \
3232
hydra.launcher.stderr_to_stdout=true \
33-
hydra.launcher.timeout_min=900 \
33+
hydra.launcher.timeout_min=20 \
3434
'+hydra.launcher.additional_parameters={export: ALL}' \
3535
'hydra.searchpath=[pkg://openpmcvl.experiment.configs]' \
3636
+experiment=biomedclip_ppr \

openpmcvl/experiment/scripts/eval/vitb16_bert256_pmcoa/ppr.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ mmlearn_run --multirun hydra.launcher.mem_gb=0 \
77
hydra.launcher.tasks_per_node=4 \
88
hydra.launcher.nodes=1 \
99
hydra.launcher.stderr_to_stdout=true \
10-
hydra.launcher.timeout_min=900 \
10+
hydra.launcher.timeout_min=420 \
1111
'+hydra.launcher.additional_parameters={export: ALL}' \
1212
'hydra.searchpath=[pkg://openpmcvl.experiment.configs]' \
1313
+experiment=biomedclip_ppr \
@@ -30,7 +30,7 @@ mmlearn_run --multirun hydra.launcher.mem_gb=0 \
3030
hydra.launcher.tasks_per_node=4 \
3131
hydra.launcher.nodes=1 \
3232
hydra.launcher.stderr_to_stdout=true \
33-
hydra.launcher.timeout_min=900 \
33+
hydra.launcher.timeout_min=20 \
3434
'+hydra.launcher.additional_parameters={export: ALL}' \
3535
'hydra.searchpath=[pkg://openpmcvl.experiment.configs]' \
3636
+experiment=biomedclip_ppr \

openpmcvl/experiment/scripts/train/pmc_oa_2_512/pmc_oa_2_cl512_train_bs256_slurm.sh

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,31 @@ mmlearn_run --multirun hydra.launcher.mem_gb=0 \
1414
experiment_name=pmcoa2_matched_512_train \
1515
dataloader.train.batch_size=128 \
1616
dataloader.val.batch_size=32 \
17+
dataloader.train.num_workers=4 \
18+
dataloader.val.num_workers=4 \
19+
task.encoders.text.pretrained=False \
20+
task.encoders.rgb.pretrained=False \
21+
trainer.max_epochs=64 \
22+
task.lr_scheduler.scheduler.t_max=54476 \
23+
task.lr_scheduler.scheduler.warmup_length=5448 \
24+
~trainer.callbacks.early_stopping
25+
26+
# a100
27+
mmlearn_run --multirun hydra.launcher.mem_gb=0 \
28+
hydra.launcher.qos=a100_arashaf \
29+
hydra.launcher.partition=a100 \
30+
hydra.launcher.gres=gpu:4 \
31+
hydra.launcher.cpus_per_task=4 \
32+
hydra.launcher.tasks_per_node=4 \
33+
hydra.launcher.nodes=1 \
34+
hydra.launcher.stderr_to_stdout=true \
35+
hydra.launcher.timeout_min=600 \
36+
'+hydra.launcher.additional_parameters={export: ALL}' \
37+
'hydra.searchpath=[pkg://openpmcvl.experiment.configs]' \
38+
+experiment=pmcoa2_matched_512 \
39+
experiment_name=pmcoa2_matched_512_train \
40+
dataloader.train.batch_size=256 \
41+
dataloader.val.batch_size=32 \
1742
dataloader.train.num_workers=3 \
1843
dataloader.val.num_workers=3 \
1944
task.encoders.text.pretrained=False \

openpmcvl/experiment/scripts/train/pmc_oa_2_512/pmc_oa_2_cl512_train_bs256_slurm_multinode.slrm

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ srun mmlearn_run \
3434
'hydra.searchpath=[pkg://openpmcvl.experiment.configs]' \
3535
+experiment=pmcoa2_matched_512 \
3636
experiment_name=pmcoa2_matched_512_train \
37-
dataloader.train.batch_size=8 \
37+
dataloader.train.batch_size=256 \
3838
dataloader.val.batch_size=32 \
39-
dataloader.train.num_workers=4 \
40-
dataloader.val.num_workers=4 \
39+
dataloader.train.num_workers=2 \
40+
dataloader.val.num_workers=2 \
4141
task.encoders.text.pretrained=False \
4242
task.encoders.rgb.pretrained=False \
4343
trainer.max_epochs=64 \

openpmcvl/experiment/tests/test_datasets.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,12 @@ def test_pmc2m_sum():
168168
), "Please set PMC2M-Sum root directory in `PMC2M_SUMM_ROOT_DIR` environment variable."
169169

170170
# test without transform and tokenizer
171-
split = "train"
171+
split = "test_clean_sep"
172172
transform = None
173173
tokenizer = None
174174
dataset = PMC2MSum(root_dir, split, transform, tokenizer)
175175
sample = dataset[0]
176+
print(f"sample: {sample}")
176177
assert isinstance(
177178
sample[Modalities.TEXT.name], str
178179
), f"Expected to find `str` in `Modalities.TEXT` but found {type(sample[Modalities.TEXT.name])}"
@@ -194,6 +195,7 @@ def test_pmc2m_sum():
194195
)
195196
dataset = PMC2MSum(root_dir, split, transform, tokenizer)
196197
sample = dataset[0]
198+
print(f"sample: {sample}")
197199
assert isinstance(
198200
sample[Modalities.TEXT.name], torch.Tensor
199201
), f"Expected to find `Tensor` in `Modalities.TEXT` but found {type(sample[Modalities.TEXT.name])}"
@@ -213,7 +215,7 @@ def test_pmc2m_sum_2():
213215
), "Please set PMC2M-Sum root directory in `PMC2M_SUMM_ROOT_DIR` environment variable."
214216

215217
# test with transform and tokenizer and dataloader
216-
split = "train_clean"
218+
split = "test_clean"
217219
batch_size = 64
218220
transform = biomedclip_vision_transform(image_crop_size=224, job_type="train")
219221
tokenizer = HFTokenizer(
@@ -250,5 +252,6 @@ def test_pmc2m_sum_2():
250252

251253

252254
if __name__ == "__main__":
255+
test_pmc2m_sum()
253256
test_pmc2m_sum_2()
254257
print("Passed")

0 commit comments

Comments
 (0)