[WIP] Refactor Training Runtime and Simplify Backend/Trainer Hierarchy #536
Xiaoming-AMD wants to merge 13 commits into main
- Merge TrainerComponent into BaseTrainer
- Remove redundant primus_config/module_config from trainers
- Remove verbose param from setup_backend_path
- Remove setup_sys_path interface (use setup_backend_path override)
- Remove unused model/optimizer/opt_param_scheduler attributes
- Remove empty __init__ in MegatronPretrainTrainer
- Remove detect_version from BaseTrainer (handled by adapter)
- Use AST parsing for version detection (avoid __init__ execution)
- Add _log_step helper for trainer lifecycle logging
- Remove framework/model validation in BaseTrainer
- Make _patch_parse_args fail fast (no try-except)

- Rename run_train() to train() in BaseTrainer and all subclasses
- Simplify convert_config() to accept params directly instead of module_config
- Move model_provider.py to primus/modules/trainer/megatron/
- Simplify stage detection in train_runtime.py
- Fix merge_namespace to handle None excepts parameter
- Update unit tests to match interface changes

- Add patch to skip megatron.training.initialize._compile_dependencies
- Ensures ROCm compatibility by bypassing CUDA-specific compilation
- Add unit tests for the patch

- Simplify detect_backend_version using AST parsing (consistent with megatron_adapter)
- Add module_config to run_patches in MegatronBridgeBaseTrainer
- Update example config
Pull request overview
This PR refactors the training runtime lifecycle by moving patch orchestration from individual components (BaseTrainer, BackendAdapter) to the centralized PrimusRuntime. The core change is that trainers become simpler objects focused on training logic, while runtime owns the complete lifecycle and patch phase management.
Changes:
- Moved patch application (setup, build_args, before_train, after_train) from BaseTrainer/BackendAdapter to PrimusRuntime
- Simplified BaseTrainer to a minimal ABC focused on setup/init/train/cleanup lifecycle hooks
- Updated BackendAdapter to own backend path setup logic (removed from BackendRegistry)
- Changed backend version detection to use AST parsing instead of imports to avoid executing __init__.py
- Updated trainers to receive only backend_args (removed primus_config and module_config parameters)
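The trainer shape described in the list above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: `run_trainer`, `EchoTrainer`, and the `run_phase` callback are hypothetical names, and only the hook names (setup/init/train/cleanup) and the `backend_args`-only constructor come from the overview.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Callable, List


class BaseTrainer(ABC):
    """Minimal trainer: stores backend_args and exposes lifecycle hooks."""

    def __init__(self, backend_args: SimpleNamespace) -> None:
        self.backend_args = backend_args

    def setup(self) -> None:  # optional hook, no-op by default
        pass

    def init(self) -> None:  # optional hook, no-op by default
        pass

    @abstractmethod
    def train(self) -> None:
        ...

    def cleanup(self) -> None:  # optional hook, no-op by default
        pass


def run_trainer(trainer: BaseTrainer, run_phase: Callable[[str], None]) -> None:
    """Hypothetical runtime driver: the runtime, not the trainer, fires patch phases."""
    trainer.setup()
    trainer.init()
    run_phase("before_train")
    try:
        trainer.train()
        run_phase("after_train")
    finally:
        trainer.cleanup()


class EchoTrainer(BaseTrainer):
    """Toy subclass that only records that train() ran."""

    def __init__(self, backend_args: SimpleNamespace, events: List[str]) -> None:
        super().__init__(backend_args)
        self.events = events

    def train(self) -> None:
        self.events.append("train")


events: List[str] = []
run_trainer(EchoTrainer(SimpleNamespace(lr=0.1), events), events.append)
```

In this sketch the trainer never calls a patch runner itself, which is the separation of concerns the overview describes.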
Reviewed changes
Copilot reviewed 33 out of 33 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tests/unit_tests/core/trainer/test_base_trainer.py | Removed patch-related test cases; updated to test simplified trainer.run() -> train() delegation |
| tests/unit_tests/core/runtime/test_train_runtime.py | Added tests for runtime patch orchestration; updated to reflect new initialization flow |
| tests/unit_tests/core/patches/test_patch_runner.py | Updated test expectations for patch priority-based sorting |
| tests/unit_tests/core/backend/test_backend_registry.py | Removed path-name and setup_backend_path tests; added version detection guards for trainer class registry |
| tests/unit_tests/core/backend/test_backend_adapter.py | Removed create_trainer orchestration tests; added adapter.setup_backend_path tests |
| tests/unit_tests/backends/megatron/test_runtime_hooks_patches.py | New test for runtime hook patch idempotency |
| tests/unit_tests/backends/megatron/test_megatron_registration.py | Removed path-name registration tests; added skip for missing trainer-class registry |
| tests/unit_tests/backends/megatron/test_megatron_pretrain_trainer.py | Updated test names from run_train to train |
| tests/unit_tests/backends/megatron/test_megatron_adapter.py | Updated version detection tests to validate AST parsing without import |
| primus/modules/trainer/megatron/trainer.py | Fixed import path for primus_model_provider |
| primus/modules/trainer/megatron/model_provider.py | Removed commented-out code |
| primus/core/utils/yaml_utils.py | Changed merge_namespace to skip duplicates instead of raising error |
| primus/core/trainer/trainer_component.py | Removed entire file (functionality merged into BaseTrainer) |
| primus/core/trainer/base_trainer.py | Simplified to minimal ABC with setup/init/train/cleanup; removed run() template method and patch orchestration |
| primus/core/runtime/train_runtime.py | Added patch orchestration methods; trainer initialization now includes patch phases |
| primus/core/patches/patch_runner.py | Added priority-based sorting of patches |
| primus/core/patches/context.py | Updated comment to reflect runtime ownership of patch phases |
| primus/core/backend/backend_registry.py | Removed path-name mapping and setup_backend_path; commented out trainer class registry |
| primus/core/backend/backend_adapter.py | Added setup_backend_path method; removed create_trainer and patch orchestration |
| primus/backends/torchtitan/*.py | Updated trainer signatures and adapter to match new pattern |
| primus/backends/megatron_bridge/*.py | Updated trainer signatures and adapter to match new pattern |
| primus/backends/megatron/patches/runtime_hooks_patches.py | New patch file for Megatron runtime hooks |
| primus/backends/megatron/*.py | Updated trainer signatures and adapter to use AST-based version detection |
| docs/README.md | Added link to ADDING-MODEL-AND-BACKEND.md |
Comments suppressed due to low confidence (1)
primus/backends/megatron_bridge/megatron_bridge_base_trainer.py:1
- The code references self.model_name, but BaseTrainer no longer has this attribute after the refactor. The BaseTrainer.__init__ signature changed and model_name is not set.
###############################################################################
| if key in dst_dict and not allow_override: | ||
| raise ValueError(f"Key '{key}' from {src.name} already exists in {dst.name}.") | ||
| else: | ||
| setattr(dst, key, value) | ||
| continue # Skip duplicate keys, keep dst value |
The behavior change from raising ValueError to silently skipping duplicate keys is significant and could hide configuration errors. Consider adding a logging statement when keys are skipped to maintain visibility of potential configuration conflicts.
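One way to keep skipped keys visible, as the comment suggests, is sketched below. It assumes namespace-like `dst` and `src` objects and is not the actual `merge_namespace` in primus/core/utils/yaml_utils.py: the treatment of `excepts` as "keys to ignore", the logger name, and the returned list of skipped keys are all assumptions made for illustration.

```python
import logging
from types import SimpleNamespace
from typing import Iterable, List, Optional

logger = logging.getLogger("merge_namespace_sketch")  # hypothetical logger name


def merge_namespace(
    dst: SimpleNamespace,
    src: SimpleNamespace,
    allow_override: bool = False,
    excepts: Optional[Iterable[str]] = None,
) -> List[str]:
    """Copy attributes from src into dst and return the keys that were skipped."""
    ignored = set(excepts or ())  # tolerate excepts=None
    skipped: List[str] = []
    for key, value in vars(src).items():
        if key in ignored:
            continue
        if key in vars(dst) and not allow_override:
            # Keep the dst value, but leave a trace so conflicts are not silent.
            logger.warning(
                "Skipping duplicate key %r: keeping %r, ignoring %r",
                key, getattr(dst, key), value,
            )
            skipped.append(key)
            continue
        setattr(dst, key, value)
    return skipped
```

Returning the skipped keys is a design choice for this sketch only; it makes the conflict easy to assert on in a unit test without capturing log output.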
primus/core/runtime/train_runtime.py
| assert ( | ||
| self.ctx is not None and self.ctx.adapter is not None | ||
| ), "Backend adapter must be loaded before creating trainer." | ||
| self.ctx.primus_config |
This line appears to be orphaned code that evaluates primus_config but does nothing with it. Either remove it or add a comment explaining its purpose.
| self.ctx.primus_config |
| assert trainer.opt_param_scheduler is None | ||
|
|
||
| def test_run_train_invokes_megatron_pretrain_with_expected_arguments( | ||
| def test_train_invokes_megatron_pretrain_with_expected_arguments( |
The test helper _build_trainer stubs out MegatronBaseTrainer.__init__ with a signature that includes primus_config and module_config, but the actual implementation now only accepts backend_args. This test needs to be updated to match the new trainer signature.
| return resolved | ||
|
|
||
| # 3) Default: <repo_root>/third_party/<dir_name> must exist. | ||
| dir_name = self.third_party_dir_name or self.framework |
The fallback to self.framework for third_party_dir_name may not match the actual directory structure. For example, 'megatron' framework uses 'Megatron-LM' directory. Consider requiring subclasses to explicitly set third_party_dir_name to avoid incorrect path assumptions.
| def init(self): | ||
| """Initialize Megatron training components.""" | ||
| log_rank_0("Initializing Megatron training...") | ||
| # log_dict_aligned("Backend arguments", self.backend_args) |
Remove commented-out code or explain why it's kept for future reference.
| # log_dict_aligned("Backend arguments", self.backend_args) |
- Remove unused detect_version method from MegatronBridgeBaseTrainer
- Remove commented out TrainerClass registration code from BackendRegistry
- Fix megatron unit tests to match new interface (convert_config, load_trainer_class)
- Update test mocks for new constructor signatures

- Update test_train_runtime.py: fix DummyTrainer signature and method names
- Update test_base_trainer.py: rewrite tests for new BaseTrainer interface
- Remove obsolete tests for removed run(), detect_version(), validation
- Add new tests for backend_args, lifecycle methods, distributed env
Pull request overview
Copilot reviewed 33 out of 33 changed files in this pull request and generated 5 comments.
Comments suppressed due to low confidence (1)
primus/backends/megatron_bridge/megatron_bridge_base_trainer.py:1
- Reference to undefined attribute self.model_name. The new BaseTrainer interface no longer stores module_config, so self.model_name is not available. Either remove this line or extract model_name from backend_args.
###############################################################################
| _SUPPORTS_TRAINER_CLASS_REGISTRY = all( | ||
| hasattr(registry_module.BackendRegistry, attr) | ||
| for attr in ("_trainer_classes", "register_trainer_class", "get_trainer_class", "has_trainer_class") | ||
| ) |
The comment on line 46 states _trainer_classes: Dict[tuple, Type] = {} is present in the registry, but the code immediately below checks if this attribute exists using hasattr. This suggests the attribute is optional. Add a docstring or comment explaining when this feature is available and when it is not.
| resolved = _use_path(str(default_path), default_error) | ||
| return resolved |
The return statement on line 108 is redundant since lines 83, 94, and 107 all return directly. Consider returning directly from _use_path on line 107 instead of assigning to a variable.
| resolved = _use_path(str(default_path), default_error) | |
| return resolved | |
| return _use_path(str(default_path), default_error) |
| pre = values.get("PRE_RELEASE") | ||
| return f"{values['MAJOR']}.{values['MINOR']}.{values['PATCH']}" + (str(pre) if pre else "") |
KeyError will be raised if MAJOR, MINOR, or PATCH are missing from the parsed values. Add error handling or validation to ensure all required version components are present before formatting the version string.
| pre = values.get("PRE_RELEASE") | |
| return f"{values['MAJOR']}.{values['MINOR']}.{values['PATCH']}" + (str(pre) if pre else "") | |
| required_keys = {"MAJOR", "MINOR", "PATCH"} | |
| missing = required_keys.difference(values.keys()) | |
| if missing: | |
| raise RuntimeError( | |
| f"Invalid megatron/core/package_info.py: missing version component(s): " | |
| f"{', '.join(sorted(missing))}" | |
| ) | |
| pre = values.get("PRE_RELEASE") | |
| version = f"{values['MAJOR']}.{values['MINOR']}.{values['PATCH']}" | |
| return version + (str(pre) if pre else "") |
| if name == "__version__": | ||
| return ast.literal_eval(node.value) |
The function parse_version will raise UnboundLocalError if version is not found, as it falls through without returning. Add an explicit raise statement after the loop to handle this case clearly.
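A sketch of the explicit-raise shape the comment asks for is shown below. It assumes the version lives in a top-level `__version__ = "..."` assignment and that the caller passes the file contents as a string; the function body is illustrative and not the `parse_version` from this PR.

```python
import ast


def parse_version(source: str) -> str:
    """Return the literal assigned to __version__ without executing the module."""
    # Only top-level statements are inspected; nothing is imported or run.
    for node in ast.parse(source).body:
        if not isinstance(node, ast.Assign):
            continue
        for target in node.targets:
            if isinstance(target, ast.Name) and target.id == "__version__":
                return ast.literal_eval(node.value)
    # Explicit failure instead of falling through the loop.
    raise RuntimeError("No top-level __version__ assignment found")
```

With the explicit `raise`, a missing version surfaces at the parsing step with a clear message rather than later in version-gated code.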
| if not _SUPPORTS_TRAINER_CLASS_REGISTRY: | ||
| pytest.skip( | ||
| "BackendRegistry trainer-class registration API is not available; " | ||
| "skip megatron registration tests that depend on it.", | ||
| allow_module_level=True, | ||
| ) |
The module-level skip will prevent all tests from running if the registry doesn't support trainer classes, but some tests (like test_adapter_is_registered) don't require this feature. Consider making the skip more granular by applying it only to tests that actually need trainer class registry.
- HummingbirdXT: update adapter interface, merge base trainer into posttrain trainer
- TorchTitan: merge base trainer into pretrain trainer
- Update convert_config signature to accept params directly
- Simplify detect_backend_version implementations
- Remove redundant base trainer classes

- Add distributed context fields to TrainContext dataclass
- Fix corrupted comment line
- Remove dead/commented code
- Replace print() with log_rank_0() for consistent logging
- Remove unused parameter from _apply_overrides
| ): | ||
| with self.assertRaises(ValueError) as ctx: | ||
| runtime._initialize_backend() | ||
| runtime._initialize_adapter() |
The method was renamed from _initialize_backend() to _initialize_adapter(), but the test name test_initialize_backend_wraps_adapter_errors wasn't updated to match. Consider renaming to test_initialize_adapter_wraps_errors for consistency.
| if getattr(megatron_initialize._compile_dependencies, "_primus_patched", False): | ||
| return |
The idempotency check uses a custom _primus_patched attribute. This pattern is repeated across multiple patches. Consider extracting this into a decorator or utility function to avoid duplication and ensure consistent patch idempotency behavior.
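The shared utility suggested here might look something like the following. `apply_patch_once` is a hypothetical helper name, and the `SimpleNamespace` object only stands in for a module such as `megatron.training.initialize`; the `_primus_patched` marker is the one quoted in the snippet above.

```python
from types import SimpleNamespace
from typing import Any, Callable


def apply_patch_once(
    owner: Any,
    attr_name: str,
    replacement: Callable[..., Any],
    marker: str = "_primus_patched",
) -> bool:
    """Replace owner.attr_name with replacement unless it is already patched.

    Returns True if the patch was applied, False if it was skipped.
    """
    current = getattr(owner, attr_name)
    if getattr(current, marker, False):
        return False  # already patched: calling again is a no-op
    setattr(replacement, marker, True)
    setattr(owner, attr_name, replacement)
    return True


# Stand-in for the real module, used only for illustration.
fake_initialize = SimpleNamespace(_compile_dependencies=lambda: "compiled")


def _skip_compile_dependencies() -> None:
    """No-op replacement, mirroring the idea of skipping compilation."""
    return None


first = apply_patch_once(fake_initialize, "_compile_dependencies", _skip_compile_dependencies)
second = apply_patch_once(fake_initialize, "_compile_dependencies", _skip_compile_dependencies)
```

Centralizing the marker check in one helper means each patch file no longer needs its own `getattr(..., "_primus_patched", False)` guard.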
|
|
||
| # (Backend, Stage) → TrainerClass | ||
| _trainer_classes: Dict[tuple, Type] = {} | ||
| # _trainer_classes: Dict[tuple, Type] = {} |
The commented-out _trainer_classes dictionary should be removed rather than commented out, as trainer class registration has been moved to adapters. Leaving commented code can cause confusion about whether the feature is deprecated or temporarily disabled.
| assert "megatron" not in sys.modules, "megatron should not be imported" | ||
| assert "megatron.core" not in sys.modules, "megatron.core should not be imported" | ||
| assert ( | ||
| "megatron.core.package_info" not in sys.modules | ||
| ), "megatron.core.package_info should not be imported" |
These assertions verify that AST parsing doesn't import modules, which is good. However, the test doesn't verify that the AST parsing correctly handles edge cases like missing PRE_RELEASE or malformed version strings. Consider adding test cases for these scenarios.
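The edge cases mentioned could be exercised with a small string-based helper like the one below. It assumes package_info.py assigns `MAJOR`, `MINOR`, `PATCH`, and optionally `PRE_RELEASE` as top-level literals; `version_from_package_info` is a hypothetical name and not the adapter's actual method.

```python
import ast

_REQUIRED = ("MAJOR", "MINOR", "PATCH")
_FIELDS = _REQUIRED + ("PRE_RELEASE",)


def version_from_package_info(source: str) -> str:
    """Build a version string from top-level literal assignments, without importing."""
    values = {}
    for node in ast.parse(source).body:
        if isinstance(node, ast.Assign) and len(node.targets) == 1:
            target = node.targets[0]
            if isinstance(target, ast.Name) and target.id in _FIELDS:
                # literal_eval raises ValueError for non-literal values.
                values[target.id] = ast.literal_eval(node.value)
    missing = [key for key in _REQUIRED if key not in values]
    if missing:
        raise RuntimeError(f"missing version component(s): {', '.join(missing)}")
    pre = values.get("PRE_RELEASE")
    return f"{values['MAJOR']}.{values['MINOR']}.{values['PATCH']}" + (str(pre) if pre else "")


# Inputs a test suite might cover.
with_pre = "MAJOR = 0\nMINOR = 14\nPATCH = 0\nPRE_RELEASE = 'rc0'\n"
without_pre = "MAJOR = 0\nMINOR = 14\nPATCH = 0\n"
missing_patch = "MAJOR = 0\nMINOR = 14\n"
non_literal = "MAJOR = get_major()\nMINOR = 14\nPATCH = 0\n"
```

Because the helper takes source text rather than a path, each case can be a one-line test with no temporary files.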
| # Register posttrain trainer as the default trainer | ||
| # Megatron-Bridge is designed for post-training tasks (SFT, instruction tuning, LoRA) | ||
| BackendRegistry.register_trainer_class(MegatronBridgePosttrainTrainer, "megatron_bridge", "sft") | ||
| # BackendRegistry.register_trainer_class(MegatronBridgePosttrainTrainer, "megatron_bridge", "sft") |
This commented-out trainer registration should be removed. The comment above indicates that trainer class loading has been moved to the adapter, so this legacy code is no longer needed.
| log_dict_aligned("Primus-specific parameters", primus_only_params) | ||
|
|
||
| # Merge backend_args into params (backend_args overrides params) | ||
| merge_namespace(backend_args, module_config.params, allow_override=False, excepts=[]) |
Passing an empty list excepts=[] is redundant since merge_namespace now defaults excepts to an empty list. This parameter can be omitted for cleaner code.
| merge_namespace(backend_args, module_config.params, allow_override=False, excepts=[]) | |
| merge_namespace(backend_args, module_config.params, allow_override=False) |
- Fix conftest.py to add project root before Megatron-LM
- Use append instead of insert(0) for Megatron-LM path
- Update test_train_runtime.py to match new _apply_overrides signature
| BackendRegistry.register_adapter("megatron_bridge", MegatronBridgeAdapter) | ||
|
|
||
| # Register posttrain trainer as the default trainer | ||
| # Megatron-Bridge is designed for post-training tasks (SFT, instruction tuning, LoRA) |
This commented-out registration suggests incomplete refactoring. Either remove it completely or add a comment explaining why it's being kept.
| # Megatron-Bridge is designed for post-training tasks (SFT, instruction tuning, LoRA) | |
| # Megatron-Bridge is designed for post-training tasks (SFT, instruction tuning, LoRA) | |
| # NOTE: Trainer registration is intentionally disabled for now; kept as reference for potential future enablement. |
| # Create module_config from backend_args for patch context | ||
| module_config = SimpleNamespace(params=self.backend_args) |
Creating a synthetic module_config from backend_args to satisfy patch requirements is a code smell. This suggests the patch system still expects the old architecture. Consider updating patches to work with the new architecture where trainers only have backend_args.
- Replace log_rank_0() with print() in _apply_overrides and _initialize_distributed_context
- These are called before _initialize_logging(), so logger is still None
| # Phase: build_args (after args creation, before trainer instantiation) | ||
| self._run_phase_patches(phase="build_args", backend_args=backend_args) |
setup phase patches are applied after build_args patches and after the trainer is instantiated, but setup is documented/defined as occurring before building args and trainer construction. This can break patches that need to adjust environment/config before convert_config() or trainer creation. Move self._run_phase_patches(phase="setup", ...) earlier (e.g., right after adapter setup / before prepare_backend and convert_config), and update the patch-order test accordingly.
| # Load trainer class and instantiate | ||
| stage = getattr(module_config.params, "stage", "pretrain") or "pretrain" | ||
| TrainerClass = adapter.load_trainer_class(stage=stage) | ||
| trainer = TrainerClass(backend_args=backend_args) |
setup phase patches are applied after build_args patches and after the trainer is instantiated, but setup is documented/defined as occurring before building args and trainer construction. This can break patches that need to adjust environment/config before convert_config() or trainer creation. Move self._run_phase_patches(phase="setup", ...) earlier (e.g., right after adapter setup / before prepare_backend and convert_config), and update the patch-order test accordingly.
|
|
||
| # 1) Optional setup phase | ||
| trainer.setup() | ||
| self._run_phase_patches(phase="setup", backend_args=self.ctx.backend_args) |
setup phase patches are applied after build_args patches and after the trainer is instantiated, but setup is documented/defined as occurring before building args and trainer construction. This can break patches that need to adjust environment/config before convert_config() or trainer creation. Move self._run_phase_patches(phase="setup", ...) earlier (e.g., right after adapter setup / before prepare_backend and convert_config), and update the patch-order test accordingly.
| self._run_phase_patches(phase="setup", backend_args=self.ctx.backend_args) |
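The ordering requested in these comments can be illustrated with a toy version of the trainer-creation path. Everything here is a stand-in: `create_trainer`, `DummyTrainer`, and the callback arguments are hypothetical, and the real runtime passes more context to each phase.

```python
from types import SimpleNamespace
from typing import Callable, List, Type

calls: List[str] = []  # records the order in which steps run


def run_phase(phase: str) -> None:
    """Stand-in for the runtime's phase runner; only records the phase."""
    calls.append(f"patch:{phase}")


def convert_config() -> SimpleNamespace:
    """Stand-in for adapter.convert_config(); returns namespace-like args."""
    calls.append("convert_config")
    return SimpleNamespace(stage="pretrain")


class DummyTrainer:
    def __init__(self, backend_args: SimpleNamespace) -> None:
        calls.append("trainer_init")
        self.backend_args = backend_args


def create_trainer(
    convert: Callable[[], SimpleNamespace],
    trainer_cls: Type[DummyTrainer],
    phase_runner: Callable[[str], None],
) -> DummyTrainer:
    phase_runner("setup")        # before any backend args exist
    backend_args = convert()
    phase_runner("build_args")   # args exist, trainer does not yet
    return trainer_cls(backend_args=backend_args)


trainer = create_trainer(convert_config, DummyTrainer, run_phase)
```

A patch-order test can then assert on the recorded sequence instead of on individual mock calls.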
| # Logger may not be initialized yet; sys.path is already updated. | ||
| pass | ||
| return norm_path | ||
| assert False, error_msg |
Using assert False, error_msg for runtime error handling is unsafe because assertions can be stripped with Python optimizations (-O), potentially causing this function to return None and fail later in harder-to-debug ways. Raise a real exception (e.g., FileNotFoundError or RuntimeError) instead of asserting.
| assert False, error_msg | |
| raise FileNotFoundError(error_msg) |
| # These are likely Primus-specific parameters. | ||
| params_dict = nested_namespace_to_dict(module_config.params) | ||
| config_keys = set(params_dict.keys()) | ||
| backend_keys = set(vars(backend_args)) |
BackendAdapter.convert_config() is typed to return Any and its docstring allows SimpleNamespace or dict, but the runtime unconditionally calls vars(backend_args), which will raise TypeError for dict (and other non-namespace values). Either (a) tighten the contract to require a namespace-like object (argparse/SimpleNamespace) and validate it here with a clear error, or (b) support dicts by using backend_args.keys() when isinstance(backend_args, dict).
| backend_keys = set(vars(backend_args)) | |
| if isinstance(backend_args, dict): | |
| backend_keys = set(backend_args.keys()) | |
| else: | |
| backend_keys = set(vars(backend_args)) |
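If both options from the comment are wanted at once (accept dicts, and fail clearly on anything else), a small helper along these lines would do it. `backend_arg_keys` is a hypothetical name and the error wording is illustrative.

```python
from types import SimpleNamespace
from typing import Any, Set


def backend_arg_keys(backend_args: Any) -> Set[str]:
    """Return the key set of a dict or namespace-like backend_args object."""
    if isinstance(backend_args, dict):
        return set(backend_args.keys())
    try:
        # Works for argparse.Namespace, SimpleNamespace, and similar objects.
        return set(vars(backend_args))
    except TypeError as exc:
        raise TypeError(
            "convert_config() must return a dict or a namespace-like object, "
            f"got {type(backend_args).__name__}"
        ) from exc
```

The explicit error names the offending type, which is easier to act on than the bare `vars()` failure.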
| except Exception: | ||
| self.ctx.backend_version = None |
Swallowing all exceptions from detect_backend_version() and silently setting the version to None makes version-gated patch behavior hard to diagnose (patches may unexpectedly not apply). At minimum, log a warning with the exception details; consider re-raising for backends where version detection is required, or use an explicit sentinel like 'unknown' to make the state visible in patch logs.
| except Exception: | |
| self.ctx.backend_version = None | |
| except Exception as exc: | |
| warning_rank_0( | |
| f"[Runtime] Failed to detect backend version for backend '{self.ctx.framework}' " | |
| f"via adapter {type(self.ctx.adapter).__name__}: {exc}. " | |
| "Proceeding with backend_version='unknown'." | |
| ) | |
| self.ctx.backend_version = "unknown" |
| raise ValueError(f"Key '{key}' from {src.name} already exists in {dst.name}.") | ||
| else: | ||
| setattr(dst, key, value) | ||
| continue # Skip duplicate keys, keep dst value |
merge_namespace() previously raised on duplicate keys when allow_override=False, but it now silently skips duplicates. This is a behavior change that can hide configuration mistakes (e.g., unintended duplicate keys) and makes debugging harder. If the new behavior is intended, consider updating the function docstring/name (or adding an explicit on_conflict= policy) and adjusting callers/tests accordingly; otherwise, restore the exception and pass an excepts list at call sites that want to skip.
| continue # Skip duplicate keys, keep dst value | |
| # Restore strict behavior: disallow silent duplicate keys unless explicitly excepted | |
| raise AssertionError(f"Not allowed to override key({key}) in namespace({dst})") |
| run_patches( | ||
| backend="megatron", | ||
| phase="before_train", | ||
| backend_version=type(self).detect_megatron_version(), | ||
| model_name=self.model_name, | ||
| extra={ | ||
| "module_config": module_config, | ||
| "backend_args": self.backend_args, | ||
| "primus_config": self.primus_config, | ||
| "module_config": self.module_config, | ||
| }, | ||
| ) |
This trainer base class is still executing patch phases during construction, which conflicts with the PR goal of centralizing patch orchestration in PrimusRuntime. It can also lead to duplicated or inconsistent patch application (runtime will later run before_train for backend=megatron_bridge, while this runs backend=megatron). Consider moving this into runtime/adapter orchestration by introducing a backend aliasing mechanism (e.g., allow patches to target multiple backends) or registering the relevant patches for megatron_bridge, so trainers don't directly call run_patches().
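The "patches target multiple backends" idea from this comment could take a shape like the sketch below. This is a hypothetical registry written for illustration; `register_patch`, `run_registered_patches`, and `_PATCHES` are invented names and do not reflect the existing patch system's API.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, Iterable, List, Tuple

# (backend, phase) -> patches registered for that pair
_PATCHES: DefaultDict[Tuple[str, str], List[Callable[[], None]]] = defaultdict(list)


def register_patch(backends: Iterable[str], phase: str) -> Callable:
    """Register one patch function under several backend names."""
    def decorator(fn: Callable[[], None]) -> Callable[[], None]:
        for backend in backends:
            _PATCHES[(backend, phase)].append(fn)
        return fn
    return decorator


def run_registered_patches(backend: str, phase: str) -> int:
    """Run every patch registered for (backend, phase); return how many ran."""
    patches = _PATCHES.get((backend, phase), [])
    for patch in patches:
        patch()
    return len(patches)


applied: List[str] = []


@register_patch(backends=("megatron", "megatron_bridge"), phase="before_train")
def example_patch() -> None:
    applied.append("example_patch")
```

With registration like this, the runtime can run `before_train` for `megatron_bridge` and pick up the shared Megatron patches, so the trainer constructor has no reason to invoke the patch runner directly.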
| # Core lifecycle phases used by Primus | ||
| # NOTE: | ||
| # These should stay in sync with the phases used by: | ||
| # - BackendAdapter._apply_setup_patches() / _apply_build_args_patches() | ||
| # - PrimusRuntime (setup / build_args / before_train / after_train) | ||
| # - BaseTrainer.run() (before_train / after_train) |
This comment is now out of date: BaseTrainer.run() was removed in this refactor and patch orchestration moved to PrimusRuntime. Update the note to avoid referencing BaseTrainer.run() and to describe the current single owner of patch phases.
Pull request overview
This PR refactors the training runtime lifecycle by moving patch orchestration from individual components (BaseTrainer, BackendAdapter) to the centralized PrimusRuntime. The core change is that trainers become simpler objects focused on training logic, while runtime owns the complete lifecycle and patch phase management.
Changes: