Skip to content

Commit

Permalink
Add comments in DataModule class and bug fix in collate
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoOlivo committed Feb 19, 2025
1 parent 1d0ea1c commit a919d06
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 78 deletions.
101 changes: 50 additions & 51 deletions pina/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import warnings
from lightning.pytorch import LightningDataModule
import torch
from torch_geometric.data import Data, Batch
from torch_geometric.data import Data
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from ..label_tensor import LabelTensor
from .dataset import PinaDatasetFactory
from .dataset import PinaDatasetFactory, PinaTensorDataset
from ..collector import Collector


Expand Down Expand Up @@ -61,6 +61,10 @@ def __init__(self, max_conditions_lengths, dataset=None):
max_conditions_lengths is None else (
self._collate_standard_dataloader)
self.dataset = dataset
if isinstance(self.dataset, PinaTensorDataset):
self._collate = self._collate_tensor_dataset
else:
self._collate = self._collate_graph_dataset

def _collate_custom_dataloader(self, batch):
return self.dataset.fetch_from_idx_list(batch)
Expand All @@ -73,7 +77,6 @@ def _collate_standard_dataloader(self, batch):
if isinstance(batch, dict):
return batch
conditions_names = batch[0].keys()

# Condition names
for condition_name in conditions_names:
single_cond_dict = {}
Expand All @@ -82,15 +85,28 @@ def _collate_standard_dataloader(self, batch):
data_list = [batch[idx][condition_name][arg] for idx in range(
min(len(batch),
self.max_conditions_lengths[condition_name]))]
if isinstance(data_list[0], LabelTensor):
single_cond_dict[arg] = LabelTensor.stack(data_list)
elif isinstance(data_list[0], torch.Tensor):
single_cond_dict[arg] = torch.stack(data_list)
elif isinstance(data_list[0], Data):
single_cond_dict[arg] = Batch.from_data_list(data_list)
single_cond_dict[arg] = self._collate(data_list)

batch_dict[condition_name] = single_cond_dict
return batch_dict

@staticmethod
def _collate_tensor_dataset(data_list):
if isinstance(data_list[0], LabelTensor):
return LabelTensor.stack(data_list)
if isinstance(data_list[0], torch.Tensor):
return torch.stack(data_list)
raise RuntimeError("Data must be Tensors or LabelTensor ")

def _collate_graph_dataset(self, data_list):
if isinstance(data_list[0], LabelTensor):
return LabelTensor.cat(data_list)
if isinstance(data_list[0], torch.Tensor):
return torch.cat(data_list)
if isinstance(data_list[0], Data):
return self.dataset.create_graph_batch(data_list)
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")

def __call__(self, batch):
return self.callable_function(batch)

