diff --git a/flashmask/flash_mask/cp_balance/__init__.py b/flashmask/flash_mask/cp_balance/__init__.py new file mode 100644 index 00000000000..6804de4266f --- /dev/null +++ b/flashmask/flash_mask/cp_balance/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .cp_balance import balance_flashmask_input +from .cp_balance_cuda_kernels import indices_rerank_cuda, indices_to_chunks_cuda diff --git a/flashmask/flash_mask/cp_balance/cp_balance.py b/flashmask/flash_mask/cp_balance/cp_balance.py new file mode 100644 index 00000000000..504f0fc7266 --- /dev/null +++ b/flashmask/flash_mask/cp_balance/cp_balance.py @@ -0,0 +1,393 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import heapq
import hashlib
from typing import Dict, List, Optional, Tuple

import numpy as np
import paddle
import paddle.distributed as dist

from .cp_balance_cuda_kernels import (
    indices_rerank_cuda,
    indices_to_chunks_cuda,
    reduce_workload,
    scanMaxMinChunkedKernel,
)

# --- Debug helpers ---

def save_tensor(x: paddle.Tensor, name: str):
    """Dump a Paddle tensor to `<name>.txt` as integers, for debugging."""
    x_np = x.numpy()
    np.savetxt(f'{name}.txt', x_np.reshape(-1, x_np.shape[-1]), fmt='%d')

def tensor_md5(tensor: paddle.Tensor) -> str:
    """Print and return the MD5 of the tensor's raw bytes (data-consistency checks)."""
    x_bytes = tensor.numpy().tobytes()
    md5_hash = hashlib.md5(x_bytes).hexdigest()
    print(f"Tensor MD5: {md5_hash}")
    return md5_hash

# --- Core workload estimation and assignment ---

def get_q_workload(
    start_row_indices: paddle.Tensor,
    q_chunk_size: int,
    m_block_size: int,
    n_block_size: int
) -> paddle.Tensor:
    """
    Estimate the compute workload of each query chunk from the sparse-attention
    start/end row indices. This is the first step of load balancing: it
    quantifies the compute cost of every data chunk.

    Args:
        start_row_indices (paddle.Tensor): Tensor of shape [B, H, S, 2] or
            [B, H, S, 4] giving, for each query token, the key-token range(s)
            to compute. Last-dim order is [LTS, LTE, UTS, UTE] for 4 columns,
            [LTS, UTE] for 2 columns.
        q_chunk_size (int): Chunk size used for load-balance analysis on the
            query side.
        m_block_size (int): Query-side block size (Br) of the FlashAttention kernel.
        n_block_size (int): Key-side block size (Bc) of the FlashAttention kernel.

    Returns:
        paddle.Tensor: Shape [1, H, Tchunks, 2], where Tchunks is the number of
            chunks; each entry is [workload, original_index].

    Raises:
        ValueError: If the last dimension of `start_row_indices` is not 2 or 4.
    """
    assert start_row_indices is not None, "start_row_indices cannot be None"
    assert q_chunk_size % m_block_size == 0, "q_chunk_size must be divisible by m_block_size"

    # 1. Unpack the start/end indices. The input may carry lower-triangle (LT)
    #    and upper-triangle (UT) Start/End columns.
    LTS, LTE, UTS, UTE = None, None, None, None
    if start_row_indices.shape[-1] == 4:
        LTS, LTE, UTS, UTE = paddle.split(start_row_indices, 4, axis=-1)
        LTS, LTE, UTS, UTE = [t.squeeze(-1) for t in (LTS, LTE, UTS, UTE)]
    elif start_row_indices.shape[-1] == 2:
        LTS, UTE = paddle.split(start_row_indices, 2, axis=-1)
        LTS, UTE = LTS.squeeze(-1), UTE.squeeze(-1)
    else:
        # NOTE(fix): fail loudly instead of the original implicit StopIteration
        # from next() on an all-None generator.
        raise ValueError(
            f"start_row_indices last dim must be 2 or 4, got {start_row_indices.shape[-1]}"
        )

    # 2. Dimension bookkeeping; LTS is always present after the split above.
    B, H, S = LTS.shape

    Tr = S // m_block_size       # number of query-side blocks
    Tc = S // n_block_size       # number of key-side blocks
    Tchunks = S // q_chunk_size  # number of load-balancing chunks
    assert Tr % Tchunks == 0, "Total row blocks must be divisible by total chunks"
    blocks_per_chunk = Tr // Tchunks

    # 3. Pre-compute the max/min of the indices inside every key block with a
    #    custom CUDA kernel. This reduces the O(S) scan to O(S/Bc) and greatly
    #    speeds up the workload estimation below.
    def scan_max_min(tensor):
        if tensor is not None:
            return scanMaxMinChunkedKernel(tensor, n_block_size, B, H, S)
        return None, None

    LTStartMax_gpu, LTStartMin_gpu = scan_max_min(LTS)
    LTEndMax_gpu, LTEndMin_gpu = scan_max_min(LTE)
    UTStartMax_gpu, UTStartMin_gpu = scan_max_min(UTS)
    UTEndMax_gpu, UTEndMin_gpu = scan_max_min(UTE)

    # 4. Count, per query block, how many key blocks are actually active.
    #    This mimics FlashAttention's blockwise traversal without doing any
    #    matmul, yielding a cheap workload estimate.
    all_indices_max_min = [
        LTStartMax_gpu, LTStartMin_gpu, LTEndMax_gpu, LTEndMin_gpu,
        UTStartMax_gpu, UTStartMin_gpu, UTEndMax_gpu, UTEndMin_gpu
    ]
    workload_per_block = reduce_workload(all_indices_max_min, B, H, Tr, Tc, m_block_size, S)

    # 5. Aggregate block workloads to chunk granularity (summed over batch).
    workload_grouped = workload_per_block.reshape([B, H, Tchunks, blocks_per_chunk, 1])
    workload_per_chunk = paddle.sum(workload_grouped, axis=3).sum(axis=0).reshape([1, H, Tchunks])

    # 6. Pair each chunk workload with its original index.
    #    NOTE(fix): the original used paddle.zeros(..., device=...), but
    #    paddle.zeros accepts no `device` keyword and would raise TypeError.
    #    Building the result with stack also keeps it on the inputs' device.
    chunk_ids = paddle.broadcast_to(
        paddle.arange(0, Tchunks, dtype="int32"), [1, H, Tchunks]
    )
    final_res = paddle.stack([workload_per_chunk.astype("int32"), chunk_ids], axis=-1)

    return final_res


