Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.
#pragma once

#include "helper.h" // For getEnvDeterministicMode, getEnvDeterministicDebug
#include "helper.h" // For getBoolEnv
#include "multiquery_attention_c16_kernel.h"

template <typename T,
Expand All @@ -33,7 +33,7 @@ __global__ void multi_query_append_attention_kernel(
const T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) *
// head_dim]
const T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
// head_dim]
const T *__restrict__ cache_v,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
Expand All @@ -54,9 +54,9 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const int sliding_window = 0,
Expand Down
4 changes: 2 additions & 2 deletions fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,10 @@ def _validate_split_kv_size(value: int) -> int:
"FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")),
# File path for file storage backend
"FILE_BACKEND_STORAGE_DIR": lambda: str(os.getenv("FILE_BACKEND_STORAGE_DIR", "/tmp/fastdeploy")),
# Custom all-reduce max buffer size in MB (default 8MB).
# Custom all-reduce max buffer size in MB (default 64MB).
# Increase this to avoid NCCL fallback for large tensors in deterministic mode.
# E.g. FD_CUSTOM_AR_MAX_SIZE_MB=128 for 128MB.
"FD_CUSTOM_AR_MAX_SIZE_MB": lambda: int(os.getenv("FD_CUSTOM_AR_MAX_SIZE_MB", "8")),
"FD_CUSTOM_AR_MAX_SIZE_MB": lambda: int(os.getenv("FD_CUSTOM_AR_MAX_SIZE_MB", "64")),
# Enable deterministic inference mode for chunked prefill alignment
"FD_DETERMINISTIC_MODE": lambda: bool(int(os.getenv("FD_DETERMINISTIC_MODE", "0"))),
# Split KV block size for deterministic alignment (must be power of 2 and > 0, default 16)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
log_softmax,
matmul_persistent,
mean_dim,
rms_norm_batch_invariant,
set_batch_invariant_mode,
)

