77import logging
88import optuna
99import random
10- from torch .utils .data import Subset , DataLoader
10+ from torch .utils .data import Subset , DataLoader , DistributedSampler
1111import torch .distributed as dist
12+ import pandas as pd
1213
1314logger = logging .getLogger (__name__ )
1415
@@ -22,11 +23,15 @@ class TorchModelTrainMixin:
2223 lr : float = 2e-5
2324 device : torch .device
2425
25- def get_sampled_dataloader (self , frac = 0.1 ):
def sample_dataset(self, frac=0.1):
    """Return a random ``Subset`` holding a fraction of ``self.dataset``.

    Args:
        frac: Fraction of the dataset to keep. The subset size is
            ``int(frac * len(self.dataset))``, i.e. truncated toward zero.

    Returns:
        A ``torch.utils.data.Subset`` over ``self.dataset`` whose indices are
        drawn without replacement from the module-level ``random`` state.
    """
    total = len(self.dataset)
    # Sampling from range(total) keeps every chosen index unique.
    picked = random.sample(range(total), int(frac * total))
    return Subset(self.dataset, picked)
32+
def get_sampled_dataloader(self, frac=0.1):
    """Build a shuffling ``DataLoader`` over a random fraction of the dataset.

    Args:
        frac: Fraction of the dataset to sample; forwarded unchanged to
            ``self.sample_dataset``.

    Returns:
        A ``DataLoader`` over the sampled subset that uses ``self.batch_size``
        and reshuffles on every pass (``shuffle=True``).
    """
    return DataLoader(
        self.sample_dataset(frac=frac),
        batch_size=self.batch_size,
        shuffle=True,
    )
3136
3237 def optuna_train (self , run_name :str = "" , n_trials :int = 30 , frac = 0.1 ):
@@ -53,6 +58,13 @@ def objective(self, trial):
5358 self .reinit_scheduler_optimizer (** kwargs )
5459 acc = self .train ()
5560 return acc
61+
def get_ddp_dataloader(self, frac=1.0):
    """Build a ``DataLoader`` whose sampler shards the data across DDP ranks.

    ``self.sample_dataset`` picks its indices with the process-local
    ``random`` state, so each rank may draw a different subset (or a
    different ordering of it). ``DistributedSampler`` partitions by
    *position* and assumes every rank sees the same dataset in the same
    order. The indices chosen on rank 0 are therefore broadcast to all ranks
    before the sampler is built; otherwise shards can overlap and some
    samples can be skipped.

    Args:
        frac: Fraction of the dataset to use; forwarded to
            ``self.sample_dataset``.

    Returns:
        A ``(dataloader, sampler)`` tuple. The sampler is returned so the
        caller can call ``sampler.set_epoch(epoch)`` to reshuffle each epoch.

    Note:
        When no process group is initialised, the method falls back to a
        single replica (``num_replicas=1, rank=0``) instead of raising, so
        the same code path can be run in a single process.
    """
    sampled_dataset = self.sample_dataset(frac=frac)

    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        if world_size > 1:
            # Every rank must index the same subset in the same order;
            # adopt rank 0's draw everywhere.
            # NOTE(review): with the NCCL backend, broadcast_object_list
            # uses the current CUDA device; confirm torch.cuda.set_device
            # has been called per rank before this point.
            shared = [list(sampled_dataset.indices)]
            dist.broadcast_object_list(shared, src=0)
            sampled_dataset = Subset(self.dataset, shared[0])
    else:
        # Single-process fallback: one replica that owns all the data.
        world_size, rank = 1, 0

    sampler = DistributedSampler(
        sampled_dataset, num_replicas=world_size, rank=rank
    )
    dataloader = DataLoader(
        sampled_dataset, batch_size=self.batch_size, sampler=sampler
    )
    return dataloader, sampler
67+
5668
5769 def _train_batch (self , x , y ):
5870 inputs = self .tokenizer (x , return_tensors = "pt" , truncation = True , padding = True )
@@ -103,7 +115,6 @@ def train(self):
103115 current_acc = self ._train_batch (tweets , labels .float ())
104116 acc .append (current_acc )
105117 except RuntimeError as e :
106- raise e
107118 logger .error (e )
108119 del tweets , labels , self .optimizer
109120 gc .collect ()
0 commit comments