Skip to content

Commit 0d567b8

Browse files
yuankaichen-amd, wenxie-amd, araina-amd
authored
Add memory projection cli (#273)
Co-authored-by: wenxie-amd <[email protected]> Co-authored-by: Anshu Raina <[email protected]>
1 parent 5f20e73 commit 0d567b8

30 files changed

+1184
-130
lines changed

primus/cli/benchmark_cli.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,29 +52,33 @@ def register_subcommand(subparsers):
5252

5353
# ---------- GEMM ----------
5454
gemm = suite_parsers.add_parser("gemm", help="GEMM microbench.")
55-
from primus.tools.benchmark import gemm_bench
55+
from primus.tools.benchmark.gemm_bench_args import add_gemm_parser
5656

57-
gemm_bench.add_gemm_parser(gemm)
57+
add_gemm_parser(gemm)
5858

5959
# ---------- DENSE-GEMM ----------
6060
dense_gemm = suite_parsers.add_parser("gemm-dense", help="GEMM-DENSE microbench.")
61-
from primus.tools.benchmark import dense_gemm_bench
61+
from primus.tools.benchmark.dense_gemm_bench_args import (
62+
add_gemm_parser as add_dense_gemm_parser,
63+
)
6264

63-
dense_gemm_bench.add_gemm_parser(dense_gemm)
65+
add_dense_gemm_parser(dense_gemm)
6466

6567
# ---------- DEEPSEEK-GEMM ----------
6668
deepseek_gemm = suite_parsers.add_parser("gemm-deepseek", help="DEEPSEEK-GEMM microbench.")
67-
from primus.tools.benchmark import deepseek_dense_gemm_bench
69+
from primus.tools.benchmark.deepseek_dense_gemm_bench_args import (
70+
add_gemm_parser as add_deepseek_gemm_parser,
71+
)
6872

69-
deepseek_dense_gemm_bench.add_gemm_parser(deepseek_gemm)
73+
add_deepseek_gemm_parser(deepseek_gemm)
7074

7175
# ---------- INTER-NODE-ALLGATHER-BY-LOCAL-RANK ----------
7276
strided_allgather_parser = suite_parsers.add_parser(
7377
"strided-allgather", help="Strided Allgather microbench."
7478
)
75-
from primus.tools.benchmark import strided_allgather_bench
79+
from primus.tools.benchmark.strided_allgather_bench_args import add_arguments
7680

77-
strided_allgather_bench.add_arguments(strided_allgather_parser)
81+
add_arguments(strided_allgather_parser)
7882

7983
parser.set_defaults(func=run)
8084

primus/cli/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@ def main():
2323
parser = argparse.ArgumentParser(prog="primus", description="Primus Unified CLI for Training & Utilities")
2424
subparsers = parser.add_subparsers(dest="command", required=True)
2525

26-
from primus.cli import benchmark_cli, train_cli
26+
from primus.cli import benchmark_cli, projection_cli, train_cli
2727

2828
# Register train subcommand (only implemented one for now)
2929
train_cli.register_subcommand(subparsers)
3030
benchmark_cli.register_subcommand(subparsers)
31+
projection_cli.register_subcommand(subparsers)
3132

3233
args, unknown_args = parser.parse_known_args()
3334

primus/cli/projection_cli.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
###############################################################################
2+
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3+
#
4+
# See LICENSE for license information.
5+
###############################################################################
6+
7+
8+
def run(args, overrides):
    """
    Dispatch the 'projection' subcommand to the selected suite.

    Args:
        args: Parsed CLI namespace; ``args.suite`` selects the projection suite.
        overrides: Extra arguments, forwarded unchanged to the suite launcher.

    Raises:
        NotImplementedError: If ``args.suite`` is anything other than "memory".
    """
    suite = args.suite
    if suite != "memory":
        raise NotImplementedError(f"Unsupported projection suite: {args.suite}")

    # Imported lazily so the projection package is only loaded when the
    # memory suite is actually requested.
    from primus.core.projection.memory_projection import launch_projection_from_cli

    launch_projection_from_cli(args, overrides)
18+
19+
20+
def register_subcommand(subparsers):
    """
    Attach the 'projection' subcommand to the main CLI parser.

    Example:
        primus projection memory --config exp.yaml

    Args:
        subparsers: argparse subparsers object from main.py

    Returns:
        parser: The parser created for the 'projection' subcommand
    """
    projection_parser = subparsers.add_parser(
        "projection",
        help="Pre-training performance projection tool",
        description="Primus performance projection entry point.",
    )
    suite_subparsers = projection_parser.add_subparsers(dest="suite", required=True)

    # ---------- memory ----------
    # The memory suite takes the same arguments as the pretrain launcher.
    memory_parser = suite_subparsers.add_parser("memory", help="Memory projection.")
    from primus.core.launcher.parser import add_pretrain_parser

    add_pretrain_parser(memory_parser)

    # `run` (defined in this module) becomes the handler stored on args.func.
    projection_parser.set_defaults(func=run)
    return projection_parser

primus/core/projection/__init__.py

Whitespace-only changes.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
###############################################################################
2+
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3+
#
4+
# See LICENSE for license information.
5+
###############################################################################
6+
7+
from abc import ABC
8+
9+
10+
class BaseModuleProfiler(ABC):
11+
"""Abstract base class for transformer-like module profiler.
12+
Provides both estimated and measured statistics.
13+
"""
14+
15+
def __init__(self, config, sub_profilers=None):
16+
self.config = config
17+
self.sub_profilers = sub_profilers
18+
19+
# -------- Parameter related --------
20+
def estimated_num_params(self, rank: int | None = None) -> int:
21+
"""Return estimated parameter count (based on formula).
22+
If rank is provided, return the parameter count for the given rank,
23+
otherwise return the total parameter count for the entire model.
24+
"""
25+
raise NotImplementedError
26+
27+
def measured_num_params(self) -> int:
28+
"""Return measured parameter count (from real tensors)."""
29+
raise NotImplementedError
30+
31+
# -------- Memory related --------
32+
def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int:
33+
"""Return estimated memory usage in bytes (activations)."""
34+
raise NotImplementedError
35+
36+
def measured_activation_memory(self, batch_size: int, seq_len: int) -> int:
37+
"""Return measured memory usage in bytes (via profiler/runtime stats)."""
38+
raise NotImplementedError
39+
40+
# -------- Performance related --------
41+
def estimated_forward_time(self, batch_size: int, seq_len: int) -> int:
42+
"""Return estimated forward latency for forward pass in milliseconds."""
43+
raise NotImplementedError
44+
45+
def estimated_backward_time(self, batch_size: int, seq_len: int) -> int:
46+
"""Return estimated latency for backward pass in milliseconds."""
47+
raise NotImplementedError
48+
49+
def measured_forward_time(self, batch_size: int, seq_len: int) -> float:
50+
"""Return measured forward latency in milliseconds."""
51+
raise NotImplementedError
52+
53+
def measured_backward_time(self, batch_size: int, seq_len: int) -> float:
54+
"""Return measured backward latency in milliseconds."""
55+
raise NotImplementedError
56+
57+
# -------- Debugging / summary --------
58+
def __repr__(self):
59+
return f"{self.__class__.__name__}"
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .projection import launch_projection_from_cli
2+
3+
# Entries of __all__ must be strings naming the public attributes; listing
# the function object itself makes `from <package> import *` raise TypeError.
__all__ = [
    "launch_projection_from_cli",
]
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import os
2+
from pathlib import Path
3+
4+
from primus.core.launcher.parser import PrimusParser
5+
from primus.core.projection.module_profilers.language_model import (
6+
build_profiler,
7+
get_language_model_profiler_spec,
8+
)
9+
from primus.core.projection.training_config import (
10+
convert_primus_config_to_projection_config,
11+
)
12+
13+
14+
def print_profiler_hierarchy(profiler, batch_size, seq_len, rank=None, name="root", depth=0, visited=None):
    """
    Recursively print the profiler hierarchy with num_params and activation_memory for each component.

    At depth 0 only the total parameter count (computed with ``rank=None``) is
    printed. At deeper levels the component name, its parameter count for
    ``rank`` and its activation memory are printed. Any exception raised while
    computing or printing a node is reported on stdout instead of propagating,
    and that node's sub-profilers are then not visited.

    Args:
        profiler: The profiler instance to print
        batch_size: Batch size for activation memory calculation
        seq_len: Sequence length for activation memory calculation
        rank: Rank for parameter calculation (if None, calculates total parameters)
        name: Name of the current profiler component
        depth: Current depth in the hierarchy (for indentation)
        visited: Set of visited profiler IDs to avoid infinite recursion
    """
    if visited is None:
        visited = set()

    # Avoid infinite recursion if profilers reference each other
    profiler_id = id(profiler)
    if profiler_id in visited:
        return
    visited.add(profiler_id)

    indent = " " * depth

    # Calculate metrics for this profiler
    try:
        if depth == 0:
            # Only output the total number of parameters for the entire model for depth 0.
            num_params = profiler.estimated_num_params(rank=None)
            print(f"{indent} Total Number of Parameters: {num_params / 1e9:.6f} Billion ({num_params:,})")
        else:
            num_params = profiler.estimated_num_params(rank=rank)
            activation_mem = profiler.estimated_activation_memory(batch_size, seq_len)
            print(f"{indent}[{name}]")
            print(f"{indent} Params: {num_params / 1e9:.6f} Billion ({num_params:,})")
            print(f"{indent} Activation Memory: {activation_mem / 1024 / 1024 / 1024:.4f} GB")

        # Recursively process sub_profilers if they exist
        sub_profilers = getattr(profiler, "sub_profilers", None)
        if sub_profilers:
            for sub_name, sub_profiler in sub_profilers.items():
                if sub_profiler is None:
                    continue
                print()  # Add spacing between components
                print_profiler_hierarchy(
                    sub_profiler, batch_size, seq_len, rank, sub_name, depth + 1, visited
                )
    except Exception as e:
        # Best-effort reporting: keep going rather than abort the whole dump.
        # The exception type is included because a bare NotImplementedError
        # has an empty message, which would leave the report uninformative.
        print(f"{indent}[{name}] - Error calculating metrics: {type(e).__name__}: {e}")
61+
62+
63+
def launch_projection_from_cli(args, overrides):
    """
    Run the memory projection for the config named on the command line.

    Parses ``args`` into a Primus config, converts it to a projection training
    config, builds the language-model profiler and prints a component-wise
    breakdown followed by a memory summary for the rank read from the ``RANK``
    environment variable (default 0).

    Args:
        args: Parsed CLI namespace; ``args.config`` is the config file path.
        overrides: Accepted for the CLI handler signature; not used here.

    Raises:
        FileNotFoundError: If ``args.config`` does not exist.
    """
    cfg_path = Path(args.config)
    if not cfg_path.exists():
        raise FileNotFoundError(f"[Primus:Projection] Config file '{cfg_path}' not found.")

    primus_config = PrimusParser().parse(args)
    training_config = convert_primus_config_to_projection_config(primus_config)
    print(training_config)

    model_profiler = build_profiler(get_language_model_profiler_spec(training_config))

    runtime_config = training_config.runtime_config
    seq_len = runtime_config.sequence_length
    batch_size = runtime_config.micro_batch_size
    rank = int(os.getenv("RANK", "0"))

    banner = "=" * 100

    # Print recursive profiler hierarchy with detailed breakdown
    print("\n" + banner)
    print(f"[Primus:Projection] Component-wise Profiling Results (Rank {rank}):")
    print(banner)
    print()

    # Print the complete hierarchy recursively
    print_profiler_hierarchy(
        model_profiler, batch_size, seq_len, rank=rank, name="LanguageModelProfiler", depth=0
    )

    # Get overall totals from the model profiler for this rank
    num_params = model_profiler.estimated_num_params(rank=rank)
    activation_memory = model_profiler.estimated_activation_memory(batch_size, seq_len)
    num_bytes_per_param = model_profiler.get_num_bytes_per_param()

    # Bytes for parameters plus optimizer state, as labelled in the summary.
    param_state_bytes = num_params * num_bytes_per_param
    total_bytes = param_state_bytes + activation_memory

    print()
    print(banner)
    print(f"[Primus:Projection] Memory Projection Summary on Rank {rank}:")
    print(f" Params: {num_params / 1e9:.6f} Billion ({num_params:,})")
    print(f" Param+Optimizer Memory: {param_state_bytes / 1024 / 1024 / 1024:.4f} GB")
    print(
        f" Activation Memory (per batch size {batch_size}, seq len {seq_len}): "
        f"{activation_memory / 1024 / 1024 / 1024:.4f} GB"
    )
    print(
        f" Projected Total Memory: "
        f"{total_bytes / 1024 / 1024 / 1024:.4f} GB"
    )
    print(banner)

primus/core/projection/module_profilers/__init__.py

Whitespace-only changes.
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
###############################################################################
2+
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3+
#
4+
# See LICENSE for license information.
5+
###############################################################################
6+
7+
8+
from primus.core.projection.base_module_profiler import BaseModuleProfiler
9+
10+
11+
class AttentionProfiler(BaseModuleProfiler):
    """Formula-based parameter and activation-memory estimates for attention.

    Reads ``self.config.model_config`` and ``self.config.model_parallel_config``.
    """

    def estimated_num_params(self, rank: int | None = None) -> int:
        """Return the estimated parameter count of the attention block.

        Args:
            rank: Accepted for interface compatibility; this estimate does not
                use it, so the same count is returned for every rank.

        Returns:
            The parameter count from the multi-latent-attention formula when
            ``multi_latent_attention`` is set, otherwise from the standard
            (Q, K, V, O projection) formula.
        """
        args = self.config.model_config

        if args.multi_latent_attention:
            qk_dim = args.qk_head_dim + args.qk_pos_emb_head_dim
            # q_term: either dense or LoRA factored Q with RoPE/Q-norm
            if args.q_lora_rank is None:
                q_term = args.hidden_size * args.num_attention_heads * qk_dim
            else:
                q_term = args.q_lora_rank * (
                    args.hidden_size + args.num_attention_heads * qk_dim + 1
                )
            attn = (
                q_term
                # kv lora + rope + kv norm
                + args.kv_lora_rank
                * (args.hidden_size + args.num_attention_heads * (args.qk_head_dim + args.v_head_dim) + 1)
                # pos emb
                + args.hidden_size * args.qk_pos_emb_head_dim
                # out proj
                + (args.num_attention_heads * args.v_head_dim) * args.hidden_size
            )
            return attn

        # Group-query attention support.
        # If GQA not enabled, fall back to per-head queries.
        num_query_groups = (
            args.num_query_groups
            if args.group_query_attention and args.num_query_groups
            else args.num_attention_heads
        )

        # Standard attention path (Q,K,V,O projections).
        # Algebraically equal to
        #   2 * hidden * hidden * ((1 + groups / heads) * (kv_channels * heads / hidden))
        # but kept in integer arithmetic so the result is an exact int rather
        # than a float produced by true division.
        return 2 * args.hidden_size * args.kv_channels * (args.num_attention_heads + num_query_groups)

    def estimated_activation_memory(self, batch_size: int, seq_len: int) -> int:
        """Return the estimated activation memory in bytes.

        Computed as ``batch_size * seq_len``, floor-divided by the tensor- and
        context-parallel sizes, times ``hidden_size``, times 4 (Q, K, V, O),
        times 2 bytes per element (bf16).
        """
        parallel = self.config.model_parallel_config
        multiplier = 4  # for Q, K, V, O
        bytes_per_element = 2  # bf16
        # The floor divisions are applied before the remaining multiplications,
        # matching left-to-right evaluation of the original expression.
        tokens_per_rank = (
            batch_size
            * seq_len
            // parallel.tensor_model_parallel_size
            // parallel.context_model_parallel_size
        )
        return tokens_per_rank * self.config.model_config.hidden_size * multiplier * bytes_per_element

0 commit comments

Comments
 (0)