v2.0: Migrate to TensorFlow 2.19, Keras 3, and pixi #76

Merged
andrewkern merged 11 commits into master from tf-migration on Mar 27, 2026

@andrewkern
Member

Summary

  • Migrate from TensorFlow 2.15 to 2.19 + Keras 3: Replace the tf.compat.v1 GPU config with the native tf.config.experimental.set_memory_growth() API, fixing the GPU underutilization reported in #65 (GPU Partial Usage and Detection Issues on HPC) and #73 (very ad-hoc fix for GPU usage on HPC). Switch from tensorflow.keras to standalone Keras 3 imports. Update model serialization from JSON+H5 to the .keras format.
  • Replace conda/pip with pixi: Single pixi install handles Python, TensorFlow, CUDA, and all dependencies. Supports GPU (default) and CPU-only (cpu) environments. Removes setup.py and requirements.txt in favor of pyproject.toml and pixi.toml.
  • Add unit tests and CI: 63 tests covering helpers, batch generator preprocessing, and simulator methods. GitHub Actions workflow runs tests on every push/PR.
  • Fix compatibility bugs: NearestNeighbors positional arg (scikit-learn), random.seed() with numpy integers (Python 3.12+), plotResults with Keras 3 prediction shapes, Keras 3 model.fit() API changes.
  • Use cuda-compat for forward-compatible CUDA, allowing TF 2.19 (CUDA 12.8) on older drivers. Removes the need to pin tskit<1.0 or msprime<1.4.
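
To illustrate the random.seed() fix listed above: newer Python versions reject seeds that are not plain int, float, str, or bytes-like values, so a NumPy integer has to be cast first. The following is a minimal sketch, not the project's code; `seed_from_numpy` is a hypothetical helper name.

```python
import random

import numpy as np


def seed_from_numpy(seed):
    """Seed the stdlib RNG from a value that may be a NumPy integer.

    Hypothetical helper: int() converts np.int64 and friends to a plain
    Python int, which random.seed() accepts on every Python version.
    """
    random.seed(int(seed))


# A NumPy integer and the equivalent Python int give the same stream.
seed_from_numpy(np.int64(12345))
first = random.random()

seed_from_numpy(12345)
second = random.random()
```

Casting at the call site keeps the seed value itself unchanged, so results stay reproducible across the migration.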

Test plan

  • pixi run -e test test — 63 unit tests pass
  • pixi run example — full VCF pipeline (SIMULATE → TRAIN → PREDICT → BSCORRECT) completes on GPU
  • pixi run example-pool — pool-seq pipeline
  • CI passes on GitHub Actions
  • GPU profiling confirms the cuDNN GRU kernel is active (353x speedup over fallback)

Closes #73, closes #65

andrewkern and others added 11 commits March 27, 2026 07:11
Replace the TF 1.x compat GPU configuration (ConfigProto/Session) with
native TF2 memory growth API, fixing the GPU underutilization reported
in issues #65 and #73. Switch from tensorflow.keras to standalone Keras 3
imports and update model serialization from JSON+H5 to the .keras format.
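
A minimal sketch of the TF2-style memory growth setup described above, assuming TensorFlow 2.x is installed. TensorFlow exposes this setting as `tf.config.experimental.set_memory_growth`; the wrapper name `enable_gpu_memory_growth` is hypothetical and is not taken from the ReLERNN sources.

```python
import tensorflow as tf


def enable_gpu_memory_growth():
    """Ask TensorFlow to allocate GPU memory on demand.

    Must run before any GPU has been initialized; TensorFlow raises a
    RuntimeError if the setting is changed afterwards.
    Returns the number of GPUs that were configured (0 on CPU-only hosts).
    """
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    return len(gpus)


n_gpus = enable_gpu_memory_growth()
print(f"memory growth enabled on {n_gpus} GPU(s)")
```

On a machine without a visible GPU the loop body never runs, so the same call is safe in the CPU-only environment.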

Replace mamba/conda/pip tooling with pixi for reproducible environment
management. All dependencies including CUDA are now handled by a single
`pixi install`. Remove setup.py and requirements.txt in favor of
pyproject.toml and pixi.toml.

Additional fixes for Python 3.12+ compatibility:
- Cast numpy integers to int for random.seed()
- Remove deprecated multiprocessing args from model.fit()
- Return tuples instead of lists from batch generators for Keras 3
- Add super().__init__() calls to all generator classes
- Disable XLA JIT compilation, set XLA_FLAGS for CUDA libdevice
- Pin tskit<1.0 and msprime<1.4 for numpy<2 compatibility with TF 2.17

Use the cuda-compat package to provide forward-compatible CUDA driver
libraries, allowing TF 2.18+ (CUDA 12.8) to run on systems with
older NVIDIA drivers (e.g. 535/CUDA 12.2). This resolves the numpy
version conflict: TF 2.18+ supports numpy 2.x, which tskit >= 1.0
requires.

Add 63 unit tests covering helpers, batch generator preprocessing,
and simulator methods. Tests run in ~4s with no GPU required via
`pixi run -e test test`.

Fix NearestNeighbors positional arg usage in sort_min_diff (both
helpers.py and sequenceBatchGenerator.py) broken by newer scikit-learn
requiring the keyword form n_neighbors=.
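
A sketch of the keyword form, assuming scikit-learn and NumPy are available. The function body is an illustrative guess at what a row-sorting helper named sort_min_diff might do and is not copied from helpers.py; the relevant point is only that `n_neighbors` is passed by keyword.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors


def sort_min_diff(amat):
    """Reorder rows by nearest-neighbour distance (illustrative only)."""
    # Keyword form: newer scikit-learn rejects NearestNeighbors(len(amat), ...).
    nn = NearestNeighbors(n_neighbors=len(amat), metric="manhattan").fit(amat)
    dist, idx = nn.kneighbors(amat)
    # Pick the row whose total distance to all other rows is smallest and
    # order the matrix by that row's neighbour ranking.
    smallest = np.argmin(dist.sum(axis=1))
    return amat[idx[smallest]]


amat = np.array([[0, 0, 1], [1, 1, 1], [0, 0, 0], [1, 1, 0]])
sorted_amat = sort_min_diff(amat)
```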

Add pytest and a test environment to pixi.toml.

Add a GitHub Actions workflow that runs the test suite on pushes to
master/tf-migration and on PRs to master. It uses pixi's CPU-only test
environment (no GPU needed).

Add examples of pixi shell and pixi run usage so users know how
to invoke ReLERNN commands after installation.

Keras 3 model.predict() returns 2D arrays (N, 1) instead of 1D.
Use .flatten() instead of iterating with float() to handle both.
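
A small NumPy-only sketch of why `.flatten()` handles both shapes. The arrays below are made-up stand-ins for prediction output, and `to_1d` is a hypothetical helper name.

```python
import numpy as np


def to_1d(predictions):
    """Return predictions as a 1D array whether the input is (N,) or (N, 1)."""
    return np.asarray(predictions).flatten()


# Stand-ins for the two prediction shapes described above.
column_shaped = np.array([[0.1], [0.2], [0.3]])  # shape (3, 1)
flat_shaped = np.array([0.1, 0.2, 0.3])          # shape (3,)

flat_a = to_1d(column_shaped)
flat_b = to_1d(flat_shaped)
```

Both calls yield an array of shape (3,), so downstream plotting code no longer depends on which shape the model returned.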

pixi printed a warning while parsing the manifest:
"WARN Encountered 1 warning while parsing the manifest:
  ⚠ The `project` field is deprecated. Use `workspace` instead."

Add CLAUDE.md, experiments/, example_output/, and baseline_results/
to gitignore. Update CLAUDE.md to reflect the new test suite.
@andrewkern andrewkern merged commit 6655efd into master Mar 27, 2026
2 checks passed
Successfully merging this pull request may close these issues.

  • very ad-hoc fix for GPU usage on HPC #73
  • GPU Partial Usage and Detection Issues on HPC #65
