
jax disagreement with pandarallel #147

@jgallowa07

Note that, for the time being, the issue described below is not causing any real problems (except the annoying warning). But it is worth documenting here, as it would be nice to patch at some point.

The Problem

pandarallel conflicts with the jax code because it explicitly sets the multiprocessing start method (context) to "fork", and thus we get the following warning:

../../../../mambaforge-pypy3/envs/multidms-dev/lib/python3.12/multiprocessing/popen_fork.py:66: 15 warnings
multidms/data.py: 15 warnings
tests/test_data.py: 210 warnings
  /home/jared/mambaforge-pypy3/envs/multidms-dev/lib/python3.12/multiprocessing/popen_fork.py:66: DeprecationWarning: This process (pid=99501) is multi-threaded, use of fork() may lead to deadlocks in the child.
    self.pid = os.fork()

The Cause

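Most likely, the warning is raised simply because jax is loaded in the parent process (and therefore in the forked processes): jax may have started background threads in the parent, and from Python 3.12 onward, calling fork() from a multi-threaded process emits this DeprecationWarning. However, because no jax operations happen within the forked processes, no actual deadlock or other issue arises.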
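A minimal sketch that reproduces the same warning without jax, assuming the only trigger is that the parent has more than one thread at the moment fork() is called. A plain `threading.Thread` stands in for whatever threads jax may start, and `multiprocessing.get_context("fork")` mirrors the start method pandarallel selects. The helper names `_child_task` and `fork_while_multithreaded` are hypothetical, and the sketch assumes a POSIX platform where the "fork" start method is available.

```python
import multiprocessing as mp
import threading
import warnings


def _child_task():
    # Pure-Python work in the child; it never touches the parent's extra thread.
    sum(range(1000))


def fork_while_multithreaded():
    """Fork one child with the "fork" start method while a background thread is alive.

    Returns the child's exit code and any "multi-threaded ... fork()"
    DeprecationWarnings recorded in the parent.
    """
    stop = threading.Event()
    # Hypothetical stand-in for background threads jax/XLA may start in the parent.
    background = threading.Thread(target=stop.wait, daemon=True)
    background.start()
    try:
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")  # DeprecationWarning is hidden by default
            ctx = mp.get_context("fork")  # the start method pandarallel selects
            # With "fork", the Process object is not pickled; the child inherits memory.
            child = ctx.Process(target=_child_task)
            child.start()  # os.fork() runs here and (on 3.12+) warns in the parent
            child.join()
    finally:
        stop.set()
        background.join()
    fork_warnings = [
        w for w in caught
        if issubclass(w.category, DeprecationWarning)
        and "multi-threaded" in str(w.message)
    ]
    return child.exitcode, fork_warnings
```

On Python 3.12+ this records the same "This process ... is multi-threaded, use of fork() may lead to deadlocks in the child." message, while the child still exits cleanly, which is consistent with the warning being harmless as long as the child never uses state owned by the other threads.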

The Solution

It seems that pandarallel is no longer maintained, so it may be nice to remove it completely and replace it with something better, such as swifter, or polars for fast table operations, though it's unclear whether these would have the same issue.

Ultimately, removing the jax import from the Data module is a reasonable thing to do. The training data could simply be converted to jax pytrees at fit time, as long as the memory burden of copying the training data on each fit is acceptable. This will need testing.
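A minimal sketch of that refactor, under stated assumptions: the `Data` class, its `training_data` dict, and the `fit` function below are hypothetical stand-ins, not the project's actual API, and the least-squares solve is only a placeholder for the real model fitting. The point is that `Data` holds plain NumPy arrays and never imports jax, while jax is imported and the pytree is built inside `fit`.

```python
import numpy as np


class Data:
    """Hypothetical stand-in for the Data module: keeps training data as NumPy
    arrays and does not import jax, so processes forked from here stay jax-free."""

    def __init__(self, X, y):
        self.training_data = {
            "X": np.asarray(X, dtype=np.float32),
            "y": np.asarray(y, dtype=np.float32),
        }


def fit(data):
    """Hypothetical fit entry point: jax is imported lazily, and the NumPy
    training data is converted to a jax pytree only at fit time."""
    import jax
    import jax.numpy as jnp

    # Convert every NumPy leaf of the dict (a pytree) to a jax array.
    pytree = jax.tree_util.tree_map(jnp.asarray, data.training_data)
    # Placeholder for the real optimization: an ordinary least-squares solve.
    beta, *_ = jnp.linalg.lstsq(pytree["X"], pytree["y"])
    return beta
```

Since `jnp.asarray` may copy each NumPy leaf onto the jax device, every call to `fit` can pay that memory cost, which is the burden the issue flags for testing.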
