Description
Note that, for the time being, the issue described below is not causing any real problems (apart from an annoying warning). But it is worth documenting here, as it would be nice to patch at some point.
The Problem
pandarallel conflicts with the jax code because it explicitly sets the multiprocessing context to "fork".
As a result, we get the following warning:
../../../../mambaforge-pypy3/envs/multidms-dev/lib/python3.12/multiprocessing/popen_fork.py:66: 15 warnings
multidms/data.py: 15 warnings
tests/test_data.py: 210 warnings
/home/jared/mambaforge-pypy3/envs/multidms-dev/lib/python3.12/multiprocessing/popen_fork.py:66: DeprecationWarning: This process (pid=99501) is multi-threaded, use of fork() may lead to deadlocks in the child.
self.pid = os.fork()
The Cause
Most likely, the warning is raised simply because jax is loaded (and has started threads) in the process that gets forked. However, because no jax operations happen within the forked processes, no actual deadlocks or other issues arise.
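To make the fork-versus-threads interaction above concrete, here is a minimal standard-library sketch of how a start method is selected. It is not pandarallel's code; the names `square`, `get_context`, and `parallel_square` are hypothetical, and whether pandarallel can be pointed at a different context is not something this sketch assumes.

```python
import multiprocessing as mp


def square(x: int) -> int:
    """Toy worker; defined at module top level so "spawn" can pickle it."""
    return x * x


def get_context(start_method: str = "spawn"):
    """Return a multiprocessing context for the requested start method.

    "fork" copies the parent's memory but only the calling thread, so a lock
    held by another thread can stay locked in the child. That risk is what
    the DeprecationWarning is about. "spawn" starts a fresh interpreter.
    """
    return mp.get_context(start_method)


def parallel_square(values, start_method: str = "spawn"):
    """Apply `square` to `values` in a 2-worker pool with the given start method."""
    ctx = get_context(start_method)
    with ctx.Pool(processes=2) as pool:
        return pool.map(square, values)
```

When called from a script under an `if __name__ == "__main__":` guard, `parallel_square([1, 2, 3])` should return `[1, 4, 9]` without forking the multi-threaded parent. The trade-off is that "spawn" workers start more slowly and must re-import everything they need.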
The Solution
It seems that pandarallel is no longer being maintained, so it may be nice to remove it completely and replace it with something a little better, such as swifter, or polars for fast table operations. It is unclear, though, whether these would have the same issue.
Ultimately, removing the jax import from the Data module is a reasonable thing to do. The training data could simply be converted to jax pytrees at the time of fitting, so long as the memory burden of copying the training data each time you fit is feasible. Testing will need to be done.
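A rough sketch of the deferred-conversion idea, under stated assumptions: `TrainingData`, `as_plain_tree`, and `to_jax_pytree` are hypothetical names, not part of the multidms API, and the real training data is likely more structured than the nested lists used here. The point is only that jax gets imported at fit time rather than when the data is prepared.

```python
from dataclasses import dataclass


@dataclass
class TrainingData:
    """Hypothetical stand-in for the Data module's output; holds no jax objects."""

    X: list  # nested lists of floats (one row of features per sample)
    y: list  # list of float targets


def as_plain_tree(data: TrainingData) -> dict:
    """Return the training data as a plain dict of nested lists (no jax needed)."""
    return {"X": data.X, "y": data.y}


def to_jax_pytree(data: TrainingData) -> dict:
    """Convert to a dict of jax arrays; jax is imported only when this runs."""
    import jax.numpy as jnp  # deferred import: not loaded while preparing data

    return {name: jnp.asarray(values) for name, values in as_plain_tree(data).items()}
```

A fitting routine would call `to_jax_pytree(data)` on entry. Each call makes a fresh copy of the arrays, which is the memory cost mentioned above and the thing that would need testing.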