Bug description
I call `trainer.fit(model, datamodule=dm)` to start training. `dm` is an instance of a class that inherits from `pl.LightningDataModule`, and in that class I override `train_dataloader`:
def train_dataloader(self):
    train_dataset = MixedBatchMultiviewDataset(
        self.args,
        self.tokenizer,
        known_exs=self.known_train,
        unknown_exs=self.unknown_train,
        feature=self.args.feature,
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=self.args.train_batch_size,
        shuffle=True,
        num_workers=self.args.num_workers,
        pin_memory=True,
        collate_fn=self.collate_batch_feat,
    )
    return train_dataloader
In the model's `on_train_epoch_start` hook, I update the dataset and then force the fit loop to set up its data again:
train_dl = self.trainer.train_dataloader
train_dl.dataset.update_pseudo_labels(uid2pl)
loop = self.trainer.fit_loop
loop._combined_loader = None
loop.setup_data()
In `training_step`, the batch still contains the old data, even though `trainer.train_dataloader.dataset` already holds the new data:
def training_step(self, batch: List[Dict[str, torch.Tensor]], batch_idx: int):
    self.mv_model._on_train_batch_start()
    logger.info(self.trainer.train_dataloader.dataset.unknown_feats)  # new
    logger.info(batch)  # old
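
One possible explanation for this symptom, not verified against Lightning's source, is that the iterator feeding `training_step` was created before the hook replaced the loader, so it keeps reading from the old dataset object. The sketch below reproduces that pattern without Lightning or PyTorch. `ToyDataset`, `ToyLoader`, and `ToyLoop` are hypothetical stand-ins, and the order in which the real fit loop creates its iterator is an assumption.

```python
class ToyDataset:
    """Hypothetical stand-in for a dataset that snapshots its labels."""

    def __init__(self, labels):
        self.labels = list(labels)  # copy, like building features up front

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return self.labels[index]


class ToyLoader:
    """Hypothetical stand-in for a DataLoader: yields items in order."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        for index in range(len(self.dataset)):
            yield self.dataset[index]


class ToyLoop:
    """Hypothetical stand-in for a loop that caches a loader and an iterator."""

    def __init__(self, make_loader):
        self.make_loader = make_loader
        self.loader = None
        self.iterator = None

    def setup_data(self):
        # Only builds a loader when none is cached.
        if self.loader is None:
            self.loader = self.make_loader()

    def start_epoch(self):
        # The iterator is bound to whatever loader exists right now.
        self.iterator = iter(self.loader)


source = {"labels": ["old", "old"]}
loop = ToyLoop(lambda: ToyLoader(ToyDataset(source["labels"])))
loop.setup_data()
loop.start_epoch()

# Mimic the hook: update the labels, drop the cached loader, rebuild it.
source["labels"] = ["new", "new"]
loop.loader = None
loop.setup_data()

stale_batches = list(loop.iterator)      # iterator created before the rebuild
fresh_batches = list(iter(loop.loader))  # iterator created after the rebuild
print(stale_batches, fresh_batches, loop.loader.dataset.labels)
```

In this toy setup the loop's loader and dataset report the new labels while the pre-existing iterator still yields the old ones, which matches the behaviour described above. If the real cause is the same, the update would only become visible once a new iterator is created from the rebuilt loader.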
What version are you seeing the problem on?
v2.3
How to reproduce the bug
No response
Error messages and logs
# Error messages and logs here please
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
More info
No response