Bug description
- When I use CUDA inside the `collate_fn` of the DataLoader to pre-process generated data in bulk, with `num_workers > 0`,
- I am required to use the `ddp_spawn` strategy in the Trainer.
- Then I get this error:
Traceback (most recent call last):
File "/home/myuser/myproject/scripts/../train.py", line 1139, in <module>
trainer.fit(training, data)
File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
call._call_and_handle_interrupt(
File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 46, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/multiprocessing.py", line 136, in launch
process_context = mp.start_processes(
File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 268, in start_processes
idx, process, tf_name = start_process(i)
File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 263, in start_process
process.start()
File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/process.py", line 121, in start
self._popen = self._Popen(self)
File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
return Popen(process_obj)
File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
super().__init__(process_obj)
File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
self._launch(process_obj)
File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
reduction.dump(process_obj, fp)
File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'TorchGraph.create_forward_hook.<locals>.after_forward_hook'
- Removing the line `wandb_logger.watch(training)` fixes the problem.
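The traceback fails inside `ForkingPickler(file, protocol).dump(obj)`, which is the step where the `spawn` start method serializes the process object. Pickle stores functions by an importable qualified name, so a closure such as `TorchGraph.create_forward_hook.<locals>.after_forward_hook` cannot be pickled. This appears to be why removing `watch` helps: the model then no longer carries those hooks into the spawned workers. The stdlib-only sketch below reproduces that mechanism without Lightning or wandb; `HookedModel`, `create_forward_hook`, `PicklableForwardHook` and `can_spawn_pickle` are hypothetical names chosen to mirror the traceback, not wandb or Lightning APIs.

```python
import pickle
from multiprocessing.reduction import ForkingPickler


class HookedModel:
    """Hypothetical stand-in for a model that still holds forward hooks,
    as a model passed through wandb_logger.watch(...) presumably does."""

    def __init__(self):
        self.hooks = []


def create_forward_hook(name):
    # Hypothetical factory mirroring 'create_forward_hook.<locals>.after_forward_hook':
    # the hook is a closure defined inside a function, so pickle cannot
    # find it under an importable qualified name.
    def after_forward_hook(module, inputs, outputs):
        print(f"forward done: {name}")

    return after_forward_hook


class PicklableForwardHook:
    """Module-level callable class: pickle stores a reference to the class
    plus the instance's __dict__, so it survives the spawn pickling step."""

    def __init__(self, name):
        self.name = name

    def __call__(self, module, inputs, outputs):
        print(f"forward done: {self.name}")


def can_spawn_pickle(obj):
    """Run the same serializer the 'spawn' start method uses
    (ForkingPickler) and report whether it succeeds."""
    try:
        ForkingPickler.dumps(obj)
    except (AttributeError, pickle.PicklingError) as exc:
        # Python 3.10 raises AttributeError ("Can't pickle local object ..."),
        # as in the traceback above; newer versions may raise PicklingError.
        print(f"{type(exc).__name__}: {exc}")
        return False
    return True


if __name__ == "__main__":
    model = HookedModel()
    print(can_spawn_pickle(model))  # True: no hooks attached

    model.hooks.append(create_forward_hook("encoder"))
    print(can_spawn_pickle(model))  # False: the local closure cannot be pickled

    model.hooks = [PicklableForwardHook("encoder")]
    print(can_spawn_pickle(model))  # True
```

The last case only illustrates the general pickling rule (module-level callables survive `spawn`); it is not a patch to wandb's hook implementation.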
What version are you seeing the problem on?
v2.4
How to reproduce the bug
No response
Error messages and logs
# Error messages and logs here please
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.4.0): 2.4.0
#- PyTorch Version (e.g., 2.4): not reported
#- Python version (e.g., 3.12): 3.10.15
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: 12.4/11.5
#- GPU models and configuration: 1xRTX 3090
#- How you installed Lightning (`conda`, `pip`, source): `conda`
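The PyTorch version is not clearly reported in the environment list above. A small stdlib-only sketch can print the versions the template asks for, assuming the relevant distributions are installed under their PyPI names `lightning`, `torch` and `wandb`; `report_versions` is a hypothetical helper, not part of Lightning.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def report_versions(packages=("lightning", "torch", "wandb")):
    """Return the Python version plus the installed version of each
    distribution, or None when the distribution is not installed."""
    found = {"python": platform.python_version()}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in report_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```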
More info
No response