Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,15 @@ class MegatronEngineConfig:
# FP8 Training Configuration
fp8_config: FP8EngineConfig | None = None

# Bridge backend used for HF<->Megatron conversion/model creation.
bridge_type: str = field(
default="mbridge",
metadata={
"help": "Bridge backend for MegatronEngine. Choices: 'mbridge' or 'megatron-bridge'.",
"choices": ["mbridge", "megatron-bridge"],
},
)


class SchedulingStrategyType(str, Enum):
separation = "separation"
Expand Down
131 changes: 86 additions & 45 deletions areal/engine/megatron_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import mbridge
import torch
import torch.distributed as dist
from megatron.bridge import AutoBridge as MegatronBridgeAutoBridge
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core.distributed import DistributedDataParallel as DDP
Expand Down Expand Up @@ -167,6 +168,7 @@ def __init__(self, config: TrainEngineConfig):
self.fp8_config.direct_convert if self.enable_fp8 else False
)
self.quantization_config: dict[str, int | str | list[str]] | None = None
self.bridge_cls: str = getattr(self.mcore_config, "bridge_type", "mbridge")

def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
if parallel_strategy is None:
Expand Down Expand Up @@ -238,47 +240,37 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):

self.tokenizer = load_hf_tokenizer(self.config.path)

with patch_bridge_for_tree_training(self.enable_tree_training):
self.bridge = mbridge.AutoBridge.from_pretrained(self.config.path)
self.bridge.dtype = self.dtype
# Set gradient checkpointing options
if self.config.gradient_checkpointing:
self.bridge.set_extra_args(
recompute_granularity=self.mcore_config.recompute_granularity,
recompute_method=self.mcore_config.recompute_method,
recompute_num_layers=self.mcore_config.recompute_num_layers,
distribute_saved_activations=self.mcore_config.distribute_saved_activations,
recompute_modules=self.mcore_config.recompute_modules,
)

self.logger.info(
"Using mbridge to create models and hf model save/load in MegatronEngine."
)
with patch_bridge_for_tree_training(
self.enable_tree_training and self.bridge_cls == "mbridge"
):
self.bridge = self._build_hf_mcore_bridge()

self.hf_config, self.tf_config = make_hf_and_mcore_config(
self.config.path, dtype=self.dtype, bridge=self.bridge
self.config.path,
dtype=self.dtype,
bridge=self.bridge,
bridge_type=self.bridge_cls,
)
self.tf_config = configure_pipeline_layer_splits(
self.parallel_strategy, self.hf_config, self.tf_config
)

# Get quantization_config from hf_config if available (for FP8 weight updates)
self.quantization_config = getattr(
self.hf_config, "quantization_config", None
)

self._check_and_apply_fp8_config()
self._validate_fp8_consistency()

# initialize mcore (DDP Wrapped) GPTModel
with self.device:
models = make_mcore_model(
hf_config=self.hf_config,
tf_config=self.tf_config,
mcore_config=self.mcore_config,
bridge=self.bridge,
is_critic=self.config.is_critic,
)
with self.device:
models = make_mcore_model(
hf_config=self.hf_config,
tf_config=self.tf_config,
mcore_config=self.mcore_config,
bridge=self.bridge,
bridge_type=self.bridge_cls,
is_critic=self.config.is_critic,
)

self.model = _MegatronModelList(models)

Expand Down Expand Up @@ -324,6 +316,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):

primary_model = self.model[0]
model_config = get_model_config(primary_model)

# NOTE: It is recommended to set this option to True for RL training on MoE models for stability.
if self.mcore_config.use_deterministic_algorithms:
set_deterministic_algorithms(model_config)
Expand Down Expand Up @@ -358,6 +351,43 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
self._create_optimizer(ft_spec)
self._initialized = True

def _build_hf_mcore_bridge(self):
if self.bridge_cls == "mbridge":
self.bridge = mbridge.AutoBridge.from_pretrained(self.config.path)
self.bridge.dtype = self.dtype
if self.config.gradient_checkpointing:
self.bridge.set_extra_args(
recompute_granularity=self.mcore_config.recompute_granularity,
recompute_method=self.mcore_config.recompute_method,
recompute_num_layers=self.mcore_config.recompute_num_layers,
distribute_saved_activations=self.mcore_config.distribute_saved_activations,
recompute_modules=self.mcore_config.recompute_modules,
)
self.logger.info(
"Using mbridge to create models and hf model save/load in MegatronEngine."
)

elif self.bridge_cls == "megatron-bridge":
if self.enable_tree_training:
raise NotImplementedError(
"Tree training is not supported with bridge_type='megatron-bridge'."
)
self.bridge = MegatronBridgeAutoBridge.from_hf_pretrained(
self.config.path,
trust_remote_code=True,
dtype=self.config.dtype,
)
self.logger.info(
"Using megatron-bridge to create models and hf model save/load in MegatronEngine."
)