def assign_tasks_heap(
    tasks: np.ndarray,
    num_buckets: int
) -> Tuple[List[List[Tuple[int, int]]], List[int], int]:
    """
    Greedily distribute weighted, indexed tasks into `num_buckets` buckets
    using a min-heap, to balance the total weight per bucket.

    Args:
        tasks (np.ndarray): Array of shape (N, 2); each row is [weight, index].
        num_buckets (int): Number of buckets (typically the CP world size).

    Returns:
        Tuple:
            - buckets (List[List[Tuple[int, int]]]): per-bucket (weight, index) tasks.
            - bucket_weights (List[int]): total weight per bucket.
            - cuts (int): number of discontinuities in the reassigned index
              order — a proxy for how fragmented the re-layout is.
    """
    n = len(tasks)
    if n == 0:
        return [[] for _ in range(num_buckets)], [0] * num_buckets, 0

    # Target number of tasks per bucket.
    batch_size = n // num_buckets

    # Heaviest-first: assign the most expensive tasks while there is still
    # freedom to balance them.
    tasks_sorted = sorted(tasks, key=lambda x: -x[0])

    buckets = [[] for _ in range(num_buckets)]
    bucket_weights = [0] * num_buckets
    bucket_counts = [0] * num_buckets

    # Min-heap of (current_weight, bucket_index): O(log M) lookup of the
    # currently lightest bucket.
    heap = [(0, i) for i in range(num_buckets)]

    # Greedy: give each task to the lightest bucket that is not yet full.
    for weight, idx in tasks_sorted:
        temp_popped = []
        found_bucket = False
        while heap:
            bucket_sum, bucket_idx = heapq.heappop(heap)
            if bucket_counts[bucket_idx] < batch_size:
                # Found a bucket with capacity; update it and push it back.
                buckets[bucket_idx].append((weight, idx))
                bucket_weights[bucket_idx] += weight
                bucket_counts[bucket_idx] += 1
                heapq.heappush(heap, (bucket_weights[bucket_idx], bucket_idx))
                found_bucket = True
                break
            else:
                # Bucket full: park it and keep looking.
                temp_popped.append((bucket_sum, bucket_idx))

        # Restore the buckets popped because they were full.
        for item in temp_popped:
            heapq.heappush(heap, item)

        if not found_bucket:
            # All buckets are at capacity (happens when n % num_buckets != 0):
            # overflow into the currently lightest bucket.
            bucket_sum, bucket_idx = heapq.heappop(heap)
            buckets[bucket_idx].append((weight, idx))
            bucket_weights[bucket_idx] += weight
            bucket_counts[bucket_idx] += 1
            heapq.heappush(heap, (bucket_weights[bucket_idx], bucket_idx))

    # (Optional) sort each bucket by original index, for easier debugging.
    for i in range(num_buckets):
        buckets[i] = sorted(buckets[i], key=lambda x: x[1])

    # Count discontinuities in the globally sorted set of assigned indices.
    all_assigned_indices = sorted([idx for bucket in buckets for _, idx in bucket])
    cuts = sum(
        1 for i in range(1, len(all_assigned_indices))
        if all_assigned_indices[i] != all_assigned_indices[i - 1] + 1
    )

    return buckets, bucket_weights, cuts


# --- Communication / re-layout helpers ---

def get_send_dict(buckets: List[List[Tuple[int, int]]], cp_size: int, rank: int) -> Dict[int, List[int]]:
    """
    Build this rank's all-to-all send plan from the global assignment.

    Args:
        buckets (List): Task assignment for every rank.
        cp_size (int): Communication group size.
        rank (int): This process's rank.

    Returns:
        Dict[int, List[int]]: target rank -> local chunk indices to send to it.
    """
    # NOTE(review): the `// cp_size` / `% cp_size` mapping assumes each rank
    # originally owns `cp_size` consecutive chunks (i.e. Tchunks == cp_size**2).
    # Confirm against how the caller shards chunks across ranks.
    send_dict = {i: [] for i in range(cp_size)}
    # Iterate over all buckets (i.e. every target rank's task list).
    for target_rank, bucket in enumerate(buckets):
        for _, chunk_idx in bucket:
            chunk_idx = int(chunk_idx)  # works for Python ints and numpy scalars
            # Send chunks whose original owner is this rank.
            if chunk_idx // cp_size == rank:
                # chunk_idx % cp_size is the chunk's local index on this rank.
                send_dict[target_rank].append(chunk_idx % cp_size)
    return send_dict

def get_recv_dict(bucket: List[Tuple[int, int]], cp_size: int) -> Dict[int, List[int]]:
    """
    Build this rank's all-to-all receive plan from its own task list.

    Args:
        bucket (List): Tasks assigned to this rank.
        cp_size (int): Communication group size.

    Returns:
        Dict[int, List[int]]: source rank -> local positions at which the
            chunks received from that rank must be placed.
    """
    recv_dict = {i: [] for i in range(cp_size)}
    for local_pos, (_, chunk_idx) in enumerate(bucket):
        # Original owner of this chunk.
        # NOTE(fix): int() instead of .item(), so plain-int indices also work.
        source_rank = int(chunk_idx) // cp_size
        recv_dict[source_rank].append(local_pos)
    return recv_dict

def balance_alltoall(
    input_tensor: paddle.Tensor,
    cp_size: int,
    cp_group,
    chunk_size: int,
    send_dict: Dict[int, List[int]],
    recv_dict: Dict[int, List[int]]
) -> paddle.Tensor:
    """
    Re-shuffle `input_tensor` across ranks with an all-to-all, following the
    send/recv plans. Handles tensors of any rank by flattening trailing dims.

    Args:
        input_tensor (paddle.Tensor): Tensor to re-shuffle (e.g. Q, K, V).
        cp_size (int): Communication group size.
        cp_group: Paddle distributed communication group.
        chunk_size (int): Size of one data chunk along the sequence dim.
        send_dict (Dict): Send plan, from `get_send_dict`.
        recv_dict (Dict): Receive plan, from `get_recv_dict`.

    Returns:
        paddle.Tensor: The re-shuffled tensor, original shape restored.
    """
    original_shape = input_tensor.shape
    B, S = original_shape[0], original_shape[1]

    # Flatten to 3D (B, S, -1) so all tensor ranks are handled uniformly.
    tensor_3d = input_tensor.reshape((B, S, -1))
    HD = tensor_3d.shape[-1]

    # 1. Gather the chunks destined for each target rank.
    send_list = []
    for target_rank in range(cp_size):
        indices_to_send = send_dict[target_rank]
        if indices_to_send:
            # Concatenate all chunks bound for the same rank.
            data_to_send = paddle.concat(
                [tensor_3d[:, idx * chunk_size:(idx + 1) * chunk_size, :] for idx in indices_to_send],
                axis=1
            )
            send_list.append(data_to_send)
        else:
            # NCCL alltoall rejects zero-sized tensors, so send a tiny dummy
            # placeholder; the peer allocates a matching dummy receive buffer.
            send_list.append(paddle.zeros((B, 1, HD), dtype=input_tensor.dtype))

    # 2. Allocate receive buffers sized by the number of incoming chunks.
    recv_list = []
    for source_rank in range(cp_size):
        num_chunks_to_recv = len(recv_dict[source_rank])
        if num_chunks_to_recv > 0:
            recv_list.append(
                paddle.empty((B, chunk_size * num_chunks_to_recv, HD), dtype=input_tensor.dtype)
            )
        else:
            # Matching dummy buffer for the sender-side placeholder.
            recv_list.append(paddle.empty((B, 1, HD), dtype=input_tensor.dtype))

    # 3. All-to-all exchange.
    dist.alltoall(out_tensor_list=recv_list, in_tensor_list=send_list, group=cp_group)

    # 4. Scatter the received chunks into their final local positions.
    final_res_3d = paddle.empty_like(tensor_3d)
    for source_rank in range(cp_size):
        local_positions = recv_dict[source_rank]
        if local_positions:
            received_data = recv_list[source_rank]
            for i, local_pos in enumerate(local_positions):
                start_s = local_pos * chunk_size
                end_s = (local_pos + 1) * chunk_size
                data_start = i * chunk_size
                data_end = (i + 1) * chunk_size
                final_res_3d[:, start_s:end_s, :] = received_data[:, data_start:data_end, :]

    # Restore the original shape.
    return final_res_3d.reshape(original_shape)


