forked from Dao-AILab/flash-attention
[Feat] CP-balance formal incorporation as flash_mask sub-module #127
Open · Enigmatisms wants to merge 2 commits into PaddlePaddle:main from Enigmatisms:cp_balance
New file (the sub-module's `__init__.py`, +2 lines):

```python
from .cp_balance import balance_flashmask_input
from .cp_balance_cuda_kernels import indices_rerank_cuda, indices_to_chunks_cuda
```
New file `cp_balance.py` (+379 lines), shown below in sections:
```python
import heapq
import hashlib
from typing import List, Tuple, Dict

import numpy as np
import paddle
import paddle.distributed as dist

from .cp_balance_cuda_kernels import (
    scanMaxMinChunkedKernel,
    reduce_workload,
    indices_to_chunks_cuda,
    indices_rerank_cuda,
)


# --- Debugging helpers ---

def save_tensor(x: paddle.Tensor, name: str):
    """Dump a Paddle tensor to a txt file for debugging."""
    x_np = x.numpy()
    np.savetxt(f'{name}.txt', x_np.reshape(-1, x_np.shape[-1]), fmt='%d')


def tensor_md5(tensor: paddle.Tensor) -> str:
    """MD5-hash a Paddle tensor to verify data consistency across runs/ranks."""
    x_bytes = tensor.numpy().tobytes()
    md5_hash = hashlib.md5(x_bytes).hexdigest()
    print(f"Tensor MD5: {md5_hash}")
    return md5_hash
```
```python
# --- Core workload estimation and assignment ---

def get_q_workload(
    start_row_indices: paddle.Tensor,
    q_chunk_size: int,
    m_block_size: int,
    n_block_size: int
) -> paddle.Tensor:
    """
    Estimate the compute workload of each query chunk from the sparse-attention
    start/end indices. This is the first step of load balancing: it quantifies
    the compute cost of every data chunk.

    Args:
        start_row_indices (paddle.Tensor): Tensor of shape [B, H, S, 2] or
            [B, H, S, 4] giving, for each query token, the range of key tokens
            it must attend to. With last dim 4 the order is [LTS, LTE, UTS, UTE];
            with last dim 2 it is [LTS, UTE].
        q_chunk_size (int): Chunk size used on the query side for load-balancing
            analysis.
        m_block_size (int): Query-side block size (Br) of the FlashAttention kernel.
        n_block_size (int): Key-side block size (Bc) of the FlashAttention kernel.

    Returns:
        paddle.Tensor: Tensor of shape [1, H, Tchunks, 2], where Tchunks is the
            number of chunks. Each chunk carries [workload, original_index]: its
            estimated workload and its original position.
    """
    assert start_row_indices is not None, "start_row_indices cannot be None"
    assert q_chunk_size % m_block_size == 0, "q_chunk_size must be divisible by m_block_size"

    # 1. Parse the start/end indices.
    # start_row_indices may carry lower-triangular (LT) and upper-triangular (UT)
    # start/end information.
    LTS, LTE, UTS, UTE = None, None, None, None
    if start_row_indices.shape[-1] == 4:
        LTS, LTE, UTS, UTE = paddle.split(start_row_indices, 4, axis=-1)
        LTS, LTE, UTS, UTE = [t.squeeze(-1) for t in (LTS, LTE, UTS, UTE)]
    elif start_row_indices.shape[-1] == 2:
        LTS, UTE = paddle.split(start_row_indices, 2, axis=-1)
        LTS, UTE = LTS.squeeze(-1), UTE.squeeze(-1)

    # 2. Read off Batch, Head and Sequence Length from any non-None tensor.
    valid_tensor = next(t for t in [LTS, LTE, UTS, UTE] if t is not None)
    B, H, S = valid_tensor.shape

    # Block counts.
    Tr = S // m_block_size        # total query-side blocks
    Tc = S // n_block_size        # total key-side blocks
    Tchunks = S // q_chunk_size   # total chunks used for load balancing
    assert Tr % Tchunks == 0, "Total row blocks must be divisible by total chunks"
    blocks_per_chunk = Tr // Tchunks

    # 3. Precompute the max/min indices within each key block with a custom CUDA
    # kernel. This is the key optimization: it reduces the O(S) scan to O(S/Bc)
    # and greatly speeds up the workload estimation below.
    def scan_max_min(tensor):
        if tensor is not None:
            return scanMaxMinChunkedKernel(tensor, n_block_size, B, H, S)
        return None, None

    LTStartMax_gpu, LTStartMin_gpu = scan_max_min(LTS)
    LTEndMax_gpu, LTEndMin_gpu = scan_max_min(LTE)
    UTStartMax_gpu, UTStartMin_gpu = scan_max_min(UTS)
    UTEndMax_gpu, UTEndMin_gpu = scan_max_min(UTE)

    # 4. Compute the workload of each query block with a custom CUDA kernel.
    # The kernel mimics FlashAttention's blockwise traversal but only counts the
    # blocks that would be active, instead of doing the actual matmuls, which
    # makes the workload estimate cheap.
    all_indices_max_min = [
        LTStartMax_gpu, LTStartMin_gpu, LTEndMax_gpu, LTEndMin_gpu,
        UTStartMax_gpu, UTStartMin_gpu, UTEndMax_gpu, UTEndMin_gpu
    ]
    workload_per_block = reduce_workload(all_indices_max_min, B, H, Tr, Tc, m_block_size, S)

    # 5. Aggregate per-block workloads to chunk granularity.
    workload_grouped = workload_per_block.reshape([B, H, Tchunks, blocks_per_chunk, 1])
    workload_per_chunk = paddle.sum(workload_grouped, axis=3).sum(axis=0).reshape([1, H, Tchunks])

    # 6. Assemble the final output: [workload, original_index] per chunk.
    # Note: paddle.zeros takes no `device` argument; the tensor lands on the
    # current place (the GPU, given the inputs above).
    final_res = paddle.zeros([1, H, Tchunks, 2], dtype='int32')
    final_res[:, :, :, 0] = workload_per_chunk
    final_res[:, :, :, 1] = paddle.arange(0, Tchunks, dtype="int32")

    return final_res
```
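For intuition, the chunk-level aggregation in steps 5 and 6 can be reproduced in plain NumPy. This is a standalone sketch: the random per-block workloads stand in for the output of the `reduce_workload` CUDA kernel, and the toy sizes are made up.

```python
import numpy as np

B, H, Tr, blocks_per_chunk = 1, 2, 8, 2   # toy sizes: 8 row blocks, 2 blocks per chunk
Tchunks = Tr // blocks_per_chunk
workload_per_block = np.random.randint(0, 10, size=(B, H, Tr))  # stand-in for reduce_workload

# Sum block workloads into chunk workloads, then pair each chunk with its
# original index -- the same [workload, index] layout the balancer consumes.
per_chunk = workload_per_block.reshape(B, H, Tchunks, blocks_per_chunk).sum(axis=3).sum(axis=0)
tasks = np.stack([per_chunk, np.broadcast_to(np.arange(Tchunks), per_chunk.shape)], axis=-1)
print(tasks.shape)  # (H, Tchunks, 2)
```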
```python
def assign_tasks_heap(
    tasks: np.ndarray,
    num_buckets: int
) -> Tuple[List[List[Tuple[int, int]]], List[int], int]:
    """
    Greedily assign weighted, indexed tasks to M buckets using a min-heap,
    so that bucket workloads stay balanced.

    Args:
        tasks (np.ndarray): Array of shape (N, 2); each row is [weight, index].
        num_buckets (int): Number of buckets (usually the communication-group
            world size).

    Returns:
        Tuple:
            - buckets (List[List[Tuple[int, int]]]): Assignment result; each
              sublist holds one bucket's tasks.
            - bucket_weights (List[int]): Total weight of each bucket.
            - cuts (int): Number of discontinuities in the reranked data, a
              measure of how contiguous the reordering is.
    """
    n = len(tasks)
    if n == 0:
        return [[] for _ in range(num_buckets)], [0] * num_buckets, 0

    # Target number of tasks per bucket.
    batch_size = n // num_buckets

    # Sort tasks by descending weight so the heaviest are placed first.
    tasks_sorted = sorted(tasks, key=lambda x: -x[0])

    # Buckets plus per-bucket bookkeeping.
    buckets = [[] for _ in range(num_buckets)]
    bucket_weights = [0] * num_buckets
    bucket_counts = [0] * num_buckets

    # Min-heap of (current_weight, bucket_index) to find the lightest bucket fast.
    heap = [(0, i) for i in range(num_buckets)]

    # Greedy placement: give each task (heaviest first) to the lightest bucket
    # that still has room.
    for weight, idx in tasks_sorted:
        temp_popped = []
        found_bucket = False
        while heap:
            bucket_sum, bucket_idx = heapq.heappop(heap)
            if bucket_counts[bucket_idx] < batch_size:
                # Found a bucket with room: update its state and push it back.
                buckets[bucket_idx].append((weight, idx))
                bucket_weights[bucket_idx] += weight
                bucket_counts[bucket_idx] += 1
                heapq.heappush(heap, (bucket_weights[bucket_idx], bucket_idx))
                found_bucket = True
                break
            else:
                # Bucket is full: set it aside and keep looking.
                temp_popped.append((bucket_sum, bucket_idx))

        # Restore the buckets that were popped because they were full.
        for item in temp_popped:
            heapq.heappush(heap, item)

        if not found_bucket:
            # All buckets are full (typically when n % num_buckets != 0):
            # give the leftover task to the currently lightest bucket.
            bucket_sum, bucket_idx = heapq.heappop(heap)
            buckets[bucket_idx].append((weight, idx))
            bucket_weights[bucket_idx] += weight
            bucket_counts[bucket_idx] += 1
            heapq.heappush(heap, (bucket_weights[bucket_idx], bucket_idx))

    # (Optional) Sort each bucket by original task index for easier debugging.
    for i in range(num_buckets):
        buckets[i] = sorted(buckets[i], key=lambda x: x[1])

    # Count cuts: within each bucket, count breaks between non-adjacent chunk
    # indices. (Counting over the globally sorted index list would always give
    # 0, since the assigned indices form a permutation of 0..n-1.)
    cuts = 0
    for bucket in buckets:
        idxs = [idx for _, idx in bucket]
        cuts += sum(1 for i in range(1, len(idxs)) if idxs[i] != idxs[i - 1] + 1)

    return buckets, bucket_weights, cuts
```
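A quick sanity check of the greedy assignment on toy data. The weights are hypothetical, the import path is an assumption, and importing the module requires the CUDA extension to be built:

```python
import numpy as np
# from flash_mask.cp_balance import assign_tasks_heap  # import path is an assumption

# Eight toy chunks with skewed weights (e.g. a causal mask makes late chunks heavy).
tasks = np.array([[9, 0], [1, 1], [8, 2], [2, 3], [7, 4], [3, 5], [6, 6], [4, 7]])
buckets, weights, cuts = assign_tasks_heap(tasks, num_buckets=2)
print(weights)  # [20, 20] -- perfectly balanced for this toy case
print(cuts)     # 1 -- bucket 0 holds chunks [0, 1, 6, 7], one discontinuity
```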
```python
# --- Communication and data-rearrangement helpers ---

def get_send_dict(buckets: List[List[Tuple[int, int]]], cp_size: int, rank: int) -> Dict[int, List[int]]:
    """
    Build the all-to-all send dictionary for the current rank from the
    load-balancing assignment.

    Args:
        buckets (List): Task assignment for all ranks.
        cp_size (int): Communication-group size.
        rank (int): Rank of the current process.

    Returns:
        Dict[int, List[int]]: Send dictionary. Keys are target ranks; values are
            the local chunk indices to send to that rank.
    """
    send_dict = {i: [] for i in range(cp_size)}
    # Walk every bucket (i.e. every target rank's task list).
    for target_rank, bucket in enumerate(buckets):
        for _, chunk_idx in bucket:
            # If this chunk originally lives on the current rank, it must be sent.
            if chunk_idx // cp_size == rank:
                # chunk_idx % cp_size is the chunk's local index on this rank.
                send_dict[target_rank].append(chunk_idx % cp_size)
    return send_dict


def get_recv_dict(bucket: List[Tuple[int, int]], cp_size: int) -> Dict[int, List[int]]:
    """
    Build the all-to-all receive dictionary from the current rank's assignment.

    Args:
        bucket (List): Tasks assigned to the current rank.
        cp_size (int): Communication-group size.

    Returns:
        Dict[int, List[int]]: Receive dictionary. Keys are source ranks; values
            are the local positions where the chunks received from that rank
            should be placed.
    """
    recv_dict = {i: [] for i in range(cp_size)}
    # Walk every task assigned to this rank.
    for local_pos, (_, chunk_idx) in enumerate(bucket):
        # chunk_idx.item() // cp_size is the rank the chunk originally lives on.
        source_rank = chunk_idx.item() // cp_size
        recv_dict[source_rank].append(local_pos)
    return recv_dict
```
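A worked example of the two dictionaries under the ownership rule above (`chunk_idx // cp_size`), with `cp_size = 2` and 4 global chunks. This is a standalone sketch assuming the module imports; the bucket contents are made up:

```python
import numpy as np
# from flash_mask.cp_balance import get_send_dict, get_recv_dict  # path is an assumption

# Rank 0 owns global chunks {0, 1}; rank 1 owns {2, 3} (chunk_idx // cp_size rule).
buckets = [
    [(9, np.int32(0)), (3, np.int32(3))],  # rank 0 will compute chunks 0 and 3
    [(7, np.int32(1)), (5, np.int32(2))],  # rank 1 will compute chunks 1 and 2
]
print(get_send_dict(buckets, cp_size=2, rank=0))
# -> {0: [0], 1: [1]}: rank 0 keeps its local chunk 0, ships local chunk 1 to rank 1
print(get_send_dict(buckets, cp_size=2, rank=1))
# -> {0: [1], 1: [0]}: rank 1 ships local chunk 1 (global 3) to rank 0, keeps global 2
print(get_recv_dict(buckets[0], cp_size=2))
# -> {0: [0], 1: [1]}: rank 0 places its own chunk at slot 0, rank 1's chunk at slot 1
```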
|
|
||
| def balance_alltoall( | ||
| input_tensor: paddle.Tensor, | ||
| cp_size: int, | ||
| cp_group, | ||
| chunk_size: int, | ||
| send_dict: Dict[int, List[int]], | ||
| recv_dict: Dict[int, List[int]] | ||
| ) -> paddle.Tensor: | ||
| """ | ||
| 执行 all-to-all 通信,根据 send/recv 字典对 `input_tensor` 进行数据重排。 | ||
| 此函数已重构,可统一处理不同维度的张量。 | ||
|
|
||
| Args: | ||
| input_tensor (paddle.Tensor): 待重排的张量,如 Q, K, V。 | ||
| cp_size (int): 通信组大小。 | ||
| cp_group (dist.Group): Paddle 分布式通信组。 | ||
| chunk_size (int): 数据块的大小。 | ||
| send_dict (Dict): 发送字典。 | ||
| recv_dict (Dict): 接收字典。 | ||
|
|
||
| Returns: | ||
| paddle.Tensor: 重排后的张量。 | ||
| """ | ||
| original_shape = input_tensor.shape | ||
| B, S = original_shape[0], original_shape[1] | ||
|
|
||
| # 将输入张量统一 reshape 为 3D (B, S, -1) 以便统一处理 | ||
| tensor_3d = input_tensor.reshape((B, S, -1)) | ||
| HD = tensor_3d.shape[-1] | ||
|
|
||
| # 1. 准备发送数据 (Gather) | ||
| # 根据 send_dict,从本地张量中收集需要发送给其他 rank 的数据块 | ||
| send_list = [] | ||
| for target_rank in range(cp_size): | ||
| indices_to_send = send_dict[target_rank] | ||
| if indices_to_send: | ||
| # 将所有要发往同一个 rank 的数据块拼接在一起 | ||
| data_to_send = paddle.concat( | ||
| [tensor_3d[:, idx * chunk_size:(idx + 1) * chunk_size, :] for idx in indices_to_send], | ||
| axis=1 | ||
| ) | ||
| send_list.append(data_to_send) | ||
| else: | ||
| # 注意:NCCL alltoall 不支持大小为 0 的张量,因此发送一个虚拟的、 | ||
| # 非常小的张量作为占位符。接收方也需对应接收。 | ||
| send_list.append(paddle.zeros((B, 1, HD), dtype=input_tensor.dtype)) | ||
|
|
||
| # 2. 准备接收缓冲区 (Scatter) | ||
| # 根据 recv_dict,为从其他 rank 接收的数据准备相应大小的空缓冲区 | ||
| recv_list = [] | ||
| for source_rank in range(cp_size): | ||
| num_chunks_to_recv = len(recv_dict[source_rank]) | ||
| if num_chunks_to_recv > 0: | ||
| recv_list.append( | ||
| paddle.empty((B, chunk_size * num_chunks_to_recv, HD), dtype=input_tensor.dtype) | ||
| ) | ||
| else: | ||
| # 对应发送方的虚拟张量,接收一个同样大小的虚拟缓冲区 | ||
| recv_list.append(paddle.empty((B, 1, HD), dtype=input_tensor.dtype)) | ||
|
|
||
| # 3. 执行 All-to-All 通信 | ||
| dist.alltoall(out_tensor_list=recv_list, in_tensor_list=send_list, group=cp_group) | ||
|
|
||
| # 4. 将接收到的数据重新组装成最终张量 | ||
| final_res_3d = paddle.empty_like(tensor_3d) | ||
| for source_rank in range(cp_size): | ||
| local_positions = recv_dict[source_rank] | ||
| if local_positions: | ||
| received_data = recv_list[source_rank] | ||
| # 将从 source_rank 接收到的数据块,放置到它们在本地应该在的位置 | ||
| for i, local_pos in enumerate(local_positions): | ||
| start_s = local_pos * chunk_size | ||
| end_s = (local_pos + 1) * chunk_size | ||
| data_start = i * chunk_size | ||
| data_end = (i + 1) * chunk_size | ||
| final_res_3d[:, start_s:end_s, :] = received_data[:, data_start:data_end, :] | ||
|
|
||
| # 恢复原始形状 | ||
| return final_res_3d.reshape(original_shape) | ||
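Since `buckets` encodes one global permutation, the same send/recv plan has to be applied to every tensor that shares the sequence dimension; the docstring of `balance_flashmask_input` below calls this out for Q, K and V. A hedged driver sketch: `q`, `k`, `v`, `cp_group` and `chunk` are assumed names, and this only runs under a multi-rank distributed launch:

```python
# Hypothetical driver: one plan, applied uniformly, keeps the q/k/v chunk
# layouts aligned across ranks.
send_dict = get_send_dict(buckets, cp_size, cp_rank)
recv_dict = get_recv_dict(buckets[cp_rank], cp_size)
q, k, v = (balance_alltoall(t, cp_size, cp_group, chunk, send_dict, recv_dict)
           for t in (q, k, v))
```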
```python
# --- Main entry point ---

def balance_flashmask_input(
    startend_row_indices: paddle.Tensor,
    cp_size: int,
    cp_rank: int,
    balance_chunk_size: int = 2048,
    q_block_size: int = 128,
    k_block_size: int = 128
) -> Tuple[paddle.Tensor, List[List[Tuple[int, int]]]]:
    """
    Main load-balancing pipeline for FlashMask inputs. Orchestrates the whole
    process: estimate workload -> assign tasks -> build the communication plan
    -> rearrange the data.

    Args:
        startend_row_indices (paddle.Tensor): Original start/end indices of the
            sparse attention mask.
        cp_size (int): Communication-group size.
        cp_rank (int): Rank of the current process.
        balance_chunk_size (int): Chunk size used for load-balancing analysis
            and data movement.
        q_block_size (int): Query-side block size of the FlashAttention kernel.
        k_block_size (int): Key-side block size of the FlashAttention kernel.

    Returns:
        Tuple:
            - local_startend_row_indices (paddle.Tensor): The start/end indices
              this rank must process after balancing and rearrangement.
            - buckets (List[List[Tuple[int, int]]]): The global assignment plan,
              needed later to rearrange Q, K, V and similar tensors the same way.
    """
    # Step 1: estimate each chunk's workload.
    paddle.base.core.nvprof_nvtx_push("get_q_workload")
    workload = get_q_workload(startend_row_indices, balance_chunk_size, q_block_size, k_block_size)
    paddle.base.core.nvprof_nvtx_pop()

    # Step 2: assign tasks on the CPU with the greedy heap algorithm.
    paddle.base.core.nvprof_nvtx_push("assign_tasks_heap")
    # Convert the workload tensor to numpy for heapq.
    tasks_np = workload.reshape([-1, 2]).cpu().numpy()
    buckets, _, _ = assign_tasks_heap(tasks_np, cp_size)
    paddle.base.core.nvprof_nvtx_pop()

    # Step 3: rerank the original index tensor according to the global plan
    # `buckets` (gather). This yields a globally reordered view of
    # `startend_row_indices`.
    paddle.base.core.nvprof_nvtx_push("startend_row_indices_rerank")
    # Flatten `buckets` into the new chunk order.
    rerank_indices = np.array([idx for bucket in buckets for _, idx in bucket], dtype=np.int32)
    indices_tensor = paddle.to_tensor(rerank_indices, dtype='int32', place=startend_row_indices.place)

    # Gather efficiently with a custom CUDA kernel.
    startend_row_indices_rerank = indices_rerank_cuda(startend_row_indices, indices_tensor)
    paddle.base.core.nvprof_nvtx_pop()

    # Step 4: derive this rank's local indices from the reranked global indices
    # (localize). Global indices (which may span the full sequence length S) are
    # converted into indices relative to the local chunks.
    paddle.base.core.nvprof_nvtx_push("indices_to_chunks")
    local_bucket_indices = [x[1] for x in buckets[cp_rank]]
    local_indices_tensor = paddle.to_tensor(local_bucket_indices, dtype='int32', place=startend_row_indices.place)

    # Clip and offset the indices efficiently with a custom CUDA kernel.
    local_startend_row_indices = indices_to_chunks_cuda(
        startend_row_indices_rerank, local_indices_tensor, balance_chunk_size
    )
    paddle.base.core.nvprof_nvtx_pop()

    return local_startend_row_indices, buckets
```
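End to end, each rank calls the balancer once per batch and reuses the returned plan for its activations. A minimal sketch, assuming the import path and an initialized distributed environment (e.g. via `paddle.distributed.launch`); `startend_row_indices` is this batch's global FlashMask index tensor:

```python
import paddle.distributed as dist
# from flash_mask import balance_flashmask_input  # import path is an assumption

dist.init_parallel_env()
cp_size, cp_rank = dist.get_world_size(), dist.get_rank()
cp_group = dist.new_group(list(range(cp_size)))

# startend_row_indices: [B, H, S, 2] or [B, H, S, 4]; S must be divisible by
# balance_chunk_size, which itself must be a multiple of the kernel block sizes.
local_indices, buckets = balance_flashmask_input(
    startend_row_indices, cp_size, cp_rank,
    balance_chunk_size=2048, q_block_size=128, k_block_size=128)
# `local_indices` drives this rank's FlashMask attention call; `buckets` feeds
# the send/recv plan used to shuffle q/k/v with balance_alltoall (see above).
```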