Skip to content
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,48 @@ python -m atom.benchmarks.benchmark_serving \
--result-dir=./ --result-filename=$RESULT_FILENAME.json
```

### Profile Analyze

ATOM supports automatic trace collection and analysis, which breaks down GPU kernel durations per module for both **prefill** and **decode** phases and exports the results to Excel (`.xlsx`) files.

#### Step 1: Collect a Trace

Launch the server with `--torch-profiler-dir` to enable the PyTorch profiler and `--mark-trace` to insert per-module annotations into the trace. Set `TORCHINDUCTOR_COMPILE_THREADS=1` to ensure deterministic compilation order.

```bash
TORCHINDUCTOR_COMPILE_THREADS=1 python -m atom.entrypoints.openai_server \
--model deepseek-ai/DeepSeek-R1 \
--kv_cache_dtype fp8 -tp 8 \
--torch-profiler-dir ./trace \
--mark-trace
```

After the server has processed requests and shut down, two `*.json.gz` trace files will have been written to the `--torch-profiler-dir` directory.

#### Step 2: Analyze the Trace

Run `parse_trace.py` on the collected trace file (use the trace file whose name starts with the model name):

```bash
python ATOM/tools/parse_trace.py ./trace/model_name_ts_*.json.gz
```

This produces two Excel files in the current directory:

| Output File | Description |
|---|---|
| `prefill_breakdown.xlsx` | Per-kernel duration breakdown for one prefill layer |
| `decode_breakdown.xlsx` | Per-kernel duration breakdown for one decode layer |

Each file contains the columns `cpu_module`, `gpu_kernel`, `duration_us`, and `sum per module`, plus values averaged across layers.

**Options:**

| Flag | Default | Description |
|---|---|---|
| `--layer N` | `3` | Target transformer layer index to analyze (0-indexed) |


## 📊 Performance

### Online Serving Throughput
Expand Down
7 changes: 6 additions & 1 deletion atom/model_engine/engine_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,18 @@ def __init__(self, config: Config, input_address: str, output_address: str):
if not config.enforce_eager:
# Start profiler before cudagraph capture only if mark-trace is enabled.
if self.profile_enbaled and self.mark_trace:
self.runner_mgr.call_func("start_profiler", wait_out=True)
self.runner_mgr.call_func(
"start_profiler", "capture_graph", wait_out=True
)
cap_cost, bs = self.runner_mgr.call_func(
"capture_cudagraph", wait_out=True
)
logger.info(
f"{self.label}: cudagraph capture{bs} cost: {cap_cost:.2f} seconds"
)
if self.profile_enbaled and self.mark_trace:
# Persist a dedicated capture-graph trace immediately.
self.runner_mgr.call_func("stop_profiler", wait_out=True)
good = True
finally:
logger.info(
Expand Down
50 changes: 35 additions & 15 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
import os
import time
import gzip
from contextlib import nullcontext
from typing import Any, Optional, Union

Expand Down Expand Up @@ -690,7 +691,7 @@ def exit(self):
torch.cuda.empty_cache()
return True

def start_profiler(self):
def start_profiler(self, trace_name: Optional[str] = None):
"""
Start profiling for this rank.

Expand All @@ -700,6 +701,35 @@ def start_profiler(self):
"""
if self.profiler_dir is not None and self.profiler is None:
enable_detailed_profiling = envs.ATOM_PROFILER_MORE
model_name = os.path.basename(self.config.model.rstrip("/"))
safe_model_name = "".join(
c if c.isalnum() or c in ("_", "-", ".") else "_" for c in model_name
)
worker_name = safe_model_name or "trace"
if isinstance(trace_name, str) and trace_name:
worker_name = "".join(
c if c.isalnum() or c in ("_", "-", ".") else "_"
for c in trace_name
)
if worker_name == "capture_graph":
if safe_model_name:
worker_name = f"{worker_name}_{safe_model_name}"
output_prefix = os.path.join(self.profiler_dir, worker_name)

def _on_trace_ready(prof):
# Use a short human-readable timestamp in file name.
ts = time.strftime("%Y%m%d_%H%M%S", time.localtime())
ms = int((time.time() % 1) * 1000)
output_path = f"{output_prefix}_ts_{ts}_{ms:03d}.pt.trace.json.gz"
tmp_json_path = output_path[:-3]
prof.export_chrome_trace(tmp_json_path)
with (
open(tmp_json_path, "rb") as src,
gzip.open(output_path, "wb") as dst,
):
dst.write(src.read())
os.remove(tmp_json_path)

self.profiler = torch_profiler.profile(
activities=[
torch_profiler.ProfilerActivity.CPU,
Expand All @@ -708,9 +738,7 @@ def start_profiler(self):
record_shapes=enable_detailed_profiling,
with_stack=enable_detailed_profiling,
profile_memory=enable_detailed_profiling,
on_trace_ready=torch_profiler.tensorboard_trace_handler(
self.profiler_dir, use_gzip=True
),
on_trace_ready=_on_trace_ready,
)
self.profiler.__enter__()
return True
Expand Down Expand Up @@ -1504,21 +1532,13 @@ def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor
is_prefill = context.is_prefill
positions = context.positions
if is_prefill or self.enforce_eager or bs > self.graph_bs[-1]:
with (
record_function(
f"prefill_bs_{bs}_ctxlens_{forward_context.attn_metadata.context_lens}"
)
if self.mark_trace
else nullcontext()
with record_function(
f"prefill_bs_{bs}_ctxlens_{forward_context.attn_metadata.context_lens}"
):
hidden_states = self.model(input_ids, positions)
logits = self.model.compute_logits(hidden_states)
else:
with (
record_function(f"decode_step_bs_{bs}")
if self.mark_trace
else nullcontext()
):
with record_function(f"decode_step_bs_{bs}"):
graph_bs = context.graph_bs
max_q_len = forward_context.attn_metadata.max_seqlen_q
graph_key = (graph_bs, max_q_len)
Expand Down
6 changes: 6 additions & 0 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from atom.plugin.prepare import is_plugin_mode, is_vllm
from atom.plugin.attention_mha import PagedAttentionImplDecoratorForPluginMode
from atom.utils.decorators import mark_trace


@PagedAttentionImplDecoratorForPluginMode
Expand Down Expand Up @@ -110,6 +111,7 @@ def forward_impl_server_mode(

return o

@mark_trace(prefix="rope_cache", torch_compile=False)
def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
attn_metadata = fwd_ctx.attn_metadata
kv_cache_data = fwd_ctx.kv_cache_data
Expand Down Expand Up @@ -214,6 +216,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):

return q, k, v, k_cache, v_cache, k_scale, v_scale

@mark_trace(prefix="paged_attention_triton", torch_compile=False)
def paged_attention_triton(
self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext
):
Expand Down Expand Up @@ -290,6 +293,7 @@ def paged_attention_triton(

return o

@mark_trace(prefix="paged_attention_asm", torch_compile=False)
def paged_attention_asm(
self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext
):
Expand All @@ -312,6 +316,7 @@ def paged_attention_asm(

return o

@mark_trace(prefix="paged_attention_persistent_asm", torch_compile=False)
def paged_attention_persistent_asm(
self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext
):
Expand Down Expand Up @@ -341,6 +346,7 @@ def paged_attention_persistent_asm(

return output

@mark_trace(prefix="prefill_attention", torch_compile=False)
def prefill_attention(
self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext
):
Expand Down
13 changes: 13 additions & 0 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from atom.model_ops.linear import use_triton_gemm
from atom.model_ops.utils import get_and_maybe_dequant_weights
from atom.utils import envs
from atom.utils.decorators import mark_trace

from atom.utils.forward_context import (
AttentionMetaData,
ForwardContext,
Expand All @@ -39,6 +41,15 @@

from atom.plugin.attention_mla import MLAAttentionImplDecoratorForPluginMode

concat_and_cache_mla = mark_trace(
concat_and_cache_mla, prefix="kv_cache", torch_compile=False
)
fused_qk_rope_concat_and_cache_mla = mark_trace(
fused_qk_rope_concat_and_cache_mla, prefix="rope_and_kv_cache", torch_compile=False
)
mla_prefill_fwd = mark_trace(mla_prefill_fwd, prefix="mla_prefill", torch_compile=False)
mla_decode_fwd = mark_trace(mla_decode_fwd, prefix="mla_decode", torch_compile=False)

# torch.set_printoptions(threshold=10_000)

logger = logging.getLogger("atom")
Expand Down Expand Up @@ -194,6 +205,7 @@ def process_weights_after_loading(self):
W_V, dtype=dtypes.fp8
)

@mark_trace(prefix="v_up_proj_and_o_proj", torch_compile=False)
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
Expand Down Expand Up @@ -228,6 +240,7 @@ def _v_up_proj_and_o_proj(self, x):
x = x.reshape(-1, self.num_heads * self.v_head_dim)
return self.o_proj(x)

@mark_trace(prefix="q_proj_and_k_up_proj", torch_compile=False)
def _q_proj_and_k_up_proj(self, x, x_scale=None):
q_nope, q_pe = (
self.q_proj(x, x_scale)
Expand Down
2 changes: 2 additions & 0 deletions atom/model_ops/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
handle_torch_function,
)
from atom.config import QuantizationConfig, LayerQuantConfig
from atom.utils.decorators import mark_trace
from torch import nn
from aiter import (
rmsnorm2d_fwd,
Expand Down Expand Up @@ -197,6 +198,7 @@ def __init__(
self.quant_type = quant_type
self.params_dtype = params_dtype

@mark_trace(prefix="rmsnorm", torch_compile=True)
def forward(
self,
x: torch.Tensor,
Expand Down
6 changes: 3 additions & 3 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
shuffle_weights,
)
from atom.utils import envs
from atom.utils.decorators import mark_trace, record_function
from atom.utils.decorators import mark_trace
from torch import nn

logger = logging.getLogger("atom")
Expand Down Expand Up @@ -174,7 +174,7 @@ def gemm_a8w8_blockscale_preshuffle_fake(
return torch.empty((*x.shape[:-1], weight.shape[0]), dtype=dtype, device=x.device)


@record_function
@mark_trace(torch_compile=False)
@torch_compile_guard(gen_fake=gemm_a8w8_blockscale_preshuffle_fake, mutates_args=[])
def gemm_a8w8_blockscale_preshuffle_impl(
x: torch.Tensor,
Expand All @@ -194,7 +194,6 @@ def gemm_a8w8_blockscale_preshuffle_impl(
return y


@mark_trace
class LinearBase(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -386,6 +385,7 @@ def process_weights_after_loading(self):
if self.quant_type == QuantType.per_1x32:
self.weight_scale.data = fp4_utils.e8m0_shuffle(self.weight_scale.data)

@mark_trace
def forward(
self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None, otype=dtypes.bf16
) -> torch.Tensor:
Expand Down
5 changes: 5 additions & 0 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from atom.utils import envs
from atom.utils.custom_register import direct_register_custom_op
from atom.utils.forward_context import get_forward_context
from atom.utils.decorators import mark_trace
from torch import nn
from transformers import PretrainedConfig
from atom.plugin.moe import FusedMoEDecoratorForPluginMode
Expand Down Expand Up @@ -449,6 +450,7 @@ def get_fused_moe_quant_config(
) -> FusedMoEQuantConfig | None:
return FUSED_MOE_UNQUANTIZED_CONFIG

@mark_trace(prefix="unquantized_moe", torch_compile=False)
def apply(
self,
layer: torch.nn.Module,
Expand Down Expand Up @@ -859,6 +861,7 @@ def get_fused_moe_quant_config(
w2_scale=layer.w2_weight_scale,
)

@mark_trace(prefix="mxfp4_moe", torch_compile=False)
def apply(
self,
layer: torch.nn.Module,
Expand Down Expand Up @@ -1273,6 +1276,7 @@ def get_fused_moe_quant_config(
block_shape=block_shape,
)

@mark_trace(prefix="compressed_fp8_moe", torch_compile=False)
def apply(
self,
layer: torch.nn.Module,
Expand Down Expand Up @@ -1637,6 +1641,7 @@ def get_fused_moe_quant_config(
block_shape=None,
)

@mark_trace(prefix="fp8_moe", torch_compile=False)
def apply(
self,
layer: torch.nn.Module,
Expand Down
3 changes: 3 additions & 0 deletions atom/model_ops/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from aiter import dtypes
from typing import Union, Optional

from atom.utils.decorators import mark_trace


def apply_rotary_emb(
x: torch.Tensor,
Expand Down Expand Up @@ -77,6 +79,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
sin = freqs.sin().unsqueeze(-2).unsqueeze(-2)
return cos, sin

@mark_trace(prefix="rope_cached")
def forward(
self,
positions: torch.Tensor,
Expand Down
Loading
Loading