# --- Main entry point ---

def balance_flashmask_input(
    startend_row_indices: paddle.Tensor,
    cp_size: int,
    cp_rank: int,
    balance_chunk_size: int = 2048,
    q_block_size: int = 128,
    k_block_size: int = 128
) -> Tuple[paddle.Tensor, List[List[Tuple[int, int]]]]:
    """
    Load-balance the FlashMask input. Orchestrates the whole pipeline:
    estimate workloads -> assign chunks to ranks -> re-layout the indices.

    Args:
        startend_row_indices (paddle.Tensor): Raw sparse-attention start/end indices.
        cp_size (int): Communication group size.
        cp_rank (int): This process's rank.
        balance_chunk_size (int): Chunk size for analysis and data movement.
        q_block_size (int): FlashAttention query block size.
        k_block_size (int): FlashAttention key block size.

    Returns:
        Tuple:
            - local_startend_row_indices (paddle.Tensor): re-balanced,
              rank-local start/end indices.
            - buckets (List[List[Tuple[int, int]]]): global assignment, reused
              to re-shuffle Q/K/V the same way.
    """
    # Step 1: estimate the workload of every chunk.
    paddle.base.core.nvprof_nvtx_push("get_q_workload")
    workload = get_q_workload(startend_row_indices, balance_chunk_size, q_block_size, k_block_size)
    paddle.base.core.nvprof_nvtx_pop()

    # Step 2: heap-greedy assignment on the CPU.
    paddle.base.core.nvprof_nvtx_push("assign_tasks_heap")
    tasks_np = workload.reshape([-1, 2]).cpu().numpy()
    buckets, _, _ = assign_tasks_heap(tasks_np, cp_size)
    paddle.base.core.nvprof_nvtx_pop()

    # Step 3: gather — rebuild a globally re-ordered view of the indices.
    paddle.base.core.nvprof_nvtx_push("startend_row_indices_rerank")
    # Flatten `buckets` into the new chunk order.
    rerank_indices = np.array([idx for bucket in buckets for _, idx in bucket], dtype=np.int32)
    indices_tensor = paddle.to_tensor(rerank_indices, dtype='int32', place=startend_row_indices.place)

    # NOTE(fix): forward balance_chunk_size — the original relied on the
    # callee's default (2048), silently breaking any other chunk size.
    startend_row_indices_rerank = indices_rerank_cuda(
        startend_row_indices, indices_tensor, balance_chunk_size
    )
    paddle.base.core.nvprof_nvtx_pop()

    # Step 4: localize — clip/offset the global indices into this rank's chunks.
    paddle.base.core.nvprof_nvtx_push("indices_to_chunks")
    local_bucket_indices = [int(x[1]) for x in buckets[cp_rank]]
    local_indices_tensor = paddle.to_tensor(local_bucket_indices, dtype='int32', place=startend_row_indices.place)

    local_startend_row_indices = indices_to_chunks_cuda(
        startend_row_indices_rerank, local_indices_tensor, balance_chunk_size
    )
    paddle.base.core.nvprof_nvtx_pop()

    return local_startend_row_indices, buckets

import flashmask_cpbalance_cudaops as cp_balance_ops

def scanMaxMinChunkedKernel(input_tensor, Bc, B, H, S):
    """Per key-block max/min scan of mask row indices via the custom CUDA op.

    Positional mapping onto the C++ ``scan_max_min`` op:
    head_size_rounded=H, seq_len_q=S, seq_len_k=S, blocksize=Bc,
    is_causal=False, softcap=0.0, window_size_left=0, window_size_right=0.
    Because an explicit blocksize (> 0) is passed, the op's internal block-size
    heuristic — the only consumer of head_size_rounded — is skipped.
    ``B`` is accepted for interface symmetry but not forwarded; the op derives
    batching from the tensor shape.
    Returns (max_out, min_out).
    """
    maxo,mino = cp_balance_ops.scan_max_min(
        input_tensor,
        H,
        S,
        S,
        Bc,
        False,
        0.0,
        0,
        0
    )

    return maxo, mino


def reduce_workload(start_row_maxmin_indice_list, B, H, Tr, Tc, Br, S):
    """Estimate per query-block workload from precomputed block max/min indices.

    ``start_row_maxmin_indice_list`` carries eight entries
    (LT/UT x Start/End x Max/Min); entries may be None for the 2-column layout.
    Positional attr mapping onto the C++ ``reduce_workload`` op:
    batch_size=B, num_heads=H, num_row_blocks=Tr, num_col_blocks=Tc,
    stride=S, row_block_size=Br, is_causal=False, m_block_size=128.
    NOTE(review): m_block_size is hard-coded to 128 here while Br is passed
    separately; confirm whether a caller-supplied q block size other than 128
    is intended to be honored.
    """
    (
        LTStartMax,
        LTStartMin,
        LTEndMax,
        LTEndMin,
        UTStartMax,
        UTStartMin,
        UTEndMax,
        UTEndMin,
    ) = start_row_maxmin_indice_list

    workload = cp_balance_ops.reduce_workload(
        LTStartMax, LTStartMin, LTEndMax, LTEndMin, UTStartMax, UTStartMin, UTEndMax, UTEndMin,
        B, H, Tr, Tc, S, Br, False, 128
    )

    return workload

def indices_to_chunks_cuda(startend_row_indices, bucket_idx, chunksize=2048):
    """Clip/offset global mask indices into the local chunk layout described
    by ``bucket_idx`` (the chunk ids owned by this rank), via the custom op."""
    result = cp_balance_ops.indices_to_chunks(startend_row_indices, bucket_idx, chunksize)
    return result

