feat: complete ensemble refactoring with parallel training support#15
Merged
Conversation
- Created EnsembleMixin with modern patterns (IndexSet, tf.data pipeline) - Added k_fold_split method to IndexSet for efficient fold creation - Refactored EnsembleLocator as a legacy compatibility wrapper - Added comprehensive tests for ensemble functionality - Integrated EnsembleMixin into core Locator class Key improvements: - Memory-efficient data handling without array copies - Consistent NA handling with na_action parameter - Integration with modern tf.data pipeline - Backward compatibility through legacy wrapper - Uses standard normalize_locs function instead of manual normalization - Uses NormalizationParams class for denormalization - Reduced cyclomatic complexity by extracting helper methods BREAKING CHANGE: EnsembleLocator is now deprecated in favor of Locator's ensemble methods (train_ensemble, predict_ensemble). The old API still works but shows deprecation warnings.
Phase 1: Create EnsembleMixin with modern patterns - Created EnsembleMixin with modern patterns (IndexSet, tf.data pipeline) - Added k_fold_split method to IndexSet for efficient fold creation - Refactored EnsembleLocator as a legacy compatibility wrapper - Uses standard normalize_locs function instead of manual normalization - Uses NormalizationParams class for denormalization - Reduced cyclomatic complexity by extracting helper methods Phase 2: Memory efficiency and model management - Implemented _train_single_fold method to avoid creating separate Locator instances - Created EnsembleModelManager for efficient model storage and lazy loading - Fixed _create_model signature to use input_shape parameter - Fixed save_fold_models parameter passing through method chain - Made JSON serialization robust by filtering out DataFrames from config Test consolidation: - Consolidated test_ensemble_mixin.py and test_ensemble_phase2.py into test_ensemble.py - All 12 tests passing with comprehensive coverage of both phases Key improvements: - Memory-efficient data handling without array copies - Consistent NA handling with na_action parameter - Integration with modern tf.data pipeline - Backward compatibility through legacy wrapper - Efficient model management with lazy loading support - Comprehensive test coverage for both phases BREAKING CHANGE: EnsembleLocator is now deprecated in favor of Locator's ensemble methods (train_ensemble, predict_ensemble). The old API still works but shows deprecation warnings.
Implemented comprehensive ensemble functionality for Locator with k-fold cross-validation and advanced training optimizations. Phase 1 - Core Ensemble Functionality: - Added EnsembleMixin with train_ensemble() and predict_ensemble() methods - Implemented memory-efficient k-fold splitting using IndexSet - Support for NA sample handling during ensemble training - Proper normalization parameter averaging across folds Phase 2 - Model Persistence: - Created EnsembleModelManager for efficient model storage/loading - Memory-efficient prediction without loading all models at once - Metadata tracking for ensemble configuration and fold information - Support for on-demand model loading during prediction Phase 4 - Training Improvements: - Mixed precision training support via GPUOptimizer integration - Automatic batch size optimization for each fold - Enhanced callbacks with patience multiplier for ensemble training - Per-fold learning rate variation for improved diversity - Memory clearing between folds to prevent OOM errors Architecture: - All functionality consolidated in ensemble_mixin.py for maintainability - Reuses existing Locator infrastructure (tf.data pipeline, GPU optimizer) - Maintains backward compatibility with standard Locator interface - Comprehensive test suite with 15 tests covering all functionality Performance: - Memory-efficient training without creating separate Locator instances - GPU optimizations automatically applied when available - Efficient prediction pipeline with on-demand model loading - Proper memory management between fold training This refactoring enables robust ensemble predictions while maintaining code clarity and performance efficiency.
Implemented comprehensive ensemble functionality for Locator with k-fold cross-validation, advanced training optimizations, and parallel GPU execution.
- Added module-level skip decorator when Ray is not installed - Fixed mock patches to use 'ray' directly instead of module path - Added checks for stub functions in signature tests - Fixed unused variable warnings - Removed duplicate test file
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
feat: complete ensemble refactoring with parallel training support
Summary
Refactored ensemble functionality from separate
EnsembleLocatorclass into integratedEnsembleMixin, adding k-fold cross-validation with parallel GPU support.Key Changes
EnsembleMixin-train_ensemble()andpredict_ensemble()methodsparallel_train_ensemble()distributes folds across multiple GPUsEnsembleModelManagerfor efficient model persistence and on-demand loadingTesting
Documentation
docs/source/ensemble_guide.rst)Breaking Changes
EnsembleLocatorclass removed - useLocatorwith ensemble methodstrain()→train_ensemble(),predict()→predict_ensemble()