else:
self.logger.info(
"Not using bridge to create models and hf model save/load in MegatronEngine."
)
self.bridge = None
return self.bridge

@property
def initialized(self) -> bool:
return self._initialized
Expand Down Expand Up @@ -1347,16 +1377,23 @@ def _save_model_to_hf(
assert self.model is not None, "Model is not initialized."
os.makedirs(path, exist_ok=True)

save_weights_to_hf_with_mbridge_fast(
bridge=self.bridge,
models=self.model,
weights_path=path,
base_model_path=base_model_path,
max_shard_size_byte=int(3e9),
max_workers=None,
is_critic=self.config.is_critic,
fp8_direct_convert=self.fp8_direct_convert,
)
if self.bridge_cls == "megatron-bridge":
self.bridge.save_hf_pretrained(
self.model,
path,
source_path=base_model_path,
)
else:
save_weights_to_hf_with_mbridge_fast(
bridge=self.bridge,
models=self.model,
weights_path=path,
base_model_path=base_model_path,
max_shard_size_byte=int(3e9),
max_workers=None,
is_critic=self.config.is_critic,
fp8_direct_convert=self.fp8_direct_convert,
)

if dist.get_rank() == 0:
if tokenizer is not None:
Expand All @@ -1369,14 +1406,18 @@ def _save_model_to_hf(

def _load_model_from_hf(self, path: str) -> None:
assert self.model is not None, "Model is not initialized."
load_weights_from_hf_with_mbridge_fast(
bridge=self.bridge,
models=self.model,
weights_path=path,
max_workers=None,
is_critic=self.config.is_critic,
fp8_direct_convert=self.fp8_direct_convert,
)

if self.bridge_cls == "megatron-bridge":
self.bridge.load_hf_weights(self.model, hf_path=path)
else:
load_weights_from_hf_with_mbridge_fast(
bridge=self.bridge,
models=self.model,
weights_path=path,
max_workers=None,
is_critic=self.config.is_critic,
fp8_direct_convert=self.fp8_direct_convert,
)

def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
Expand Down
89 changes: 85 additions & 4 deletions areal/models/mcore/registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import dataclasses
from typing import Any

import torch
from mbridge.core.bridge import Bridge
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed import DistributedDataParallelConfig as MCoreDDPConfig
Expand Down Expand Up @@ -96,12 +98,20 @@ def unwrap_to_gpt_model(model: torch.nn.Module) -> GPTModel:

# Model registry for different architectures
def make_hf_and_mcore_config(
hf_path: str, dtype: torch.dtype, bridge=None
hf_path: str,
dtype: torch.dtype,
bridge=None,
bridge_type: str = "mbridge",
) -> tuple[PretrainedConfig, TransformerConfig]:
if bridge is not None:
if bridge is not None and bridge_type == "mbridge":
hf_config = bridge.hf_config
hf_config._name_or_path = hf_path
return hf_config, bridge.config
elif bridge is not None and bridge_type == "megatron-bridge":
hf_config = getattr(bridge.hf_pretrained, "config", bridge.hf_pretrained)
if hasattr(hf_config, "_name_or_path"):
hf_config._name_or_path = hf_path
return hf_config, bridge.transformer_config
else:
hf_config: PretrainedConfig = AutoConfig.from_pretrained(
pretrained_model_name_or_path=hf_path,
Expand Down Expand Up @@ -132,10 +142,11 @@ def make_mcore_model(
hf_config: PretrainedConfig,
tf_config: TransformerConfig,
mcore_config: MegatronEngineConfig | None = None,
bridge: Bridge | None = None,
bridge: Bridge | Any | None = None,
bridge_type: str = "mbridge",
is_critic: bool = False,
) -> list[GPTModel | DDP]:
if bridge is not None:
if bridge is not None and bridge_type == "mbridge":
models = bridge.get_model(
# TODO: Add DDP options when supporting training
wrap_with_ddp=mcore_config.wrap_with_ddp,
Expand All @@ -156,6 +167,76 @@ def make_mcore_model(
_replace_output_layer_with_value_head(_model, tf_config)

return models

if bridge is not None and bridge_type == "megatron-bridge":
provider = bridge.to_megatron_provider(load_weights=False)
vpp_size = mcore_config.virtual_pipeline_parallel_size or 0

provider.tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
provider.pipeline_model_parallel_size = (
mpu.get_pipeline_model_parallel_world_size()
)
provider.virtual_pipeline_model_parallel_size = (
vpp_size if vpp_size > 1 else None
)
provider.context_parallel_size = mpu.get_context_parallel_world_size()
provider.expert_model_parallel_size = mpu.get_expert_model_parallel_world_size()
provider.expert_tensor_parallel_size = (
mpu.get_expert_tensor_parallel_world_size()
)
provider.sequence_parallel = mpu.get_tensor_model_parallel_world_size() > 1
provider.pipeline_dtype = tf_config.params_dtype

provider.recompute_granularity = mcore_config.recompute_granularity
provider.recompute_method = mcore_config.recompute_method
provider.recompute_num_layers = mcore_config.recompute_num_layers
provider.distribute_saved_activations = (
mcore_config.distribute_saved_activations
)
provider.recompute_modules = mcore_config.recompute_modules

provider.account_for_embedding_in_pipeline_split = False
provider.account_for_loss_in_pipeline_split = False

# Keep these four flags aligned with mbridge base defaults.
provider.variable_seq_lengths = True
logger.warning(
"Ignoring mcore_config.moe_token_dispatcher_type=%s for bridge_type='megatron-bridge'; "
"using 'alltoall' and variable_seq_lengths=True.",
mcore_config.moe_token_dispatcher_type,
)
provider.moe_token_dispatcher_type = "alltoall"
provider.batch_p2p_comm = False
provider.overlap_p2p_comm = (
vpp_size > 1 and provider.pipeline_model_parallel_size > 1
)

# Aligning tf config settings with provider for consistency.
tf_config.variable_seq_lengths = provider.variable_seq_lengths
tf_config.moe_token_dispatcher_type = provider.moe_token_dispatcher_type
tf_config.batch_p2p_comm = provider.batch_p2p_comm
tf_config.overlap_p2p_comm = provider.overlap_p2p_comm

provider.finalize()

models = provider.provide_distributed_model(
ddp_config=MCoreDDPConfig(**dataclasses.asdict(mcore_config.ddp)),
fp16=tf_config.fp16,
bf16=tf_config.bf16,
use_megatron_fsdp=mcore_config.use_custom_fsdp,
use_torch_fsdp2=mcore_config.use_torch_fsdp2,
wrap_with_ddp=mcore_config.wrap_with_ddp,
overlap_param_gather_with_optimizer_step=mcore_config.overlap_param_gather_with_optimizer_step,
)
models = list(models)

if is_critic:
for model in models:
_model = unwrap_to_gpt_model(model)
_replace_output_layer_with_value_head(_model, tf_config)

return models

else:
if (
mcore_config is not None
Expand Down
1 change: 1 addition & 0 deletions areal/tools/validate_docker_installation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class DockerInstallationValidator(BaseInstallationValidator):
"flash_attn_3",
"megatron-core",
"mbridge",
"megatron-bridge",
"causal_conv1d",
}

Expand Down
1 change: 1 addition & 0 deletions areal/tools/validation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class BaseInstallationValidator:
# and should be validated dynamically by subclasses
"megatron-core",
"mbridge",
"megatron-bridge",
"ray",
"datasets",
"hydra-core",
Expand Down
1 change: 1 addition & 0 deletions docs/en/_toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ parts:
- file: reference/checkpointing
- file: reference/metrics_tracking
- file: reference/alloc_mode
- file: reference/bridge_backend
- file: reference/tree_training
- file: reference/rollout_workflow
- file: reference/agent_workflow
Expand Down
37 changes: 37 additions & 0 deletions docs/en/reference/bridge_backend.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Megatron-HF Bridge Backend

AReaL currently supports two bridge backends for `MegatronEngine`:

- `mbridge` (default)
- `megatron-bridge`

Set the backend with:

```yaml
actor:
megatron:
bridge_type: mbridge
```

- Use `bridge_type=megatron-bridge` to enable the new path.
- `mbridge` is the default choice if this argument is not present.

## Why this feature exists

- `mbridge` is being deprecated and does not provide PEFT/LoRA support.
- `megatron-bridge` supports more and newer model architectures.
- `megatron-bridge` provides built-in PEFT/LoRA implementations.

## Recommendation

- For new GPU training workflows, prefer `megatron-bridge`.
- Keep `mbridge` for backward compatibility and environments that still depend on it.
- Prefer `mbridge` when using disk-based weight broadcast, as it has an optimized HF
  load/save path.
- If you use XCCL for weight broadcast, load/save time is less important.

## Current limitation

- Tree-attention training in `MegatronEngine` currently supports only `mbridge`.
- The `megatron-bridge` backend is not supported in the tree-attention path yet.
- `megatron-bridge` does not yet support the faster/optimized HF model load/save implementations.
Loading
Loading