Expand Down Expand Up @@ -157,14 +173,35 @@ def __init__(self,
logging.debug('Start initialization of Pina DataModule')
logging.info('Start initialization of Pina DataModule')
super().__init__()

# Store fixed attributes
self.batch_size = batch_size
self.shuffle = shuffle
self.repeat = repeat
self.automatic_batching = automatic_batching
if batch_size is None and num_workers != 0:
warnings.warn(
"Setting num_workers when batch_size is None has no effect on "
"the DataLoading process.")
self.num_workers = 0
else:
self.num_workers = num_workers
if batch_size is None and pin_memory:
warnings.warn("Setting pin_memory to True has no effect when "
"batch_size is None.")
self.pin_memory = False
else:
self.pin_memory = pin_memory

# Collect data
collector = Collector(problem)
collector.store_fixed_data()
collector.store_sample_domains()

# Check if the splits are correct
self._check_slit_sizes(train_size, test_size, val_size, predict_size)

# Begin Data splitting
# Split input data into subsets
splits_dict = {}
if train_size > 0:
splits_dict['train'] = train_size
Expand All @@ -186,23 +223,6 @@ def __init__(self,
self.predict_dataset = None
else:
self.predict_dataloader = super().predict_dataloader

collector = Collector(problem)
collector.store_fixed_data()
collector.store_sample_domains()

self.automatic_batching = self._set_automatic_batching_option(
collector, automatic_batching)

if batch_size is None and num_workers != 0:
warnings.warn(
"Setting num_workers when batch_size is None has no effect on "
"the DataLoading process.")
if batch_size is None and pin_memory:
warnings.warn("Setting pin_memory to True has no effect when "
"batch_size is None.")
self.num_workers = num_workers
self.pin_memory = pin_memory
self.collector_splits = self._create_splits(collector, splits_dict)
self.transfer_batch_to_device = self._transfer_batch_to_device

Expand Down Expand Up @@ -318,10 +338,10 @@ def _create_dataloader(self, split, dataset):
if self.batch_size is not None:
sampler = PinaSampler(dataset, shuffle)
if self.automatic_batching:
collate = Collator(self.find_max_conditions_lengths(split))

collate = Collator(self.find_max_conditions_lengths(split),
dataset=dataset)
else:
collate = Collator(None, dataset)
collate = Collator(None, dataset=dataset)
return DataLoader(dataset, self.batch_size,
collate_fn=collate, sampler=sampler,
num_workers=self.num_workers)
Expand Down Expand Up @@ -395,27 +415,6 @@ def _check_slit_sizes(train_size, test_size, val_size, predict_size):
if abs(train_size + test_size + val_size + predict_size - 1) > 1e-6:
raise ValueError("The sum of the splits must be 1")

@staticmethod
def _set_automatic_batching_option(collector, automatic_batching):
"""
Determines whether automatic batching should be enabled.
If all 'input_points' in the collector's data collections are
tensors (torch.Tensor or LabelTensor), it respects the provided
`automatic_batching` value; otherwise, mainly in the Graph scenario,
it forces automatic batching on.
:param Collector collector: Collector object with contains all data
retrieved from input conditions
:param bool automatic_batching : If the user wants to enable automatic
batching or not
"""
if all(isinstance(v['input_points'], (torch.Tensor, LabelTensor))
for v in collector.data_collections.values()):
return automatic_batching if automatic_batching is not None \
else False
return True

@property
def input_points(self):
"""
Expand Down
48 changes: 37 additions & 11 deletions pina/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from torch.utils.data import Dataset
from abc import abstractmethod
from torch_geometric.data import Batch
from torch_geometric.data import Batch, Data
from pina import LabelTensor


Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(self, conditions_dict, max_conditions_lengths,
if automatic_batching:
self._getitem_func = self._getitem_int
else:
self._getitem_func = self._getitem_list
self._getitem_func = self._getitem_dummy

def _getitem_int(self, idx):
return {
Expand All @@ -84,7 +84,7 @@ def fetch_from_idx_list(self, idx):
return to_return_dict

@staticmethod
def _getitem_list(idx):
def _getitem_dummy(idx):
return idx

def get_all_data(self):
Expand All @@ -104,13 +104,29 @@ def input_points(self):
}


class PinaBatch(Batch):
def __init__(self):
super().__init__(self)

def extract(self, labels):
x = self.x
if labels != x.labels:
self.x = x.extract(labels)
return self


class PinaGraphDataset(PinaDataset):

def __init__(self, conditions_dict, max_conditions_lengths,
automatic_batching):
super().__init__(conditions_dict, max_conditions_lengths)
self.in_labels = {}
self.out_labels = None
if automatic_batching:
self._getitem_func = self._getitem_int
else:
self._getitem_func = self._getitem_dummy

ex_data = conditions_dict[list(conditions_dict.keys())[
0]]['input_points'][0]
for name, attr in ex_data.items():
Expand All @@ -137,22 +153,25 @@ def fetch_from_idx_list(self, idx):
if self.length > condition_len:
cond_idx = [idx % condition_len for idx in cond_idx]
to_return_dict[condition] = {
k: self._create_graph_batch_from_list(v, cond_idx)
k: self._create_graph_batch_from_list([v[i] for i in idx])
if isinstance(v, list)
else self._create_output_batch(v, cond_idx)
else self._create_output_batch(v[idx])
for k, v in data.items()
}

return to_return_dict

def _base_create_graph_batch_from_list(self, data, idx):
batch = Batch.from_data_list([data[i] for i in idx])
def _base_create_graph_batch_from_list(self, data):
batch = PinaBatch.from_data_list(data)
return batch

def _base_create_output_batch(self, data, idx):
out = data[idx].reshape(-1, *data[idx].shape[2:])
def _base_create_output_batch(self, data):
out = data.reshape(-1, *data.shape[2:])
return out

def _getitem_dummy(self, idx):
return idx

def _getitem_int(self, idx):
return {
k: {k_data: v[k_data][idx % len(v['input_points'])] for k_data
Expand All @@ -164,8 +183,7 @@ def get_all_data(self):
return self.fetch_from_idx_list(index)

def __getitem__(self, idx):
return self._getitem_int(idx) if isinstance(idx, int) else \
self.fetch_from_idx_list(idx=idx)
return self._getitem_func(idx)

def _labelise_batch(self, func):
@functools.wraps(func)
Expand All @@ -186,3 +204,11 @@ def wrapper(*args, **kwargs):
out.labels = self.out_labels
return out
return wrapper

def create_graph_batch(self, data):
"""
# TODO
"""
if isinstance(data[0], Data):
return self._create_graph_batch_from_list(data)
return self._create_output_batch(data)
3 changes: 2 additions & 1 deletion pina/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def _build_graph_list(self, x, pos, edge_index, edge_attr,

@staticmethod
def _build_edge_attr(x, pos, edge_index):
distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
distance = torch.abs(pos[edge_index[0]] -
pos[edge_index[1]]).as_subclass(torch.Tensor)
return distance

@staticmethod
Expand Down
4 changes: 3 additions & 1 deletion pina/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,16 @@ def __init__(self,
if compile is None or sys.platform == "win32":
compile = False

self.automatic_batching = automatic_batching if automatic_batching \
is not None else False
# set attributes
self.compile = compile
self.solver = solver
self.batch_size = batch_size
self._move_to_device()
self.data_module = None
self._create_datamodule(train_size, test_size, val_size, predict_size,
batch_size, automatic_batching, pin_memory,
batch_size, automatic_batching, pin_memory,
num_workers)

# logging
Expand Down
3 changes: 1 addition & 2 deletions pina/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ def labelize_forward(forward, input_variables, output_variables):
:type output_variables: list[str] | tuple[str]
"""
def wrapper(x):
if isinstance(x, LabelTensor):
x = x.extract(input_variables)
x = x.extract(input_variables)
output = forward(x)
# keep it like this, directly using LabelTensor(...) raises errors
# when compiling the code
Expand Down
Loading

0 comments on commit a919d06

Please sign in to comment.