2
2
import warnings
3
3
from lightning .pytorch import LightningDataModule
4
4
import torch
5
- from torch_geometric .data import Data , Batch
5
+ from torch_geometric .data import Data
6
6
from torch .utils .data import DataLoader , SequentialSampler , RandomSampler
7
7
from torch .utils .data .distributed import DistributedSampler
8
8
from ..label_tensor import LabelTensor
9
- from .dataset import PinaDatasetFactory
9
+ from .dataset import PinaDatasetFactory , PinaTensorDataset
10
10
from ..collector import Collector
11
11
12
12
@@ -61,6 +61,10 @@ def __init__(self, max_conditions_lengths, dataset=None):
61
61
max_conditions_lengths is None else (
62
62
self ._collate_standard_dataloader )
63
63
self .dataset = dataset
64
+ if isinstance (self .dataset , PinaTensorDataset ):
65
+ self ._collate = self ._collate_tensor_dataset
66
+ else :
67
+ self ._collate = self ._collate_graph_dataset
64
68
65
69
def _collate_custom_dataloader(self, batch):
    """Collate by delegating to the dataset itself.

    Used when the dataset implements its own batching: *batch* is a list
    of indices, and the dataset fetches and assembles the items.

    :param batch: list of sample indices produced by the sampler.
    :return: whatever the dataset's ``fetch_from_idx_list`` builds.
    """
    dataset = self.dataset
    return dataset.fetch_from_idx_list(batch)
@@ -73,7 +77,6 @@ def _collate_standard_dataloader(self, batch):
73
77
if isinstance (batch , dict ):
74
78
return batch
75
79
conditions_names = batch [0 ].keys ()
76
-
77
80
# Condition names
78
81
for condition_name in conditions_names :
79
82
single_cond_dict = {}
@@ -82,15 +85,28 @@ def _collate_standard_dataloader(self, batch):
82
85
data_list = [batch [idx ][condition_name ][arg ] for idx in range (
83
86
min (len (batch ),
84
87
self .max_conditions_lengths [condition_name ]))]
85
- if isinstance (data_list [0 ], LabelTensor ):
86
- single_cond_dict [arg ] = LabelTensor .stack (data_list )
87
- elif isinstance (data_list [0 ], torch .Tensor ):
88
- single_cond_dict [arg ] = torch .stack (data_list )
89
- elif isinstance (data_list [0 ], Data ):
90
- single_cond_dict [arg ] = Batch .from_data_list (data_list )
88
+ single_cond_dict [arg ] = self ._collate (data_list )
89
+
91
90
batch_dict [condition_name ] = single_cond_dict
92
91
return batch_dict
93
92
93
@staticmethod
def _collate_tensor_dataset(data_list):
    """Stack a list of homogeneous tensors into one batched tensor.

    ``LabelTensor`` is tested before ``torch.Tensor`` so that
    label-aware stacking is preferred when both checks would match.

    :param data_list: non-empty list of ``LabelTensor`` or ``torch.Tensor``.
    :return: the stacked batch (new leading batch dimension).
    :raises RuntimeError: if the items are neither kind of tensor.
    """
    sample = data_list[0]
    # Dispatch table keeps the original check order (LabelTensor first).
    for kind, stack in ((LabelTensor, LabelTensor.stack),
                        (torch.Tensor, torch.stack)):
        if isinstance(sample, kind):
            return stack(data_list)
    raise RuntimeError("Data must be Tensors or LabelTensor ")
100
+
101
def _collate_graph_dataset(self, data_list):
    """Concatenate items coming from a graph dataset into one batch.

    Tensors are concatenated along the existing first dimension (not
    stacked), while PyG ``Data`` objects are merged into a graph batch
    by the dataset. ``LabelTensor`` is tested before ``torch.Tensor``
    so label-aware concatenation is preferred.

    :param data_list: non-empty list of ``LabelTensor``, ``torch.Tensor``
        or ``torch_geometric.data.Data`` items.
    :return: the collated batch.
    :raises RuntimeError: for any other item type.
    """
    sample = data_list[0]
    if isinstance(sample, LabelTensor):
        result = LabelTensor.cat(data_list)
    elif isinstance(sample, torch.Tensor):
        result = torch.cat(data_list)
    elif isinstance(sample, Data):
        result = self.dataset.create_graph_batch(data_list)
    else:
        raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
    return result
109
+
94
110
def __call__(self, batch):
    """Make the collator directly usable as a DataLoader ``collate_fn``.

    :param batch: the raw batch handed over by the DataLoader.
    :return: the collated batch produced by the configured strategy.
    """
    collate = self.callable_function
    return collate(batch)
