Skip to content

Commit a919d06

Browse files
committed
Add comments in DataModule class and bug fix in collate
1 parent 1d0ea1c commit a919d06

File tree

6 files changed

+183
-78
lines changed

6 files changed

+183
-78
lines changed

pina/data/data_module.py

Lines changed: 50 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import warnings
33
from lightning.pytorch import LightningDataModule
44
import torch
5-
from torch_geometric.data import Data, Batch
5+
from torch_geometric.data import Data
66
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
77
from torch.utils.data.distributed import DistributedSampler
88
from ..label_tensor import LabelTensor
9-
from .dataset import PinaDatasetFactory
9+
from .dataset import PinaDatasetFactory, PinaTensorDataset
1010
from ..collector import Collector
1111

1212

@@ -61,6 +61,10 @@ def __init__(self, max_conditions_lengths, dataset=None):
6161
max_conditions_lengths is None else (
6262
self._collate_standard_dataloader)
6363
self.dataset = dataset
64+
if isinstance(self.dataset, PinaTensorDataset):
65+
self._collate = self._collate_tensor_dataset
66+
else:
67+
self._collate = self._collate_graph_dataset
6468

6569
def _collate_custom_dataloader(self, batch):
6670
return self.dataset.fetch_from_idx_list(batch)
@@ -73,7 +77,6 @@ def _collate_standard_dataloader(self, batch):
7377
if isinstance(batch, dict):
7478
return batch
7579
conditions_names = batch[0].keys()
76-
7780
# Condition names
7881
for condition_name in conditions_names:
7982
single_cond_dict = {}
@@ -82,15 +85,28 @@ def _collate_standard_dataloader(self, batch):
8285
data_list = [batch[idx][condition_name][arg] for idx in range(
8386
min(len(batch),
8487
self.max_conditions_lengths[condition_name]))]
85-
if isinstance(data_list[0], LabelTensor):
86-
single_cond_dict[arg] = LabelTensor.stack(data_list)
87-
elif isinstance(data_list[0], torch.Tensor):
88-
single_cond_dict[arg] = torch.stack(data_list)
89-
elif isinstance(data_list[0], Data):
90-
single_cond_dict[arg] = Batch.from_data_list(data_list)
88+
single_cond_dict[arg] = self._collate(data_list)
89+
9190
batch_dict[condition_name] = single_cond_dict
9291
return batch_dict
9392

93+
@staticmethod
94+
def _collate_tensor_dataset(data_list):
95+
if isinstance(data_list[0], LabelTensor):
96+
return LabelTensor.stack(data_list)
97+
if isinstance(data_list[0], torch.Tensor):
98+
return torch.stack(data_list)
99+
raise RuntimeError("Data must be Tensors or LabelTensor ")
100+
101+
def _collate_graph_dataset(self, data_list):
102+
if isinstance(data_list[0], LabelTensor):
103+
return LabelTensor.cat(data_list)
104+
if isinstance(data_list[0], torch.Tensor):
105+
return torch.cat(data_list)
106+
if isinstance(data_list[0], Data):
107+
return self.dataset.create_graph_batch(data_list)
108+
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
109+
94110
def __call__(self, batch):
95111
return self.callable_function(batch)
96112

