Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 17 additions & 84 deletions megatron/inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,6 @@

import torch

from megatron.core.inference.config import (
CudaGraphSizingDistribution,
InferenceConfig,
KVCacheManagementMode,
MambaInferenceStateConfig,
PrefixCachingCoordinatorPolicy,
PrefixCachingEvictionPolicy,
)
from megatron.core.inference.contexts import DynamicInferenceContext
from megatron.core.inference.engines import DynamicInferenceEngine
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
Expand All @@ -28,7 +20,7 @@
from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer
from megatron.core.transformer.enums import InferenceCudaGraphScope
from megatron.core.transformer.module import MegatronModule
from megatron.core.utils import get_attr_wrapped_model, log_single_rank, unwrap_model
from megatron.core.utils import log_single_rank, unwrap_model
from megatron.training import get_args
from megatron.training import get_model as _get_model
from megatron.training import get_tokenizer, get_wandb_writer
Expand Down Expand Up @@ -327,47 +319,20 @@ def add_inference_args(parser: ArgumentParser) -> ArgumentParser:


def get_inference_config_from_model_and_args(model: MegatronModule, args):
"""Returns a `InferenceConfig` constructed from the model and command line arguments."""

# Max sequence length.
position_embedding_type = get_attr_wrapped_model(model, "position_embedding_type")
model_max_seq_len = get_attr_wrapped_model(model, "max_sequence_length")
inf_max_seq_len = args.inference_max_seq_length
max_batch_size = args.inference_dynamic_batching_max_requests

if position_embedding_type == "learned_absolute":
# When using absolute position embeddings, it is critical that the
# context's `max_sequence_length` is less than or equal to the model's
# `max_sequence_length`. Otherwise, the context's `position_ids` will
# contain ids greater than the dimension of the position embedding
# tensor, which will result in an index error.
if inf_max_seq_len:
max_sequence_length = min(model_max_seq_len, inf_max_seq_len)
else:
max_sequence_length = model_max_seq_len
assert max_batch_size is None or max_batch_size <= model_max_seq_len
else:
max_sequence_length = inf_max_seq_len
if args.inference_dynamic_batching_max_requests is not None:
max_sequence_length = max(max_sequence_length, max_batch_size)
"""Returns an `InferenceConfig` constructed from the model and command line arguments.

mamba_inference_state_config = MambaInferenceStateConfig.from_model(
model,
conv_states_dtype=args.mamba_inference_conv_states_dtype,
ssm_states_dtype=args.mamba_inference_ssm_states_dtype,
)
pg_collection = get_attr_wrapped_model(model, "pg_collection")

# Get inference logging configuration from args
log_inference_wandb = args.inference_wandb_logging
inference_logging_step_interval = args.inference_logging_step_interval
Delegates to ``InferenceSetupConfig.to_inference_config`` so the declarative
``InferenceSetupConfig`` (built from args) is the single source of truth for translating
inference args into the runtime engine ``InferenceConfig``.
"""
from megatron.training.argument_utils import inference_cfg_from_args

# Get metrics writer if logging is enabled and on the logging rank
# Use the same rank convention as training (last rank logs)
# Get metrics writer if logging is enabled and on the logging rank.
# Use the same rank convention as training (last rank logs).
metrics_writer = None
if (
inference_logging_step_interval > 0
and log_inference_wandb
args.inference_logging_step_interval > 0
and args.inference_wandb_logging
and args.rank == (args.world_size - 1)
):
metrics_writer = get_wandb_writer()
Expand All @@ -379,47 +344,15 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args):
"wandb module is available. Inference logging will be disabled.",
)

