ATOM Model Operations Guide

ATOM (AiTer Optimized Model) wraps AITER kernels with model-level abstractions for LLM inference on AMD ROCm/HIP GPUs. This guide documents every operator class in atom/model_ops/, their AITER kernel mappings, quantization paths, and fused kernel chains.


Quick Reference

| ATOM Class | File | AITER Kernel / Import | Purpose |
|---|---|---|---|
| LinearBase | linear.py | tgemm.mm, gemm_a8w8, gemm_a8w8_bpreshuffle, gemm_a8w8_blockscale_bpreshuffle, gemm_a4w4 | Quantized linear dispatch |
| ColumnParallelLinear | linear.py | (inherits LinearBase) | Column-sharded TP linear |
| RowParallelLinear | linear.py | (inherits LinearBase) | Row-sharded TP linear |
| QKVParallelLinear | linear.py | (inherits ColumnParallelLinear) | Fused Q/K/V projection |
| MergedColumnParallelLinear | linear.py | (inherits LinearBase) | Merged gate+up projection |
| Attention | base_attention.py | unified_attention_with_output_base (custom op) | Unified attention entry |
| MHA Attention | attention_mha.py | flash_attn_varlen_func, pa_fwd_asm, pa_persistent_fwd, pa_decode_gluon | Multi-head attention |
| MLAAttention | attention_mla.py | mla_decode_fwd, mla_prefill_fwd, concat_and_cache_mla, fused_qk_rope_concat_and_cache_mla | Multi-head latent attention |
| FusedMoE | moe.py | aiter.fused_moe.fused_moe, asm_moe | Mixture of experts |
| RMSNorm | layernorm.py | rmsnorm2d_fwd, rmsnorm2d_fwd_with_add, fused_add_rmsnorm_pad | RMS normalization |
| LayerNorm | layernorm.py | layernorm2d_fwd, layernorm2d_fwd_with_add | Layer normalization |
| SiluAndMul | activation.py | aiter.silu_and_mul | SiLU gated activation |
| VocabParallelEmbedding | embed_head.py | F.embedding + TP all-reduce | Vocab embedding |
| ParallelLMHead | embed_head.py | tgemm.mm + tensor_model_parallel_all_gather | LM output head |
| RotaryEmbedding | rotary_embedding.py | aiter.rope_cached_positions_2c_fwd_inplace | Rotary position embedding |
| Sampler | sampler.py | aiter.mixed_sample_outer_exponential, aiter.ops.triton.topk.topk, aiter.ops.triton.softmax.softmax | Token sampling |
| RejectionSampler | rejection_sampler.py | Triton rejection_greedy_sample_kernel | Speculative decoding |

1. AITER Integration Overview

ATOM is a thin model-level inference engine. Every compute-heavy operation delegates to an AITER kernel. The general pattern is:

  1. An ATOM nn.Module owns model weights and configuration.
  2. Its forward() method selects the appropriate AITER function based on quantization type, parallelism settings, and phase (prefill vs. decode).
  3. Results are optionally reduced across tensor-parallel (TP) or data-parallel (DP) groups.
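
A minimal sketch of this pattern, using a hypothetical class with placeholder matmuls standing in for the AITER kernel calls:

import torch
import torch.nn as nn
import torch.distributed as dist

class QuantizedLinearSketch(nn.Module):
    """Illustrative only: owns weights, dispatches on quant type, optionally reduces."""

    def __init__(self, weight: torch.Tensor, quant_type: str, reduce_results: bool = False):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=False)
        self.quant_type = quant_type          # e.g. "No", "per_Token", "per_1x128"
        self.reduce_results = reduce_results  # all-reduce across the TP group when input-sharded

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1. Select the kernel from the quantization type; the plain matmul below stands in
        #    for tgemm.mm / gemm_a8w8 / gemm_a4w4 in this sketch.
        out = x @ self.weight.t()
        # 2. Optionally reduce partial results across the tensor-parallel group.
        if self.reduce_results and dist.is_initialized():
            dist.all_reduce(out)
        return out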

AITER Kernel Mapping Table

| ATOM Wrapper | AITER Function / Import Path | Backend Type |
|---|---|---|
| LinearBase.forward (No quant) | aiter.tuned_gemm.tgemm.mm | hipBLASLt |
| LinearBase.forward (per_Tensor FP8) | aiter.tuned_gemm.tgemm.mm with scales | hipBLASLt |
| LinearBase.forward (per_Token INT8) | aiter.gemm_a8w8 | CK |
| LinearBase.forward (per_Token FP8) | aiter.gemm_a8w8_bpreshuffle | CK |
| LinearBase.forward (per_1x128 FP8) | aiter.gemm_a8w8_blockscale_bpreshuffle | CK |
| LinearBase.forward (per_1x32 MXFP4) | aiter.gemm_a4w4 | CK |
| MHA prefill | aiter.flash_attn_varlen_func | ASM / CK |
| MHA decode (ASM) | aiter.pa_fwd_asm | ASM |
| MHA decode (persistent ASM) | aiter.pa_persistent_fwd | ASM |
| MHA decode (Triton) | aiter.ops.triton.gluon.pa_decode_gluon | Triton |
| MHA prefill (Triton unified) | aiter.ops.triton.unified_attention.unified_attention | Triton |
| MLA decode | aiter.mla.mla_decode_fwd | ASM |
| MLA prefill | aiter.mla.mla_prefill_fwd | ASM |
| MLA KV cache | aiter.concat_and_cache_mla | CK |
| RoPE | aiter.rope_cached_positions_2c_fwd_inplace | Triton |
| RMSNorm | aiter.rmsnorm2d_fwd | CK |
| SiLU+Mul | aiter.silu_and_mul | CK |
| TopK routing | aiter.topk_softmax, aiter.grouped_topk, aiter.biased_grouped_topk | CK |
| Sampling | aiter.mixed_sample_outer_exponential | CK |
| FusedMoE | aiter.fused_moe.fused_moe | CK |
| ASM MoE | aiter.fused_moe_bf16_asm.asm_moe | ASM |
| Quantization | aiter.get_hip_quant(QuantType) | CK / Triton |

2. Linear Operations

All linear layers inherit from LinearBase in atom/model_ops/linear.py.

2.1 Class Hierarchy

LinearBase (nn.Module)
  +-- ReplicatedLinear          # No TP sharding
  |     +-- MergedReplicatedLinear
  +-- ColumnParallelLinear      # tp_dim=0, shard output
  |     +-- QKVParallelLinear   # Fused Q/K/V with per-head sharding
  +-- MergedColumnParallelLinear # tp_dim=0, merged gate+up
  +-- RowParallelLinear          # tp_dim=1, shard input, optional all-reduce

2.2 Quantization Dispatch

LinearBase.forward() dispatches to different GEMM kernels based on QuantType:

| QuantType | Weight dtype | GEMM Kernel | Scale Shape |
|---|---|---|---|
| No | BF16/FP16 | tgemm.mm (hipBLASLt) | None |
| per_Tensor | FP8 | tgemm.mm with scale_a, scale_b | [num_partitions, 1] |
| per_Token (INT8) | INT8 | gemm_a8w8 | [output_size, 1] |
| per_Token (FP8) | FP8 | gemm_a8w8_bpreshuffle | [output_size, 1] |
| per_1x128 | FP8 | gemm_a8w8_blockscale_bpreshuffle | [ceil(N/128), ceil(K/128)] |
| per_1x32 | MXFP4 (fp4x2) | gemm_a4w4 | [N, ceil(K/32)] (e8m0) |

When x_scale is not provided, the input is dynamically quantized via get_hip_quant(quant_type).
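
For reference, per-token dynamic quantization (the fallback when no x_scale is supplied) amounts to the following plain-PyTorch sketch; the real path goes through AITER's get_hip_quant rather than this helper:

import torch

def per_token_quant_sketch(x: torch.Tensor, qmax: float = 127.0):
    """Quantize each token (row) of x to int8 with its own scale (illustrative only)."""
    # One scale per token: max-abs over the hidden dimension.
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    x_q = torch.clamp(torch.round(x / x_scale), -qmax, qmax).to(torch.int8)
    return x_q, x_scale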

2.3 Tensor Parallel Sharding

  • ColumnParallelLinear (tp_dim=0): Shards weight rows (output dimension) across GPUs. Each GPU owns output_size / tp_size rows.
  • RowParallelLinear (tp_dim=1): Shards weight columns (input dimension). If reduce_results=True, output is all-reduced across TP group.
  • QKVParallelLinear: Extends ColumnParallelLinear with per-head sharding. Q heads are evenly divided; KV heads are either divided or replicated when num_kv_heads < tp_size.
  • MergedColumnParallelLinear: Handles gate and up projections merged into a single weight with output_sizes as a list (e.g., [intermediate_size, intermediate_size]).
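
The shard bookkeeping is simple arithmetic; a sketch (with illustrative helper and variable names) of how the per-rank sizes fall out:

def shard_sizes_sketch(output_size: int, input_size: int, num_kv_heads: int,
                       tp_size: int) -> tuple[int, int, int, int]:
    """Illustrative shard-size math for the parallel linear variants."""
    # ColumnParallelLinear: split the output dimension across ranks.
    col_out_per_rank = output_size // tp_size
    # RowParallelLinear: split the input dimension; partial outputs are all-reduced
    # when reduce_results=True.
    row_in_per_rank = input_size // tp_size
    # QKVParallelLinear: KV heads are divided, or replicated when num_kv_heads < tp_size.
    kv_heads_per_rank = max(1, num_kv_heads // tp_size)
    kv_replication = max(1, tp_size // num_kv_heads)
    return col_out_per_rank, row_in_per_rank, kv_heads_per_rank, kv_replication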

2.4 Weight Processing

After loading, process_weights_after_loading() handles:

  • e4m3fn to e4m3fnuz normalization (AMD FP8 format conversion).
  • Weight reshuffling via shuffle_weights() for pre-shuffled GEMM kernels.
  • Scale reshuffling via fp4_utils.e8m0_shuffle() for MXFP4 block scales.
  • Per-tensor requantization via requantize_with_max_scale() when multiple output partitions have separate scales.

3. Attention Operations

3.1 Base: Attention (base_attention.py)

The top-level Attention class in base_attention.py is a dispatcher. It:

  1. Selects the backend via get_attn_backend() from atom/utils/selector.py.
  2. Instantiates the backend's implementation class (impl_cls).
  3. Registers itself in compilation_config.static_forward_context under layer_name.
  4. On forward(), calls torch.ops.aiter.unified_attention_with_output_base, which is a custom op decorated with @mark_spliting_op -- this prevents torch.compile from tracing into attention internals, enabling full-graph capture.

Backend selection logic (in selector.py):

| Condition | Backend Class | Implementation |
|---|---|---|
| use_mla=True | AiterMLABackend | MLAAttention from attention_mla.py |
| use_mla=False | AiterBackend | Attention from attention_mha.py |
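
The selection reduces to a single flag check, as in this sketch (not the literal selector.py code):

def get_attn_backend_sketch(use_mla: bool) -> str:
    # Illustrative: mirrors the use_mla branch in atom/utils/selector.py.
    if use_mla:
        return "AiterMLABackend"   # impl_cls: MLAAttention (attention_mla.py)
    return "AiterBackend"          # impl_cls: Attention (attention_mha.py)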

3.2 Multi-Head Attention (attention_mha.py)

The MHA Attention class handles standard models (Llama, Qwen3, Mixtral, etc.).

Forward flow:

  1. Reshape Q, K, V to [num_tokens, num_heads, head_dim].
  2. Apply RoPE + KV cache write via rope_cache().
  3. Dispatch to the appropriate backend via dispatch_backend().

RoPE + KV cache paths:

| Condition | Kernel Chain |
|---|---|
| q_norm + k_norm + rotary_emb present | fused_qk_norm_rope_cache_quant_shuffle (single fused kernel for QK norm, RoPE, cache write, optional FP8 quant) |
| Triton path (sliding_window != -1 or head_dim != 128) + rotary_emb | fused_qk_rope_reshape_and_cache (Triton fused RoPE + reshape + cache) |
| ASM path + rotary_emb | rotary_emb(position, q, k) then reshape_and_cache or reshape_and_cache_with_pertoken_quant |

Attention dispatch:

| Phase | Condition | Method | AITER Kernel |
|---|---|---|---|
| Prefill | Always | prefill_attention | aiter.flash_attn_varlen_func |
| Decode | use_triton_attn=True | paged_attention_triton | torch.ops.aiter.pa_decode_gluon |
| Decode | block_size == 1024 | paged_attention_persistent_asm | aiter.pa_persistent_fwd |
| Decode | Default | paged_attention_asm | aiter.pa_fwd_asm |

The use_triton_attn flag is set when sliding_window != -1 or head_dim != 128.
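
The decode-side dispatch can be summarized as the sketch below (illustrative; the real dispatch_backend() also handles attention metadata and output buffers):

def dispatch_decode_sketch(use_triton_attn: bool, block_size: int) -> str:
    """Illustrative mapping of decode conditions to paged-attention kernels."""
    if use_triton_attn:                 # sliding_window != -1 or head_dim != 128
        return "pa_decode_gluon"        # Triton path
    if block_size == 1024:
        return "pa_persistent_fwd"      # persistent ASM path
    return "pa_fwd_asm"                 # default ASM path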

3.3 Multi-head Latent Attention (attention_mla.py)

MLAAttention implements DeepSeek's MLA with a compressed KV representation. Key data structures:

@dataclass
class MLAModules:
    q_lora_rank: Optional[int]
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    qk_head_dim: int
    v_head_dim: int
    rotary_emb: torch.nn.Module
    q_proj: Optional[torch.nn.Module]
    kv_b_proj: torch.nn.Module
    o_proj: torch.nn.Module
    indexer: Optional[torch.nn.Module]

Forward flow:

  1. If prefill and not sparse: Standard MHA-style prefill with flash_attn_varlen_func, preceded by kv_b_proj GEMM to produce K_nope and V from compressed kv_c_normed.
  2. Otherwise: Fused Q projection + K up-projection via batched FP8/FP4 BMM (_q_proj_and_k_up_proj), then:
    • fused_qk_rope_concat_and_cache_mla writes to KV cache.
    • Decode: mla_decode_fwd (ASM persistent MLA kernel).
    • Prefill (sparse): mla_prefill_fwd.
  3. V up-projection + O projection via batched BMM (_v_up_proj_and_o_proj).

Batched GEMM backends for MLA projections:

| Condition | Kernel |
|---|---|
| ATOM_USE_TRITON_MXFP4_BMM=True | batched_gemm_a16wfp4 (Triton FP4 BMM) |
| Default | batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant (Triton FP8 BMM) |

Prefill GEMM optimizations (for kv_b_proj):

| Condition | Kernel |
|---|---|
| ATOM_USE_TRITON_GEMM=True + FP4 weights | fused_gemm_afp4wfp4_preshuffle_split_cat (GEMM + split K/V + cat rope in one kernel) |
| ATOM_USE_TRITON_GEMM=True + FP8 weights | fused_gemm_a8w8_blockscale_preshuffle_split_cat |
| Default | kv_b_proj(kv_c_normed) then manual split + cat |

3.4 Backend Abstraction (attentions/backends.py)

The AttentionBackend abstract class defines three required methods:

  • get_name() -- Returns backend identifier string.
  • get_builder_cls() -- Returns the AttentionMetadataBuilder subclass.
  • get_impl_cls() -- Returns the attention implementation class.

CommonAttentionBuilder provides shared metadata preparation (slot mapping, block tables, cumulative sequence lengths) used by both AiterBackend and AiterMLABackend.

3.5 KV Cache Operations

| Operation | AITER Kernel | Used By |
|---|---|---|
| Standard KV cache write | aiter.reshape_and_cache | MHA (BF16 KV) |
| FP8 KV cache write | aiter.reshape_and_cache_with_pertoken_quant | MHA (FP8 KV) |
| MLA KV cache write | aiter.concat_and_cache_mla | MLA prefill |
| Fused QK RoPE + MLA cache | aiter.fused_qk_rope_concat_and_cache_mla | MLA decode |

4. Mixture of Experts (MoE)

4.1 FusedMoE Class (moe.py)

FusedMoE is the top-level MoE module. It handles:

  • Expert routing via select_experts().
  • Weight creation and quantization dispatch via quant_method.
  • Tensor/Expert/Data parallelism via FusedMoEParallelConfig.
  • Optional shared expert fusion and MORI communication.

Constructor parameters:

FusedMoE(
    num_experts: int,        # Global number of experts
    top_k: int,              # Experts per token
    hidden_size: int,        # Input hidden dimension
    intermediate_size: int,  # Expert intermediate dimension
    reduce_results: bool,    # Whether to all-reduce output
    renormalize: bool,       # Renormalize routing weights
    use_grouped_topk: bool,  # Use grouped top-k (DeepSeek)
    activation: ActivationType,  # Silu, Gelu, Swiglu, etc.
    ...
)
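
A construction example using the parameters listed above (the values are illustrative, not defaults; remaining keyword arguments are omitted as in the signature):

# Illustrative instantiation for a Mixtral-style layer.
moe = FusedMoE(
    num_experts=8,            # 8 routed experts
    top_k=2,                  # 2 experts per token
    hidden_size=4096,
    intermediate_size=14336,
    reduce_results=True,      # all-reduce the combined expert output across TP
    renormalize=True,         # renormalize routing weights after top-k
    use_grouped_topk=False,   # standard top-k routing
    activation=ActivationType.Silu,
)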

4.2 Quantization Methods

FusedMoE selects a quant_method at construction time:

| Quant Config | Method Class | GEMM Kernel |
|---|---|---|
| QuantType.No | UnquantizedFusedMoEMethod | aiter.fused_moe.fused_moe |
| FP8 (dtypes.fp8) | Fp8MoEMethod | aiter.fused_moe.fused_moe with quant_type |
| FP8 compressed-tensors | CompressedTensorsFp8MoEMethod | aiter.fused_moe.fused_moe or asm_moe |
| MXFP4 (dtypes.fp4x2) | Mxfp4MoEMethod | aiter.fused_moe.fused_moe or Triton triton_kernel_moe_forward |

The ASM MoE path (asm_moe from aiter.fused_moe_bf16_asm) is used by FP8 methods and supports a16 mode where activations remain in BF16/FP16 while weights are FP8/INT8.

4.3 TopK Routing (topK.py)

| Routing Function | AITER Kernel | Used For |
|---|---|---|
| rocm_aiter_topk_softmax | aiter.topk_softmax | Standard top-k (Mixtral) |
| rocm_aiter_grouped_topk | aiter.grouped_topk | Grouped top-k (DeepSeek) |
| rocm_aiter_biased_grouped_topk | aiter.biased_grouped_topk | Biased grouped top-k (DeepSeek V3) |

Shared expert fusion: When is_rocm_aiter_fusion_shared_expert_enabled() returns True, the top-k buffers are extended with shared expert IDs appended after routed expert IDs. This allows shared expert computation to be fused into the same MoE kernel call. The metadata is initialized via init_aiter_topK_meta_data().
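
Conceptually, the buffer extension looks like the following sketch (illustrative tensors and helper name; the real metadata is set up by init_aiter_topK_meta_data()):

import torch

def append_shared_experts_sketch(topk_ids, topk_weights, num_routed_experts, num_shared_experts):
    """Append shared-expert IDs after the routed top-k IDs so one MoE call covers both."""
    num_tokens = topk_ids.shape[0]
    # Shared experts are given IDs immediately after the routed experts.
    shared_ids = torch.arange(
        num_routed_experts, num_routed_experts + num_shared_experts, device=topk_ids.device
    ).expand(num_tokens, -1)
    # Shared experts are applied to every token with weight 1.0 in this sketch.
    shared_weights = torch.ones(num_tokens, num_shared_experts,
                                device=topk_weights.device, dtype=topk_weights.dtype)
    return (torch.cat([topk_ids, shared_ids], dim=-1),
            torch.cat([topk_weights, shared_weights], dim=-1))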

4.4 FusedMoEParallelConfig

@dataclass
class FusedMoEParallelConfig:
    tp_size: int       # Tensor parallel size
    dp_size: int       # Data parallel size
    ep_size: int       # Expert parallel size
    tp_rank: int
    dp_rank: int
    ep_rank: int
    use_ep: bool       # Whether expert parallelism is active
    local_ep_size: int # Local EP size (GPUs per node * TP)

Key properties:

  • use_all2all_kernels: True when dp_size > 1, EP is enabled, and MORI is available.
  • use_mori_kernels: Always True (currently).

4.5 MORI Integration (fused_moe/mori_prepare_finalize.py)

MORI (MoE Router Infrastructure) provides all-to-all communication kernels for expert parallelism. MoriPrepareAndFinalize implements:

  • prepare(): Dispatches tokens to remote experts via mori_op.dispatch(). Optionally quantizes activations to FP8 before dispatch.
  • finalize(): Combines expert outputs via mori_op.combine() and copies results back.

The FusedMoEModularKernel orchestrates the prepare-compute-finalize pipeline.
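
The orchestration follows the usual prepare → compute → finalize shape. A sketch under the assumption that prepare()/finalize() wrap the MORI dispatch/combine calls (this is not the actual FusedMoEModularKernel API):

class ModularMoESketch:
    """Illustrative prepare-compute-finalize pipeline."""

    def __init__(self, prepare_finalize, experts):
        self.prepare_finalize = prepare_finalize   # e.g. a MoriPrepareAndFinalize-like object
        self.experts = experts                     # the local fused expert computation

    def forward(self, hidden_states, topk_ids, topk_weights):
        # 1. prepare(): all-to-all dispatch of tokens to the ranks owning their experts,
        #    optionally quantizing activations to FP8 before dispatch.
        dispatched, metadata = self.prepare_finalize.prepare(hidden_states, topk_ids, topk_weights)
        # 2. compute: run the local expert kernels on the received tokens.
        expert_out = self.experts(dispatched, metadata)
        # 3. finalize(): all-to-all combine of expert outputs back to the source ranks.
        return self.prepare_finalize.finalize(expert_out, metadata)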

4.6 MoE Quantization Config (fused_moe/config.py)

FusedMoEQuantConfig describes activation and weight quantization for MoE layers:

@dataclass
class FusedMoEQuantConfig:
    _a1: FusedMoEQuantDesc   # First activation (input to gate_up)
    _a2: FusedMoEQuantDesc   # Second activation (input to down_proj)
    _w1: FusedMoEQuantDesc   # gate_up_proj weights
    _w2: FusedMoEQuantDesc   # down_proj weights

Factory functions:

  • fp8_w8a8_moe_quant_config() -- FP8 weights and activations.
  • mxfp4_w4a16_moe_quant_config() -- MXFP4 weights, unquantized activations.
  • FUSED_MOE_UNQUANTIZED_CONFIG -- No quantization.

4.7 Triton MoE Fallback (fused_moe_triton.py)

triton_kernel_moe_forward() provides a Triton-based MoE path using the triton_kernels library. It uses routing() for expert assignment and matmul_ogs() for the expert GEMM. This path is currently used for MXFP4 MoE on GFX94x hardware.


5. Normalization

5.1 RMSNorm (layernorm.py)

RMSNorm supports multiple forward paths depending on configuration flags:

| Condition | Kernel / Path | Returns |
|---|---|---|
| x_pad_to_multiple > 0, no residual | fused_rmsnorm_pad_ (Triton fused_add_rmsnorm_pad) | Padded output |
| x_pad_to_multiple > 0, with residual | fused_add_rmsnorm_pad_ | (output, residual) |
| fused_allreduce=True and tp_size > 1 | tensor_model_parallel_fused_allreduce_rmsnorm | (output, residual) |
| fused_quant=True and x_scale provided | fused_rms_fp8_per_tensor_static_quant | (FP8 output, scale) |
| fused_quant=True and per_1x32 | fused_rms_mxfp4_quant | (MXFP4 output, scale) |
| Default, no residual | rmsnorm2d_fwd | Output |
| Default, with residual | rmsnorm2d_fwd_with_add | (output, residual) |

Constructor parameters:

RMSNorm(
    dim: int,
    eps: float = 1e-6,
    x_pad_to_multiple: int = 0,
    fused_allreduce: bool = False,
    fused_quant: bool = False,
    quant_config: Optional[QuantizationConfig] = None,
)
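
A usage example for the default paths, assuming the forward accepts an optional residual as the table above implies (shapes and values are illustrative):

import torch

norm = RMSNorm(dim=4096, eps=1e-6)
x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
residual = torch.randn_like(x)

y = norm(x)                      # rmsnorm2d_fwd path: returns the normalized tensor
y, residual = norm(x, residual)  # rmsnorm2d_fwd_with_add path: returns (output, residual)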

5.2 LayerNorm (layernorm.py)

LayerNorm wraps layernorm2d_fwd and layernorm2d_fwd_with_add (with bias support):

LayerNorm(dim: int, eps: float = 1e-6)
  • Without residual: layernorm2d_fwd(x, weight, bias, eps)
  • With residual: layernorm2d_fwd_with_add(out, x, residual, residual_out, weight, bias, eps)

6. Activation Functions

6.1 SiluAndMul (activation.py)

SiluAndMul computes SiLU(x_first_half) * x_second_half. It splits the last dimension in half.

| Condition | Kernel | Output |
|---|---|---|
| fused_quant=True + x_scale provided (FP8) | fused_silu_mul_fp8_per_tensor_static_quant | (FP8 output, scale) |
| fused_quant=True + per_1x32 (MXFP4) | fused_reduce_act_mul_and_mxfp4_quant (via mxfp4_act_mul_quant_fuse) | (MXFP4 output, scale) |
| Default | aiter.silu_and_mul(out, x) | BF16 output |
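
The unfused math is simply the gated SiLU below, shown as a plain-PyTorch reference for what aiter.silu_and_mul computes:

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    """Reference: SiLU(x_first_half) * x_second_half over the last dimension."""
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up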

Constructor:

SiluAndMul(
    fused_quant: bool = False,
    quant_config: Optional[QuantizationConfig] = None,
)

7. Embedding & Output Head

7.1 VocabParallelEmbedding (embed_head.py)

Partitions the vocabulary across TP ranks. Each rank holds num_embeddings / tp_size rows.

Forward:

  1. Mask input token IDs to this rank's partition range [vocab_start_idx, vocab_end_idx).
  2. F.embedding() on local partition.
  3. Zero out out-of-range positions.
  4. all_reduce() across TP group.
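
A plain-PyTorch sketch of these four steps (illustrative helper; the real module also handles vocab padding and TP group plumbing):

import torch
import torch.nn.functional as F

def vocab_parallel_embedding_sketch(input_ids, local_weight, vocab_start_idx, vocab_end_idx, tp_group=None):
    """Embed token IDs using only this rank's vocab slice, then all-reduce."""
    # 1. Mask IDs outside [vocab_start_idx, vocab_end_idx).
    in_range = (input_ids >= vocab_start_idx) & (input_ids < vocab_end_idx)
    local_ids = torch.where(in_range, input_ids - vocab_start_idx, torch.zeros_like(input_ids))
    # 2. Look up the local partition.
    out = F.embedding(local_ids, local_weight)
    # 3. Zero positions this rank does not own.
    out = out * in_range.unsqueeze(-1).to(out.dtype)
    # 4. Sum partial results across the TP group.
    if tp_group is not None:
        torch.distributed.all_reduce(out, group=tp_group)
    return out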

7.2 ParallelLMHead (embed_head.py)

Extends VocabParallelEmbedding for the output projection. Key differences:

  • Forward extracts only the last token per sequence during prefill (via cu_seqlens_q[1:] - 1).
  • Uses tgemm.mm(x, self.weight, self.bias) for the logit computation (not F.linear).
  • Calls tensor_model_parallel_all_gather() to gather logits across TP ranks.

8. Rotary Position Embedding (RoPE)

8.1 RotaryEmbedding (rotary_embedding.py)

Precomputes cos/sin caches at initialization and applies RoPE in-place.

Constructor:

RotaryEmbedding(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool = True,
    dtype: Optional[torch.dtype] = None,
)

Forward: Calls aiter.rope_cached_positions_2c_fwd_inplace(query_, key_, cos, sin, positions, rotate_style, ...) which applies RoPE to Q and K tensors in-place using precomputed caches indexed by position IDs.

8.2 get_rope() Factory

get_rope(head_size, rotary_dim, max_position, base, rope_scaling=None)

Returns a cached RotaryEmbedding instance. Currently rope_scaling must be None.

8.3 Integration in Attention

  • MHA (attention_mha.py): RoPE is applied during the rope_cache() phase, either via the fused fused_qk_norm_rope_cache_quant_shuffle kernel, via fused_qk_rope_reshape_and_cache, or via standalone rotary_emb(position, q, k).
  • MLA (attention_mla.py): RoPE is applied to q_pe and k_rope tensors. During decode, this is fused into fused_qk_rope_concat_and_cache_mla. During prefill, it is applied via self.rotary_emb(positions, prefill_q_pe, k_rope).

9. Sampling

9.1 Sampler (sampler.py)

The Sampler handles both greedy (temperature=0) and random (temperature>0) sampling in a single kernel call.

Forward:

def forward(self, logits, temperatures) -> sampled_tokens:
    mixed_sample_outer_exponential(sampled_tokens, logits, exponential, temperatures, eps)

aiter.mixed_sample_outer_exponential performs temperature-scaled exponential sampling: it divides logits by temperature, then uses the Gumbel-max trick with pre-generated exponential random variates.
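
For intuition, the same mixed greedy/random behavior can be written in plain PyTorch (a reference sketch, not the AITER kernel):

import torch

def mixed_sample_reference(logits, temperatures, eps=1e-10):
    """Greedy when temperature == 0, Gumbel-max sampling otherwise."""
    # Exponential variates -> Gumbel noise: G = -log(E), with E ~ Exp(1).
    exponential = torch.empty_like(logits).exponential_()
    gumbel = -torch.log(exponential + eps)
    temps = temperatures.unsqueeze(-1)
    greedy = logits.argmax(dim=-1)
    random = (logits / temps.clamp(min=eps) + gumbel).argmax(dim=-1)
    return torch.where(temperatures == 0, greedy, random)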

Fallback methods (currently unreachable due to early return):

  • greedy_sample(): aiter.ops.triton.topk.topk(logits, 1)
  • random_sample(): aiter.ops.triton.softmax.softmax(logits) followed by exponential sampling and topk.

9.2 RejectionSampler (rejection_sampler.py)

Implements rejection sampling for speculative decoding (MTP). Given draft token IDs and target model logits:

  1. Computes target_argmax = target_logits.argmax(dim=-1).
  2. Runs a Triton kernel rejection_greedy_sample_kernel that sequentially compares draft tokens against target argmax, accepting until first mismatch.
  3. On full acceptance, appends the bonus token.
  4. Returns (output_token_ids, num_bonus_tokens).
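
A host-side reference for the per-sequence logic of the Triton kernel (illustrative, operating on Python lists rather than GPU tensors):

def rejection_greedy_sample_reference(draft_token_ids, target_argmax, bonus_token_id):
    """Accept draft tokens while they match the target argmax; stop at the first mismatch."""
    output = []
    for pos, draft in enumerate(draft_token_ids):
        target = target_argmax[pos]
        output.append(target)        # the target's token is emitted at every visited position
        if draft != target:          # first mismatch: reject the remaining draft tokens
            return output, 0
    output.append(bonus_token_id)    # all drafts accepted: append the bonus token
    return output, 1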

10. Fused Kernel Chains

ATOM uses fused kernels to reduce memory traffic by combining multiple operations into a single kernel launch.

| Fused Operation | Components | Controlled By | AITER Kernel |
|---|---|---|---|
| RMSNorm + FP8 quant | RMSNorm, per-tensor FP8 static quant | RMSNorm(fused_quant=True) + x_scale | fused_rms_fp8_per_tensor_static_quant |
| RMSNorm + MXFP4 quant | RMSNorm, per-1x32 MXFP4 quant | RMSNorm(fused_quant=True) + QuantType.per_1x32 | fused_rms_mxfp4_quant |
| RMSNorm + add + pad | Residual add, RMSNorm, output padding | RMSNorm(x_pad_to_multiple>0) | fused_add_rmsnorm_pad |
| AllReduce + RMSNorm | TP all-reduce, RMSNorm | RMSNorm(fused_allreduce=True) | tensor_model_parallel_fused_allreduce_rmsnorm |
| SiLU + mul + FP8 quant | SiLU activation, multiply, FP8 quant | SiluAndMul(fused_quant=True) + x_scale | fused_silu_mul_fp8_per_tensor_static_quant |
| SiLU + mul + MXFP4 quant | SiLU activation, multiply, MXFP4 quant | SiluAndMul(fused_quant=True) + QuantType.per_1x32 | fused_reduce_act_mul_and_mxfp4_quant |
| QK norm + RoPE + cache + quant | Q/K norm, RoPE, KV cache write, optional FP8 quant, weight shuffle | q_norm + k_norm + rotary_emb all present | fused_qk_norm_rope_cache_quant_shuffle |
| RoPE + reshape + cache | RoPE, K reshape, KV cache write | Triton attention path | fused_qk_rope_reshape_and_cache |
| QK RoPE + MLA cache | Q RoPE, KV concat, MLA cache write, FP8 quant | MLA decode path | fused_qk_rope_concat_and_cache_mla |
| GEMM + split + cat (FP4) | KV_b_proj GEMM, split K_nope/V, cat K_rope | ATOM_USE_TRITON_GEMM=True + FP4 weights | fused_gemm_afp4wfp4_preshuffle_split_cat |
| GEMM + split + cat (FP8) | KV_b_proj GEMM, split K_nope/V, cat K_rope | ATOM_USE_TRITON_GEMM=True + FP8 weights | fused_gemm_a8w8_blockscale_preshuffle_split_cat |
| FP8 BMM + RoPE + cache (MLA) | Batched FP8 BMM, RoPE, MLA KV cache write | MLA decode with FP8 | fused_fp8_bmm_rope_cat_and_cache_mla |
| FP4 BMM + RoPE + cache (MLA) | Batched FP4 BMM, RoPE, MLA KV cache write | MLA decode with MXFP4 | fused_fp4_bmm_rope_cat_and_cache_mla |

Source Files

atom/model_ops/

| File | Description |
|---|---|
| linear.py | LinearBase, ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear, ReplicatedLinear, MergedReplicatedLinear |
| activation.py | SiluAndMul with fused FP8/MXFP4 quantization |
| layernorm.py | RMSNorm, LayerNorm with fused allreduce/quant/pad variants |
| base_attention.py | Top-level Attention dispatcher with custom op registration |
| attention_mha.py | MHA implementation: prefill (flash), decode (ASM/Triton paged attention) |
| attention_mla.py | MLAAttention, MLAModules -- DeepSeek MLA with compressed KV |
| moe.py | FusedMoE, FusedMoEParallelConfig, UnquantizedFusedMoEMethod, Fp8MoEMethod, Mxfp4MoEMethod, CompressedTensorsFp8MoEMethod |
| fused_moe_triton.py | triton_kernel_moe_forward -- Triton MoE via triton_kernels library |
| embed_head.py | VocabParallelEmbedding, ParallelLMHead |
| rotary_embedding.py | RotaryEmbedding, get_rope |
| topK.py | rocm_aiter_topk_softmax, rocm_aiter_grouped_topk, init_aiter_topK_meta_data |
| sampler.py | Sampler -- unified greedy/random sampling |
| rejection_sampler.py | RejectionSampler -- speculative decoding rejection sampling |
| base_config.py | QuantizeMethodBase abstract class |
| utils.py | Helper utilities: shuffle_weights, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, etc. |

atom/model_ops/attentions/

| File | Description |
|---|---|
| backends.py | AttentionBackend, AttentionMetadataBuilder, CommonAttentionBuilder, AttentionImpl abstract classes |
| aiter_attention.py | AiterBackend, AiterAttentionMetadataBuilder -- MHA backend with persistent ASM paged attention support |
| aiter_mla.py | AiterMLABackend, AiterMLAMetadataBuilder -- MLA backend with sparse attention support |

atom/model_ops/fused_moe/

| File | Description |
|---|---|
| config.py | FusedMoEConfig, FusedMoEQuantConfig, FusedMoEQuantDesc, GroupShape, factory functions (fp8_w8a8_moe_quant_config, mxfp4_w4a16_moe_quant_config) |
| modular_kernel.py | FusedMoEModularKernel, FusedMoEPrepareAndFinalize, ExpertTokensMetadata -- modular MoE kernel pipeline |
| mori_prepare_finalize.py | MoriPrepareAndFinalize -- MORI all-to-all dispatch/combine for expert parallelism |
| utils.py | MoE utility functions |

atom/utils/

| File | Description |
|---|---|
| selector.py | get_attn_backend() -- selects AiterBackend or AiterMLABackend based on use_mla flag |