Expand All @@ -22,6 +23,7 @@
"matmul_persistent",
"log_softmax",
"mean_dim",
"rms_norm_batch_invariant",
"get_batch_invariant_attention_block_size",
"AttentionBlockSize",
]
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,69 @@ def mean_batch_invariant(
return result


# ---------------------------------------------------------------------------
# Batch-invariant RMSNorm (Triton): one program per row, fixed reduction order
# ---------------------------------------------------------------------------


@triton.jit
def _rms_norm_kernel( # pragma: no cover
    input_ptr,
    weight_ptr,
    output_ptr,
    input_row_stride: tl.constexpr,
    output_row_stride: tl.constexpr,
    n_cols: tl.constexpr,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """Per-row RMSNorm: y = x * rsqrt(mean(x^2) + eps) * weight.

    Each program instance handles exactly one row, so the reduction order
    over the row is fixed by (n_cols, BLOCK_SIZE) alone and never depends
    on how many rows (the batch/M dimension) are in the launch — this is
    what makes the op batch-invariant (M-invariant).

    Args (device pointers / compile-time constants):
        input_ptr:         base pointer of the [n_rows, n_cols] input.
        weight_ptr:        per-channel scale, length n_cols.
        output_ptr:        base pointer of the [n_rows, n_cols] output.
        input_row_stride:  elements between consecutive input rows.
        output_row_stride: elements between consecutive output rows.
        n_cols:            row length (last-dim size).
        eps:               stabilizer added to mean(x^2); runtime scalar.
        BLOCK_SIZE:        columns processed per inner-loop step.
    """
    # int64 row index so row_idx * stride cannot overflow for large tensors.
    row_idx = tl.program_id(0).to(tl.int64)
    row_start = input_ptr + row_idx * input_row_stride
    out_start = output_ptr + row_idx * output_row_stride

    # Pass 1: sum of squares accumulated in float32 regardless of input dtype.
    # Masked lanes contribute exactly 0.0 so partial tiles don't perturb the sum.
    sum_sq = tl.zeros([1], dtype=tl.float32)
    for off in range(0, n_cols, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        x = tl.load(row_start + cols, mask=mask, other=0.0).to(tl.float32)
        sum_sq += tl.sum(tl.where(mask, x * x, 0.0))

    inv_rms = 1.0 / tl.sqrt(sum_sq / n_cols + eps)

    # Pass 2: reload the row, normalize in float32, apply the per-channel
    # scale, then cast back to the output buffer's element type on store.
    for off in range(0, n_cols, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        x = tl.load(row_start + cols, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(weight_ptr + cols, mask=mask, other=1.0).to(tl.float32)
        y = x * inv_rms * w
        tl.store(out_start + cols, y.to(out_start.dtype.element_ty), mask=mask)


def rms_norm_batch_invariant(x: paddle.Tensor, weight: paddle.Tensor, eps: float = 1e-6) -> paddle.Tensor:
    """M-invariant RMSNorm: each row computed independently via Triton.

    Flattens ``x`` to 2-D over its last dimension, launches one Triton
    program per row (so the in-row reduction order is independent of the
    batch size), then restores the original shape.

    Args:
        x: Input tensor; normalized over its last dimension.
        weight: Per-channel scale; assumed length ``x.shape[-1]`` — TODO confirm callers.
        eps: Stabilizer added to mean(x^2) before the rsqrt.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    orig_shape = x.shape
    x_2d = x.reshape([-1, x.shape[-1]]).contiguous()
    weight = weight.contiguous()
    n_rows, n_cols = x_2d.shape
    out = paddle.empty_like(x_2d)
    BLOCK_SIZE = 1024
    # Both x_2d and out are freshly contiguous 2-D tensors, so their row
    # stride is exactly n_cols. Pass n_cols directly instead of the
    # torch-style `.stride(0)` method, which paddle.Tensor does not expose
    # (NOTE(review): `.stride(0)` raises AttributeError on Paddle — verify
    # against the Paddle version pinned by this project).
    _rms_norm_kernel[(n_rows,)](
        x_2d,
        weight,
        out,
        n_cols,  # input_row_stride
        n_cols,  # output_row_stride
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out.reshape(orig_shape)


_original_ops = {"mm": None, "addmm": None, "_log_softmax": None, "mean_dim": None, "bmm": None}

_batch_invariant_MODE = False
Expand Down
19 changes: 18 additions & 1 deletion fastdeploy/model_executor/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from paddle import nn
from paddle.distributed import fleet

import fastdeploy.envs as envs
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.utils import h2d_copy, set_weight_attrs, slice_fn
Expand Down Expand Up @@ -315,6 +316,22 @@ def forward(self, ids_remove_padding: paddle.Tensor = None, forward_meta: Forwar
)
input_embedings = paddle.concat(inputs_embeds_temp, -1)
else:
input_embedings = self.embeddings(ids_remove_padding)
if envs.FD_DETERMINISTIC_MODE and self.world_size > 1: # pragma: no cover
# Bypass Paddle's _mp_allreduce (NCCL) with Custom AR for determinism.
from paddle.distributed.fleet.layers.mpu import mp_ops

from fastdeploy.distributed.communication import (
tensor_model_parallel_all_reduce,
)

output_parallel = mp_ops._c_lookup_table(
self.embeddings.weight,
ids_remove_padding,
start_index=self.embeddings.vocab_start_index,
vocab_size=self.embeddings.num_embeddings,
)
input_embedings = tensor_model_parallel_all_reduce(output_parallel, self.tp_group)
else:
input_embedings = self.embeddings(ids_remove_padding)

return input_embedings
36 changes: 23 additions & 13 deletions fastdeploy/model_executor/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.ops.triton_ops import _TRITON_AVAILABLE, qk_rmsnorm_fused

from .batch_invariant_ops import (
is_batch_invariant_mode_enabled,
rms_norm_batch_invariant,
)
from .utils import get_tensor, modules_to_convert


Expand Down Expand Up @@ -237,19 +241,25 @@ def forward(
return norm_out.astype(x_dtype), residual_out
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
else:
norm_out = self.norm_func(
x,
norm_weight=self.weight,
norm_bias=None,
epsilon=self.eps,
begin_norm_axis=self.begin_norm_axis,
bias=self.bias,
residual=residual_input,
quant_scale=(-1 if self.quant_scale is None else self.quant_scale),
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
if is_batch_invariant_mode_enabled():
# M-invariant path: per-row Triton kernel, no cross-row reduction
if residual_input is not None:
x = x + residual_input
norm_out = rms_norm_batch_invariant(x, self.weight, self.eps), x
else:
norm_out = self.norm_func(
x,
norm_weight=self.weight,
norm_bias=None,
epsilon=self.eps,
begin_norm_axis=self.begin_norm_axis,
bias=self.bias,
residual=residual_input,
quant_scale=(-1 if self.quant_scale is None else self.quant_scale),
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
else:
if residual_input is not None:
x = x + residual_input
Expand Down
Loading
Loading