11from pathlib import Path
22from typing import Any , Callable , Optional
33
4- import lightning as L
54import pyarrow .parquet as pq
65import torch
76from datasets import load_dataset
7+ from lightning .pytorch import LightningDataModule
88from torch .utils .data import Dataset
99
1010
@@ -15,24 +15,19 @@ class Column:
1515 Args:
1616 name: The name of the column
1717 transform: Optional transformation function to apply to the column data
18- **kwargs: Additional parameters that can be stored with the column
18+ shape: Optional shape to reshape the data to before applying the transform
1919 """
2020
2121 def __init__ (
2222 self ,
2323 name : str ,
2424 transform : Optional [Callable [[Any ], Any ]] = None ,
2525 shape : Optional [tuple [int , ...]] = None ,
26- ** kwargs ,
2726 ):
2827 self .name = name
2928 self .transform = transform
3029 self .shape = shape
3130
32- # Store any additional parameters
33- for key , value in kwargs .items ():
34- setattr (self , key , value )
35-
3631 def apply_transform (self , data : Any ) -> Any :
3732 """
3833 Apply the transformation to the given data.
@@ -138,7 +133,7 @@ def __getitem__(self, idx):
138133 return transformed_sample
139134
140135
141- class DataModule (L . LightningDataModule ):
136+ class DataModule (LightningDataModule ):
142137 """DataModule for loading datasets from Hugging Face Datasets library and preparing them for
143138 PyTorch. This DataModule supports loading datasets from Hugging Face Datasets format (e.g.
144139 "ylecun/mnist") or from local parquet files. It also allows specifying which columns to load
@@ -167,17 +162,21 @@ class DataModule(L.LightningDataModule):
167162 def __init__ (
168163 self ,
169164 path : str ,
170- columns : Optional [list [Column ]] = None ,
165+ columns : Optional [list [dict [ str , Any ] ]] = None ,
171166 return_dict : bool = True ,
172167 validation_size : float = 0.2 ,
173168 test_size : float = 0.5 ,
174169 in_gpu_memory : bool = False ,
175- ** dataloader_kwargs ,
170+ batch_size : int = 32 ,
171+ shuffle : bool = True ,
172+ num_workers : int = 0 ,
176173 ):
177174 super ().__init__ ()
178175
179176 if columns is None :
180177 columns = [Column (name = "data" )]
178+ else :
179+ columns = [Column (** c ) if isinstance (c , dict ) else c for c in columns ]
181180
182181 self .path : str = path
183182 self .columns : list [Column ] = columns
@@ -190,7 +189,7 @@ def __init__(
190189 self .in_gpu_memory : bool = in_gpu_memory
191190
192191 # Store DataLoader kwargs for forwarding
193- self .dataloader_kwargs = dataloader_kwargs
192+ self .dataloader_kwargs = { "batch_size" : batch_size , "shuffle" : shuffle , "num_workers" : num_workers }
194193
195194 def prepare_data (self ):
196195 load_dataset (self .path )
0 commit comments