def indices_rerank_cuda(startend_row_indices, indices, balance_chunk_size=2048):
    """Gather sequence-dim chunks of ``startend_row_indices`` ([B, H, S, D])
    in the order given by ``indices`` (chunk ids), via the custom op.
    Output sequence length is num_chunks * balance_chunk_size."""
    B, H, S, D = startend_row_indices.shape
    num_chunks = (S + balance_chunk_size - 1) // balance_chunk_size
    startend_row_indices_rerank = cp_balance_ops.indices_rerank(startend_row_indices, indices, B, H, S,D,num_chunks,balance_chunk_size)
    return startend_row_indices_rerank
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the extraction of this file stripped angle-bracket contents
// (template parameter lists, kernel launch <<<...>>> configs, .data<T>()
// calls). They have been restored below to the only well-formed reading
// consistent with the call sites — verify against the original file.

#include "paddle/extension.h"

#define CHECK_CUDA_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")

// Mirror of the FlashAttention kernel's key-side block size (kBlockN / Bc)
// selection heuristic, so workload estimation uses the same tiling.
int get_kBlockN(int head_size_rounded, bool is_flashmask, bool is_causal, bool has_softcap,
                bool is_local, int seqlen_q, int seqlen_k, bool has_lt_end, bool has_ut_start) {
  if (head_size_rounded <= 64) {
    if (is_flashmask && !is_causal) {
      return 96;
    } else if ((is_causal && has_softcap) || is_flashmask) {
      return 128;
    } else {
      return 128;
    }
  } else if (head_size_rounded <= 128) {
    if (is_causal || is_local || has_softcap) {
      return 128;
    } else {
      if (seqlen_q >= 1024 || seqlen_k >= 1024) {
        return 128;
      } else {
        return 64;
      }
    }
  } else if (head_size_rounded <= 256) {
    if (has_lt_end && has_ut_start) {
      return 32;
    } else {
      return 64;
    }
  } else {
    // Unsupported case
    throw std::runtime_error("head_size_rounded not supported");
  }
}

// Computes, for each kBlockN-sized block of each row, the max and min of the
// input values. One warp handles one block; blockDim = (32, 4) so four rows
// are processed per thread block. Output stride is padded to a multiple of 4.
// NOTE(review): __reduce_max_sync / __reduce_min_sync require compute
// capability 8.0+ per the CUDA docs — confirm the build targets allow this.
template <int kBlockN>
__global__
void scanMaxMinChunkedKernel(
    const int *input, int b, int n, int *maxo, int *mino) {
  int bid = threadIdx.y + blockIdx.y * blockDim.y;
  if (bid >= b) return;
  int i_offset = bid * n;
  input = input + i_offset;
  // Padded number of blocks per row (multiple of 4), matching the host side.
  const int nblock_seqlen = ((n + kBlockN - 1) / kBlockN + 3) & 0xfffffffc;
  constexpr int nums = (kBlockN + 31) / 32;  // values per lane in one block
  int warpId = blockIdx.x;
  int tid = threadIdx.x;
  int lane_id = threadIdx.x % 32;
  int maxv, minv;
  int idx = warpId * kBlockN + tid;
  if (warpId * kBlockN + kBlockN > n) {
    // Tail block: guard every load against the row end.
    maxv = 0;
    minv = INT_MAX;
    #pragma unroll
    for (int i = 0; i < nums; i++) {
      if (idx < n && lane_id + i * 32 < kBlockN) {
        maxv = max(maxv, input[idx]);
        minv = min(minv, input[idx]);
      }
      idx += 32;
    }
  } else {
    // Full block: no row-end check needed.
    maxv = 0;
    minv = INT_MAX;
    #pragma unroll
    for (int i = 0; i < nums; i++) {
      if (lane_id + i * 32 < kBlockN) {
        maxv = max(maxv, input[idx]);
        minv = min(minv, input[idx]);
        idx += 32;
      }
    }
  }
  __syncwarp();
  maxv = __reduce_max_sync(0xffffffff, maxv);
  minv = __reduce_min_sync(0xffffffff, minv);
  if (tid == 0) {
    maxo[bid * nblock_seqlen + warpId] = maxv;
    mino[bid * nblock_seqlen + warpId] = minv;
  }
}

// Enum for pointer dispatching in reduce_workload_kernel
enum PtrDispatch { SINGLE_PTR = 1, DUAL_PTR = 2, FULL_PTR = 4 };

// Counts, for each query block (bh, tr), how many key blocks are neither
// fully masked out nor skippable — a proxy for that block's compute cost.
// One thread per key block (up to 1024), reduced to a single count via a
// two-stage warp/shared-memory reduction.
template <int kBlockM, int PTR_DISPATCH_TAG, bool is_causal>
__global__ void reduce_workload_kernel(
    const int* LTStartMax, const int* LTStartMin,
    const int* LTEndMax, const int* LTEndMin,
    const int* UTStartMax, const int* UTStartMin,
    const int* UTEndMax, const int* UTEndMin,
    int* workload,  // [B, H, Tr, 1]
    int BH, int Tr, int Tc, int S,
    int Br  // m_block_size (informational; tiling uses kBlockM)
) {
  int bh = blockIdx.y;
  int tr = blockIdx.x;
  int tc = threadIdx.x;
  int warpId = threadIdx.x / 32;
  int laneId = threadIdx.x % 32;

  if (tr >= Tr) return;

  int wl = 0;
  bool fully_masked = true;
  bool partially_masked = false;
  int lt_start_max = INT_MAX;
  int lt_start_min = INT_MAX;
  int lt_end_max = INT_MAX;
  int lt_end_min = INT_MAX;
  int ut_start_max = INT_MIN;
  int ut_start_min = INT_MIN;
  int ut_end_max = INT_MIN;
  int ut_end_min = INT_MIN;

  __shared__ int smem[32];

  const int idx = bh * Tc + tc;
  const int q_idx = bh * Tr + tr;

  // m_block_s/e: Q block boundaries within a single (batch, head) — use tr only, not q_idx.
  // q_idx includes the bh offset for output indexing, but mask values are in [0, S) per (b,h).
  const int m_block_s = tr * kBlockM;
  const int m_block_e = m_block_s + kBlockM < S ? m_block_s + kBlockM : S;

  lt_start_max = tc < Tc ? LTStartMax[idx] : INT_MAX;
  lt_start_min = tc < Tc ? LTStartMin[idx] : INT_MAX;

  // Branch specialization on which mask columns are present.
  if constexpr (PTR_DISPATCH_TAG == FULL_PTR) {
    lt_end_max = tc < Tc ? LTEndMax[idx] : INT_MAX;
    lt_end_min = tc < Tc ? LTEndMin[idx] : INT_MAX;
    ut_start_max = tc < Tc ? UTStartMax[idx] : INT_MIN;
    ut_start_min = tc < Tc ? UTStartMin[idx] : INT_MIN;
    ut_end_max = tc < Tc ? UTEndMax[idx] : INT_MIN;
    ut_end_min = tc < Tc ? UTEndMin[idx] : INT_MIN;

    fully_masked = (m_block_s >= lt_start_max && m_block_e <= lt_end_min) ||
                   (m_block_s >= ut_start_max && m_block_e <= ut_end_min);
    partially_masked = (m_block_s < lt_end_max && m_block_e > lt_start_min) ||
                       (m_block_s < ut_end_max && m_block_e > ut_start_min);
  }
  else if constexpr (PTR_DISPATCH_TAG == DUAL_PTR) {
    if constexpr (is_causal) {
      lt_end_max = tc < Tc ? LTEndMax[idx] : INT_MAX;
      lt_end_min = tc < Tc ? LTEndMin[idx] : INT_MAX;
      fully_masked = m_block_s >= lt_start_max && m_block_e <= lt_end_min;
      partially_masked = m_block_s < lt_end_max && m_block_e > lt_start_min;
    } else {
      ut_end_max = tc < Tc ? UTEndMax[idx] : INT_MIN;
      ut_end_min = tc < Tc ? UTEndMin[idx] : INT_MIN;
      fully_masked = (m_block_s >= lt_start_max) || (m_block_e <= ut_end_min);
      partially_masked = (m_block_e > lt_start_min) || (m_block_s < ut_end_max);
    }
  }
  else if constexpr (PTR_DISPATCH_TAG == SINGLE_PTR) {
    fully_masked = m_block_s >= lt_start_max;
    partially_masked = m_block_e > lt_start_min;
  }

  // Threads beyond the real key-block count contribute nothing.
  if (tc >= Tc) {
    fully_masked = true;
    partially_masked = false;
  }
  wl = fully_masked ? 0 : 1;

  unsigned mask = 0xffffffff;
  // warp reduce sum
  int wl_sum = wl;
  for (int offset = 16; offset > 0; offset >>= 1) {
    wl_sum += __shfl_down_sync(mask, wl_sum, offset);
  }
  if (laneId == 0) {
    smem[warpId] = wl_sum;
  }
  __syncthreads();

  // Second stage: warp 0 reduces the per-warp partial sums.
  if (threadIdx.x < 32) {
    int val = (threadIdx.x < (blockDim.x + 31) / 32) ? smem[threadIdx.x] : 0;
    for (int offset = 16; offset > 0; offset >>= 1) {
      val += __shfl_down_sync(mask, val, offset);
    }
    if (threadIdx.x == 0) {
      workload[q_idx] = val;
    }
  }
}

