feat: complete ensemble refactoring with parallel training support #15

Merged
andrewkern merged 7 commits into main from feature/ensemble-refactoring on Jul 7, 2025

Conversation

@andrewkern

Member

feat: complete ensemble refactoring with parallel training support

Summary

Refactored ensemble functionality from the separate EnsembleLocator class into an integrated EnsembleMixin, and added k-fold cross-validation with parallel GPU support.

Key Changes

  • Integrated ensemble into main Locator class via EnsembleMixin - train_ensemble() and predict_ensemble() methods
  • Added parallel ensemble training - parallel_train_ensemble() distributes folds across multiple GPUs
  • Implemented EnsembleModelManager for efficient model persistence and on-demand loading
  • Added uncertainty quantification - predictions include standard deviation across ensemble members
  • Fixed critical bugs in parallel implementation (genotype serialization, sample ID alignment, division by zero)
  • Reduced verbosity - cleaner output with configurable logging levels
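
As an illustration of the uncertainty quantification item above, the following is a minimal sketch of how per-sample means and standard deviations could be computed across ensemble members. The function name `summarize_ensemble`, the toy data, and the choice of population standard deviation are assumptions of this sketch, not the PR's actual implementation. Plain Python lists keep it dependency-free.

```python
from statistics import fmean, pstdev


def summarize_ensemble(member_predictions):
    """Return per-sample mean and standard deviation across ensemble members.

    member_predictions: one list per ensemble member, each holding one
    (x, y) coordinate pair per sample, in the same sample order.
    """
    n_samples = len(member_predictions[0])
    if any(len(preds) != n_samples for preds in member_predictions):
        raise ValueError("all members must predict the same samples")

    means, stds = [], []
    for i in range(n_samples):
        xs = [preds[i][0] for preds in member_predictions]
        ys = [preds[i][1] for preds in member_predictions]
        means.append((fmean(xs), fmean(ys)))
        # Population standard deviation (ddof=0); an assumption of this sketch.
        stds.append((pstdev(xs), pstdev(ys)))
    return means, stds


# Three hypothetical members predicting two samples each.
members = [
    [(1.0, 2.0), (10.0, 20.0)],
    [(2.0, 2.0), (10.0, 22.0)],
    [(3.0, 2.0), (10.0, 24.0)],
]
means, stds = summarize_ensemble(members)
```

A large standard deviation for a sample indicates that the fold models disagree about its location.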

Testing

  • 10 comprehensive tests for ensemble functionality
  • 12 tests for parallel ensemble training
  • Integration tests with NA handling and GPU optimization

Documentation

  • New comprehensive ensemble guide (docs/source/ensemble_guide.rst)
  • Updated API reference, usage guide, and examples
  • Added parallel ensemble section to multi-GPU guide

Breaking Changes

  • EnsembleLocator class removed - use Locator with ensemble methods
  • API simplified: train() → train_ensemble(), predict() → predict_ensemble()

- Created EnsembleMixin with modern patterns (IndexSet, tf.data pipeline)
- Added k_fold_split method to IndexSet for efficient fold creation
- Refactored EnsembleLocator as a legacy compatibility wrapper
- Added comprehensive tests for ensemble functionality
- Integrated EnsembleMixin into core Locator class

Key improvements:
- Memory-efficient data handling without array copies
- Consistent NA handling with na_action parameter
- Integration with modern tf.data pipeline
- Backward compatibility through legacy wrapper
- Uses standard normalize_locs function instead of manual normalization
- Uses NormalizationParams class for denormalization
- Reduced cyclomatic complexity by extracting helper methods

BREAKING CHANGE: EnsembleLocator is now deprecated in favor of Locator's
ensemble methods (train_ensemble, predict_ensemble). The old API still
works but shows deprecation warnings.
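
A legacy wrapper of the kind described above might look roughly like the sketch below. Both classes here are stand-ins written for illustration: only the names `Locator`, `EnsembleLocator`, `train_ensemble`, and `predict_ensemble` come from this PR, while the `k_folds` parameter and the return values are hypothetical.

```python
import warnings


class Locator:
    """Stand-in for the real class; only the method names come from the PR."""

    def train_ensemble(self, k_folds=5):  # k_folds is a hypothetical parameter
        return f"trained {k_folds} folds"

    def predict_ensemble(self):
        return "ensemble predictions"


class EnsembleLocator(Locator):
    """Legacy wrapper: keeps the old entry points and warns on construction."""

    def __init__(self):
        warnings.warn(
            "EnsembleLocator is deprecated; use Locator.train_ensemble() "
            "and Locator.predict_ensemble() instead.",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller's line
        )
        super().__init__()

    def train(self, k_folds=5):
        # Old name forwards to the new method.
        return self.train_ensemble(k_folds=k_folds)

    def predict(self):
        return self.predict_ensemble()
```
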
Phase 1: Create EnsembleMixin with modern patterns
- Created EnsembleMixin with modern patterns (IndexSet, tf.data pipeline)
- Added k_fold_split method to IndexSet for efficient fold creation
- Refactored EnsembleLocator as a legacy compatibility wrapper
- Uses standard normalize_locs function instead of manual normalization
- Uses NormalizationParams class for denormalization
- Reduced cyclomatic complexity by extracting helper methods
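
To illustrate the index-based fold creation mentioned above, here is a minimal sketch that splits sample indices into k folds without copying any data arrays. The function `k_fold_indices` and its parameters are hypothetical and are not the signature of `IndexSet.k_fold_split`.

```python
import random


def k_fold_indices(n_samples, k, seed=0):
    """Split range(n_samples) into k (train, validation) index pairs."""
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")

    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)

    # Deal shuffled indices into k folds whose sizes differ by at most one.
    folds = [indices[i::k] for i in range(k)]

    splits = []
    for i, val_idx in enumerate(folds):
        # Training indices are every fold except the held-out one.
        train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
        splits.append((sorted(train_idx), sorted(val_idx)))
    return splits


splits = k_fold_indices(n_samples=10, k=5, seed=42)
```

Because only integer indices are passed around, each fold can select its rows from the shared arrays at training time.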

Phase 2: Memory efficiency and model management
- Implemented _train_single_fold method to avoid creating separate Locator instances
- Created EnsembleModelManager for efficient model storage and lazy loading
- Fixed _create_model signature to use input_shape parameter
- Fixed save_fold_models parameter passing through method chain
- Made JSON serialization robust by filtering out DataFrames from config
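
The DataFrame filtering in the last item could be approached as in the sketch below. Rather than checking for DataFrames specifically, this generic version drops any value that `json` cannot encode, which is a substitution made to keep the example free of third-party imports. The function name, the config keys, and the `FakeFrame` placeholder are all hypothetical.

```python
import json


def json_safe_config(config):
    """Return (safe_config, dropped_keys), keeping only JSON-encodable values."""
    safe, dropped = {}, []
    for key, value in config.items():
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            dropped.append(key)
        else:
            safe[key] = value
    return safe, dropped


class FakeFrame:
    """Placeholder for an object json cannot encode, such as a DataFrame."""


config = {"batch_size": 32, "out": "ensemble_run", "sample_data": FakeFrame()}
safe, dropped = json_safe_config(config)
serialized = json.dumps(safe, sort_keys=True)
```
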

Test consolidation:
- Consolidated test_ensemble_mixin.py and test_ensemble_phase2.py into test_ensemble.py
- All 12 tests passing with comprehensive coverage of both phases

Key improvements:
- Memory-efficient data handling without array copies
- Consistent NA handling with na_action parameter
- Integration with modern tf.data pipeline
- Backward compatibility through legacy wrapper
- Efficient model management with lazy loading support
- Comprehensive test coverage for both phases

BREAKING CHANGE: EnsembleLocator is now deprecated in favor of Locator's
ensemble methods (train_ensemble, predict_ensemble). The old API still
works but shows deprecation warnings.
Implemented comprehensive ensemble functionality for Locator with k-fold cross-validation
and advanced training optimizations.

Phase 1 - Core Ensemble Functionality:
- Added EnsembleMixin with train_ensemble() and predict_ensemble() methods
- Implemented memory-efficient k-fold splitting using IndexSet
- Support for NA sample handling during ensemble training
- Proper normalization parameter averaging across folds

Phase 2 - Model Persistence:
- Created EnsembleModelManager for efficient model storage/loading
- Memory-efficient prediction without loading all models at once
- Metadata tracking for ensemble configuration and fold information
- Support for on-demand model loading during prediction
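
On-demand loading during prediction can be sketched as follows. This is not the `EnsembleModelManager` code: `LazyModelStore`, `predict_mean`, the pickle file format, and the toy dict "models" are assumptions chosen so the example runs with the standard library alone.

```python
import pickle
import tempfile
from pathlib import Path


class LazyModelStore:
    """Holds model file paths and loads each model only when iterated."""

    def __init__(self, paths):
        self.paths = list(paths)

    def __len__(self):
        return len(self.paths)

    def __iter__(self):
        for path in self.paths:
            with open(path, "rb") as handle:
                # Models are read one at a time, not all up front.
                yield pickle.load(handle)


def predict_mean(store, x):
    """Average predictions over all stored models using a running mean."""
    mean = 0.0
    for n, model in enumerate(store, start=1):
        mean += (model["slope"] * x - mean) / n
    return mean


# Toy "models" are dicts with one parameter so they can be pickled.
with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for fold, slope in enumerate([1.0, 2.0, 3.0]):
        path = Path(tmp) / f"fold_{fold}.pkl"
        path.write_bytes(pickle.dumps({"slope": slope}))
        paths.append(path)
    result = predict_mean(LazyModelStore(paths), x=2.0)
```

The running mean means no list of per-model predictions has to be kept, so peak memory stays close to that of a single model.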

Phase 4 - Training Improvements:
- Mixed precision training support via GPUOptimizer integration
- Automatic batch size optimization for each fold
- Enhanced callbacks with patience multiplier for ensemble training
- Per-fold learning rate variation for improved diversity
- Memory clearing between folds to prevent OOM errors
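
One way to picture per-fold learning rate variation and a patience multiplier is the sketch below. The even spacing scheme, the function name, and every default value are hypothetical; the PR does not state how the learning rates are varied.

```python
def fold_training_settings(k_folds, base_lr=1e-3, lr_spread=0.2,
                           base_patience=10, patience_multiplier=1.5):
    """Return one settings dict per fold.

    Learning rates are spaced evenly over
    [base_lr * (1 - lr_spread), base_lr * (1 + lr_spread)].
    """
    if k_folds < 1:
        raise ValueError("k_folds must be at least 1")

    patience = round(base_patience * patience_multiplier)
    settings = []
    for fold in range(k_folds):
        # Map the fold index to a position in [-1, 1]; a single fold sits at 0.
        position = 0.0 if k_folds == 1 else 2 * fold / (k_folds - 1) - 1
        settings.append({
            "fold": fold,
            "learning_rate": base_lr * (1 + lr_spread * position),
            "patience": patience,
        })
    return settings


settings = fold_training_settings(k_folds=5)
```
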

Architecture:
- All functionality consolidated in ensemble_mixin.py for maintainability
- Reuses existing Locator infrastructure (tf.data pipeline, GPU optimizer)
- Maintains backward compatibility with standard Locator interface
- Comprehensive test suite with 15 tests covering all functionality

Performance:
- Memory-efficient training without creating separate Locator instances
- GPU optimizations automatically applied when available
- Efficient prediction pipeline with on-demand model loading
- Proper memory management between training folds

This refactoring enables robust ensemble predictions while maintaining
code clarity and performance efficiency.
Implemented comprehensive ensemble functionality for Locator with k-fold cross-validation,
advanced training optimizations, and parallel GPU execution.
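
For the parallel GPU execution mentioned above, a simple way to spread folds over devices is round-robin assignment, sketched here. The function `assign_folds_to_gpus` is hypothetical, and round-robin is an assumed policy; the PR does not describe how `parallel_train_ensemble()` schedules folds.

```python
def assign_folds_to_gpus(n_folds, gpu_ids):
    """Map each GPU id to the list of fold indices it should train."""
    if not gpu_ids:
        raise ValueError("at least one GPU id is required")

    assignment = {gpu: [] for gpu in gpu_ids}
    for fold in range(n_folds):
        # Cycle through the GPUs so the load differs by at most one fold.
        assignment[gpu_ids[fold % len(gpu_ids)]].append(fold)
    return assignment


assignment = assign_folds_to_gpus(n_folds=5, gpu_ids=[0, 1])
```
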
- Added module-level skip decorator when Ray is not installed
- Fixed mock patches to use 'ray' directly instead of module path
- Added checks for stub functions in signature tests
- Fixed unused variable warnings
- Removed duplicate test file
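
Skipping tests when Ray is not installed can be done along these lines. This sketch uses the standard library's `unittest` so that it runs without extra packages; in a pytest suite, `pytest.importorskip("ray")` at module level is a common equivalent. The test class and its single test are hypothetical.

```python
import importlib.util
import unittest

# True only when the optional dependency can be located.
RAY_AVAILABLE = importlib.util.find_spec("ray") is not None


@unittest.skipUnless(RAY_AVAILABLE, "Ray is not installed")
class TestParallelEnsemble(unittest.TestCase):
    def test_ray_imports(self):
        import ray  # Safe: the class is skipped when Ray is missing.

        self.assertTrue(hasattr(ray, "init"))


def run_suite():
    """Run the tests and return the unittest result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestParallelEnsemble)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```
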
@andrewkern andrewkern merged commit 45f2338 into main Jul 7, 2025
6 checks passed
@andrewkern andrewkern deleted the feature/ensemble-refactoring branch July 7, 2025 22:08