diff --git a/README.md b/README.md index b958d601b..b3fec433f 100644 --- a/README.md +++ b/README.md @@ -160,6 +160,48 @@ python -m atom.benchmarks.benchmark_serving \ --result-dir=./ --result-filename=$RESULT_FILENAME.json ``` +### Profile Analysis + +ATOM supports automatic trace collection and analysis, which breaks down GPU kernel durations per module for both **prefill** and **decode** phases and exports the results to Excel (`.xlsx`) files. + +#### Step 1: Collect a Trace + +Launch the server with `--torch-profiler-dir` to enable the PyTorch profiler and `--mark-trace` to insert per-module annotations into the trace. Set `TORCHINDUCTOR_COMPILE_THREADS=1` to ensure deterministic compilation order. + +```bash +TORCHINDUCTOR_COMPILE_THREADS=1 python -m atom.entrypoints.openai_server \ + --model deepseek-ai/DeepSeek-R1 \ + --kv_cache_dtype fp8 -tp 8 \ + --torch-profiler-dir ./trace \ + --mark-trace +``` + +After the server processes requests and shuts down, two `*.json.gz` trace files will be generated in the `--torch-profiler-dir` directory. + +#### Step 2: Analyze the Trace + +Run `parse_trace.py` on the collected trace file (use the trace file whose name starts with the model name): + +```bash +python ATOM/tools/parse_trace.py ./trace/model_name_ts_*.json.gz +``` + +This produces two Excel files in the current directory: + +| Output File | Description | +|---|---| +| `prefill_breakdown.xlsx` | Per-kernel duration breakdown for one prefill layer | +| `decode_breakdown.xlsx` | Per-kernel duration breakdown for one decode layer | +Each file contains columns: `cpu_module`, `gpu_kernel`, `duration_us`, `sum per module`, plus averaged values across layers. 
+ +**Options:** + +| Flag | Default | Description | +|---|---|---| +| `--layer N` | `3` | Target transformer layer index to analyze (0-indexed) | + + ## 📊 Performance ### Online Serving Throughput diff --git a/atom/model_engine/engine_core.py b/atom/model_engine/engine_core.py index 260d136d0..dbc5db9b7 100644 --- a/atom/model_engine/engine_core.py +++ b/atom/model_engine/engine_core.py @@ -86,13 +86,18 @@ def __init__(self, config: Config, input_address: str, output_address: str): if not config.enforce_eager: # Start profiler before cudagraph capture only if mark-trace is enabled. if self.profile_enbaled and self.mark_trace: - self.runner_mgr.call_func("start_profiler", wait_out=True) + self.runner_mgr.call_func( + "start_profiler", "capture_graph", wait_out=True + ) cap_cost, bs = self.runner_mgr.call_func( "capture_cudagraph", wait_out=True ) logger.info( f"{self.label}: cudagraph capture{bs} cost: {cap_cost:.2f} seconds" ) + if self.profile_enbaled and self.mark_trace: + # Persist a dedicated capture-graph trace immediately. + self.runner_mgr.call_func("stop_profiler", wait_out=True) good = True finally: logger.info( diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 4c6b667fb..0a7ed2a87 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -5,6 +5,7 @@ import math import os import time +import gzip from contextlib import nullcontext from typing import Any, Optional, Union @@ -690,7 +691,7 @@ def exit(self): torch.cuda.empty_cache() return True - def start_profiler(self): + def start_profiler(self, trace_name: Optional[str] = None): """ Start profiling for this rank. 
@@ -700,6 +701,35 @@ def start_profiler(self): """ if self.profiler_dir is not None and self.profiler is None: enable_detailed_profiling = envs.ATOM_PROFILER_MORE + model_name = os.path.basename(self.config.model.rstrip("/")) + safe_model_name = "".join( + c if c.isalnum() or c in ("_", "-", ".") else "_" for c in model_name + ) + worker_name = safe_model_name or "trace" + if isinstance(trace_name, str) and trace_name: + worker_name = "".join( + c if c.isalnum() or c in ("_", "-", ".") else "_" + for c in trace_name + ) + if worker_name == "capture_graph": + if safe_model_name: + worker_name = f"{worker_name}_{safe_model_name}" + output_prefix = os.path.join(self.profiler_dir, worker_name) + + def _on_trace_ready(prof): + # Use a short human-readable timestamp in file name. + ts = time.strftime("%Y%m%d_%H%M%S", time.localtime()) + ms = int((time.time() % 1) * 1000) + output_path = f"{output_prefix}_ts_{ts}_{ms:03d}.pt.trace.json.gz" + tmp_json_path = output_path[:-3] + prof.export_chrome_trace(tmp_json_path) + with ( + open(tmp_json_path, "rb") as src, + gzip.open(output_path, "wb") as dst, + ): + dst.write(src.read()) + os.remove(tmp_json_path) + self.profiler = torch_profiler.profile( activities=[ torch_profiler.ProfilerActivity.CPU, @@ -708,9 +738,7 @@ def start_profiler(self): record_shapes=enable_detailed_profiling, with_stack=enable_detailed_profiling, profile_memory=enable_detailed_profiling, - on_trace_ready=torch_profiler.tensorboard_trace_handler( - self.profiler_dir, use_gzip=True - ), + on_trace_ready=_on_trace_ready, ) self.profiler.__enter__() return True @@ -1504,21 +1532,13 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor is_prefill = context.is_prefill positions = context.positions if is_prefill or self.enforce_eager or bs > self.graph_bs[-1]: - with ( - record_function( - f"prefill_bs_{bs}_ctxlens_{forward_context.attn_metadata.context_lens}" - ) - if self.mark_trace - else nullcontext() + with record_function( + 
f"prefill_bs_{bs}_ctxlens_{forward_context.attn_metadata.context_lens}" ): hidden_states = self.model(input_ids, positions) logits = self.model.compute_logits(hidden_states) else: - with ( - record_function(f"decode_step_bs_{bs}") - if self.mark_trace - else nullcontext() - ): + with record_function(f"decode_step_bs_{bs}"): graph_bs = context.graph_bs max_q_len = forward_context.attn_metadata.max_seqlen_q graph_key = (graph_bs, max_q_len) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index b38ca0ee3..8e0e1681e 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -17,6 +17,7 @@ from atom.plugin.prepare import is_plugin_mode, is_vllm from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode +from atom.utils.decorators import mark_trace @PagedAttentionImplDecoratorForPluginMode @@ -110,6 +111,7 @@ def forward_impl_server_mode( return o + @mark_trace(prefix="rope_cache", torch_compile=False) def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): attn_metadata = fwd_ctx.attn_metadata kv_cache_data = fwd_ctx.kv_cache_data @@ -214,6 +216,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): return q, k, v, k_cache, v_cache, k_scale, v_scale + @mark_trace(prefix="paged_attention_triton", torch_compile=False) def paged_attention_triton( self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext ): @@ -290,6 +293,7 @@ def paged_attention_triton( return o + @mark_trace(prefix="paged_attention_asm", torch_compile=False) def paged_attention_asm( self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext ): @@ -312,6 +316,7 @@ def paged_attention_asm( return o + @mark_trace(prefix="paged_attention_persistent_asm", torch_compile=False) def paged_attention_persistent_asm( self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext ): @@ -341,6 +346,7 @@ def paged_attention_persistent_asm( return output + 
@mark_trace(prefix="prefill_attention", torch_compile=False) def prefill_attention( self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext ): diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 9d6fe94d1..68bd056a5 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -23,6 +23,8 @@ from atom.model_ops.linear import use_triton_gemm from atom.model_ops.utils import get_and_maybe_dequant_weights from atom.utils import envs +from atom.utils.decorators import mark_trace + from atom.utils.forward_context import ( AttentionMetaData, ForwardContext, @@ -39,6 +41,15 @@ from atom.plugin.attention_mla import MLAAttentionImplDecoratorForPluginMode +concat_and_cache_mla = mark_trace( + concat_and_cache_mla, prefix="kv_cache", torch_compile=False +) +fused_qk_rope_concat_and_cache_mla = mark_trace( + fused_qk_rope_concat_and_cache_mla, prefix="rope_and_kv_cache", torch_compile=False +) +mla_prefill_fwd = mark_trace(mla_prefill_fwd, prefix="mla_prefill", torch_compile=False) +mla_decode_fwd = mark_trace(mla_decode_fwd, prefix="mla_decode", torch_compile=False) + # torch.set_printoptions(threshold=10_000) logger = logging.getLogger("atom") @@ -194,6 +205,7 @@ def process_weights_after_loading(self): W_V, dtype=dtypes.fp8 ) + @mark_trace(prefix="v_up_proj_and_o_proj", torch_compile=False) def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -228,6 +240,7 @@ def _v_up_proj_and_o_proj(self, x): x = x.reshape(-1, self.num_heads * self.v_head_dim) return self.o_proj(x) + @mark_trace(prefix="q_proj_and_k_up_proj", torch_compile=False) def _q_proj_and_k_up_proj(self, x, x_scale=None): q_nope, q_pe = ( self.q_proj(x, x_scale) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 9f16752a6..7f556fcc5 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -9,6 +9,7 @@ 
handle_torch_function, ) from atom.config import QuantizationConfig, LayerQuantConfig +from atom.utils.decorators import mark_trace from torch import nn from aiter import ( rmsnorm2d_fwd, @@ -197,6 +198,7 @@ def __init__( self.quant_type = quant_type self.params_dtype = params_dtype + @mark_trace(prefix="rmsnorm", torch_compile=True) def forward( self, x: torch.Tensor, diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 1b9eb615f..3cfbbb393 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -28,7 +28,7 @@ shuffle_weights, ) from atom.utils import envs -from atom.utils.decorators import mark_trace, record_function +from atom.utils.decorators import mark_trace from torch import nn logger = logging.getLogger("atom") @@ -174,7 +174,7 @@ def gemm_a8w8_blockscale_preshuffle_fake( return torch.empty((*x.shape[:-1], weight.shape[0]), dtype=dtype, device=x.device) -@record_function +@mark_trace(torch_compile=False) @torch_compile_guard(gen_fake=gemm_a8w8_blockscale_preshuffle_fake, mutates_args=[]) def gemm_a8w8_blockscale_preshuffle_impl( x: torch.Tensor, @@ -194,7 +194,6 @@ def gemm_a8w8_blockscale_preshuffle_impl( return y -@mark_trace class LinearBase(nn.Module): def __init__( self, @@ -386,6 +385,7 @@ def process_weights_after_loading(self): if self.quant_type == QuantType.per_1x32: self.weight_scale.data = fp4_utils.e8m0_shuffle(self.weight_scale.data) + @mark_trace def forward( self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None, otype=dtypes.bf16 ) -> torch.Tensor: diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index f6f9e5e76..a4c1db0bd 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -50,6 +50,7 @@ from atom.utils import envs from atom.utils.custom_register import direct_register_custom_op from atom.utils.forward_context import get_forward_context +from atom.utils.decorators import mark_trace from torch import nn from transformers import PretrainedConfig from atom.plugin.moe import 
FusedMoEDecoratorForPluginMode @@ -449,6 +450,7 @@ def get_fused_moe_quant_config( ) -> FusedMoEQuantConfig | None: return FUSED_MOE_UNQUANTIZED_CONFIG + @mark_trace(prefix="unquantized_moe", torch_compile=False) def apply( self, layer: torch.nn.Module, @@ -859,6 +861,7 @@ def get_fused_moe_quant_config( w2_scale=layer.w2_weight_scale, ) + @mark_trace(prefix="mxfp4_moe", torch_compile=False) def apply( self, layer: torch.nn.Module, @@ -1273,6 +1276,7 @@ def get_fused_moe_quant_config( block_shape=block_shape, ) + @mark_trace(prefix="compressed_fp8_moe", torch_compile=False) def apply( self, layer: torch.nn.Module, @@ -1637,6 +1641,7 @@ def get_fused_moe_quant_config( block_shape=None, ) + @mark_trace(prefix="fp8_moe", torch_compile=False) def apply( self, layer: torch.nn.Module, diff --git a/atom/model_ops/rotary_embedding.py b/atom/model_ops/rotary_embedding.py index d15dfdcf1..5186f321c 100644 --- a/atom/model_ops/rotary_embedding.py +++ b/atom/model_ops/rotary_embedding.py @@ -8,6 +8,8 @@ from aiter import dtypes from typing import Union, Optional +from atom.utils.decorators import mark_trace + def apply_rotary_emb( x: torch.Tensor, @@ -77,6 +79,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: sin = freqs.sin().unsqueeze(-2).unsqueeze(-2) return cos, sin + @mark_trace(prefix="rope_cached") def forward( self, positions: torch.Tensor, diff --git a/atom/utils/decorators.py b/atom/utils/decorators.py index 8fa916218..eb16cd246 100644 --- a/atom/utils/decorators.py +++ b/atom/utils/decorators.py @@ -28,62 +28,46 @@ torch_compile_start_time: float = 0.0 -def record_function(prefix: Union[str, Callable, None] = None): - """ - Decorator that wraps a function with torch.profiler.record_function. 
+def _resolve_record_span_name( + func: Callable, + args, + kwargs, + explicit_prefix: Optional[str] = None, +): + if explicit_prefix is not None: + return str(explicit_prefix) + + span_name = func.__name__ + runtime_prefix = kwargs.get("prefix") + if isinstance(runtime_prefix, str) and runtime_prefix: + return runtime_prefix + + try: + base_sig = inspect.signature(inspect.unwrap(func)) + bound = base_sig.bind_partial(*args, **kwargs) + runtime_prefix = bound.arguments.get("prefix") + except Exception: + runtime_prefix = None + + if isinstance(runtime_prefix, str) and runtime_prefix: + return runtime_prefix + return span_name + + +def _decorate_record_function(func: Callable, prefix: Optional[str] = None): + @wraps(func) + def _wrapped(*args, **kwargs): + # Keep this decorator no-op unless mark-trace is enabled. + from atom.utils.graph_marker import is_graph_marker_enabled - Usage: - - @record_function - - @record_function("my_prefix") - """ + if not is_graph_marker_enabled(): + return func(*args, **kwargs) - def _decorate(func: Callable): - # Try to recover the original callable signature even when func is wrapped - # by other decorators. - base_func = inspect.unwrap(func) - try: - base_sig = inspect.signature(base_func) - except (TypeError, ValueError): - base_sig = None - - @wraps(func) - def _wrapped(*args, **kwargs): - # Keep this decorator no-op unless mark-trace is enabled. 
- from atom.utils.graph_marker import is_graph_marker_enabled - - if not is_graph_marker_enabled(): - return func(*args, **kwargs) - - # Priority: - # 1) explicit decorator prefix: @record_function("xxx") - # 2) runtime function argument named "prefix" when non-empty - # 3) function name fallback - if prefix is not None: - span_name = str(prefix) - else: - span_name = func.__name__ - runtime_prefix = kwargs.get("prefix") - if not (isinstance(runtime_prefix, str) and runtime_prefix): - if base_sig is not None: - try: - bound = base_sig.bind_partial(*args, **kwargs) - runtime_prefix = bound.arguments.get("prefix") - except Exception: - runtime_prefix = None - if isinstance(runtime_prefix, str) and runtime_prefix: - span_name = runtime_prefix - - with torch.profiler.record_function(f"{span_name}"): - return func(*args, **kwargs) - - return _wrapped - - # Support @record_function without parentheses. - if callable(prefix): - func = prefix - prefix = None - return _decorate(func) - return _decorate + span_name = _resolve_record_span_name(func, args, kwargs, prefix) + with torch.profiler.record_function(f"{span_name}"): + return func(*args, **kwargs) + + return _wrapped def _graph_marker_first_tensor(obj, name: str): @@ -126,43 +110,85 @@ def _graph_marker_first_tensor(obj, name: str): return obj, False -def mark_trace(cls): - forward = getattr(cls, "forward", None) - if forward is None: - return cls - if getattr(forward, "__mark_trace_wrapped__", False): - return cls +def _decorate_mark_trace_torch_compile(func: Callable, prefix: Optional[str] = None): + if getattr(func, "__mark_trace_wrapped__", False): + return func from atom.utils.graph_marker import is_graph_marker_enabled - def wrapped_forward(self, *args, **kwargs): - # When mark-trace is disabled, bypass all wrapping logic entirely + owner_name = func.__qualname__.split(".")[0] + try: + unwrapped = inspect.unwrap(func) + params = list(inspect.signature(unwrapped).parameters.values()) + skip_first_arg = 
bool(params) and params[0].name in {"self", "cls"} + except (TypeError, ValueError): + skip_first_arg = False + + @wraps(func) + def wrapped(*args, **kwargs): + # When mark-trace is disabled, bypass all wrapping logic entirely. if not is_graph_marker_enabled(): - return forward(self, *args, **kwargs) + return func(*args, **kwargs) + + marker_prefix = str(prefix) if prefix is not None else owner_name + start_idx = 0 + if skip_first_arg and args: + marker_prefix = ( + str(prefix) + if prefix is not None + else getattr(args[0], "prefix", owner_name) + ) + start_idx = 1 + if prefix is None: + runtime_prefix = kwargs.get("prefix") + if isinstance(runtime_prefix, str) and runtime_prefix: + marker_prefix = runtime_prefix - prefix = getattr(self, "prefix", cls.__name__) # Mark only the first tensor across args/kwargs, keeping names stable. args_l = list(args) marked = False - for i, a in enumerate(args_l): + for i in range(start_idx, len(args_l)): if marked: break - aa, marked = _graph_marker_first_tensor(a, f"{prefix}_start") + aa, marked = _graph_marker_first_tensor(args_l[i], f"{marker_prefix}_start") args_l[i] = aa if not marked: for k, v in list(kwargs.items()): if marked: break - vv, marked = _graph_marker_first_tensor(v, f"{prefix}_start") + vv, marked = _graph_marker_first_tensor(v, f"{marker_prefix}_start") kwargs[k] = vv - args = tuple(args_l) - y = forward(self, *args, **kwargs) - yy, _ = _graph_marker_first_tensor(y, f"{prefix}_end") + + y = func(*tuple(args_l), **kwargs) + yy, _ = _graph_marker_first_tensor(y, f"{marker_prefix}_end") return yy - wrapped_forward.__mark_trace_wrapped__ = True - cls.forward = wrapped_forward - return cls + wrapped.__mark_trace_wrapped__ = True + return wrapped + + +def mark_trace( + func: Optional[Callable] = None, + *, + torch_compile: bool = True, + prefix: Optional[str] = None, +): + """ + Unified trace decorator. + + - torch_compile=True: original graph_marker-based mark_trace behavior. 
+ - torch_compile=False: record_function behavior. + """ + + def _decorate(target: Callable): + if torch_compile: + return _decorate_mark_trace_torch_compile(target, prefix) + return _decorate_record_function(target, prefix) + + # Support both @mark_trace and @mark_trace(...) + if func is not None: + return _decorate(func) + return _decorate # We remove it from utils/__init__.py to avoid circular import diff --git a/atom/utils/graph_marker_instrumentation.py b/atom/utils/graph_marker_instrumentation.py index ba97f8f40..30d529820 100644 --- a/atom/utils/graph_marker_instrumentation.py +++ b/atom/utils/graph_marker_instrumentation.py @@ -269,11 +269,7 @@ def _wrap_region_with_record_function( if end_marker_idx <= start_marker_idx + 1: return - # Only add layer prefix if prefix doesn't already contain layer info - if layer_id is not None and layer_id >= 0 and "model.layers" not in prefix: - tag = f"layer_{layer_id}_{prefix}" - else: - tag = prefix + tag = prefix with_line = f'{indent}with record_function("{tag}"):\n' insert_at = start_marker_idx + 1 diff --git a/tools/parse_trace.py b/tools/parse_trace.py index ba0ff04ba..0e55a9125 100644 --- a/tools/parse_trace.py +++ b/tools/parse_trace.py @@ -12,6 +12,8 @@ import bisect import argparse import re +import os +from glob import glob from typing import List, Dict, Any, Tuple, Optional from openpyxl import Workbook @@ -19,7 +21,7 @@ FILTER_OUT = ["fill_"] # Sampling-related modules and low-level ops to filter out in prefill -FILTER_OUT_PREFILL = ["aten::", "aiter::gemm_a16w16", "aiter::mixed_sample"] +FILTER_OUT_PREFILL = ["aiter::mixed_sample"] # ============================================================================= @@ -34,13 +36,6 @@ def load_trace(filepath: str) -> Dict[str, Any]: return json.load(f) -def is_within( - child_ts: float, child_dur: float, parent_ts: float, parent_dur: float -) -> bool: - """Check if child event is within parent's time range.""" - return child_ts >= parent_ts and (child_ts + 
child_dur) <= (parent_ts + parent_dur) - - def is_kernel_launch(name: str) -> bool: """Check if name is a kernel launch (contains 'launch' and 'kernel').""" n = name.lower() @@ -57,6 +52,50 @@ def should_filter_prefill(name: str) -> bool: return any(f in name for f in FILTER_OUT_PREFILL) + +def is_strict_norm_name(name: str) -> bool: + """Match norm module names strictly, not by substring.""" + if not isinstance(name, str): + return False + n = name.strip().lower() + return n == "layernorm" or n == "rmsnorm" + + +def extract_model_name_from_trace_filename(filepath: str) -> Optional[str]: + """ + Extract model name from trace filename prefix before `_ts_`. + Examples: + - Meta-Llama-3.1-8B-Instruct_ts_... -> Meta-Llama-3.1-8B-Instruct + - capture_graph_Meta-Llama-3.1-8B-Instruct_ts_... -> Meta-Llama-3.1-8B-Instruct + """ + base = os.path.basename(filepath) + if "_ts_" not in base: + return None + prefix = base.split("_ts_", 1)[0] + if prefix.startswith("capture_graph_"): + prefix = prefix[len("capture_graph_") :] + return prefix or None + + +def find_capture_graph_trace_path(run_trace_path: str) -> Optional[str]: + """ + Find capture graph trace file in the same directory as the run trace. + Pattern: capture_graph_<model_name>_ts_*.pt.trace.json[.gz] + """ + model_name = extract_model_name_from_trace_filename(run_trace_path) + if not model_name: + return None + trace_dir = os.path.dirname(run_trace_path) or "." 
+ pattern = os.path.join(trace_dir, f"capture_graph_{model_name}_ts_*.pt.trace.json*") + candidates = sorted(glob(pattern), key=os.path.getmtime, reverse=True) + if not candidates: + return None + run_abs = os.path.abspath(run_trace_path) + for fp in candidates: + if os.path.abspath(fp) != run_abs: + return fp + return None + + def write_breakdown_xlsx( output_xlsx: str, rows: List[List[Any]], @@ -103,7 +142,7 @@ def build_groups(block_rows: List[List[Any]]) -> List[Tuple[int, int, str, float renamed_group_mods = [g[2] for g in main_groups] seen_rmsnorm = 0 for gi, mod in enumerate(renamed_group_mods): - if isinstance(mod, str) and "rmsnorm" in mod.lower(): + if is_strict_norm_name(mod): if seen_rmsnorm == 0: renamed_group_mods[gi] = "input_layernorm" elif seen_rmsnorm == 1: @@ -151,39 +190,66 @@ def _normalize_module_for_avg(name: str) -> str: def build_avg_rows_from_layers( - layer_rows_list: List[List[List[Any]]], - layer_start_idx: int, + layer_rows_list: List[Tuple[int, List[List[Any]]]], section_name: str, ) -> Optional[List[List[Any]]]: """ - Build AVG rows across layers using layer-3 rows as template. - Returns None if any layer cannot be aligned by (module, kernel) sequence. + Build AVG rows across layers with fallback: + 1) try contiguous layers from avg_start_layer. + 2) if mismatch, retry every other layer: start, start+2, start+4, ... + Returns None if still not alignable. """ if not layer_rows_list: return [] - base = layer_rows_list[0] - base_sig = [(_normalize_module_for_avg(r[0]), r[1]) for r in base] - - for rel_idx, rows in enumerate(layer_rows_list[1:], start=1): - sig = [(_normalize_module_for_avg(r[0]), r[1]) for r in rows] - if sig != base_sig: - bad_layer = layer_start_idx + rel_idx - print( - f"{section_name} avg skipped: layer {bad_layer} does not match layer {layer_start_idx} layout." 
+ def _try_build( + selected_layers: List[Tuple[int, List[List[Any]]]], + ) -> Tuple[Optional[List[List[Any]]], Optional[int]]: + base_layer, base_rows = selected_layers[0] + base_sig = [(_normalize_module_for_avg(r[0]), r[1]) for r in base_rows] + + for layer_idx, rows in selected_layers[1:]: + sig = [(_normalize_module_for_avg(r[0]), r[1]) for r in rows] + if sig != base_sig: + return None, layer_idx + + n = len(selected_layers) + avg_rows: List[List[Any]] = [] + for i, (_, kernel) in enumerate(base_sig): + # Keep display style from base layer rows. + display_mod = base_rows[i][0] + avg_dur = ( + sum(float(selected_layers[layer_i][1][i][2]) for layer_i in range(n)) + / n ) - return None - - n = len(layer_rows_list) - avg_rows: List[List[Any]] = [] - for i, (mod, kernel) in enumerate(base_sig): - # Keep original module display style from layer_start_idx rows. - display_mod = base[i][0] - avg_dur = ( - sum(float(layer_rows_list[layer_idx][i][2]) for layer_idx in range(n)) / n - ) - avg_rows.append([display_mod, kernel, avg_dur]) - return avg_rows + avg_rows.append([display_mod, kernel, avg_dur]) + return avg_rows, None + + start_layer = layer_rows_list[0][0] + avg_rows, bad_layer = _try_build(layer_rows_list) + if avg_rows is not None: + return avg_rows + + print( + f"{section_name} avg mismatch: layer {bad_layer} does not match layer {start_layer} layout." 
+ ) + fallback_layers = [ + item for item in layer_rows_list if (item[0] - start_layer) % 2 == 0 + ] + fallback_indices = [str(layer_idx) for layer_idx, _ in fallback_layers] + print( + f"{section_name} avg retry with every other layer: {'.'.join(fallback_indices)}" + ) + if len(fallback_layers) < 2: + print(f"{section_name} avg skipped: fallback has fewer than 2 layers.") + return None + + avg_rows, bad_layer = _try_build(fallback_layers) + if avg_rows is not None: + return avg_rows + + print(f"{section_name} avg skipped: still mismatch at layer {bad_layer}.") + return None # ============================================================================= @@ -275,73 +341,6 @@ def has_kernel_launch(self, event: Dict) -> bool: return self.count_kernel_launches(event) > 0 -# ============================================================================= -# Legacy functions (for prefill compatibility) -# ============================================================================= - - -def find_events(events: List[Dict], name: str, prefix: bool = False) -> List[Dict]: - """Find all duration events (ph='X') with given name, sorted by time.""" - if prefix: - result = [ - e - for e in events - if e.get("name", "").startswith(name) and e.get("ph") == "X" - ] - else: - result = [e for e in events if e.get("name") == name and e.get("ph") == "X"] - return sorted(result, key=lambda x: x["ts"]) - - -def get_gpu_kernels(events: List[Dict], start_ts: float) -> List[Dict]: - """Get GPU kernels (cat='kernel') starting from given timestamp.""" - result = [e for e in events if e.get("cat") == "kernel" and e["ts"] >= start_ts] - return sorted(result, key=lambda x: x["ts"]) - - -def get_direct_children(parent: Dict, events: List[Dict]) -> List[Dict]: - """Get direct children of parent event (excluding nested children).""" - p_ts, p_dur = parent["ts"], parent.get("dur", 0) - - candidates = [ - e - for e in events - if e.get("ph") == "X" - and e is not parent - and is_within(e.get("ts", 0), 
e.get("dur", 0), p_ts, p_dur) - ] - - direct = [] - for c in candidates: - c_ts, c_dur = c["ts"], c.get("dur", 0) - is_direct = not any( - is_within(c_ts, c_dur, o["ts"], o.get("dur", 0)) - for o in candidates - if o is not c - ) - if is_direct: - direct.append(c) - - return sorted(direct, key=lambda x: x["ts"]) - - -def count_kernel_launches(event: Dict, events: List[Dict]) -> int: - """Count kernel launches within event's subtree.""" - e_ts, e_dur = event["ts"], event.get("dur", 0) - return sum( - 1 - for e in events - if e.get("ph") == "X" - and is_kernel_launch(e.get("name", "")) - and is_within(e.get("ts", 0), e.get("dur", 0), e_ts, e_dur) - ) - - -def has_kernel_launch(event: Dict, events: List[Dict]) -> bool: - """Check if event's subtree contains any kernel launch.""" - return count_kernel_launches(event, events) > 0 - - # ============================================================================= # Parse Functions # ============================================================================= @@ -349,14 +348,7 @@ def has_kernel_launch(event: Dict, events: List[Dict]) -> bool: def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> None: """ - Parse prefill phase: find the actual prefill event on CPU trace (user_annotation). - - Warmup rule: - - If only one prefill exists, it is the actual prefill (no warmup). - - If >=2 prefills exist: - - If there is a decode_step_bs* event between prefill[0] and prefill[1], prefill[0] - is treated as warmup and prefill[1] is the actual prefill. - - Otherwise, prefill[0] is the actual prefill. + Parse prefill phase from a run trace (no warmup mixed in this trace). """ # CPU side prefill/decode annotations. # Accept both legacy "prefill" and traced variants like @@ -376,33 +368,9 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - return actual_prefill_idx = 0 - warmup_detected = False - - # Only evaluate warmup when there are at least two prefills. 
- if len(prefills) >= 2: - first = prefills[0] - second = prefills[1] - gap_start = first["ts"] + first.get("dur", 0) - gap_end = second["ts"] - - # If decode_step_bs appears in [gap_start, gap_end], first prefill is warmup. - has_decode_between = any( - e.get("ph") == "X" - and e.get("cat") == "user_annotation" - and e.get("name", "").startswith("decode_step_bs") - and gap_start <= e.get("ts", 0) <= gap_end - for e in events - ) - if has_decode_between: - actual_prefill_idx = 1 - warmup_detected = True - actual_prefill = prefills[actual_prefill_idx] print(f"Found {len(prefills)} prefill events (user_annotation)") - if warmup_detected: - print("Warmup detected: decode_step_bs found between prefill[0] and prefill[1]") - else: - print("No warmup prefill detected by rule, using prefill[0]") + print("Using first prefill event (run trace has no warmup phase).") print( f"Using prefill[{actual_prefill_idx}] " f"(ts={actual_prefill.get('ts', 0):.0f}, dur={actual_prefill.get('dur', 0):.0f})" @@ -447,42 +415,11 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - # Layer extraction by rmsnorm positions: # each layer has 2 rmsnorm modules, layer N starts at rmsnorm index 2*N (0-based). - TARGET_LAYER = target_layer all_norm_indices = [ i for i, item in enumerate(launch_level2_items) - if "rmsnorm" in item["level2_event"].get("name", "").lower() + if is_strict_norm_name(item["level2_event"].get("name", "")) ] - # Last rmsnorm is final layernorm, not part of transformer layers. 
- norm_indices = all_norm_indices[:-1] if len(all_norm_indices) > 0 else [] - print( - f"Found {len(all_norm_indices)} rmsnorm modules in level2-with-launch rows " - f"({len(norm_indices)} used for layer split, excluding final layernorm)" - ) - - mod_start = 0 - mod_end = 0 - norm_start_idx = TARGET_LAYER * 2 - norm_end_idx = (TARGET_LAYER + 1) * 2 - final_norm_idx = ( - all_norm_indices[-1] if len(all_norm_indices) > 0 else len(launch_level2_items) - ) - if norm_start_idx >= len(norm_indices): - print( - f"Not enough rmsnorm modules for layer {TARGET_LAYER}, writing empty XLSX" - ) - else: - mod_start = norm_indices[norm_start_idx] - mod_end = ( - norm_indices[norm_end_idx] - if norm_end_idx < len(norm_indices) - else final_norm_idx - ) - print( - f"Layer {TARGET_LAYER} range by rmsnorm: " - f"rows [{mod_start}:{mod_end}) from rmsnorm #{norm_start_idx+1} to #{norm_end_idx+1}" - ) - print(f"Layer {TARGET_LAYER} modules: {mod_end - mod_start}") # Build launch->kernel mapping by correlation id. # Build launch candidates from current prefill thread/range once. 
@@ -533,55 +470,38 @@ def parse_prefill(events: List[Dict], output_xlsx: str, target_layer: int = 3) - kernels.append({"name": k.get("name", "N/A"), "dur": k.get("dur", 0)}) item_kernels.append(kernels) + def _resolve_moe_child_name_prefill(event: Dict[str, Any]) -> str: + mod_name = event.get("name", "") + if "moe" not in mod_name.lower(): + return mod_name + children = prefill_idx.get_direct_children(event) + children_with_launch = [c for c in children if prefill_idx.has_kernel_launch(c)] + if children_with_launch: + return children_with_launch[0].get("name", mod_name) + return mod_name + def build_rows_from_item_range(start: int, end: int) -> List[List[Any]]: rows = [] for i in range(start, end): item = launch_level2_items[i] - mod_name = item["level2_event"].get("name", "") + mod_name = _resolve_moe_child_name_prefill(item["level2_event"]) if should_filter_prefill(mod_name): continue kernels = [k for k in item_kernels[i] if k["name"] not in ("", "N/A")] if not kernels: continue - if "moe_forward" in mod_name.lower(): - rows.extend(process_moe_module(mod_name, len(kernels), 0, kernels)) - else: - for k in kernels: - rows.append( - [clean_module_name(mod_name, k["name"]), k["name"], k["dur"]] - ) + rows.extend(process_module(mod_name, len(kernels), 0, kernels)) return rows - # Target layer rows. - csv_rows = ( - build_rows_from_item_range(mod_start, mod_end) - if norm_start_idx < len(norm_indices) - else [] + _extract_layer_and_write( + all_norm_indices, + len(launch_level2_items), + target_layer, + "Prefill", + "prefill", + build_rows_from_item_range, + output_xlsx, ) - print(f"Layer {TARGET_LAYER} launch->kernel mapping rows: {len(csv_rows)}") - - print(f"Prefill decode-style CSV rows (after filters): {len(csv_rows)}") - - # AVG rows from layer 3 to last layer. 
- avg_rows = None - avg_layer_rows: List[List[List[Any]]] = [] - avg_start_layer = 3 - layer = avg_start_layer - while 2 * layer < len(norm_indices): - s = norm_indices[2 * layer] - e_idx = 2 * (layer + 1) - e = norm_indices[e_idx] if e_idx < len(norm_indices) else final_norm_idx - avg_layer_rows.append(build_rows_from_item_range(s, e)) - layer += 1 - if avg_layer_rows: - avg_rows = build_avg_rows_from_layers( - avg_layer_rows, avg_start_layer, "Prefill" - ) - if avg_rows is not None: - print(f"Prefill avg rows: {len(avg_rows)}") - - # Write XLSX for prefill. - write_breakdown_xlsx(output_xlsx, csv_rows, sheet_name="prefill", avg_rows=avg_rows) def clean_module_name(name: str, mapped_kernel_name: str = "") -> str: @@ -597,60 +517,105 @@ def clean_module_name(name: str, mapped_kernel_name: str = "") -> str: if name.startswith("aiter::"): name = name[7:] # len('aiter::') == 7 - # Rename based on keywords (rope takes priority) - name_lower = name.lower() - if "rope" in name_lower and "cache" in name_lower: - return "rope & kv_cache" - if "rope" in name_lower: - return "rope" - if "cache" in name_lower and "gemm" not in name_lower: - return "kv_cache" - return name -def process_moe_module( - mod_name: str, kernel_count: int, start_gpu_idx: int, gpu_kernels: List[Dict] -) -> List[List]: +def _extract_layer_and_write( + all_norm_indices: List[int], + total_module_count: int, + target_layer: int, + section_name: str, + sheet_name: str, + build_rows_fn, + output_xlsx: str, +) -> None: + """ + Shared layer-extraction, AVG computation, and XLSX write logic + used by both parse_prefill and parse_decode. 
+ + Args: + all_norm_indices: indices of all norm modules (including final layernorm) + total_module_count: total number of modules (used as fallback final_norm_idx) + target_layer: which layer to extract + section_name: "Prefill" or "Decode" for print messages + sheet_name: XLSX sheet name + build_rows_fn: callable(start, end) -> List[List[Any]] + output_xlsx: output file path """ - Process moe_forward module: categorize kernels by name. + norm_indices = all_norm_indices[:-1] if len(all_norm_indices) > 0 else [] + print( + f"Found {len(all_norm_indices)} norm modules " + f"({len(norm_indices)} used for layer split, excluding final layernorm)" + ) - - 'moesort' in kernel name -> moe_sort - - 'topk' in kernel name -> moe_topk - - others -> keep original mod_name + TARGET_LAYER = target_layer + norm_start_idx = TARGET_LAYER * 2 + norm_end_idx = (TARGET_LAYER + 1) * 2 + final_norm_idx = ( + all_norm_indices[-1] if len(all_norm_indices) > 0 else total_module_count + ) - Returns list of [display_name, gpu_kernel_name, gpu_dur] rows. 
- """ - rows = [] - for i in range(kernel_count): - gpu_idx = start_gpu_idx + i - gpu_kernel_name = "N/A" - gpu_dur = 0 - if gpu_idx < len(gpu_kernels): - gpu = gpu_kernels[gpu_idx] - gpu_kernel_name = gpu.get("name", "N/A") - gpu_dur = gpu.get("dur", 0) + mod_start = 0 + mod_end = 0 + if norm_start_idx >= len(norm_indices): + print(f"Not enough norms for layer {TARGET_LAYER}") + if sheet_name == "prefill": + print( + f"Not enough rmsnorm modules for layer {TARGET_LAYER}, writing empty XLSX" + ) + write_breakdown_xlsx(output_xlsx, [], sheet_name=sheet_name) + return + else: + mod_start = norm_indices[norm_start_idx] + mod_end = ( + norm_indices[norm_end_idx] + if norm_end_idx < len(norm_indices) + else final_norm_idx + ) + print( + f"Layer {TARGET_LAYER}: modules [{mod_start}:{mod_end}] " + f"(norms at indices {norm_start_idx}, {norm_start_idx+1})" + ) - # Determine category based on kernel name - kernel_lower = gpu_kernel_name.lower() - if "moesort" in kernel_lower: - category = "moe_sort" - elif "topk" in kernel_lower: - category = "moe_topk" - else: - category = clean_module_name(mod_name, gpu_kernel_name) - - # Always show category/module name on each row. - display_name = category - rows.append([display_name, gpu_kernel_name, gpu_dur]) + avg_start_layer = TARGET_LAYER + avg_end_layer = (len(norm_indices) - 1) // 2 + if avg_start_layer <= avg_end_layer: + print( + f"Target layer: {TARGET_LAYER}; AVG layers: [{avg_start_layer}..{avg_end_layer}]" + ) + else: + print( + f"Target layer: {TARGET_LAYER}; AVG layers: disabled (no eligible layers)" + ) - return rows + # Target layer rows. + rows = build_rows_fn(mod_start, mod_end) + print(f"Layer {TARGET_LAYER} rows: {len(rows)}") + # AVG rows. 
+ avg_rows = None + avg_layer_rows: List[Tuple[int, List[List[Any]]]] = [] + layer = avg_start_layer + while 2 * layer < len(norm_indices): + s = norm_indices[2 * layer] + e_idx = 2 * (layer + 1) + e = norm_indices[e_idx] if e_idx < len(norm_indices) else final_norm_idx + avg_layer_rows.append((layer, build_rows_fn(s, e))) + layer += 1 + if avg_layer_rows: + avg_rows = build_avg_rows_from_layers(avg_layer_rows, section_name) + if avg_rows is not None: + print(f"{section_name} avg rows: {len(avg_rows)}") + + write_breakdown_xlsx(output_xlsx, rows, sheet_name=sheet_name, avg_rows=avg_rows) + print(f"Layer {TARGET_LAYER} modules: {mod_end - mod_start}") + print(f"XLSX written to: {output_xlsx} ({len(rows)} rows)") -def process_regular_module( + +def process_module( mod_name: str, kernel_count: int, start_gpu_idx: int, gpu_kernels: List[Dict] ) -> List[List]: - """Process regular module and show module name on every row.""" + """Process a module and return [display_name, gpu_kernel_name, gpu_dur] rows.""" rows = [] for i in range(kernel_count): gpu_idx = start_gpu_idx + i @@ -665,9 +630,16 @@ def process_regular_module( return rows -def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> None: +def parse_decode( + run_events: List[Dict], + capture_events: List[Dict], + output_xlsx: str, + target_layer: int = 3, +) -> None: """ - Parse decode phase: map capture_graph modules to GPU kernels. 
+ Parse decode phase: + - use run trace for first decode_step and real kernel timings + - use capture trace for capture_graph module hierarchy Output CSV columns: cpu_module, gpu_kernel, duration_us """ @@ -676,7 +648,7 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> # Find GPU-annotated decode_step events (cat='gpu_user_annotation') decode_steps = [ e - for e in events + for e in run_events if e.get("name", "").startswith("decode_step") and e.get("ph") == "X" and e.get("cat") == "gpu_user_annotation" @@ -687,25 +659,7 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> print("No decode_step (gpu_user_annotation) events found.") return - # Skip warmup: find first gap > 100ms (warmup/run boundary) - # Normal decode gaps are < 5ms, so 100ms is safe threshold - WARMUP_GAP_THRESHOLD = 100000 # 100ms in microseconds - actual_run_idx = 0 - found_warmup_boundary = False - for i in range(1, len(decode_steps)): - gap = decode_steps[i]["ts"] - ( - decode_steps[i - 1]["ts"] + decode_steps[i - 1].get("dur", 0) - ) - if gap > WARMUP_GAP_THRESHOLD: - actual_run_idx = i - found_warmup_boundary = True - print(f"Warmup/run boundary at [{i-1}]->[{i}], gap={gap/1000:.1f}ms") - break - - if not found_warmup_boundary: - print("No warmup detected (no gap > 100ms), using first decode_step") - - first_ds = decode_steps[actual_run_idx] + first_ds = decode_steps[0] first_ds_name = first_ds.get("name", "") target_bs: Optional[int] = None if "_bs_" in first_ds_name: @@ -723,12 +677,14 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> # Find matching capture_graph capture_graphs = [ - e for e in events if e.get("name") == target_cg_name and e.get("ph") == "X" + e + for e in capture_events + if e.get("name") == target_cg_name and e.get("ph") == "X" ] if not capture_graphs and target_bs is not None: # Prefer the largest capture_graph_bs_K where K < target_bs. 
lower_bs_candidates: List[Tuple[int, Dict[str, Any]]] = [] - for e in events: + for e in capture_events: if e.get("ph") != "X": continue n = e.get("name", "") @@ -749,7 +705,7 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> # Fallback: find any capture_graph capture_graphs = [ e - for e in events + for e in capture_events if e.get("name", "").startswith("capture_graph") and e.get("ph") == "X" ] capture_graphs = sorted(capture_graphs, key=lambda x: x["ts"]) @@ -767,7 +723,7 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> cg_end = cg_start + cg.get("dur", 0) cg_events = [ e - for e in events + for e in capture_events if e.get("ph") == "X" and e.get("ts", 0) >= cg_start and e.get("ts", 0) + e.get("dur", 0) <= cg_end @@ -781,7 +737,7 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> gpu_kernels = [ e - for e in events + for e in run_events if e.get("cat") == "kernel" and ds1_start <= e["ts"] <= ds1_end ] gpu_kernels = sorted(gpu_kernels, key=lambda x: x["ts"]) @@ -798,8 +754,19 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> # Collect all modules with their kernel info all_modules = [] # list of (mod_name, kernel_count, start_gpu_idx) + all_module_events = [] gpu_idx = 0 + def _resolve_moe_child_name_decode(event: Dict[str, Any]) -> str: + mod_name = event.get("name", "") + if "moe" not in mod_name.lower(): + return mod_name + children = idx.get_direct_children(event) + children_with_launch = [c for c in children if idx.has_kernel_launch(c)] + if children_with_launch: + return children_with_launch[0].get("name", mod_name) + return mod_name + for child in kernel_children: child_name = child.get("name", "") if should_filter(child_name): @@ -813,87 +780,103 @@ def parse_decode(events: List[Dict], output_xlsx: str, target_layer: int = 3) -> modules = sub_kernel_children if sub_kernel_children else [child] for mod in modules: - 
mod_name = mod.get("name", "") + mod_name = _resolve_moe_child_name_decode(mod) kernel_count = idx.count_kernel_launches(mod) all_modules.append((mod_name, kernel_count, gpu_idx)) + all_module_events.append(mod) gpu_idx += kernel_count - # Find norm positions (rmsnorm in name) - all_norm_indices = [ - i for i, (name, _, _) in enumerate(all_modules) if "rmsnorm" in name.lower() - ] - # Last rmsnorm is final layernorm, not part of transformer layers. - norm_indices = all_norm_indices[:-1] if len(all_norm_indices) > 0 else [] - print( - f"Found {len(all_norm_indices)} norm modules " - f"({len(norm_indices)} used for layer split, excluding final layernorm)" - ) - - # Extract layer 3 (4th layer, 0-indexed) - # Each layer has 2 norms, so layer N starts at norm index 2*N - TARGET_LAYER = target_layer - norm_start_idx = TARGET_LAYER * 2 # 6 (7th norm, 0-indexed) - norm_end_idx = (TARGET_LAYER + 1) * 2 # 8 (9th norm, 0-indexed) - - final_norm_idx = ( - all_norm_indices[-1] if len(all_norm_indices) > 0 else len(all_modules) + # Decode module sequence should start from the first norm module to avoid + # wrapper/initialization nodes before the model block. 
+ first_norm_module_idx = next( + (i for i, (name, _, _) in enumerate(all_modules) if is_strict_norm_name(name)), + None, ) - if norm_start_idx >= len(norm_indices): - print(f"Not enough norms for layer {TARGET_LAYER}") + if first_norm_module_idx is None: + print("No norm module found in capture_graph modules.") return - - # Module range for layer 3: from norm_indices[6] to norm_indices[8] (exclusive) - mod_start = norm_indices[norm_start_idx] - mod_end = ( - norm_indices[norm_end_idx] - if norm_end_idx < len(norm_indices) - else final_norm_idx + if first_norm_module_idx > 0: + print(f"Dropped {first_norm_module_idx} leading modules before first norm.") + all_modules = all_modules[first_norm_module_idx:] + all_module_events = all_module_events[first_norm_module_idx:] + + # Anchor module->kernel alignment by first norm's correlated launch kernel: + # 1) find first launch inside first norm and read its correlation/kernel + # 2) find same kernel's first occurrence in run decode kernels + # 3) rebuild module start indices from that anchor. 
+ capture_runtime_launches = [ + e + for e in cg_events + if e.get("cat") == "cuda_runtime" and is_kernel_launch(e.get("name", "")) + ] + capture_runtime_launches.sort(key=lambda x: x.get("ts", 0)) + capture_launch_ts = [e.get("ts", 0) for e in capture_runtime_launches] + + def _first_kernel_name_in_capture(mod_event: Dict[str, Any]) -> Optional[str]: + m_start = mod_event.get("ts", 0) + m_end = m_start + mod_event.get("dur", 0) + left = bisect.bisect_left(capture_launch_ts, m_start) + right = bisect.bisect_right(capture_launch_ts, m_end) + for launch in capture_runtime_launches[left:right]: + corr = (launch.get("args") or {}).get("correlation") + if corr is None: + continue + kernel_name = (launch.get("args") or {}).get("kernel", "") + if kernel_name: + return str(kernel_name) + return None + + anchor_kernel_name = _first_kernel_name_in_capture(all_module_events[0]) + if not anchor_kernel_name: + raise RuntimeError( + "Cannot resolve anchor kernel from first rmsnorm correlation in capture trace." + ) + found = next( + ( + i + for i, k in enumerate(gpu_kernels) + if k.get("name", "") == anchor_kernel_name + ), + None, ) - + if found is None: + raise RuntimeError( + f"Anchor kernel '{anchor_kernel_name}' not found in run decode kernels." 
+ ) + anchor_gpu_idx = found print( - f"Layer {TARGET_LAYER}: modules [{mod_start}:{mod_end}] (norms at indices {norm_start_idx}, {norm_start_idx+1})" + f"Aligned from first norm kernel: {anchor_kernel_name} at gpu_idx={anchor_gpu_idx}" ) + rebuilt_modules = [] + running_gpu_idx = anchor_gpu_idx + for name, count, _ in all_modules: + rebuilt_modules.append((name, count, running_gpu_idx)) + running_gpu_idx += count + all_modules = rebuilt_modules + + # Find norm positions (rmsnorm in name) + all_norm_indices = [ + i for i, (name, _, _) in enumerate(all_modules) if is_strict_norm_name(name) + ] + def build_rows_for_module_range(start: int, end: int) -> List[List[Any]]: rows = [] for mod_name, kernel_count, start_gpu_idx in all_modules[start:end]: - if "moe_forward" in mod_name.lower(): - rows.extend( - process_moe_module( - mod_name, kernel_count, start_gpu_idx, gpu_kernels - ) - ) - else: - rows.extend( - process_regular_module( - mod_name, kernel_count, start_gpu_idx, gpu_kernels - ) - ) + rows.extend( + process_module(mod_name, kernel_count, start_gpu_idx, gpu_kernels) + ) return rows - # Target layer rows. - rows = build_rows_for_module_range(mod_start, mod_end) - - # AVG rows from layer 3 to last layer. 
- avg_rows = None - avg_layer_rows: List[List[List[Any]]] = [] - layer = 3 - while 2 * layer < len(norm_indices): - s = norm_indices[2 * layer] - e_idx = 2 * (layer + 1) - e = norm_indices[e_idx] if e_idx < len(norm_indices) else final_norm_idx - avg_layer_rows.append(build_rows_for_module_range(s, e)) - layer += 1 - if avg_layer_rows: - avg_rows = build_avg_rows_from_layers(avg_layer_rows, 3, "Decode") - if avg_rows is not None: - print(f"Decode avg rows: {len(avg_rows)}") - - # Write XLSX - write_breakdown_xlsx(output_xlsx, rows, sheet_name="decode", avg_rows=avg_rows) - - print(f"Layer {TARGET_LAYER} modules: {mod_end - mod_start}") - print(f"XLSX written to: {output_xlsx} ({len(rows)} rows)") + _extract_layer_and_write( + all_norm_indices, + len(all_modules), + target_layer, + "Decode", + "decode", + build_rows_for_module_range, + output_xlsx, + ) # ============================================================================= @@ -918,10 +901,24 @@ def main(): filepath = args.filepath target_layer = args.layer - print(f"Loading: {filepath}") + print(f"Loading run trace: {filepath}") trace = load_trace(filepath) events = trace.get("traceEvents", []) - print(f"Loaded {len(events)} events\n") + print(f"Loaded run events: {len(events)}") + + capture_trace_path = find_capture_graph_trace_path(filepath) + if capture_trace_path is None: + print( + "Warning: matching capture trace not found; decode analysis will fallback " + "to current trace for capture_graph hierarchy." 
+ ) + capture_events = events + else: + print(f"Loading capture trace: {capture_trace_path}") + capture_trace = load_trace(capture_trace_path) + capture_events = capture_trace.get("traceEvents", []) + print(f"Loaded capture events: {len(capture_events)}") + print("") print("=" * 60) print("PREFILL ANALYSIS") @@ -931,7 +928,12 @@ def main(): print("\n" + "=" * 60) print("DECODE ANALYSIS") print("=" * 60) - parse_decode(events, "decode_breakdown.xlsx", target_layer=target_layer) + parse_decode( + events, + capture_events, + "decode_breakdown.xlsx", + target_layer=target_layer, + ) if __name__ == "__main__":