@@ -157,14 +173,35 @@ def __init__(self,
157173
logging.debug('Start initialization of Pina DataModule')
158174
logging.info('Start initialization of Pina DataModule')
159175
super().__init__()
176+
177+
# Store fixed attributes
160178
self.batch_size = batch_size
161179
self.shuffle = shuffle
162180
self.repeat = repeat
181+
self.automatic_batching = automatic_batching
182+
if batch_size is None and num_workers != 0:
183+
warnings.warn(
184+
"Setting num_workers when batch_size is None has no effect on "
185+
"the DataLoading process.")
186+
self.num_workers = 0
187+
else:
188+
self.num_workers = num_workers
189+
if batch_size is None and pin_memory:
190+
warnings.warn("Setting pin_memory to True has no effect when "
191+
"batch_size is None.")
192+
self.pin_memory = False
193+
else:
194+
self.pin_memory = pin_memory
195+
196+
# Collect data
197+
collector = Collector(problem)
198+
collector.store_fixed_data()
199+
collector.store_sample_domains()
163200

164201
# Check if the splits are correct
165202
self._check_slit_sizes(train_size, test_size, val_size, predict_size)
166203

167-
# Begin Data splitting
204+
# Split input data into subsets
168205
splits_dict = {}
169206
if train_size > 0:
170207
splits_dict['train'] = train_size
@@ -186,23 +223,6 @@ def __init__(self,
186223
self.predict_dataset = None
187224
else:
188225
self.predict_dataloader = super().predict_dataloader
189-
190-
collector = Collector(problem)
191-
collector.store_fixed_data()
192-
collector.store_sample_domains()
193-
194-
self.automatic_batching = self._set_automatic_batching_option(
195-
collector, automatic_batching)
196-
197-
if batch_size is None and num_workers != 0:
198-
warnings.warn(
199-
"Setting num_workers when batch_size is None has no effect on "
200-
"the DataLoading process.")
201-
if batch_size is None and pin_memory:
202-
warnings.warn("Setting pin_memory to True has no effect when "
203-
"batch_size is None.")
204-
self.num_workers = num_workers
205-
self.pin_memory = pin_memory
206226
self.collector_splits = self._create_splits(collector, splits_dict)
207227
self.transfer_batch_to_device = self._transfer_batch_to_device
208228

@@ -318,10 +338,10 @@ def _create_dataloader(self, split, dataset):
318338
if self.batch_size is not None:
319339
sampler = PinaSampler(dataset, shuffle)
320340
if self.automatic_batching:
321-
collate = Collator(self.find_max_conditions_lengths(split))
322-
341+
collate = Collator(self.find_max_conditions_lengths(split),
342+
dataset=dataset)
323343
else:
324-
collate = Collator(None, dataset)
344+
collate = Collator(None, dataset=dataset)
325345
return DataLoader(dataset, self.batch_size,
326346
collate_fn=collate, sampler=sampler,
327347
num_workers=self.num_workers)
@@ -395,27 +415,6 @@ def _check_slit_sizes(train_size, test_size, val_size, predict_size):
395415
if abs(train_size + test_size + val_size + predict_size - 1) > 1e-6:
396416
raise ValueError("The sum of the splits must be 1")
397417

398-
@staticmethod
399-
def _set_automatic_batching_option(collector, automatic_batching):
400-
"""
401-
Determines whether automatic batching should be enabled.
402-
403-
If all 'input_points' in the collector's data collections are
404-
tensors (torch.Tensor or LabelTensor), it respects the provided
405-
`automatic_batching` value; otherwise, mainly in the Graph scenario,
406-
it forces automatic batching on.
407-
408-
:param Collector collector: Collector object with contains all data
409-
retrieved from input conditions
410-
:param bool automatic_batching : If the user wants to enable automatic
411-
batching or not
412-
"""
413-
if all(isinstance(v['input_points'], (torch.Tensor, LabelTensor))
414-
for v in collector.data_collections.values()):
415-
return automatic_batching if automatic_batching is not None \
416-
else False
417-
return True
418-
419418
@property
420419
def input_points(self):
421420
"""

pina/data/dataset.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from torch.utils.data import Dataset
77
from abc import abstractmethod
8-
from torch_geometric.data import Batch
8+
from torch_geometric.data import Batch, Data
99
from pina import LabelTensor
1010

1111

@@ -64,7 +64,7 @@ def __init__(self, conditions_dict, max_conditions_lengths,
6464
if automatic_batching:
6565
self._getitem_func = self._getitem_int
6666
else:
67-
self._getitem_func = self._getitem_list
67+
self._getitem_func = self._getitem_dummy
6868

6969
def _getitem_int(self, idx):
7070
return {
@@ -84,7 +84,7 @@ def fetch_from_idx_list(self, idx):
8484
return to_return_dict
8585

8686
@staticmethod
87-
def _getitem_list(idx):
87+
def _getitem_dummy(idx):
8888
return idx
8989

9090
def get_all_data(self):
@@ -104,13 +104,29 @@ def input_points(self):
104104
}
105105

106106

107+
class PinaBatch(Batch):
108+
def __init__(self):
109+
super().__init__(self)
110+
111+
def extract(self, labels):
112+
x = self.x
113+
if labels != x.labels:
114+
self.x = x.extract(labels)
115+
return self
116+
117+
107118
class PinaGraphDataset(PinaDataset):
108119

109120
def __init__(self, conditions_dict, max_conditions_lengths,
110121
automatic_batching):
111122
super().__init__(conditions_dict, max_conditions_lengths)
112123
self.in_labels = {}
113124
self.out_labels = None
125+
if automatic_batching:
126+
self._getitem_func = self._getitem_int
127+
else:
128+
self._getitem_func = self._getitem_dummy
129+
114130
ex_data = conditions_dict[list(conditions_dict.keys())[
115131
0]]['input_points'][0]
116132
for name, attr in ex_data.items():
@@ -137,22 +153,25 @@ def fetch_from_idx_list(self, idx):
137153
if self.length > condition_len:
138154
cond_idx = [idx % condition_len for idx in cond_idx]
139155
to_return_dict[condition] = {
140-
k: self._create_graph_batch_from_list(v, cond_idx)
156+
k: self._create_graph_batch_from_list([v[i] for i in idx])
141157
if isinstance(v, list)
142-
else self._create_output_batch(v, cond_idx)
158+
else self._create_output_batch(v[idx])
143159
for k, v in data.items()
144160
}
145161

146162
return to_return_dict
147163

148-
def _base_create_graph_batch_from_list(self, data, idx):
149-
batch = Batch.from_data_list([data[i] for i in idx])
164+
def _base_create_graph_batch_from_list(self, data):
165+
batch = PinaBatch.from_data_list(data)
150166
return batch
151167

152-
def _base_create_output_batch(self, data, idx):
153-
out = data[idx].reshape(-1, *data[idx].shape[2:])
168+
def _base_create_output_batch(self, data):
169+
out = data.reshape(-1, *data.shape[2:])
154170
return out
155171

172+
def _getitem_dummy(self, idx):
173+
return idx
174+
156175
def _getitem_int(self, idx):
157176
return {
158177
k: {k_data: v[k_data][idx % len(v['input_points'])] for k_data
@@ -164,8 +183,7 @@ def get_all_data(self):
164183
return self.fetch_from_idx_list(index)
165184

166185
def __getitem__(self, idx):
167-
return self._getitem_int(idx) if isinstance(idx, int) else \
168-
self.fetch_from_idx_list(idx=idx)
186+
return self._getitem_func(idx)
169187

170188
def _labelise_batch(self, func):
171189
@functools.wraps(func)
@@ -186,3 +204,11 @@ def wrapper(*args, **kwargs):
186204
out.labels = self.out_labels
187205
return out
188206
return wrapper
207+
208+
def create_graph_batch(self, data):
209+
"""
210+
# TODO
211+
"""
212+
if isinstance(data[0], Data):
213+
return self._create_graph_batch_from_list(data)
214+
return self._create_output_batch(data)

pina/graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ def _build_graph_list(self, x, pos, edge_index, edge_attr,
125125

126126
@staticmethod
127127
def _build_edge_attr(x, pos, edge_index):
128-
distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
128+
distance = torch.abs(pos[edge_index[0]] -
129+
pos[edge_index[1]]).as_subclass(torch.Tensor)
129130
return distance
130131

131132
@staticmethod

pina/trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,16 @@ def __init__(self,
106106
if compile is None or sys.platform == "win32":
107107
compile = False
108108

109+
self.automatic_batching = automatic_batching if automatic_batching \
110+
is not None else False
109111
# set attributes
110112
self.compile = compile
111113
self.solver = solver
112114
self.batch_size = batch_size
113115
self._move_to_device()
114116
self.data_module = None
115117
self._create_datamodule(train_size, test_size, val_size, predict_size,
116-
batch_size, automatic_batching, pin_memory,
118+
batch_size, automatic_batching, pin_memory,
117119
num_workers)
118120

119121
# logging

pina/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ def labelize_forward(forward, input_variables, output_variables):
4848
:type output_variables: list[str] | tuple[str]
4949
"""
5050
def wrapper(x):
51-
if isinstance(x, LabelTensor):
52-
x = x.extract(input_variables)
51+
x = x.extract(input_variables)
5352
output = forward(x)
5453
# keep it like this, directly using LabelTensor(...) raises errors
5554
# when compiling the code

0 commit comments

Comments
 (0)