// For each mask index value, computes its position in the local (re-chunked)
// coordinate system: per owned chunk, clamp the value into [0, chunk_size],
// offset by the chunk's local slot, and keep the largest result.
__global__ void indices_to_chunks_kernel(
    const int* startend_row_indices,
    const int* chunk_bucket_indices,
    int* chunked_result,
    int num_rows,
    int num_buckets,
    int chunk_size)
{
  int row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= num_rows) return;

  int max_chunk_index = 0;
  int row_val = startend_row_indices[row];

  for (int bucket = 0; bucket < num_buckets; ++bucket) {
    int bucket_idx = chunk_bucket_indices[bucket];
    int chunk_start = bucket_idx * chunk_size;
    int local_index = row_val - chunk_start;
    local_index = max(local_index, 0);
    local_index = min(local_index, chunk_size);

    // Non-zero coverage of this chunk: offset into the local layout.
    if (local_index > 0) {
      local_index += bucket * chunk_size;
    }

    if (bucket == 0 || local_index > max_chunk_index) {
      max_chunk_index = local_index;
    }
  }
  chunked_result[row] = max_chunk_index;
}

// Gathers seq-dim chunks of a [B, H, S, D] tensor into the order given by
// chunk_indices. One thread per output element.
__global__ void indices_rerank_kernel(
    const int* startend_row_indices,
    int* output_reranked_indices,
    const int* chunk_indices,
    int batch_size,
    int num_heads,
    int seq_len,
    int feature_dim,
    int num_chunks,
    int chunk_size
) {
  int output_seq_len = num_chunks * chunk_size;
  int total_elements = batch_size * output_seq_len * num_heads * feature_dim;
  int flat_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (flat_idx >= total_elements) return;

  // Decompose the flat output index into (b, h, s_out, d).
  int d = flat_idx % feature_dim;
  int s_out = (flat_idx / feature_dim) % output_seq_len;
  int h = (flat_idx / feature_dim / output_seq_len) % num_heads;
  int b = (flat_idx / feature_dim / output_seq_len / num_heads) % batch_size;

  int chunk_id = s_out / chunk_size;
  int chunk_offset = s_out % chunk_size;
  int src_s = chunk_indices[chunk_id] * chunk_size + chunk_offset;

  // Source position beyond the input sequence: leave output untouched.
  if (src_s >= seq_len) return;

  int src_flat_idx = ((b * num_heads + h) * seq_len + src_s) * feature_dim + d;
  int dst_flat_idx = flat_idx;

  output_reranked_indices[dst_flat_idx] = startend_row_indices[src_flat_idx];
}




// ============================================================================
// ScanMaxMin Operator
// ============================================================================

// Host wrapper: flattens leading dims into a batch, picks the block size,
// allocates padded outputs and dispatches the templated kernel.
std::vector<paddle::Tensor> scan_max_min_cuda(
    const paddle::Tensor& input,
    const int head_size_rounded,
    const int seq_len_q,
    const int seq_len_k,
    const int blocksize = -1,
    const bool is_causal = false,
    const float softcap = 0.0,
    const int window_size_left = 0,
    const int window_size_right = 0) {
  CHECK_CUDA_INPUT(input);

  // The scanMaxMin kernel treats input as flat [batch, seqlen].
  // Input tensor is [B, H, S] from Python (H is always 1 in practice; after squeeze(-1) from [B,H,S,D]).
  // We compute total_batch = product of all dims except the last, so it handles [B,S], [B,H,S] etc.
  const auto dims = input.shape();
  const auto ndim = dims.size();
  int64_t total_batch = 1;
  for (int i = 0; i < ndim - 1; i++) total_batch *= dims[i];
  const auto num_sequences = dims[ndim - 1];
  // head_dim only used by get_kBlockN heuristic; safe default when blocksize is explicit
  const auto head_dim = (ndim >= 4) ? dims[3] : 1;

  PADDLE_ENFORCE_EQ(
      num_sequences,
      seq_len_k,
      common::errors::InvalidArgument(
          "Input tensor's third dimension (num_sequences) must be equal to seq_len_k."));

  const bool is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
  const bool is_flashmask = true;
  const bool has_softcap = softcap > 0.0;
  const bool has_lt_end = !is_causal && head_dim >= 2;
  const bool has_ut_start = head_dim == 4;

  const int kernel_block_size_n =
      blocksize > 0 ? blocksize : get_kBlockN(head_size_rounded,
                                              is_flashmask,
                                              is_causal,
                                              has_softcap,
                                              is_local,
                                              seq_len_q,
                                              seq_len_k,
                                              has_lt_end,
                                              has_ut_start);

  // Pad the number of blocks to be a multiple of 4 for performance
  const int num_blocks_seqlen =
      ((num_sequences + kernel_block_size_n - 1) / kernel_block_size_n + 3) & 0xfffffffc;

  std::vector<int64_t> output_shape = {total_batch, num_blocks_seqlen};
  auto max_output = paddle::empty(output_shape, input.dtype(), input.place());
  auto min_output = paddle::empty(output_shape, input.dtype(), input.place());

  // Launch kernel: one warp per key block (x), four batch rows per block (y).
  dim3 block_dim(32, 4);
  dim3 grid_dim((num_sequences + kernel_block_size_n - 1) / kernel_block_size_n,
                (total_batch + 3) / 4);

  const cudaStream_t stream = input.stream();

  switch (kernel_block_size_n) {
    case 32:
      scanMaxMinChunkedKernel<32><<<grid_dim, block_dim, 0, stream>>>(
          input.data<int>(), total_batch, num_sequences,
          max_output.data<int>(), min_output.data<int>());
      break;
    case 64:
      scanMaxMinChunkedKernel<64><<<grid_dim, block_dim, 0, stream>>>(
          input.data<int>(), total_batch, num_sequences,
          max_output.data<int>(), min_output.data<int>());
      break;
    case 96:
      scanMaxMinChunkedKernel<96><<<grid_dim, block_dim, 0, stream>>>(
          input.data<int>(), total_batch, num_sequences,
          max_output.data<int>(), min_output.data<int>());
      break;
    case 128:
      scanMaxMinChunkedKernel<128><<<grid_dim, block_dim, 0, stream>>>(
          input.data<int>(), total_batch, num_sequences,
          max_output.data<int>(), min_output.data<int>());
      break;
    default:
      PD_THROW("Unsupported kernel_block_size_n: %d", kernel_block_size_n);
  }
  return {max_output, min_output};
}

