20 changes: 12 additions & 8 deletions primus/cli/benchmark_cli.py
@@ -52,29 +52,33 @@ def register_subcommand(subparsers):

# ---------- GEMM ----------
gemm = suite_parsers.add_parser("gemm", help="GEMM microbench.")
- from primus.tools.benchmark import gemm_bench
+ from primus.tools.benchmark.gemm_bench_args import add_gemm_parser

- gemm_bench.add_gemm_parser(gemm)
+ add_gemm_parser(gemm)

# ---------- DENSE-GEMM ----------
dense_gemm = suite_parsers.add_parser("gemm-dense", help="GEMM-DENSE microbench.")
- from primus.tools.benchmark import dense_gemm_bench
+ from primus.tools.benchmark.dense_gemm_bench_args import (
+     add_gemm_parser as add_dense_gemm_parser,
+ )

- dense_gemm_bench.add_gemm_parser(dense_gemm)
+ add_dense_gemm_parser(dense_gemm)

# ---------- DEEPSEEK-GEMM ----------
deepseek_gemm = suite_parsers.add_parser("gemm-deepseek", help="DEEPSEEK-GEMM microbench.")
- from primus.tools.benchmark import deepseek_dense_gemm_bench
+ from primus.tools.benchmark.deepseek_dense_gemm_bench_args import (
+     add_gemm_parser as add_deepseek_gemm_parser,
+ )

- deepseek_dense_gemm_bench.add_gemm_parser(deepseek_gemm)
+ add_deepseek_gemm_parser(deepseek_gemm)

# ---------- INTER-NODE-ALLGATHER-BY-LOCAL-RANK ----------
strided_allgather_parser = suite_parsers.add_parser(
    "strided-allgather", help="Strided Allgather microbench."
)
- from primus.tools.benchmark import strided_allgather_bench
+ from primus.tools.benchmark.strided_allgather_bench_args import add_arguments

- strided_allgather_bench.add_arguments(strided_allgather_parser)
+ add_arguments(strided_allgather_parser)

parser.set_defaults(func=run)

3 changes: 2 additions & 1 deletion primus/cli/main.py
@@ -23,11 +23,12 @@ def main():
parser = argparse.ArgumentParser(prog="primus", description="Primus Unified CLI for Training & Utilities")
subparsers = parser.add_subparsers(dest="command", required=True)

- from primus.cli import benchmark_cli, train_cli
+ from primus.cli import benchmark_cli, projection_cli, train_cli

# Register train subcommand (only implemented one for now)
train_cli.register_subcommand(subparsers)
benchmark_cli.register_subcommand(subparsers)
+ projection_cli.register_subcommand(subparsers)

args, unknown_args = parser.parse_known_args()

48 changes: 48 additions & 0 deletions primus/cli/projection_cli.py
@@ -0,0 +1,48 @@
###############################################################################
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################


def run(args, overrides):
"""
Entry point for the 'projection' subcommand.
"""
if args.suite == "memory":
from primus.core.projection.memory_projection import launch_projection_from_cli

launch_projection_from_cli(args, overrides)
else:
raise NotImplementedError(f"Unsupported projection suite: {args.suite}")


def register_subcommand(subparsers):
"""
Register the 'projection' subcommand to the main CLI parser.

Example:
primus projection memory --config exp.yaml
Args:
subparsers: argparse subparsers object from main.py

Returns:
parser: The parser for this subcommand
"""

parser = subparsers.add_parser(
"projection",
help="Pre-training performance projection tool",
description="Primus performance projection entry point.",
)
suite_parsers = parser.add_subparsers(dest="suite", required=True)

# ---------- memory (reuses the pretrain argument parser) ----------
pretrain = suite_parsers.add_parser("memory", help="Memory projection.")
from primus.core.launcher.parser import add_pretrain_parser

add_pretrain_parser(pretrain)

parser.set_defaults(func=run)

return parser
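
For orientation, here is a minimal, self-contained sketch of the argparse registration-and-dispatch pattern this subcommand follows. It is an illustration only: the run() body, the --config flag, and the args.func(...) dispatch are assumptions based on the set_defaults(func=run) calls visible in this diff, not the actual Primus wiring.

# Sketch of the subcommand pattern (not the real primus CLI); dispatch via args.func is assumed.
import argparse

def run(args, overrides):
    print(f"suite={args.suite}, config={args.config}, overrides={overrides}")

def register_subcommand(subparsers):
    parser = subparsers.add_parser("projection", help="Pre-training performance projection tool")
    suite_parsers = parser.add_subparsers(dest="suite", required=True)
    memory = suite_parsers.add_parser("memory", help="Memory projection.")
    memory.add_argument("--config", required=True)  # hypothetical flag for this sketch
    parser.set_defaults(func=run)
    return parser

if __name__ == "__main__":
    top = argparse.ArgumentParser(prog="primus-sketch")
    subparsers = top.add_subparsers(dest="command", required=True)
    register_subcommand(subparsers)
    args, unknown = top.parse_known_args(["projection", "memory", "--config", "exp.yaml"])
    args.func(args, unknown)  # -> suite=memory, config=exp.yaml, overrides=[]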
Empty file.
59 changes: 59 additions & 0 deletions primus/core/projection/base_module_profiler.py
@@ -0,0 +1,59 @@
###############################################################################
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################

from abc import ABC


class BaseModuleProfiler(ABC):
"""Abstract base class for transformer-like module profiler.
Provides both estimated and measured statistics.
"""

def __init__(self, config, sub_profilers=None):
self.config = config
self.sub_profilers = sub_profilers

# -------- Parameter related --------
def estimated_num_params(self, rank: int | None = None) -> int:
"""Return estimated parameter count (based on formula).
If rank is provided, return the parameter count for the given rank,
otherwise return the total parameter count for the entire model.
"""
raise NotImplementedError

def measured_num_params(self) -> int:
"""Return measured parameter count (from real tensors)."""
raise NotImplementedError

# -------- Memory related --------
def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int:
"""Return estimated memory usage in bytes (activations)."""
raise NotImplementedError

def measured_activation_memory(self, batch_size: int, seq_len: int) -> int:
"""Return measured memory usage in bytes (via profiler/runtime stats)."""
raise NotImplementedError

# -------- Performance related --------
def estimated_forward_time(self, batch_size: int, seq_len: int) -> int:
"""Return estimated forward latency for forward pass in milliseconds."""
raise NotImplementedError

def estimated_backward_time(self, batch_size: int, seq_len: int) -> int:
"""Return estimated latency for backward pass in milliseconds."""
raise NotImplementedError

def measured_forward_time(self, batch_size: int, seq_len: int) -> float:
"""Return measured forward latency in milliseconds."""
raise NotImplementedError

def measured_backward_time(self, batch_size: int, seq_len: int) -> float:
"""Return measured backward latency in milliseconds."""
raise NotImplementedError

# -------- Debugging / summary --------
def __repr__(self):
return f"{self.__class__.__name__}"
5 changes: 5 additions & 0 deletions primus/core/projection/memory_projection/__init__.py
@@ -0,0 +1,5 @@
from .projection import launch_projection_from_cli

__all__ = ["launch_projection_from_cli"]
112 changes: 112 additions & 0 deletions primus/core/projection/memory_projection/projection.py
@@ -0,0 +1,112 @@
import os
from pathlib import Path

from primus.core.launcher.parser import PrimusParser
from primus.core.projection.module_profilers.language_model import (
build_profiler,
get_language_model_profiler_spec,
)
from primus.core.projection.training_config import (
convert_primus_config_to_projection_config,
)


def print_profiler_hierarchy(profiler, batch_size, seq_len, rank=None, name="root", depth=0, visited=None):
"""
Recursively print the profiler hierarchy with num_params and activation_memory for each component.

Args:
profiler: The profiler instance to print
batch_size: Batch size for activation memory calculation
seq_len: Sequence length for activation memory calculation
rank: Rank for parameter calculation (if None, calculates total parameters)
name: Name of the current profiler component
depth: Current depth in the hierarchy (for indentation)
visited: Set of visited profiler IDs to avoid infinite recursion
"""
if visited is None:
visited = set()

# Avoid infinite recursion if profilers reference each other
profiler_id = id(profiler)
if profiler_id in visited:
return
visited.add(profiler_id)

indent = " " * depth

# Calculate metrics for this profiler
try:
if depth == 0:
# At depth 0, report only the total parameter count for the entire model.
num_params = profiler.estimated_num_params(rank=None)
print(f"{indent} Total Number of Parameters: {num_params / 1e9:.6f} Billion ({num_params:,})")
else:
num_params = profiler.estimated_num_params(rank=rank)
activation_mem = profiler.estimated_activation_memory(batch_size, seq_len)
print(f"{indent}[{name}]")
print(f"{indent} Params: {num_params / 1e9:.6f} Billion ({num_params:,})")
print(f"{indent} Activation Memory: {activation_mem / 1024 / 1024 / 1024:.4f} GB")

# Recursively process sub_profilers if they exist
if hasattr(profiler, "sub_profilers") and profiler.sub_profilers:
for sub_name, sub_profiler in profiler.sub_profilers.items():
if sub_profiler is not None:
print() # Add spacing between components
print_profiler_hierarchy(
sub_profiler, batch_size, seq_len, rank, sub_name, depth + 1, visited
)
except Exception as e:
print(f"{indent}[{name}] - Error calculating metrics: {e}")


def launch_projection_from_cli(args, overrides):
"""
Entry point for the 'projection' subcommand.

"""
cfg_path = Path(args.config)
if not cfg_path.exists():
raise FileNotFoundError(f"[Primus:Projection] Config file '{cfg_path}' not found.")

config_parser = PrimusParser()
primus_config = config_parser.parse(args)
training_config = convert_primus_config_to_projection_config(primus_config)
print(training_config)

model_profiler_spec = get_language_model_profiler_spec(training_config)
model_profiler = build_profiler(model_profiler_spec)

seq_len = training_config.runtime_config.sequence_length
batch_size = training_config.runtime_config.micro_batch_size
rank = int(os.getenv("RANK", "0"))

# Print recursive profiler hierarchy with detailed breakdown
print("\n" + "=" * 100)
print(f"[Primus:Projection] Component-wise Profiling Results (Rank {rank}):")
print("=" * 100)
print()

# Print the complete hierarchy recursively
print_profiler_hierarchy(
model_profiler, batch_size, seq_len, rank=rank, name="LanguageModelProfiler", depth=0
)

# Get overall totals from the model profiler for this rank
num_params = model_profiler.estimated_num_params(rank=rank)
activation_memory = model_profiler.estimated_activation_memory(batch_size, seq_len)
num_bytes_per_param = model_profiler.get_num_bytes_per_param()
print()
print("=" * 100)
print(f"[Primus:Projection] Memory Projection Summary on Rank {rank}:")
print(f" Params: {num_params / 1e9:.6f} Billion ({num_params:,})")
print(f" Param+Optimizer Memory: {num_params * num_bytes_per_param / 1024 / 1024 / 1024:.4f} GB")
print(
f" Activation Memory (per batch size {batch_size}, seq len {seq_len}): "
f"{activation_memory / 1024 / 1024 / 1024:.4f} GB"
)
print(
f" Projected Total Memory: "
f"{(num_params * num_bytes_per_param + activation_memory) / 1024 / 1024 / 1024:.4f} GB"
)
print("=" * 100)
Empty file.
69 changes: 69 additions & 0 deletions primus/core/projection/module_profilers/attention.py
@@ -0,0 +1,69 @@
###############################################################################
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################


from primus.core.projection.base_module_profiler import BaseModuleProfiler


class AttentionProfiler(BaseModuleProfiler):
def estimated_num_params(self, rank: int | None = None) -> int:
args = self.config.model_config
# Group-query & multi-latent attention support.
# If GQA is not enabled, fall back to one key/value group per head (standard MHA).
num_query_groups = (
args.num_query_groups
if args.group_query_attention and args.num_query_groups
else args.num_attention_heads
)

# Projection ratio: (kv_channels * n_heads) / hidden_size
query_proj_to_hidden = (args.kv_channels * args.num_attention_heads) / args.hidden_size

if args.multi_latent_attention:
# q_term: either dense or LoRA factored Q with RoPE/Q-norm
if args.q_lora_rank is None:
q_term = (
args.hidden_size
* args.num_attention_heads
* (args.qk_head_dim + args.qk_pos_emb_head_dim)
)
else:
q_term = args.q_lora_rank * (
args.hidden_size
+ args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)
+ 1
)
attn = (
q_term
# kv lora + rope + kv norm
+ args.kv_lora_rank
* (args.hidden_size + args.num_attention_heads * (args.qk_head_dim + args.v_head_dim) + 1)
# pos emb
+ args.hidden_size * args.qk_pos_emb_head_dim
# out proj
+ (args.num_attention_heads * args.v_head_dim) * args.hidden_size
)
return attn

# Standard attention path (Q,K,V,O projections)
return (
2
* args.hidden_size
* args.hidden_size
* ((1 + (num_query_groups / args.num_attention_heads)) * query_proj_to_hidden)
)

def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int:
multiplier = 4 # for Q, K, V, O
return (
batch_size
* seq_len
// self.config.model_parallel_config.tensor_model_parallel_size
// self.config.model_parallel_config.context_model_parallel_size
* self.config.model_config.hidden_size
* multiplier
* 2
) # bf16
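
As a sanity check on the standard (non-MLA) branch above, a small worked example with made-up Llama-like sizes (hidden 4096, 32 heads, kv_channels 128, 8 query groups); the closed-form expression matches summing the Q/K/V/O weight matrices directly.

# Hypothetical sizes; verifies 2*h*h*((1 + g/n) * r) against an explicit Q/K/V/O sum (no biases).
h, n, kv, g = 4096, 32, 128, 8      # hidden_size, num_attention_heads, kv_channels, num_query_groups
r = (kv * n) / h                    # query_proj_to_hidden ratio (1.0 here)

closed_form = 2 * h * h * ((1 + g / n) * r)
explicit = (h * kv * n) + 2 * (h * kv * g) + (kv * n * h)   # Q + (K and V) + O
print(int(closed_form), explicit)   # 41943040 41943040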