v2.0: Migrate to TensorFlow 2.19, Keras 3, and pixi#76
Merged
andrewkern merged 11 commits intomasterfrom Mar 27, 2026
Merged
Conversation
Replace the TF 1.x compat GPU configuration (ConfigProto/Session) with native TF2 memory growth API, fixing the GPU underutilization reported in issues #65 and #73. Switch from tensorflow.keras to standalone Keras 3 imports and update model serialization from JSON+H5 to the .keras format. Replace mamba/conda/pip tooling with pixi for reproducible environment management. All dependencies including CUDA are now handled by a single `pixi install`. Remove setup.py and requirements.txt in favor of pyproject.toml and pixi.toml. Additional fixes for Python 3.12+ compatibility: - Cast numpy integers to int for random.seed() - Remove deprecated multiprocessing args from model.fit() - Return tuples instead of lists from batch generators for Keras 3 - Add super().__init__() calls to all generator classes - Disable XLA JIT compilation, set XLA_FLAGS for CUDA libdevice - Pin tskit<1.0 and msprime<1.4 for numpy<2 compatibility with TF 2.17
Use cuda-compat package to provide forward-compatible CUDA driver libraries, allowing TF 2.18+ (CUDA 12.8) to run on systems with older NVIDIA drivers (e.g. 535/CUDA 12.2). This resolves the numpy version conflict: TF 2.18+ supports numpy 2.x which tskit >= 1.0 requires.
Add 63 unit tests covering helpers, batch generator preprocessing, and simulator methods. Tests run in ~4s with no GPU required via `pixi run -e test test`. Fix NearestNeighbors positional arg usage in sort_min_diff (both helpers.py and sequenceBatchGenerator.py) broken by newer scikit-learn requiring the keyword form n_neighbors=. Add pytest and a test environment to pixi.toml.
Runs the test suite on push to master/tf-migration and on PRs to master. Uses pixi's CPU-only test environment (no GPU needed).
Add examples of pixi shell and pixi run usage so users know how to invoke ReLERNN commands after installation.
Keras 3 model.predict() returns 2D arrays (N, 1) instead of 1D. Use .flatten() instead of iterating with float() to handle both.
Got a warning saying " WARN Encountered 1 warning while parsing the manifest: ⚠ The `project` field is deprecated. Use `workspace` instead."
Add CLAUDE.md, experiments/, example_output/, and baseline_results/ to gitignore. Update CLAUDE.md to reflect the new test suite.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
tf.compat.v1GPU config with nativetf.config.set_memory_growth(), fixing GPU underutilization reported in GPU Partial Usage and Detection Issues on HPC #65 and very ad-hoc fix for GPU usage on HPC #73. Switch fromtensorflow.kerasto standalone Keras 3 imports. Update model serialization from JSON+H5 to.kerasformat.pixi installhandles Python, TensorFlow, CUDA, and all dependencies. Supports GPU (default) and CPU-only (cpu) environments. Removessetup.pyandrequirements.txtin favor ofpyproject.tomlandpixi.toml.NearestNeighborspositional arg (scikit-learn),random.seed()with numpy integers (Python 3.12+),plotResultswith Keras 3 prediction shapes, Keras 3model.fit()API changes.cuda-compatfor forward-compatible CUDA, allowing TF 2.19 (CUDA 12.8) on older drivers. Removes the need to pintskit<1.0ormsprime<1.4.Test plan
pixi run -e test test— 63 unit tests passpixi run example— full VCF pipeline (SIMULATE → TRAIN → PREDICT → BSCORRECT) completes on GPUpixi run example-pool— pool-seq pipelineCloses #73, closes #65