Skip to content

Commit 5d2b8f2

Browse files
authored
Fix invalid jsonargparse yaml using spherinator.data.DataModule (#255)
* Refactor DataModule to use LightningDataModule directly instead of importing lightning as L
* Remove dataloader_kwargs from DataModule init
* Remove kwargs from Column
* Update DataModule to accept columns as a list of dictionaries for initialization
* Remove dict_kwargs from the DataModule configuration in illustris.yaml
1 parent 44b3db2 commit 5d2b8f2

2 files changed

Lines changed: 10 additions & 12 deletions

File tree

experiments/illustris.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ data:
1616
init_args:
1717
degrees: 180
1818
return_dict: False
19-
dict_kwargs:
2019
batch_size: 512
2120
shuffle: True
2221
num_workers: 4

src/spherinator/data/data_module.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from pathlib import Path
22
from typing import Any, Callable, Optional
33

4-
import lightning as L
54
import pyarrow.parquet as pq
65
import torch
76
from datasets import load_dataset
7+
from lightning.pytorch import LightningDataModule
88
from torch.utils.data import Dataset
99

1010

@@ -15,24 +15,19 @@ class Column:
1515
Args:
1616
name: The name of the column
1717
transform: Optional transformation function to apply to the column data
18-
**kwargs: Additional parameters that can be stored with the column
18+
shape: Optional shape to reshape the data to before applying the transform
1919
"""
2020

2121
def __init__(
2222
self,
2323
name: str,
2424
transform: Optional[Callable[[Any], Any]] = None,
2525
shape: Optional[tuple[int, ...]] = None,
26-
**kwargs,
2726
):
2827
self.name = name
2928
self.transform = transform
3029
self.shape = shape
3130

32-
# Store any additional parameters
33-
for key, value in kwargs.items():
34-
setattr(self, key, value)
35-
3631
def apply_transform(self, data: Any) -> Any:
3732
"""
3833
Apply the transformation to the given data.
@@ -138,7 +133,7 @@ def __getitem__(self, idx):
138133
return transformed_sample
139134

140135

141-
class DataModule(L.LightningDataModule):
136+
class DataModule(LightningDataModule):
142137
"""DataModule for loading datasets from Hugging Face Datasets library and preparing them for
143138
PyTorch. This DataModule supports loading datasets from Hugging Face Datasets format (e.g.
144139
"ylecun/mnist") or from local parquet files. It also allows specifying which columns to load
@@ -167,17 +162,21 @@ class DataModule(L.LightningDataModule):
167162
def __init__(
168163
self,
169164
path: str,
170-
columns: Optional[list[Column]] = None,
165+
columns: Optional[list[dict[str, Any]]] = None,
171166
return_dict: bool = True,
172167
validation_size: float = 0.2,
173168
test_size: float = 0.5,
174169
in_gpu_memory: bool = False,
175-
**dataloader_kwargs,
170+
batch_size: int = 32,
171+
shuffle: bool = True,
172+
num_workers: int = 0,
176173
):
177174
super().__init__()
178175

179176
if columns is None:
180177
columns = [Column(name="data")]
178+
else:
179+
columns = [Column(**c) if isinstance(c, dict) else c for c in columns]
181180

182181
self.path: str = path
183182
self.columns: list[Column] = columns
@@ -190,7 +189,7 @@ def __init__(
190189
self.in_gpu_memory: bool = in_gpu_memory
191190

192191
# Store DataLoader kwargs for forwarding
193-
self.dataloader_kwargs = dataloader_kwargs
192+
self.dataloader_kwargs = {"batch_size": batch_size, "shuffle": shuffle, "num_workers": num_workers}
194193

195194
def prepare_data(self):
196195
load_dataset(self.path)

0 commit comments

Comments
 (0)