Commit: optimize code

sallyjunjun committed Feb 19, 2025
1 parent c392ba5 commit 7ac997e
Showing 2 changed files with 50 additions and 79 deletions.
124 changes: 48 additions & 76 deletions internlm/model/modules/linear.py
@@ -1020,7 +1020,7 @@ def __init__(

if has_te:

class TEColumnParallelLinear(te.pytorch.Linear):
class TELinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
"""
@@ -1032,34 +1032,36 @@ def __init__(
bias: bool,
skip_bias_add: bool,
is_expert: bool,
split_mode: str = "none",
tp_comm_buffer_name: str = None,
):
if is_expert:
raise ValueError("Transformer Engine linear layers do not yet support MoE")

# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
# return_bias=True. Here we need a single Tensor
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True

extra_kwargs = {"params_dtype": gpc.config.model.dtype}
if is_te_min_version("0.12.0"):
extra_kwargs["device"] = torch.cuda.current_device()
extra_kwargs["device"] = torch.cuda.current_device()

if gpc.config.parallel["tensor"].get("tp_overlap", False):
extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_bulk_wgrad", True
)
extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_bulk_dgrad", True
)
if is_te_min_version("1.5.0", check_equality=False):
extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_overlap_ag", True
)
if split_mode == "column":
extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_bulk_wgrad", True
)
extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_bulk_dgrad", True
)
elif split_mode == "row":
extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_overlap_rs", True
)
else:
raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
assert (
@@ -1070,6 +1072,7 @@ def __init__(
parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
tp_size = gpc.get_world_size(parallel_mode)
tp_group = gpc.get_group(parallel_mode)

super().__init__(
in_features=in_features,
out_features=out_features,
@@ -1078,7 +1081,7 @@ def __init__(
tp_size=tp_size,
bias=bias,
return_bias=self.te_return_bias,
parallel_mode="column",
parallel_mode=split_mode,
**extra_kwargs,
)

@@ -1088,14 +1091,13 @@ def forward(self, x):
x = x.transpose(0, 1)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
out = out.transpose(0, 1)

self.is_first_microbatch = False

return out

class TERowParallelLinear(te.pytorch.Linear):
class TEColumnParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
Wrapper for the TELinear layer.
"""

def __init__(
@@ -1107,72 +1109,42 @@ def __init__(
is_expert: bool = False,
tp_comm_buffer_name: str = None,
):
# TE returns a zero length Tensor when bias=False and
# return_bias=True. Here we need a single Tensor
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True

extra_kwargs = {"params_dtype": gpc.config.model.dtype}
if is_te_min_version("0.12.0"):
extra_kwargs["device"] = torch.cuda.current_device()

if gpc.config.parallel["tensor"].get("tp_overlap", False):
if is_te_min_version("1.5.0"):
extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_overlap_ag", True
)
extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_overlap_rs", True
)
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs"] = False
else:
raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name

self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1
parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
# Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
tp_size = gpc.get_world_size(parallel_mode)
tp_group = gpc.get_group(parallel_mode)
explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)

split_mode = "row"
if explicit_expert_comm:
assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size)
in_features = in_features // tp_size
split_mode = None
tp_size = 1
tp_group = None

super().__init__(
in_features=in_features,
out_features=out_features,
sequence_parallel=gpc.config.parallel.sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
in_features,
out_features,
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=split_mode,
**extra_kwargs,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
split_mode="column",
tp_comm_buffer_name=tp_comm_buffer_name,
)

def forward(self, x):
"""Forward."""
_is_first_microbatch = self.is_first_microbatch
x = x.transpose(0, 1)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
out = out.transpose(0, 1)
self.is_first_microbatch = False
class TERowParallelLinear(TELinear):
"""
Wrapper for the TELinear layer.
"""

return out
def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: str = None,
):
super().__init__(
in_features,
out_features,
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
split_mode="row",
tp_comm_buffer_name=tp_comm_buffer_name,
)

else:
TELinear = ParallelLinearWithCommExt
TEColumnParallelLinear = ColumnParallelLinear
TERowParallelLinear = RowParallelLinear

5 changes: 2 additions & 3 deletions internlm/train/pipeline.py
@@ -59,8 +59,7 @@
RewardModelLinear,
RowParallelLinear,
ScaleColumnParallelLinear,
TEColumnParallelLinear,
TERowParallelLinear,
TELinear,
new_linear,
)
from internlm.model.modules.norm import new_layer_norm
@@ -209,7 +208,7 @@ def _check_module(name, module):
elif gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp():
setattr(param, IS_WEIGHT_EXPERT_DATA_PARALLEL, True)
# for non-moe linear module
elif isinstance(module, (ParallelLinearWithCommExt, TERowParallelLinear, TEColumnParallelLinear)):
elif isinstance(module, (ParallelLinearWithCommExt, TELinear)):
for param in module.parameters():
if gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
setattr(param, IS_TENSOR_ZERO_PARALLEL, True)
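
The narrowed isinstance check in _check_module still matches both parallel variants because TEColumnParallelLinear and TERowParallelLinear now inherit from TELinear. A standalone illustration with hypothetical stub classes (not the real modules):

# Standalone illustration (hypothetical stubs): isinstance matches subclasses,
# so checking the base class is enough once both variants inherit from it.
class TELinearStub:
    pass


class TEColumnParallelLinearStub(TELinearStub):
    pass


class TERowParallelLinearStub(TELinearStub):
    pass


for module in (TEColumnParallelLinearStub(), TERowParallelLinearStub()):
    assert isinstance(module, TELinearStub)
print("both parallel variants are caught by the single base-class check")

In the non-TE branch the aliases above keep the check valid as well, since TELinear falls back to ParallelLinearWithCommExt.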