// Op entry point: device dispatch for scan_max_min.
std::vector<paddle::Tensor> ScanMaxMin(
    const paddle::Tensor& input,
    int head_size_rounded,
    int seq_len_q,
    int seq_len_k,
    int blocksize,
    bool is_causal,
    float softcap,
    int window_size_left,
    int window_size_right) {
#ifdef PADDLE_WITH_CUDA
  if (input.is_gpu()) {
    return scan_max_min_cuda(input,
                             head_size_rounded,
                             seq_len_q,
                             seq_len_k,
                             blocksize,
                             is_causal,
                             softcap,
                             window_size_left,
                             window_size_right);
  }
#endif
  PD_THROW("Unsupported device: ScanMaxMin operator is only available for CUDA.");
}


// ============================================================================
// ReduceWorkload Operator
// ============================================================================

// Selects the pointer-dispatch specialization (which optional mask columns
// are present) and launches the matching reduce_workload_kernel instance.
template <int kBlockM>
void launch_reduce_workload_kernel(
    const paddle::Tensor& lt_start_max,
    const paddle::Tensor& lt_start_min,
    const paddle::optional<paddle::Tensor>& lt_end_max,
    const paddle::optional<paddle::Tensor>& lt_end_min,
    const paddle::optional<paddle::Tensor>& ut_start_max,
    const paddle::optional<paddle::Tensor>& ut_start_min,
    const paddle::optional<paddle::Tensor>& ut_end_max,
    const paddle::optional<paddle::Tensor>& ut_end_min,
    paddle::Tensor& workload,
    int batch_times_heads,
    int num_row_blocks,
    int num_col_blocks,
    int stride,
    int row_block_size,
    bool is_causal,
    cudaStream_t stream) {

  dim3 block_dim(1024, 1);
  dim3 grid_dim(num_row_blocks, batch_times_heads);

  int ptr_dispatch_tag = SINGLE_PTR;
  if (lt_end_max || ut_end_max) {
    ptr_dispatch_tag = DUAL_PTR;
    if (ut_start_max) {
      ptr_dispatch_tag = FULL_PTR;
    }
  }

  int* workload_ptr = workload.data<int>();
  const int* lt_start_max_ptr = lt_start_max.data<int>();
  const int* lt_start_min_ptr = lt_start_min.data<int>();
  const int* lt_end_max_ptr = lt_end_max ? lt_end_max.get().data<int>() : nullptr;
  const int* lt_end_min_ptr = lt_end_min ? lt_end_min.get().data<int>() : nullptr;
  const int* ut_start_max_ptr = ut_start_max ? ut_start_max.get().data<int>() : nullptr;
  const int* ut_start_min_ptr = ut_start_min ? ut_start_min.get().data<int>() : nullptr;
  const int* ut_end_max_ptr = ut_end_max ? ut_end_max.get().data<int>() : nullptr;
  const int* ut_end_min_ptr = ut_end_min ? ut_end_min.get().data<int>() : nullptr;

  if (ptr_dispatch_tag == FULL_PTR) {
    reduce_workload_kernel<kBlockM, FULL_PTR, false><<<grid_dim, block_dim, 0, stream>>>(
        lt_start_max_ptr, lt_start_min_ptr, lt_end_max_ptr, lt_end_min_ptr,
        ut_start_max_ptr, ut_start_min_ptr, ut_end_max_ptr, ut_end_min_ptr,
        workload_ptr, batch_times_heads, num_row_blocks, num_col_blocks, stride, row_block_size);
  } else if (ptr_dispatch_tag == DUAL_PTR) {
    if (is_causal) {
      reduce_workload_kernel<kBlockM, DUAL_PTR, true><<<grid_dim, block_dim, 0, stream>>>(
          lt_start_max_ptr, lt_start_min_ptr, lt_end_max_ptr, lt_end_min_ptr,
          ut_start_max_ptr, ut_start_min_ptr, ut_end_max_ptr, ut_end_min_ptr,
          workload_ptr, batch_times_heads, num_row_blocks, num_col_blocks, stride, row_block_size);
    } else {
      reduce_workload_kernel<kBlockM, DUAL_PTR, false><<<grid_dim, block_dim, 0, stream>>>(
          lt_start_max_ptr, lt_start_min_ptr, lt_end_max_ptr, lt_end_min_ptr,
          ut_start_max_ptr, ut_start_min_ptr, ut_end_max_ptr, ut_end_min_ptr,
          workload_ptr, batch_times_heads, num_row_blocks, num_col_blocks, stride, row_block_size);
    }
  } else if (ptr_dispatch_tag == SINGLE_PTR) {
    reduce_workload_kernel<kBlockM, SINGLE_PTR, false><<<grid_dim, block_dim, 0, stream>>>(
        lt_start_max_ptr, lt_start_min_ptr, lt_end_max_ptr, lt_end_min_ptr,
        ut_start_max_ptr, ut_start_min_ptr, ut_end_max_ptr, ut_end_min_ptr,
        workload_ptr, batch_times_heads, num_row_blocks, num_col_blocks, stride, row_block_size);
  } else {
    PD_THROW("Unknown pointer dispatch tag.");
  }
}

