
[Draft] [Feature] Add Flux Kernel Support for TP+SP Parallel #777


Open · wants to merge 9 commits into base: deepep
@@ -1,6 +1,8 @@
import os
import torch
import torch.distributed as dist

from lightllm.utils.envs_utils import enable_env_vars
from ..transformer_layer_infer import TransformerLayerInfer
from ...infer_struct import InferStateInfo
from lightllm.utils.infer_utils import mark_cost_time
@@ -22,6 +24,8 @@ def __init__(self, layer_num, network_config, mode):
self.tp_o_head_num_ = -1
self.head_dim_ = -1
self.embed_dim_ = -1
self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
self.enable_flux = enable_env_vars("LIGHTLLM_ENABLE_FLUX")
return

def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
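The two flags read in __init__ above gate the new code path. A minimal sketch of how a launch script might set them before the model is constructed (the accepted values for LIGHTLLM_ENABLE_FLUX are an assumption based on the ENABLE_DP check; enable_env_vars may recognize a different set):

# Hypothetical launch-time configuration; set before init_distributed_env / model construction.
import os

os.environ["ENABLE_DP"] = "1"             # matches the "ON" / "TRUE" / "1" check above
os.environ["LIGHTLLM_ENABLE_FLUX"] = "1"  # routes the _tpsp_* layer methods through the flux kernels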
@@ -1,5 +1,7 @@
import os
import flux
import torch
from torch.distributed import ProcessGroup
from abc import abstractmethod
from typing import Optional, Tuple, List, Dict, Union
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
@@ -40,6 +42,110 @@ def __init__(
# marks whether a bias exists; initialized by subclasses
self.has_bias: bool = None

def flux_gemm_rs(
self,
input_tensor: torch.Tensor,
out: Optional[torch.Tensor] = None,
use_custom_tensor_mananger: bool = True,
transpose_weight: bool = True,
):
# logger.debug("flux_gemm_rs kernel start")
# ensure the weight is contiguous
if not self.weight.is_contiguous():
self.weight = self.weight.contiguous()
assert self.weight.is_contiguous(), "weight must be contiguous"

local_M = input_tensor.size(0)
M = local_M * self.tp_world_size_
N = self.weight.size(1)

if out is None:
# out shape: (local_M, N)
shape = (local_M, self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
if use_custom_tensor_mananger:
out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False)
else:
out = torch.zeros(shape, dtype=dtype, device=device)
# logger.debug(f"{input_tensor.shape}, {self.weight.shape}, {out.shape}")

# with flux.util.group_profile(
# name="gemm_rs_" + os.environ["TORCHELASTIC_RUN_ID"], do_prof=False, group=torch.distributed.group.WORLD
# ):
gemm_rs_op = flux.GemmRS(
torch.distributed.group.WORLD,
1,
(M + 1024 - 1) // 1024 * 1024,  # M rounded up to a multiple of 1024
N,
input_tensor.dtype,
out.dtype,
transpose_weight=transpose_weight,
)
# logger.debug(f"gemm_rs_kernel initialized M={M}, N={N}, local_M={local_M}")
out = gemm_rs_op.forward(
input_tensor,
self.weight,
bias=self.bias,
fast_accum=False,
reduce_scatter_option=flux.ReduceScatterOption(),
)
return out
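For reference, the unfused equivalent of flux_gemm_rs is the partial GEMM plus reduce_scatter_tensor that the non-flux branch of _tpsp_get_o below already performs. A minimal sketch of that reference semantics, assuming torch.distributed is initialized and the local row count is a multiple of the TP world size:

# Unfused reference for flux_gemm_rs: per-rank partial GEMM, then sum-and-scatter over rows.
import torch
import torch.distributed as dist

def gemm_reduce_scatter_reference(x: torch.Tensor, w: torch.Tensor, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    partial = x @ w  # (local_rows, N); each rank holds a partial sum over its K shard
    out = torch.empty(
        (partial.size(0) // world_size, partial.size(1)), dtype=partial.dtype, device=partial.device
    )
    dist.reduce_scatter_tensor(out, partial, op=dist.ReduceOp.SUM, group=group)
    return out  # each rank keeps its (local_rows // world_size, N) row slice of the summed result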

def flux_ag_gemm(
self,
input_tensor: torch.Tensor,
out: Optional[torch.Tensor] = None,
use_custom_tensor_mananger: bool = True,
transpose_weight: bool = True,
):
# logger.debug("flux_ag_gemm kernel start")
# ensure the weight is contiguous
if not self.weight.is_contiguous():
self.weight = self.weight.contiguous()
assert self.weight.is_contiguous(), "weight must be contiguous"

local_M = input_tensor.size(0)
M = local_M * self.tp_world_size_
K = input_tensor.size(1)
N = self.weight.size(1)

if out is None:
shape = (M, self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
if use_custom_tensor_mananger:
out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False)
else:
out = torch.zeros(shape, dtype=dtype, device=device)
# logger.debug(f"{input_tensor.shape}, {self.weight.shape}, {out.shape}")

# with flux.util.group_profile(
# name="ag_gemm_" + os.environ["TORCHELASTIC_RUN_ID"], do_prof=False, group=torch.distributed.group.WORLD
# ):
ag_option = flux.AllGatherOption()
ag_gemm_op = flux.AGKernel(
torch.distributed.group.WORLD,
1,
M,
N,
K,
input_tensor.dtype,
output_dtype=out.dtype,
)
# full_input = torch.empty((M, K), dtype=input_tensor.dtype, device=input_tensor.device)
# logger.debug(f"ag_gemm_kernel initialized M={M}, N={N}, K={K}")
out = ag_gemm_op.forward(
input_tensor,
self.weight,
output=out,
bias=self.bias,
transpose_weight=transpose_weight,
# gathered_input=full_input,
all_gather_option=ag_option,
)
return out, None
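Likewise, the unfused equivalent of flux_ag_gemm is an all-gather over the token dimension followed by a plain GEMM against this rank's weight shard, mirroring the non-flux branch of _tpsp_get_qkv below. A minimal sketch under the same assumptions:

# Unfused reference for flux_ag_gemm: gather every rank's rows, then one local GEMM.
import torch
import torch.distributed as dist

def all_gather_gemm_reference(x: torch.Tensor, w: torch.Tensor, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    gathered = torch.empty((x.size(0) * world_size, x.size(1)), dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(gathered, x, group=group)
    return gathered @ w  # (local_rows * world_size, N_local)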

def mm(
self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True
) -> torch.Tensor:
122 changes: 73 additions & 49 deletions lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -140,18 +140,25 @@ def _get_qkv(
def _tpsp_get_qkv(
self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
) -> torch.Tensor:
if self.tp_world_size_ > 1:
sp_token_num, hidden_dim = input.shape
gather_input = self.alloc_tensor(
(sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device
)
all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False)
input = gather_input[0 : len(infer_state.position_cos), :]

q = layer_weight.q_proj.mm(input)
cache_kv = layer_weight.kv_proj.mm(
input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
if not self.enable_flux or self.tp_world_size_ == 1:
if self.tp_world_size_ > 1:
sp_token_num, hidden_dim = input.shape
gather_input = self.alloc_tensor(
(sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device
)
all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False)
input = gather_input[0 : len(infer_state.position_cos), :]

q = layer_weight.q_proj.mm(input)
cache_kv = layer_weight.kv_proj.mm(
input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
else:
q, _ = layer_weight.q_proj.flux_ag_gemm(input)
kv, _ = layer_weight.kv_proj.flux_ag_gemm(input)
q = q[: len(infer_state.position_cos)]
kv = kv[: len(infer_state.position_cos)]
cache_kv = kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)

rotary_emb_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
@@ -248,20 +255,28 @@ def _tpsp_get_o(
) -> torch.Tensor:
input = input.view(-1, self.tp_o_head_num_ * self.head_dim_)
dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_
o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device)
layer_weight.o_proj.mm(input, out=o_tensor[0 : len(infer_state.position_cos), :])

if self.tp_world_size_ > 1:
sp_token_num = o_tensor.shape[0] // self.tp_world_size_
reduce_o_tensor = self.alloc_tensor((sp_token_num, self.embed_dim_), dtype=input.dtype, device=input.device)
reduce_scatter_tensor(
output=reduce_o_tensor,
input=o_tensor,
op=dist.ReduceOp.SUM,
group=infer_state.dist_group,
async_op=False,
)
o_tensor = reduce_o_tensor
if not self.enable_flux or self.tp_world_size_ == 1:
o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device)
layer_weight.o_proj.mm(input, out=o_tensor[0 : len(infer_state.position_cos), :])

if self.tp_world_size_ > 1:
sp_token_num = o_tensor.shape[0] // self.tp_world_size_
reduce_o_tensor = self.alloc_tensor(
(sp_token_num, self.embed_dim_), dtype=input.dtype, device=input.device
)
reduce_scatter_tensor(
output=reduce_o_tensor,
input=o_tensor,
op=dist.ReduceOp.SUM,
group=infer_state.dist_group,
async_op=False,
)
o_tensor = reduce_o_tensor
else:
# pad the input to dest_size rows
pad_input = self.alloc_tensor((dest_size, input.size(1)), dtype=input.dtype, device=input.device)
pad_input[0 : input.size(0), :] = input
o_tensor = layer_weight.o_proj.flux_gemm_rs(pad_input)

return o_tensor
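The padding in the flux branch above exists because the fused reduce-scatter needs a row count divisible by the TP world size. A small worked example of the dest_size arithmetic (numbers are illustrative):

# Worked example of the dest_size padding used in _tpsp_get_o.
import triton

token_num, tp_world_size = 10, 4
dest_size = triton.cdiv(token_num, tp_world_size) * tp_world_size  # ceil(10 / 4) * 4 = 12
pad_rows = dest_size - token_num                                    # 2 zero rows are appended
rows_per_rank = dest_size // tp_world_size                          # 3 output rows per rank
assert (dest_size, pad_rows, rows_per_rank) == (12, 2, 3)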

@@ -280,30 +295,39 @@ def _tpsp_ffn(
self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
) -> torch.Tensor:
input = input.view(-1, self.embed_dim_)
if self.tp_world_size_ > 1:
if not self.enable_flux or self.tp_world_size_ == 1:
if self.tp_world_size_ > 1:
sp_token_num, hidden_dim = input.shape
gather_input = self.alloc_tensor(
(sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device
)
all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False)
input = gather_input

up_gate_out = layer_weight.gate_up_proj.mm(input)
ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype)
silu_and_mul_fwd(up_gate_out, ffn1_out)
input = None
up_gate_out = None
ffn2_out = layer_weight.down_proj.mm(ffn1_out)
ffn1_out = None
if self.tp_world_size_ > 1:
sp_token_num = ffn2_out.shape[0] // self.tp_world_size_
reduce_o_tensor = self.alloc_tensor(
(sp_token_num, self.embed_dim_), dtype=ffn2_out.dtype, device=ffn2_out.device
)
reduce_scatter_tensor(
reduce_o_tensor, ffn2_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False
)
ffn2_out = reduce_o_tensor
else:
up_gate_out, _ = layer_weight.gate_up_proj.flux_ag_gemm(input)
sp_token_num, hidden_dim = input.shape
gather_input = self.alloc_tensor(
(sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device
)
all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False)
input = gather_input

up_gate_out = layer_weight.gate_up_proj.mm(input)
ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype)
silu_and_mul_fwd(up_gate_out, ffn1_out)
input = None
up_gate_out = None
ffn2_out = layer_weight.down_proj.mm(ffn1_out)
ffn1_out = None
if self.tp_world_size_ > 1:
sp_token_num = ffn2_out.shape[0] // self.tp_world_size_
reduce_o_tensor = self.alloc_tensor(
(sp_token_num, self.embed_dim_), dtype=ffn2_out.dtype, device=ffn2_out.device
)
reduce_scatter_tensor(
reduce_o_tensor, ffn2_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False
)
ffn2_out = reduce_o_tensor
ffn1_out = self.alloc_tensor((sp_token_num * self.tp_world_size_, up_gate_out.size(1) // 2), input.dtype)
silu_and_mul_fwd(up_gate_out, ffn1_out)
input = None
up_gate_out = None
ffn2_out = layer_weight.down_proj.flux_gemm_rs(ffn1_out)
return ffn2_out
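Taken together, the flux branch of _tpsp_ffn only materializes the full gathered token count between the two fused GEMMs and hands a sequence-parallel shard back at both ends. A shape-flow sketch with illustrative numbers (4 TP ranks, 3 padded tokens per rank, per-rank intermediate size 2752):

# Shape flow of the flux TP+SP FFN branch (all numbers illustrative).
sp, tp, embed_dim, inter_local = 3, 4, 4096, 2752
gathered_rows = sp * tp            # rows returned by gate_up_proj.flux_ag_gemm: 12
up_gate_width = inter_local * 2    # packed gate|up columns: 5504
ffn1_width = up_gate_width // 2    # width after silu_and_mul_fwd: 2752
assert (gathered_rows, ffn1_width) == (12, inter_local)
# down_proj.flux_gemm_rs then reduce-scatters the partial down projection,
# returning a (sp, embed_dim) = (3, 4096) shard on each rank.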

# # keep code
6 changes: 6 additions & 0 deletions lightllm/utils/dist_utils.py
@@ -1,3 +1,4 @@
import flux
import torch.distributed as dist
import os
import torch
@@ -102,6 +103,11 @@ def init_distributed_env(kvargs):
rank=kvargs["rank_id"],
world_size=kvargs["world_size"],
)

# initialize the flux communication kernels
flux.init_flux_shm(torch.distributed.group.WORLD)
torch.cuda.synchronize()

# warmup nccl communicator
_a = torch.zeros([1]).to(f"cuda:{device_id}")
dist.all_reduce(_a)
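For clarity, the initialization order the flux kernels rely on is: process group first, then flux's shared-memory setup, then the NCCL warmup. A condensed sketch of that sequence (kvargs handling and the actual init_method are omitted; this assumes MASTER_ADDR / MASTER_PORT are already set in the environment):

# Condensed initialization sketch mirroring init_distributed_env above.
import torch
import torch.distributed as dist
import flux

def init_flux_for_rank(rank: int, world_size: int, device_id: int) -> None:
    torch.cuda.set_device(device_id)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    flux.init_flux_shm(dist.group.WORLD)  # set up flux's shared-memory workspace
    torch.cuda.synchronize()              # ensure the setup has completed on this device
    warmup = torch.zeros([1], device=f"cuda:{device_id}")
    dist.all_reduce(warmup)               # warm up the NCCL communicator, as the code above does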