return InferenceConfig(
verbose=True,
block_size_tokens=args.inference_dynamic_batching_block_size,
buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
paused_buffer_size_gb=args.inference_dynamic_batching_paused_buffer_size_gb,
mamba_memory_ratio=args.inference_dynamic_batching_mamba_memory_ratio,
num_cuda_graphs=(
args.inference_dynamic_batching_num_cuda_graphs
if args.inference_cuda_graph_scope != InferenceCudaGraphScope.none
else None
),
max_requests=args.inference_dynamic_batching_max_requests,
max_tokens=args.inference_dynamic_batching_max_tokens,
unified_memory_level=args.inference_dynamic_batching_unified_memory_level,
kv_cache_management_mode=KVCacheManagementMode(args.rl_kv_cache_management_mode),
cuda_graph_mixed_prefill_count=args.inference_dynamic_batching_cuda_graph_mixed_prefill_count, # pylint: disable=line-too-long
cuda_graph_sizing_distribution=CudaGraphSizingDistribution(
args.inference_dynamic_batching_cuda_graph_sizing_distribution
),
use_cuda_graphs_for_non_decode_steps=not args.decode_only_cuda_graphs,
cuda_graph_all_prefills=args.inference_cuda_graph_all_prefills,
setup_cfg = inference_cfg_from_args(args)
return setup_cfg.to_inference_config(
model,
return_log_probs=args.return_log_probs,
kv_cache_management_mode=args.rl_kv_cache_management_mode,
static_kv_memory_pointers=args.rl_persist_cuda_graphs,
max_sequence_length=max_sequence_length,
mamba_inference_state_config=mamba_inference_state_config,
pg_collection=pg_collection,
use_flashinfer_fused_rope=args.use_flashinfer_fused_rope,
materialize_only_last_token_logits=(not args.return_log_probs),
track_generated_token_events=args.inference_dynamic_batching_track_generated_token_events,
track_paused_request_events=args.inference_dynamic_batching_track_paused_request_events,
enable_chunked_prefill=args.enable_chunked_prefill,
enable_prefix_caching=args.inference_dynamic_batching_enable_prefix_caching,
prefix_caching_eviction_policy=PrefixCachingEvictionPolicy(args.inference_dynamic_batching_prefix_caching_eviction_policy),
prefix_caching_coordinator_policy=PrefixCachingCoordinatorPolicy(args.inference_dynamic_batching_prefix_caching_coordinator_policy),
prefix_caching_routing_alpha=getattr(args, 'inference_dynamic_batching_prefix_caching_routing_alpha', 0.5),
prefix_caching_mamba_gb=getattr(args, 'inference_dynamic_batching_prefix_caching_mamba_gb', None),
enable_cuda_graphs=(args.inference_cuda_graph_scope != InferenceCudaGraphScope.none),
metrics_writer=metrics_writer,
logging_step_interval=args.inference_logging_step_interval,
num_speculative_tokens=args.num_speculative_tokens,
use_synchronous_zmq_collectives=args.inference_use_synchronous_zmq_collectives,
disable_ep_consensus=args.inference_disable_ep_consensus,
sampling_backend=args.inference_dynamic_batching_sampling_backend,
)


Expand Down
61 changes: 61 additions & 0 deletions megatron/training/argument_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from megatron.training.config import (
DistributedInitConfig,
InferenceSetupConfig,
InferenceConfigContainer,
PretrainConfigContainer,
SchedulerConfig,
TokenizerConfig,
Expand Down Expand Up @@ -522,3 +524,62 @@ def pretrain_cfg_container_from_args(args: Namespace, model_cfg=None) -> Pretrai
)

return cfg


def inference_cfg_from_args(args: Namespace) -> InferenceSetupConfig:
    """Translate parsed command-line arguments into an ``InferenceSetupConfig``.

    The fields of ``InferenceSetupConfig`` share their names with the argparse ``dest``
    names produced by ``_add_inference_args``, so the matching values are copied straight
    out of ``args`` without any renaming.

    The result is the declarative/serializable inference config. The runtime engine config
    (``megatron.core.inference.config.InferenceConfig``) is obtained from it with
    ``inference_cfg_from_args(args).to_inference_config(model, ...)``.

    Args:
        args: Parsed argparse namespace holding the inference arguments.

    Returns:
        The ``InferenceSetupConfig`` populated from ``args``.
    """
    setup_cfg = _default_config_from_args(InferenceSetupConfig, args)
    return setup_cfg


def inference_cfg_container_from_args(
    args: Namespace, model_cfg=None
) -> InferenceConfigContainer:
    """Assemble an ``InferenceConfigContainer`` from the argparse arguments.

    Inference counterpart of ``pretrain_cfg_container_from_args``: only the configs that
    inference needs are built (no optimizer, scheduler, training, validation, DDP, rerun, or
    straggler configs). The container is intended to be handed to ``initialize_megatron``
    by inference entry points.

    Args:
        args: Parsed and validated argparse namespace (e.g. from ``parse_and_validate_args``).
        model_cfg: Optional pre-built model config. When None, one is derived from ``args``:
            a HybridModelConfig if ``--hybrid-layer-pattern`` is set, otherwise a
            GPTModelConfig.

    Returns:
        The populated ``InferenceConfigContainer``.
    """
    # Only derive a model config when the caller did not provide one.
    if model_cfg is None:
        uses_hybrid_pattern = getattr(args, "hybrid_layer_pattern", None) is not None
        build_model_cfg = hybrid_config_from_args if uses_hybrid_pattern else gpt_config_from_args
        model_cfg = build_model_cfg(args)

    # Checkpoint fields whose argparse flags are negated or differently named are filled in
    # by hand on top of the directly-copied values.
    checkpoint_fields = _default_config_from_args(CheckpointConfig, args, return_instance=False)
    checkpoint_fields.update(
        save_optim=not args.no_save_optim,
        save_rng=not args.no_save_rng,
        load_optim=not args.no_load_optim,
        load_rng=not args.no_load_rng,
        fully_parallel_save=args.ckpt_fully_parallel_save,
        fully_parallel_load=args.ckpt_fully_parallel_load,
    )

    # ``--profile`` is stored under a different name on the profiling config.
    profiling_fields = _default_config_from_args(ProfilingConfig, args, return_instance=False)
    profiling_fields.update(use_nsys_profiler=args.profile)

    return InferenceConfigContainer(
        model=model_cfg,
        checkpoint=CheckpointConfig(**checkpoint_fields),
        inference=inference_cfg_from_args(args),
        dist=_default_config_from_args(DistributedInitConfig, args),
        rng=_default_config_from_args(RNGConfig, args),
        tokenizer=_default_config_from_args(TokenizerConfig, args),
        logger=_default_config_from_args(LoggerConfig, args),
        profiling=ProfilingConfig(**profiling_fields),
    )
3 changes: 2 additions & 1 deletion megatron/training/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
RerunStateMachineConfig,
StragglerDetectionConfig,
)
from megatron.training.config.inference_config import InferenceSetupConfig