// Host wrapper: allocates the [B, H, Tr, 1] workload tensor and dispatches
// on the compile-time query block size (kBlockM).
std::vector<paddle::Tensor> reduce_workload_cuda(
    const paddle::Tensor& lt_start_max,
    const paddle::Tensor& lt_start_min,
    const paddle::optional<paddle::Tensor>& lt_end_max,
    const paddle::optional<paddle::Tensor>& lt_end_min,
    const paddle::optional<paddle::Tensor>& ut_start_max,
    const paddle::optional<paddle::Tensor>& ut_start_min,
    const paddle::optional<paddle::Tensor>& ut_end_max,
    const paddle::optional<paddle::Tensor>& ut_end_min,
    int batch_size,
    int num_heads,
    int num_row_blocks,
    int num_col_blocks,
    int stride,
    int row_block_size,
    bool is_causal,
    int m_block_size) {

  const int kBlockM = m_block_size;
  const int batch_times_heads = batch_size * num_heads;

  // Use the actual padded stride from scanMaxMin output, not the caller's unpadded num_col_blocks.
  // scanMaxMin pads nblock_seqlen to a multiple of 4 for performance; if num_col_blocks differs
  // from the tensor's actual column count, the flat index bh*Tc+tc would be wrong.
  const int Tc_stride = static_cast<int>(lt_start_max.shape()[1]);

  // Allocate output tensor
  std::vector<int64_t> output_shape = {batch_size, num_heads, num_row_blocks, 1};
  auto workload = paddle::empty(output_shape, lt_start_max.dtype(), lt_start_max.place());

  cudaStream_t stream = lt_start_max.stream();

  switch (kBlockM) {
    case 64:
      launch_reduce_workload_kernel<64>(
          lt_start_max, lt_start_min, lt_end_max, lt_end_min, ut_start_max,
          ut_start_min, ut_end_max, ut_end_min, workload, batch_times_heads,
          num_row_blocks, Tc_stride, stride, row_block_size, is_causal, stream);
      break;
    case 96:
      launch_reduce_workload_kernel<96>(
          lt_start_max, lt_start_min, lt_end_max, lt_end_min, ut_start_max,
          ut_start_min, ut_end_max, ut_end_min, workload, batch_times_heads,
          num_row_blocks, Tc_stride, stride, row_block_size, is_causal, stream);
      break;
    case 128:
      launch_reduce_workload_kernel<128>(
          lt_start_max, lt_start_min, lt_end_max, lt_end_min, ut_start_max,
          ut_start_min, ut_end_max, ut_end_min, workload, batch_times_heads,
          num_row_blocks, Tc_stride, stride, row_block_size, is_causal, stream);
      break;
    default:
      PD_THROW("Unsupported m_block_size: %d", kBlockM);
  }
  return {workload};
}

// Op entry point: device dispatch for reduce_workload.
std::vector<paddle::Tensor> ReduceWorkloadOp(
    const paddle::Tensor& lt_start_max,
    const paddle::Tensor& lt_start_min,
    const paddle::optional<paddle::Tensor>& lt_end_max,
    const paddle::optional<paddle::Tensor>& lt_end_min,
    const paddle::optional<paddle::Tensor>& ut_start_max,
    const paddle::optional<paddle::Tensor>& ut_start_min,
    const paddle::optional<paddle::Tensor>& ut_end_max,
    const paddle::optional<paddle::Tensor>& ut_end_min,
    int batch_size,
    int num_heads,
    int num_row_blocks,
    int num_col_blocks,
    int stride,
    int row_block_size,
    bool is_causal,
    int m_block_size) {
#ifdef PADDLE_WITH_CUDA
  if (lt_start_max.is_gpu()) {
    return reduce_workload_cuda(lt_start_max,
                                lt_start_min,
                                lt_end_max,
                                lt_end_min,
                                ut_start_max,
                                ut_start_min,
                                ut_end_max,
                                ut_end_min,
                                batch_size,
                                num_heads,
                                num_row_blocks,
                                num_col_blocks,
                                stride,
                                row_block_size,
                                is_causal,
                                m_block_size);
  }
#endif
  PD_THROW("Unsupported device: ReduceWorkload operator is only available for CUDA.");
}


// ============================================================================
// IndicesToChunks & IndicesRerank Operators
// ============================================================================

// Op entry point: converts global mask index values into the local
// chunked coordinate system (see indices_to_chunks_kernel).
std::vector<paddle::Tensor> IndicesToChunksOp(
    const paddle::Tensor& row_indices,
    const paddle::Tensor& chunk_bucket_indices,
    int chunk_size) {
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_EQ(row_indices.is_gpu(), true,
      common::errors::InvalidArgument("Input 'row_indices' must be a CUDA tensor."));

  auto chunked_result = paddle::empty_like(row_indices);

  const int num_rows = row_indices.numel();
  const int num_buckets = chunk_bucket_indices.numel();
  const int num_threads_per_block = 256;
  const int num_blocks = (num_rows + num_threads_per_block - 1) / num_threads_per_block;

  indices_to_chunks_kernel<<<num_blocks, num_threads_per_block, 0, row_indices.stream()>>>(
      row_indices.data<int>(),
      chunk_bucket_indices.data<int>(),
      chunked_result.data<int>(),
      num_rows,
      num_buckets,
      chunk_size);

  return {chunked_result};
#else
  PD_THROW("Unsupported device: IndicesToChunks operator is only available for CUDA.");
#endif
}

// Op entry point: gathers sequence-dim chunks of the mask indices into the
// order given by chunk_indices (see indices_rerank_kernel).
std::vector<paddle::Tensor> IndicesRerankOp(
    const paddle::Tensor& input_row_indices,
    const paddle::Tensor& chunk_indices,
    int batch_size,
    int num_heads,
    int seq_len,
    int feature_dim,
    int num_chunks,
    int chunk_size) {
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_EQ(input_row_indices.is_gpu(), true,
      common::errors::InvalidArgument("Input 'input_row_indices' must be a CUDA tensor."));

  const int output_seq_len = num_chunks * chunk_size;
  auto reranked_indices = paddle::empty({batch_size, num_heads, output_seq_len, feature_dim},
                                        input_row_indices.dtype(),
                                        input_row_indices.place());

  const int total_elements = batch_size * output_seq_len * num_heads * feature_dim;
  const int num_threads_per_block = 256;
  const int num_blocks = (total_elements + num_threads_per_block - 1) / num_threads_per_block;

  indices_rerank_kernel<<<num_blocks, num_threads_per_block, 0, input_row_indices.stream()>>>(
      input_row_indices.data<int>(),
      reranked_indices.data<int>(),
      chunk_indices.data<int>(),
      batch_size,
      num_heads,
      seq_len,
      feature_dim,
      num_chunks,
      chunk_size);

  return {reranked_indices};
#else
  PD_THROW("Unsupported device: IndicesRerank operator is only available for CUDA.");
#endif
}


// ============================================================================
// Operator Registrations
// ============================================================================