96
112
@@ -157,14 +173,35 @@ def __init__(self,
157
173
logging .debug ('Start initialization of Pina DataModule' )
158
174
logging .info ('Start initialization of Pina DataModule' )
159
175
super ().__init__ ()
176
+
177
+ # Store fixed attributes
160
178
self .batch_size = batch_size
161
179
self .shuffle = shuffle
162
180
self .repeat = repeat
181
+ self .automatic_batching = automatic_batching
182
+ if batch_size is None and num_workers != 0 :
183
+ warnings .warn (
184
+ "Setting num_workers when batch_size is None has no effect on "
185
+ "the DataLoading process." )
186
+ self .num_workers = 0
187
+ else :
188
+ self .num_workers = num_workers
189
+ if batch_size is None and pin_memory :
190
+ warnings .warn ("Setting pin_memory to True has no effect when "
191
+ "batch_size is None." )
192
+ self .pin_memory = False
193
+ else :
194
+ self .pin_memory = pin_memory
195
+
196
+ # Collect data
197
+ collector = Collector (problem )
198
+ collector .store_fixed_data ()
199
+ collector .store_sample_domains ()
163
200
164
201
# Check if the splits are correct
165
202
self ._check_slit_sizes (train_size , test_size , val_size , predict_size )
166
203
167
- # Begin Data splitting
204
+ # Split input data into subsets
168
205
splits_dict = {}
169
206
if train_size > 0 :
170
207
splits_dict ['train' ] = train_size
@@ -186,23 +223,6 @@ def __init__(self,
186
223
self .predict_dataset = None
187
224
else :
188
225
self .predict_dataloader = super ().predict_dataloader
189
-
190
- collector = Collector (problem )
191
- collector .store_fixed_data ()
192
- collector .store_sample_domains ()
193
-
194
- self .automatic_batching = self ._set_automatic_batching_option (
195
- collector , automatic_batching )
196
-
197
- if batch_size is None and num_workers != 0 :
198
- warnings .warn (
199
- "Setting num_workers when batch_size is None has no effect on "
200
- "the DataLoading process." )
201
- if batch_size is None and pin_memory :
202
- warnings .warn ("Setting pin_memory to True has no effect when "
203
- "batch_size is None." )
204
- self .num_workers = num_workers
205
- self .pin_memory = pin_memory
206
226
self .collector_splits = self ._create_splits (collector , splits_dict )
207
227
self .transfer_batch_to_device = self ._transfer_batch_to_device
208
228
@@ -318,10 +338,10 @@ def _create_dataloader(self, split, dataset):
318
338
if self .batch_size is not None :
319
339
sampler = PinaSampler (dataset , shuffle )
320
340
if self .automatic_batching :
321
- collate = Collator (self .find_max_conditions_lengths (split ))
322
-
341
+ collate = Collator (self .find_max_conditions_lengths (split ),
342
+ dataset = dataset )
323
343
else :
324
- collate = Collator (None , dataset )
344
+ collate = Collator (None , dataset = dataset )
325
345
return DataLoader (dataset , self .batch_size ,
326
346
collate_fn = collate , sampler = sampler ,
327
347
num_workers = self .num_workers )
@@ -395,27 +415,6 @@ def _check_slit_sizes(train_size, test_size, val_size, predict_size):
395
415
if abs (train_size + test_size + val_size + predict_size - 1 ) > 1e-6 :
396
416
raise ValueError ("The sum of the splits must be 1" )
397
417
398
- @staticmethod
399
- def _set_automatic_batching_option (collector , automatic_batching ):
400
- """
401
- Determines whether automatic batching should be enabled.
402
-
403
- If all 'input_points' in the collector's data collections are
404
- tensors (torch.Tensor or LabelTensor), it respects the provided
405
- `automatic_batching` value; otherwise, mainly in the Graph scenario,
406
- it forces automatic batching on.
407
-
408
- :param Collector collector: Collector object with contains all data
409
- retrieved from input conditions
410
- :param bool automatic_batching : If the user wants to enable automatic
411
- batching or not
412
- """
413
- if all (isinstance (v ['input_points' ], (torch .Tensor , LabelTensor ))
414
- for v in collector .data_collections .values ()):
415
- return automatic_batching if automatic_batching is not None \
416
- else False
417
- return True
418
-
419
418
@property
420
419
def input_points (self ):
421
420
"""
0 commit comments