from megatron.training.config.container import PretrainConfigContainer
from megatron.training.config.container import InferenceConfigContainer, PretrainConfigContainer
from megatron.training.config.instantiate_utils import TargetAllowlist, target_allowlist
33 changes: 33 additions & 0 deletions megatron/training/config/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from megatron.core.msc_utils import MultiStorageClientFeature
from megatron.core.optimizer import OptimizerConfig
from megatron.training.config.common_config import DistributedInitConfig, ProfilingConfig, RNGConfig
from megatron.training.config.inference_config import InferenceSetupConfig
from megatron.training.config.instantiate_utils import InstantiationMode, instantiate
from megatron.training.config.resilience_config import (
RerunStateMachineConfig,
Expand Down Expand Up @@ -247,3 +248,35 @@ class PretrainConfigContainer(ConfigContainerBase):

rerun_state_machine: RerunStateMachineConfig = field(default_factory=RerunStateMachineConfig)
straggler: StragglerDetectionConfig | None = None


@dataclass(kw_only=True)
class InferenceConfigContainer(ConfigContainerBase):
    """Top-level container for inference entry points.

    This is the inference counterpart to :class:`PretrainConfigContainer`. It holds only the
    configs that inference actually needs and is intentionally shaped differently from the
    training container: there is no optimizer, LR schedule, train/validation loop, DDP, rerun
    state machine, or straggler detection.

    Explicitly NOT included (relative to ``PretrainConfigContainer``): ``TrainingConfig``,
    ``OptimizerConfig``, ``SchedulerConfig``, ``ValidationConfig``,
    ``DistributedDataParallelConfig``, ``RerunStateMachineConfig``, ``StragglerDetectionConfig``.

    The dataclass is keyword-only: ``model``, ``checkpoint`` and ``inference`` have no
    default and must always be supplied, while the remaining fields fall back to a
    default-constructed instance of their config class.
    """

    model: HybridModelConfig | GPTModelConfig
    """Which model to load for inference."""

    checkpoint: CheckpointConfig
    """Checkpoint configuration used to load model weights."""

    inference: InferenceSetupConfig
    """Declarative inference settings (the serializable, args-shaped layer). Use
    ``InferenceSetupConfig.to_inference_config(model, ...)`` to build the runtime
    ``megatron.core.inference.config.InferenceConfig`` consumed by the engine."""

    # Optional sub-configs. Each uses ``default_factory`` so every container gets its own
    # freshly constructed default instance rather than sharing one across containers.
    # NOTE(review): the exact settings carried by each config class are defined outside
    # this block — confirm against the respective class definitions.
    dist: DistributedInitConfig = field(default_factory=DistributedInitConfig)
    rng: RNGConfig = field(default_factory=RNGConfig)
    tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
    logger: LoggerConfig = field(default_factory=LoggerConfig)
    profiling: ProfilingConfig = field(default_factory=ProfilingConfig)
Loading