PD_BUILD_OP(scan_max_min)
    .Inputs({"Input"})
    .Outputs({"MaxOut", "MinOut"})
    .Attrs({"head_size_rounded: int",
            "seq_len_q: int",
            "seq_len_k: int",
            "blocksize: int",
            "is_causal: bool",
            "softcap: float",
            "window_size_left: int",
            "window_size_right: int"})
    .SetKernelFn(PD_KERNEL(ScanMaxMin));

PD_BUILD_OP(reduce_workload)
    .Inputs({"LTStartMax", "LTStartMin",
             paddle::Optional("LTEndMax"), paddle::Optional("LTEndMin"),
             paddle::Optional("UTStartMax"), paddle::Optional("UTStartMin"),
             paddle::Optional("UTEndMax"), paddle::Optional("UTEndMin")})
    .Outputs({"Workload"})
    .Attrs({"batch_size: int",
            "num_heads: int",
            "num_row_blocks: int",
            "num_col_blocks: int",
            "stride: int",
            "row_block_size: int",
            "is_causal: bool",
            "m_block_size: int"})
    .SetKernelFn(PD_KERNEL(ReduceWorkloadOp));

PD_BUILD_OP(indices_to_chunks)
    .Inputs({"RowIndices", "ChunkBucketIndices"})
    .Outputs({"ChunkedResult"})
    .Attrs({"chunk_size: int"})
    .SetKernelFn(PD_KERNEL(IndicesToChunksOp));

PD_BUILD_OP(indices_rerank)
    .Inputs({"InputRowIndices", "ChunkIndices"})
    .Outputs({"RerankedIndices"})
    .Attrs({"batch_size: int",
            "num_heads: int",
            "seq_len: int",
            // NOTE(review): the source was truncated after "seq_len: int"; the
            // remaining attrs are reconstructed from IndicesRerankOp's
            // parameter list — verify against the original file.
            "feature_dim: int",
            "num_chunks: int",
            "chunk_size: int"})
    .SetKernelFn(PD_KERNEL(IndicesRerankOp));
"feature_dim: int", + "num_chunks: int", + "chunk_size: int"}) + .SetKernelFn(PD_KERNEL(IndicesRerankOp)); \ No newline at end of file diff --git a/flashmask/flash_mask/cp_balance/csrc/setup.py b/flashmask/flash_mask/cp_balance/csrc/setup.py new file mode 100644 index 00000000000..eac80bf937a --- /dev/null +++ b/flashmask/flash_mask/cp_balance/csrc/setup.py @@ -0,0 +1,134 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import shutil +import re + + +def get_version_from_txt(): + version_file = os.path.join(os.path.dirname(__file__), "version.txt") + with open(version_file, "r") as f: + version = f.read().strip() + return version + + +def custom_version_scheme(version): + base_version = get_version_from_txt() + date_str = ( + subprocess.check_output( + ["git", "log", "-1", "--format=%cd", "--date=format:%Y%m%d"] + ) + .decode() + .strip() + ) + return f"{base_version}.dev{date_str}" + + +def no_local_scheme(version): + return "" + + +def change_pwd(): + """change_pwd""" + path = os.path.dirname(__file__) + if path: + os.chdir(path) + +def get_cuda_version(): + nvcc_path = shutil.which("nvcc") + if nvcc_path is None: + raise FileNotFoundError( + "nvcc command not found. Please make sure CUDA toolkit is installed and nvcc is in PATH." 
+ ) + + result = subprocess.run( + ["nvcc", "--version"], + capture_output=True, + text=True, + check=True, + ) + version_output = result.stdout + + match = re.search(r"release (\d+)\.(\d+)", version_output) + if not match: + raise ValueError( + f"Cannot parse CUDA version from nvcc output:\n{version_output}" + ) + cuda_major = int(match.group(1)) + cuda_minor = int(match.group(2)) + return cuda_major, cuda_minor + + +def setup_ops_extension(): + from paddle.utils.cpp_extension import CUDAExtension, setup + + nvcc_args = [ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "-maxrregcount=32", + "-lineinfo", + "-DCUTLASS_DEBUG_TRACE_LEVEL=0", + "-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_90a,code=sm_90a", + "-gencode=arch=compute_100,code=sm_100", + "-DNDEBUG", + ] + cuda_major, cuda_minor = get_cuda_version() + if cuda_major < 12: + raise ValueError( + f"CUDA version must be >= 12. 
Detected version: {cuda_major}.{cuda_minor}" + ) + if cuda_major == 12 and cuda_minor < 8: + nvcc_args = [arg for arg in nvcc_args if "compute_100" not in arg] + + ext_module = CUDAExtension( + sources=[ + # cpp files + # cuda files + "./cp_balance_utils.cu", + ], + include_dirs=[ + os.path.join(os.getcwd(), "./"), + ], + extra_compile_args={ + "cxx": [ + "-O3", + "-w", + "-Wno-abi", + "-fPIC", + "-std=c++17", + ], + "nvcc": nvcc_args, + }, + ) + + change_pwd() + setup( + name="flashmask_cpbalance_cudaops", + ext_modules=[ext_module], + version="0.0.1", + setup_requires=["setuptools_scm"], + ) + + +setup_ops_extension() \ No newline at end of file diff --git a/flashmask/setup.py b/flashmask/setup.py index b6940550a88..e074eaf5274 100644 --- a/flashmask/setup.py +++ b/flashmask/setup.py @@ -76,7 +76,8 @@ def _get_version(): # ============================================================ # Packages: exclude modules not being built # ============================================================ -exclude_packages = ['build', 'build.*', 'tests', 'tests.*'] +exclude_packages = ['build', 'build.*', 'tests', 'tests.*', + 'flash_mask.cp_balance.csrc', 'flash_mask.cp_balance.csrc.*'] if not BUILD_FA3: exclude_packages += [ 'flash_mask.flashmask_attention_v3', @@ -393,3 +394,25 @@ def _get_cuda_version(): paddle_setup(**setup_kwargs, ext_modules=ext_modules) else: setuptools_setup(**setup_kwargs) + +# ============================================================ +# CP Balance: CUDA extension (built via its own setup.py after main setup) +# Paddle's cpp_extension.setup only supports 1 Extension per call, +# so we invoke cp_balance's setup.py as a subprocess. 
+# ============================================================ +if BUILD_FA3: + cp_balance_csrc_dir = os.path.join(FLASH_MASK_DIR, 'cp_balance', 'csrc') + print("[flashmask] Building CP Balance CUDA extension...") + result = subprocess.run( + [sys.executable, 'setup.py', 'install'], + cwd=cp_balance_csrc_dir, + capture_output=True, + text=True, + ) + if result.returncode != 0: + print(f"[flashmask] CP Balance build STDERR:\n{result.stderr}") + raise RuntimeError( + f"Failed to build CP Balance CUDA extension.\n" + f"You can build it manually: cd {cp_balance_csrc_dir} && python setup.py install" + ) + print("[flashmask] CP Balance CUDA extension built successfully.")