ATOM (AiTer Optimized Model) wraps AITER kernels with model-level abstractions for LLM inference on AMD ROCm/HIP GPUs. This guide documents every operator class in `atom/model_ops/`, their AITER kernel mappings, quantization paths, and fused kernel chains.
| ATOM Class | File | AITER Kernel / Import | Purpose |
|---|---|---|---|
| `LinearBase` | `linear.py` | `tgemm.mm`, `gemm_a8w8`, `gemm_a8w8_bpreshuffle`, `gemm_a8w8_blockscale_bpreshuffle`, `gemm_a4w4` | Quantized linear dispatch |
| `ColumnParallelLinear` | `linear.py` | (inherits `LinearBase`) | Column-sharded TP linear |
| `RowParallelLinear` | `linear.py` | (inherits `LinearBase`) | Row-sharded TP linear |
| `QKVParallelLinear` | `linear.py` | (inherits `ColumnParallelLinear`) | Fused Q/K/V projection |
| `MergedColumnParallelLinear` | `linear.py` | (inherits `LinearBase`) | Merged gate+up projection |
| `Attention` | `base_attention.py` | `unified_attention_with_output_base` (custom op) | Unified attention entry |
| `Attention` (MHA) | `attention_mha.py` | `flash_attn_varlen_func`, `pa_fwd_asm`, `pa_persistent_fwd`, `pa_decode_gluon` | Multi-head attention |
| `MLAAttention` | `attention_mla.py` | `mla_decode_fwd`, `mla_prefill_fwd`, `concat_and_cache_mla`, `fused_qk_rope_concat_and_cache_mla` | Multi-head latent attention |
| `FusedMoE` | `moe.py` | `aiter.fused_moe.fused_moe`, `asm_moe` | Mixture of experts |
| `RMSNorm` | `layernorm.py` | `rmsnorm2d_fwd`, `rmsnorm2d_fwd_with_add`, `fused_add_rmsnorm_pad` | RMS normalization |
| `LayerNorm` | `layernorm.py` | `layernorm2d_fwd`, `layernorm2d_fwd_with_add` | Layer normalization |
| `SiluAndMul` | `activation.py` | `aiter.silu_and_mul` | SiLU gated activation |
| `VocabParallelEmbedding` | `embed_head.py` | `F.embedding` + TP all-reduce | Vocab embedding |
| `ParallelLMHead` | `embed_head.py` | `tgemm.mm` + `tensor_model_parallel_all_gather` | LM output head |
| `RotaryEmbedding` | `rotary_embedding.py` | `aiter.rope_cached_positions_2c_fwd_inplace` | Rotary position embedding |
| `Sampler` | `sampler.py` | `aiter.mixed_sample_outer_exponential`, `aiter.ops.triton.topk.topk`, `aiter.ops.triton.softmax.softmax` | Token sampling |
| `RejectionSampler` | `rejection_sampler.py` | Triton `rejection_greedy_sample_kernel` | Speculative decoding |
ATOM is a thin model-level inference engine: every compute-heavy operation delegates to an AITER kernel. The general pattern is:

- An ATOM `nn.Module` owns model weights and configuration.
- Its `forward()` method selects the appropriate AITER function based on quantization type, parallelism settings, and phase (prefill vs. decode).
- Results are optionally reduced across tensor-parallel (TP) or data-parallel (DP) groups.
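As a sketch of this pattern (the class, quant-type strings, and kernel names here are illustrative stand-ins, not ATOM's actual code), the dispatch-then-reduce flow looks like:

```python
# Hypothetical sketch of the ATOM wrapper pattern: dispatch on quant type,
# then optionally reduce across the TP group.
class QuantizedLinearSketch:
    def __init__(self, quant_type, tp_size=1):
        self.quant_type = quant_type
        self.tp_size = tp_size

    def forward(self, x):
        # 1. Select a kernel based on quantization type (stands in for the
        #    real dispatch to tgemm.mm / gemm_a8w8 / gemm_a4w4).
        kernel = {
            "No": "tgemm.mm",
            "per_Token_int8": "gemm_a8w8",
            "per_1x32": "gemm_a4w4",
        }[self.quant_type]
        out = f"{kernel}({x})"
        # 2. Optionally reduce across the tensor-parallel group.
        if self.tp_size > 1:
            out = f"all_reduce({out})"
        return out

print(QuantizedLinearSketch("No").forward("x"))                    # tgemm.mm(x)
print(QuantizedLinearSketch("per_Token_int8", tp_size=2).forward("x"))  # all_reduce(gemm_a8w8(x))
```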
| ATOM Wrapper | AITER Function / Import Path | Backend Type |
|---|---|---|
| `LinearBase.forward` (no quant) | `aiter.tuned_gemm.tgemm.mm` | hipBLASLt |
| `LinearBase.forward` (`per_Tensor` FP8) | `aiter.tuned_gemm.tgemm.mm` with scales | hipBLASLt |
| `LinearBase.forward` (`per_Token` INT8) | `aiter.gemm_a8w8` | CK |
| `LinearBase.forward` (`per_Token` FP8) | `aiter.gemm_a8w8_bpreshuffle` | CK |
| `LinearBase.forward` (`per_1x128` FP8) | `aiter.gemm_a8w8_blockscale_bpreshuffle` | CK |
| `LinearBase.forward` (`per_1x32` MXFP4) | `aiter.gemm_a4w4` | CK |
| MHA prefill | `aiter.flash_attn_varlen_func` | ASM / CK |
| MHA decode (ASM) | `aiter.pa_fwd_asm` | ASM |
| MHA decode (persistent ASM) | `aiter.pa_persistent_fwd` | ASM |
| MHA decode (Triton) | `aiter.ops.triton.gluon.pa_decode_gluon` | Triton |
| MHA prefill (Triton unified) | `aiter.ops.triton.unified_attention.unified_attention` | Triton |
| MLA decode | `aiter.mla.mla_decode_fwd` | ASM |
| MLA prefill | `aiter.mla.mla_prefill_fwd` | ASM |
| MLA KV cache | `aiter.concat_and_cache_mla` | CK |
| RoPE | `aiter.rope_cached_positions_2c_fwd_inplace` | Triton |
| RMSNorm | `aiter.rmsnorm2d_fwd` | CK |
| SiLU+Mul | `aiter.silu_and_mul` | CK |
| TopK routing | `aiter.topk_softmax`, `aiter.grouped_topk`, `aiter.biased_grouped_topk` | CK |
| Sampling | `aiter.mixed_sample_outer_exponential` | CK |
| FusedMoE | `aiter.fused_moe.fused_moe` | CK |
| ASM MoE | `aiter.fused_moe_bf16_asm.asm_moe` | ASM |
| Quantization | `aiter.get_hip_quant(QuantType)` | CK / Triton |
All linear layers inherit from `LinearBase` in `atom/model_ops/linear.py`:

```
LinearBase (nn.Module)
 +-- ReplicatedLinear              # No TP sharding
 |    +-- MergedReplicatedLinear
 +-- ColumnParallelLinear          # tp_dim=0, shard output
 |    +-- QKVParallelLinear        # Fused Q/K/V with per-head sharding
 +-- MergedColumnParallelLinear    # tp_dim=0, merged gate+up
 +-- RowParallelLinear             # tp_dim=1, shard input, optional all-reduce
```
`LinearBase.forward()` dispatches to different GEMM kernels based on `QuantType`:

| QuantType | Weight dtype | GEMM Kernel | Scale Shape |
|---|---|---|---|
| `No` | BF16/FP16 | `tgemm.mm` (hipBLASLt) | None |
| `per_Tensor` | FP8 | `tgemm.mm` with `scale_a`, `scale_b` | `[num_partitions, 1]` |
| `per_Token` (INT8) | INT8 | `gemm_a8w8` | `[output_size, 1]` |
| `per_Token` (FP8) | FP8 | `gemm_a8w8_bpreshuffle` | `[output_size, 1]` |
| `per_1x128` | FP8 | `gemm_a8w8_blockscale_bpreshuffle` | `[ceil(N/128), ceil(K/128)]` |
| `per_1x32` | MXFP4 (fp4x2) | `gemm_a4w4` | `[N, ceil(K/32)]` (e8m0) |
When `x_scale` is not provided, the input is dynamically quantized via `get_hip_quant(quant_type)`.
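To make the block-scale shapes concrete, here is the scale-shape arithmetic for a hypothetical weight of shape `[N, K]` (the sizes below are illustrative):

```python
import math

# Illustrative scale-shape arithmetic for the block-scaled quantization
# modes above, for a hypothetical weight of shape [N, K] = [4096, 14336].
N, K = 4096, 14336

# per_1x128: one scale per 128x128 weight block.
per_1x128_shape = (math.ceil(N / 128), math.ceil(K / 128))

# per_1x32: one e8m0 scale per 32-element group along K.
per_1x32_shape = (N, math.ceil(K / 32))

print(per_1x128_shape)  # (32, 112)
print(per_1x32_shape)   # (4096, 448)
```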
- `ColumnParallelLinear` (`tp_dim=0`): Shards weight rows (output dimension) across GPUs. Each GPU owns `output_size / tp_size` rows.
- `RowParallelLinear` (`tp_dim=1`): Shards weight columns (input dimension). If `reduce_results=True`, the output is all-reduced across the TP group.
- `QKVParallelLinear`: Extends `ColumnParallelLinear` with per-head sharding. Q heads are evenly divided; KV heads are either divided or replicated when `num_kv_heads < tp_size`.
- `MergedColumnParallelLinear`: Handles gate and up projections merged into a single weight, with `output_sizes` given as a list (e.g., `[intermediate_size, intermediate_size]`).
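A small sketch of the sharding arithmetic (function names and the one-head-per-rank replication rule are illustrative assumptions, not ATOM's API):

```python
# Illustrative sharding arithmetic for the parallel linear layers above.
def column_parallel_shard(output_size, tp_size):
    # tp_dim=0: each rank owns output_size / tp_size rows of the weight.
    assert output_size % tp_size == 0
    return output_size // tp_size

def qkv_heads_per_rank(num_q_heads, num_kv_heads, tp_size):
    # Q heads are evenly divided; KV heads are replicated when there are
    # fewer KV heads than ranks (grouped-query attention).
    q_per_rank = num_q_heads // tp_size
    if num_kv_heads >= tp_size:
        kv_per_rank = num_kv_heads // tp_size
    else:
        kv_per_rank = 1  # each rank holds a replicated KV head
    return q_per_rank, kv_per_rank

print(column_parallel_shard(4096, 8))   # 512
print(qkv_heads_per_rank(32, 8, 8))     # (4, 1)
print(qkv_heads_per_rank(32, 4, 8))     # (4, 1) -- KV heads replicated
```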
After loading, `process_weights_after_loading()` handles:

- e4m3fn to e4m3fnuz normalization (AMD FP8 format conversion).
- Weight reshuffling via `shuffle_weights()` for pre-shuffled GEMM kernels.
- Scale reshuffling via `fp4_utils.e8m0_shuffle()` for MXFP4 block scales.
- Per-tensor requantization via `requantize_with_max_scale()` when multiple output partitions have separate scales.
The top-level `Attention` class in `base_attention.py` is a dispatcher. It:

- Selects the backend via `get_attn_backend()` from `atom/utils/selector.py`.
- Instantiates the backend's implementation class (`impl_cls`).
- Registers itself in `compilation_config.static_forward_context` under `layer_name`.
- On `forward()`, calls `torch.ops.aiter.unified_attention_with_output_base`, a custom op decorated with `@mark_spliting_op`. This prevents `torch.compile` from tracing into attention internals, enabling full-graph capture.
Backend selection logic (in `selector.py`):

| Condition | Backend Class | Implementation |
|---|---|---|
| `use_mla=True` | `AiterMLABackend` | `MLAAttention` from `attention_mla.py` |
| `use_mla=False` | `AiterBackend` | `Attention` from `attention_mha.py` |
The MHA `Attention` class handles standard models (Llama, Qwen3, Mixtral, etc.).

Forward flow:

- Reshape Q, K, V to `[num_tokens, num_heads, head_dim]`.
- Apply RoPE and write the KV cache via `rope_cache()`.
- Dispatch to the appropriate backend via `dispatch_backend()`.
RoPE + KV cache paths:

| Condition | Kernel Chain |
|---|---|
| `q_norm` + `k_norm` + `rotary_emb` present | `fused_qk_norm_rope_cache_quant_shuffle` (single fused kernel for QK norm, RoPE, cache write, optional FP8 quant) |
| Triton path (`sliding_window != -1` or `head_dim != 128`) + `rotary_emb` | `fused_qk_rope_reshape_and_cache` (Triton fused RoPE + reshape + cache) |
| ASM path + `rotary_emb` | `rotary_emb(position, q, k)` then `reshape_and_cache` or `reshape_and_cache_with_pertoken_quant` |
Attention dispatch:

| Phase | Condition | Method | AITER Kernel |
|---|---|---|---|
| Prefill | Always | `prefill_attention` | `aiter.flash_attn_varlen_func` |
| Decode | `use_triton_attn=True` | `paged_attention_triton` | `torch.ops.aiter.pa_decode_gluon` |
| Decode | `block_size == 1024` | `paged_attention_persistent_asm` | `aiter.pa_persistent_fwd` |
| Decode | Default | `paged_attention_asm` | `aiter.pa_fwd_asm` |
The `use_triton_attn` flag is set when `sliding_window != -1` or `head_dim != 128`.
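The decode dispatch rules above can be sketched as follows (the function name and string return values are hypothetical, not ATOM's code):

```python
# Illustrative sketch of the decode dispatch rules in the table above.
def select_decode_kernel(sliding_window, head_dim, block_size):
    use_triton_attn = sliding_window != -1 or head_dim != 128
    if use_triton_attn:
        return "pa_decode_gluon"      # Triton paged attention
    if block_size == 1024:
        return "pa_persistent_fwd"    # persistent ASM kernel
    return "pa_fwd_asm"               # default ASM paged attention

print(select_decode_kernel(-1, 128, 16))    # pa_fwd_asm
print(select_decode_kernel(-1, 128, 1024))  # pa_persistent_fwd
print(select_decode_kernel(4096, 128, 16))  # pa_decode_gluon
```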
`MLAAttention` implements DeepSeek's MLA with a compressed KV representation. Key data structures:

```python
@dataclass
class MLAModules:
    q_lora_rank: Optional[int]
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    qk_head_dim: int
    v_head_dim: int
    rotary_emb: torch.nn.Module
    q_proj: Optional[torch.nn.Module]
    kv_b_proj: torch.nn.Module
    o_proj: torch.nn.Module
    indexer: Optional[torch.nn.Module]
```

Forward flow:
- If prefill and not sparse: standard MHA-style prefill with `flash_attn_varlen_func`, preceded by a `kv_b_proj` GEMM to produce K_nope and V from the compressed `kv_c_normed`.
- Otherwise: fused Q projection + K up-projection via batched FP8/FP4 BMM (`_q_proj_and_k_up_proj`), then:
  - `fused_qk_rope_concat_and_cache_mla` writes to the KV cache.
  - Decode: `mla_decode_fwd` (ASM persistent MLA kernel).
  - Prefill (sparse): `mla_prefill_fwd`.
  - V up-projection + O projection via batched BMM (`_v_up_proj_and_o_proj`).
Batched GEMM backends for MLA projections:

| Condition | Kernel |
|---|---|
| `ATOM_USE_TRITON_MXFP4_BMM=True` | `batched_gemm_a16wfp4` (Triton FP4 BMM) |
| Default | `batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant` (Triton FP8 BMM) |
Prefill GEMM optimizations (for `kv_b_proj`):

| Condition | Kernel |
|---|---|
| `ATOM_USE_TRITON_GEMM=True` + FP4 weights | `fused_gemm_afp4wfp4_preshuffle_split_cat` (GEMM + split K/V + cat rope in one kernel) |
| `ATOM_USE_TRITON_GEMM=True` + FP8 weights | `fused_gemm_a8w8_blockscale_preshuffle_split_cat` |
| Default | `kv_b_proj(kv_c_normed)` then manual split + cat |
The `AttentionBackend` abstract class defines three required methods:

- `get_name()` -- returns the backend identifier string.
- `get_builder_cls()` -- returns the `AttentionMetadataBuilder` subclass.
- `get_impl_cls()` -- returns the attention implementation class.

`CommonAttentionBuilder` provides shared metadata preparation (slot mapping, block tables, cumulative sequence lengths) used by both `AiterBackend` and `AiterMLABackend`.
| Operation | AITER Kernel | Used By |
|---|---|---|
| Standard KV cache write | `aiter.reshape_and_cache` | MHA (BF16 KV) |
| FP8 KV cache write | `aiter.reshape_and_cache_with_pertoken_quant` | MHA (FP8 KV) |
| MLA KV cache write | `aiter.concat_and_cache_mla` | MLA prefill |
| Fused QK RoPE + MLA cache | `aiter.fused_qk_rope_concat_and_cache_mla` | MLA decode |
`FusedMoE` is the top-level MoE module. It handles:

- Expert routing via `select_experts()`.
- Weight creation and quantization dispatch via `quant_method`.
- Tensor/expert/data parallelism via `FusedMoEParallelConfig`.
- Optional shared expert fusion and MORI communication.
Constructor parameters:

```python
FusedMoE(
    num_experts: int,            # Global number of experts
    top_k: int,                  # Experts per token
    hidden_size: int,            # Input hidden dimension
    intermediate_size: int,      # Expert intermediate dimension
    reduce_results: bool,        # Whether to all-reduce output
    renormalize: bool,           # Renormalize routing weights
    use_grouped_topk: bool,      # Use grouped top-k (DeepSeek)
    activation: ActivationType,  # Silu, Gelu, Swiglu, etc.
    ...
)
```

`FusedMoE` selects a `quant_method` at construction time:
| Quant Config | Method Class | GEMM Kernel |
|---|---|---|
| `QuantType.No` | `UnquantizedFusedMoEMethod` | `aiter.fused_moe.fused_moe` |
| FP8 (`dtypes.fp8`) | `Fp8MoEMethod` | `aiter.fused_moe.fused_moe` with `quant_type` |
| FP8 compressed-tensors | `CompressedTensorsFp8MoEMethod` | `aiter.fused_moe.fused_moe` or `asm_moe` |
| MXFP4 (`dtypes.fp4x2`) | `Mxfp4MoEMethod` | `aiter.fused_moe.fused_moe` or Triton `triton_kernel_moe_forward` |
The ASM MoE path (`asm_moe` from `aiter.fused_moe_bf16_asm`) is used by FP8 methods and supports an `a16` mode where activations remain in BF16/FP16 while weights are FP8/INT8.
| Routing Function | AITER Kernel | Used For |
|---|---|---|
| `rocm_aiter_topk_softmax` | `aiter.topk_softmax` | Standard top-k (Mixtral) |
| `rocm_aiter_grouped_topk` | `aiter.grouped_topk` | Grouped top-k (DeepSeek) |
| `rocm_aiter_biased_grouped_topk` | `aiter.biased_grouped_topk` | Biased grouped top-k (DeepSeek V3) |
Shared expert fusion: when `is_rocm_aiter_fusion_shared_expert_enabled()` returns `True`, the top-k buffers are extended with shared expert IDs appended after the routed expert IDs. This allows shared expert computation to be fused into the same MoE kernel call. The metadata is initialized via `init_aiter_topK_meta_data()`.
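A minimal sketch of the buffer extension (the function name and the choice of weight 1.0 for shared experts are illustrative assumptions, not ATOM's actual buffers):

```python
# Illustrative sketch of shared-expert fusion: routed top-k IDs are extended
# with shared expert IDs so one MoE kernel call covers both.
def extend_topk_with_shared(topk_ids, topk_weights, num_routed, num_shared):
    # Shared experts get IDs after the routed expert range.
    shared_ids = list(range(num_routed, num_routed + num_shared))
    # Shared experts are always active; here each gets weight 1.0.
    return topk_ids + shared_ids, topk_weights + [1.0] * num_shared

ids, weights = extend_topk_with_shared([3, 17], [0.6, 0.4],
                                       num_routed=64, num_shared=2)
print(ids)      # [3, 17, 64, 65]
print(weights)  # [0.6, 0.4, 1.0, 1.0]
```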
```python
@dataclass
class FusedMoEParallelConfig:
    tp_size: int        # Tensor parallel size
    dp_size: int        # Data parallel size
    ep_size: int        # Expert parallel size
    tp_rank: int
    dp_rank: int
    ep_rank: int
    use_ep: bool        # Whether expert parallelism is active
    local_ep_size: int  # Local EP size (GPUs per node * TP)
```

Key properties:

- `use_all2all_kernels`: `True` when `dp_size > 1`, EP is enabled, and MORI is available.
- `use_mori_kernels`: always `True` (currently).
MORI (MoE Router Infrastructure) provides all-to-all communication kernels for expert parallelism. `MoriPrepareAndFinalize` implements:

- `prepare()`: dispatches tokens to remote experts via `mori_op.dispatch()`. Optionally quantizes activations to FP8 before dispatch.
- `finalize()`: combines expert outputs via `mori_op.combine()` and copies results back.

The `FusedMoEModularKernel` orchestrates the prepare-compute-finalize pipeline.
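A toy version of that prepare-compute-finalize pipeline (all class and method names here are illustrative, and the arithmetic just stands in for dispatch, expert compute, and combine):

```python
# Minimal sketch of a prepare-compute-finalize pipeline in the style of
# FusedMoEModularKernel.
class PrepareAndFinalizeSketch:
    def prepare(self, tokens):
        # Stand-in for mori_op.dispatch(): route tokens to their experts.
        return [t * 10 for t in tokens]

    def finalize(self, expert_out):
        # Stand-in for mori_op.combine(): gather expert outputs back.
        return sum(expert_out)

class ModularKernelSketch:
    def __init__(self, prep, experts):
        self.prep = prep
        self.experts = experts  # the fused expert compute step

    def __call__(self, tokens):
        dispatched = self.prep.prepare(tokens)
        computed = [self.experts(t) for t in dispatched]
        return self.prep.finalize(computed)

kernel = ModularKernelSketch(PrepareAndFinalizeSketch(), experts=lambda t: t + 1)
print(kernel([1, 2, 3]))  # 63
```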
`FusedMoEQuantConfig` describes activation and weight quantization for MoE layers:

```python
@dataclass
class FusedMoEQuantConfig:
    _a1: FusedMoEQuantDesc  # First activation (input to gate_up)
    _a2: FusedMoEQuantDesc  # Second activation (input to down_proj)
    _w1: FusedMoEQuantDesc  # gate_up_proj weights
    _w2: FusedMoEQuantDesc  # down_proj weights
```

Factory functions:

- `fp8_w8a8_moe_quant_config()` -- FP8 weights and activations.
- `mxfp4_w4a16_moe_quant_config()` -- MXFP4 weights, unquantized activations.
- `FUSED_MOE_UNQUANTIZED_CONFIG` -- no quantization.
`triton_kernel_moe_forward()` provides a Triton-based MoE path using the `triton_kernels` library. It uses `routing()` for expert assignment and `matmul_ogs()` for the expert GEMM. This path is currently used for MXFP4 MoE on GFX94x hardware.
`RMSNorm` supports multiple forward paths depending on configuration flags:

| Condition | Kernel / Path | Returns |
|---|---|---|
| `x_pad_to_multiple > 0`, no residual | `fused_rmsnorm_pad_` (Triton `fused_add_rmsnorm_pad`) | Padded output |
| `x_pad_to_multiple > 0`, with residual | `fused_add_rmsnorm_pad_` | `(output, residual)` |
| `fused_allreduce=True` and `tp_size > 1` | `tensor_model_parallel_fused_allreduce_rmsnorm` | `(output, residual)` |
| `fused_quant=True` and `x_scale` provided | `fused_rms_fp8_per_tensor_static_quant` | `(FP8 output, scale)` |
| `fused_quant=True` and `per_1x32` | `fused_rms_mxfp4_quant` | `(MXFP4 output, scale)` |
| Default, no residual | `rmsnorm2d_fwd` | Output |
| Default, with residual | `rmsnorm2d_fwd_with_add` | `(output, residual)` |
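For reference, the usual add-then-normalize semantics of an RMSNorm-with-residual path, as a pure-Python sketch (the exact ordering inside `rmsnorm2d_fwd_with_add` is an assumption here; add-first is the common convention):

```python
import math

# Pure-Python reference for RMSNorm with a fused residual add: add the
# residual first, then normalize by the root-mean-square and scale by weight.
def rmsnorm_with_add(x, residual, weight, eps=1e-6):
    added = [a + b for a, b in zip(x, residual)]  # residual_out
    rms = math.sqrt(sum(v * v for v in added) / len(added) + eps)
    out = [v / rms * w for v, w in zip(added, weight)]
    return out, added

out, residual_out = rmsnorm_with_add([1.0, 2.0], [1.0, 0.0], [1.0, 1.0])
print(residual_out)  # [2.0, 2.0]
```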
Constructor parameters:

```python
RMSNorm(
    dim: int,
    eps: float = 1e-6,
    x_pad_to_multiple: int = 0,
    fused_allreduce: bool = False,
    fused_quant: bool = False,
    quant_config: Optional[QuantizationConfig] = None,
)
```

`LayerNorm` wraps `layernorm2d_fwd` and `layernorm2d_fwd_with_add` (with bias support):

- Constructor: `LayerNorm(dim: int, eps: float = 1e-6)`
- Without residual: `layernorm2d_fwd(x, weight, bias, eps)`
- With residual: `layernorm2d_fwd_with_add(out, x, residual, residual_out, weight, bias, eps)`
`SiluAndMul` computes `SiLU(x_first_half) * x_second_half`, splitting the last dimension in half.
| Condition | Kernel | Output |
|---|---|---|
| `fused_quant=True` + `x_scale` provided (FP8) | `fused_silu_mul_fp8_per_tensor_static_quant` | `(FP8 output, scale)` |
| `fused_quant=True` + `per_1x32` (MXFP4) | `fused_reduce_act_mul_and_mxfp4_quant` (via `mxfp4_act_mul_quant_fuse`) | `(MXFP4 output, scale)` |
| Default | `aiter.silu_and_mul(out, x)` | BF16 output |
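A pure-Python reference for the split-and-gate semantics (this is the textbook SiLU gating formula, not the AITER kernel):

```python
import math

# Reference for SiluAndMul: split the last dimension in half, apply
# SiLU (x * sigmoid(x)) to the first half, multiply by the second half.
def silu_and_mul(x):
    d = len(x) // 2
    gate, up = x[:d], x[d:]
    return [g / (1.0 + math.exp(-g)) * u for g, u in zip(gate, up)]

out = silu_and_mul([0.0, 1.0, 3.0, 2.0])
print(out)  # [0.0, ~1.462]
```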
Constructor:

```python
SiluAndMul(
    fused_quant: bool = False,
    quant_config: Optional[QuantizationConfig] = None,
)
```

`VocabParallelEmbedding` partitions the vocabulary across TP ranks. Each rank holds `num_embeddings / tp_size` rows.
Forward:

- Mask input token IDs to this rank's partition range `[vocab_start_idx, vocab_end_idx)`.
- `F.embedding()` on the local partition.
- Zero out out-of-range positions.
- `all_reduce()` across the TP group.
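The masked-lookup-plus-all-reduce flow can be sketched as follows (the table contents and sizes are illustrative, and the all-reduce is modeled as an elementwise sum across ranks):

```python
# Pure-Python sketch of the vocab-parallel embedding lookup described above.
def embed_on_rank(token_ids, table, vocab_start, vocab_end):
    out = []
    for t in token_ids:
        if vocab_start <= t < vocab_end:
            out.append(table[t - vocab_start])  # local lookup
        else:
            out.append(0.0)                     # zero out-of-range positions
    return out

# 2 ranks, vocabulary of 4, 1-d "embeddings" for brevity.
rank0 = embed_on_rank([0, 3], [10.0, 11.0], 0, 2)
rank1 = embed_on_rank([0, 3], [12.0, 13.0], 2, 4)
result = [a + b for a, b in zip(rank0, rank1)]  # all_reduce(sum)
print(result)  # [10.0, 13.0]
```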
`ParallelLMHead` extends `VocabParallelEmbedding` for the output projection. Key differences:

- Forward extracts only the last token per sequence during prefill (via `cu_seqlens_q[1:] - 1`).
- Uses `tgemm.mm(x, self.weight, self.bias)` for the logit computation (not `F.linear`).
- Calls `tensor_model_parallel_all_gather()` to gather logits across TP ranks.
`RotaryEmbedding` precomputes cos/sin caches at initialization and applies RoPE in-place.

Constructor:

```python
RotaryEmbedding(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool = True,
    dtype: Optional[torch.dtype] = None,
)
```

Forward: calls `aiter.rope_cached_positions_2c_fwd_inplace(query_, key_, cos, sin, positions, rotate_style, ...)`, which applies RoPE to the Q and K tensors in-place using precomputed caches indexed by position IDs.

`get_rope(head_size, rotary_dim, max_position, base, rope_scaling=None)` returns a cached `RotaryEmbedding` instance. Currently `rope_scaling` must be `None`.
- MHA (`attention_mha.py`): RoPE is applied during the `rope_cache()` phase, either via the fused `fused_qk_norm_rope_cache_quant_shuffle` kernel, via `fused_qk_rope_reshape_and_cache`, or via standalone `rotary_emb(position, q, k)`.
- MLA (`attention_mla.py`): RoPE is applied to the `q_pe` and `k_rope` tensors. During decode, this is fused into `fused_qk_rope_concat_and_cache_mla`. During prefill, it is applied via `self.rotary_emb(positions, prefill_q_pe, k_rope)`.
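For reference, textbook NeoX-style RoPE on a single head, showing what the precomputed cos/sin caches encode (this is the standard formulation, not the AITER kernel):

```python
import math

# Pure-Python reference for NeoX-style rotary embedding on one head.
# Dimension i is paired with dimension i + d/2 and rotated by an angle
# that depends on the token position and the pair's frequency.
def rope_neox(x, position, base=10000.0):
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        theta = position * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        # rotate the pair (x[i], x[i + half])
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i] * s + x[i + half] * c
    return out

print(rope_neox([1.0, 0.0], 0))  # position 0 is the identity: [1.0, 0.0]
```

Rotation preserves the vector norm, which is why RoPE can be applied in-place without rescaling.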
`Sampler` implements unified sampling, supporting both greedy (temperature=0) and random (temperature>0) sampling in a single kernel call.

Forward:

```python
def forward(self, logits, temperatures) -> sampled_tokens:
    mixed_sample_outer_exponential(sampled_tokens, logits, exponential, temperatures, eps)
```

`aiter.mixed_sample_outer_exponential` performs temperature-scaled exponential sampling: it divides the logits by temperature, then uses the Gumbel-max trick with pre-generated exponential random variates.
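The exponential-race trick can be illustrated in pure Python (this is the standard Gumbel-max construction, not the AITER kernel itself): sampling token `i` with probability `softmax(logits/T)[i]` is equivalent to taking the argmax of `exp(logits[i]/T) / E_i` for i.i.d. exponential variates `E_i`.

```python
import math
import random

# Pure-Python illustration of mixed greedy/exponential sampling.
def mixed_sample(logits, temperature, rng=random):
    if temperature == 0.0:
        # Greedy: plain argmax of the logits.
        return max(range(len(logits)), key=lambda i: logits[i])
    # Random: exponential race over exp(logits / T).
    exp_vars = [rng.expovariate(1.0) for _ in logits]
    return max(range(len(logits)),
               key=lambda i: math.exp(logits[i] / temperature) / exp_vars[i])

print(mixed_sample([0.1, 2.0, -1.0], temperature=0.0))  # 1 (greedy argmax)
```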
Fallback methods (currently unreachable due to an early return):

- `greedy_sample()`: `aiter.ops.triton.topk.topk(logits, 1)`
- `random_sample()`: `aiter.ops.triton.softmax.softmax(logits)` followed by exponential sampling and `topk`.
`RejectionSampler` implements rejection sampling for speculative decoding (MTP). Given draft token IDs and target model logits, it:

- Computes `target_argmax = target_logits.argmax(dim=-1)`.
- Runs a Triton kernel `rejection_greedy_sample_kernel` that sequentially compares draft tokens against the target argmax, accepting until the first mismatch.
- On full acceptance, appends the bonus token.
- Returns `(output_token_ids, num_bonus_tokens)`.
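A pure-Python sketch of the accept-until-mismatch loop (the function name and return shape are illustrative; emitting the target's token at the first mismatch is the standard convention and an assumption here):

```python
# Sketch of greedy rejection sampling for speculative decoding.
def greedy_rejection_sample(draft_tokens, target_argmax, bonus_token):
    output = []
    for draft, target in zip(draft_tokens, target_argmax):
        if draft == target:
            output.append(draft)   # accept the matching draft token
        else:
            output.append(target)  # emit the target's token at the mismatch
            return output, 0       # stop at the first mismatch
    output.append(bonus_token)     # full acceptance: append the bonus token
    return output, 1

print(greedy_rejection_sample([5, 7, 9], [5, 7, 9], bonus_token=11))  # ([5, 7, 9, 11], 1)
print(greedy_rejection_sample([5, 8, 9], [5, 7, 9], bonus_token=11))  # ([5, 7], 0)
```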
ATOM uses fused kernels to reduce memory traffic by combining multiple operations into a single kernel launch.
| Fused Operation | Components | Controlled By | AITER Kernel |
|---|---|---|---|
| RMSNorm + FP8 quant | RMSNorm, per-tensor FP8 static quant | `RMSNorm(fused_quant=True)` + `x_scale` | `fused_rms_fp8_per_tensor_static_quant` |
| RMSNorm + MXFP4 quant | RMSNorm, per-1x32 MXFP4 quant | `RMSNorm(fused_quant=True)` + `QuantType.per_1x32` | `fused_rms_mxfp4_quant` |
| RMSNorm + add + pad | Residual add, RMSNorm, output padding | `RMSNorm(x_pad_to_multiple>0)` | `fused_add_rmsnorm_pad` |
| AllReduce + RMSNorm | TP all-reduce, RMSNorm | `RMSNorm(fused_allreduce=True)` | `tensor_model_parallel_fused_allreduce_rmsnorm` |
| SiLU + mul + FP8 quant | SiLU activation, multiply, FP8 quant | `SiluAndMul(fused_quant=True)` + `x_scale` | `fused_silu_mul_fp8_per_tensor_static_quant` |
| SiLU + mul + MXFP4 quant | SiLU activation, multiply, MXFP4 quant | `SiluAndMul(fused_quant=True)` + `QuantType.per_1x32` | `fused_reduce_act_mul_and_mxfp4_quant` |
| QK norm + RoPE + cache + quant | Q/K norm, RoPE, KV cache write, optional FP8 quant, weight shuffle | `q_norm` + `k_norm` + `rotary_emb` all present | `fused_qk_norm_rope_cache_quant_shuffle` |
| RoPE + reshape + cache | RoPE, K reshape, KV cache write | Triton attention path | `fused_qk_rope_reshape_and_cache` |
| QK RoPE + MLA cache | Q RoPE, KV concat, MLA cache write, FP8 quant | MLA decode path | `fused_qk_rope_concat_and_cache_mla` |
| GEMM + split + cat (FP4) | `kv_b_proj` GEMM, split K_nope/V, cat K_rope | `ATOM_USE_TRITON_GEMM=True` + FP4 weights | `fused_gemm_afp4wfp4_preshuffle_split_cat` |
| GEMM + split + cat (FP8) | `kv_b_proj` GEMM, split K_nope/V, cat K_rope | `ATOM_USE_TRITON_GEMM=True` + FP8 weights | `fused_gemm_a8w8_blockscale_preshuffle_split_cat` |
| FP8 BMM + RoPE + cache (MLA) | Batched FP8 BMM, RoPE, MLA KV cache write | MLA decode with FP8 | `fused_fp8_bmm_rope_cat_and_cache_mla` |
| FP4 BMM + RoPE + cache (MLA) | Batched FP4 BMM, RoPE, MLA KV cache write | MLA decode with MXFP4 | `fused_fp4_bmm_rope_cat_and_cache_mla` |
| File | Description |
|---|---|
| `linear.py` | `LinearBase`, `ColumnParallelLinear`, `RowParallelLinear`, `QKVParallelLinear`, `MergedColumnParallelLinear`, `ReplicatedLinear`, `MergedReplicatedLinear` |
| `activation.py` | `SiluAndMul` with fused FP8/MXFP4 quantization |
| `layernorm.py` | `RMSNorm`, `LayerNorm` with fused allreduce/quant/pad variants |
| `base_attention.py` | Top-level `Attention` dispatcher with custom op registration |
| `attention_mha.py` | MHA implementation: prefill (flash), decode (ASM/Triton paged attention) |
| `attention_mla.py` | `MLAAttention`, `MLAModules` -- DeepSeek MLA with compressed KV |
| `moe.py` | `FusedMoE`, `FusedMoEParallelConfig`, `UnquantizedFusedMoEMethod`, `Fp8MoEMethod`, `Mxfp4MoEMethod`, `CompressedTensorsFp8MoEMethod` |
| `fused_moe_triton.py` | `triton_kernel_moe_forward` -- Triton MoE via the `triton_kernels` library |
| `embed_head.py` | `VocabParallelEmbedding`, `ParallelLMHead` |
| `rotary_embedding.py` | `RotaryEmbedding`, `get_rope` |
| `topK.py` | `rocm_aiter_topk_softmax`, `rocm_aiter_grouped_topk`, `init_aiter_topK_meta_data` |
| `sampler.py` | `Sampler` -- unified greedy/random sampling |
| `rejection_sampler.py` | `RejectionSampler` -- speculative decoding rejection sampling |
| `base_config.py` | `QuantizeMethodBase` abstract class |
| `utils.py` | Helper utilities: `shuffle_weights`, `normalize_e4m3fn_to_e4m3fnuz`, `per_tensor_dequantize`, etc. |
| File | Description |
|---|---|
| `backends.py` | `AttentionBackend`, `AttentionMetadataBuilder`, `CommonAttentionBuilder`, `AttentionImpl` abstract classes |
| `aiter_attention.py` | `AiterBackend`, `AiterAttentionMetadataBuilder` -- MHA backend with persistent ASM paged attention support |
| `aiter_mla.py` | `AiterMLABackend`, `AiterMLAMetadataBuilder` -- MLA backend with sparse attention support |
| File | Description |
|---|---|
| `config.py` | `FusedMoEConfig`, `FusedMoEQuantConfig`, `FusedMoEQuantDesc`, `GroupShape`, factory functions (`fp8_w8a8_moe_quant_config`, `mxfp4_w4a16_moe_quant_config`) |
| `modular_kernel.py` | `FusedMoEModularKernel`, `FusedMoEPrepareAndFinalize`, `ExpertTokensMetadata` -- modular MoE kernel pipeline |
| `mori_prepare_finalize.py` | `MoriPrepareAndFinalize` -- MORI all-to-all dispatch/combine for expert parallelism |
| `utils.py` | MoE utility functions |
| File | Description |
|---|---|
| `selector.py` | `get_attn_backend()` -- selects `AiterBackend` or `AiterMLABackend` based on the `use_mla` flag |