Skip to content

Commit

Permalink
fix transformer_engine import error and tp_overlap key error
Browse files Browse the repository at this point in the history
  • Loading branch information
sallyjunjun committed Feb 19, 2025
1 parent a71091c commit 5036935
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 160 deletions.
2 changes: 2 additions & 0 deletions configs/7B_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@
tp_overlap_cfg=dict(
tp_comm_overlap_ag=True,
tp_comm_overlap_rs=True,
tp_comm_bulk_wgrad=True,
tp_comm_bulk_dgrad=True,
),
),
pipeline=dict(size=1, interleaved_overlap=True),
Expand Down
4 changes: 2 additions & 2 deletions internlm/core/parallel/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def get_parallel_strategies_split_mode(linear_name: str) -> str:
if linear_name in ("gate"):
return "gate" # for MoE model
elif linear_name in ("wqkv", "wq", "wk", "wv", "wkv", "w1", "w3", "w13"):
if gpc.config.parallel.tensor.tp_overlap:
if gpc.config.parallel["tensor"].get("tp_overlap", False):
return "tecolumn"
else:
return "column"
Expand All @@ -170,7 +170,7 @@ def get_parallel_strategies_split_mode(linear_name: str) -> str:
elif linear_name in ("wo", "out_proj", "w2") and tp_mode == TensorParallelMode.isp.name:
return "column"
elif linear_name in ("wo", "out_proj", "w2"):
if gpc.config.parallel.tensor.tp_overlap:
if gpc.config.parallel["tensor"].get("tp_overlap", False):
return "terow"
else:
return "row"
Expand Down
2 changes: 1 addition & 1 deletion internlm/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __init__(
scheduler_hooks=get_scheduler_hooks(self.metric, optimizer, isp_communicator),
)

if gpc.config.parallel["tensor"]["tp_overlap"]:
if gpc.config.parallel["tensor"].get("tp_overlap", False):
self._initialize_tp_comm_ub()

# set attributes
Expand Down
1 change: 0 additions & 1 deletion internlm/initialize/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,6 @@ def args_sanity_check():
tp_comm_overlap_rs=True,
tp_comm_bulk_wgrad=True,
tp_comm_bulk_dgrad=True,
tp_comm_overlap_rs_dgrad=False,
)

# set default value for weight parallel
Expand Down
305 changes: 157 additions & 148 deletions internlm/model/modules/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@

import torch
import torch.distributed as dist
import transformer_engine as te

try:
import transformer_engine as te

has_te = True
except (ModuleNotFoundError, ImportError):
has_te = False

from torch import nn

from internlm.accelerator import get_accelerator
Expand Down Expand Up @@ -1011,161 +1018,163 @@ def __init__(
self.full_weight_shape = torch.Size((num_groups, in_features, out_features))


class TEColumnParallelLinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
"""
if has_te:

def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
if is_expert:
raise ValueError("Transformer Engine linear layers do not yet support MoE")

# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True

extra_kwargs = {"params_dtype": gpc.config.model.dtype}
if is_te_min_version("0.12.0"):
extra_kwargs["device"] = torch.cuda.current_device()

if gpc.config.parallel["tensor"]["tp_overlap"]:
extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_bulk_wgrad", True
)
extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_bulk_dgrad", True
)
if is_te_min_version("1.5.0", check_equality=False):
extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_overlap_ag", True
)
else:
raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name

parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
tp_size = gpc.get_world_size(parallel_mode)
tp_group = gpc.get_group(parallel_mode)
super().__init__(
in_features=in_features,
out_features=out_features,
sequence_parallel=gpc.config.parallel.sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
bias=bias,
return_bias=self.te_return_bias,
parallel_mode="column",
**extra_kwargs,
)

def forward(self, x):
"""Forward."""
_is_first_microbatch = self.is_first_microbatch
x = x.transpose(0, 1)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
out = out.transpose(0, 1)

self.is_first_microbatch = False
class TEColumnParallelLinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
"""

return out
def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
if is_expert:
raise ValueError("Transformer Engine linear layers do not yet support MoE")

# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True

extra_kwargs = {"params_dtype": gpc.config.model.dtype}
if is_te_min_version("0.12.0"):
extra_kwargs["device"] = torch.cuda.current_device()

if gpc.config.parallel["tensor"].get("tp_overlap", False):
extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_bulk_wgrad", True
)
extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_bulk_dgrad", True
)
if is_te_min_version("1.5.0", check_equality=False):
extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_overlap_ag", True
)
else:
raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name

parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
tp_size = gpc.get_world_size(parallel_mode)
tp_group = gpc.get_group(parallel_mode)
super().__init__(
in_features=in_features,
out_features=out_features,
sequence_parallel=gpc.config.parallel.sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
bias=bias,
return_bias=self.te_return_bias,
parallel_mode="column",
**extra_kwargs,
)

def forward(self, x):
"""Forward."""
_is_first_microbatch = self.is_first_microbatch
x = x.transpose(0, 1)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
out = out.transpose(0, 1)

class TERowParallelLinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
"""
self.is_first_microbatch = False

def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: str = None,
):
# TE returns a zero length Tensor when bias=False and
# return_bias=True. Here we need a single Tensor
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True

extra_kwargs = {"params_dtype": gpc.config.model.dtype}
if is_te_min_version("0.12.0"):
extra_kwargs["device"] = torch.cuda.current_device()

if gpc.config.parallel["tensor"]["tp_overlap"]:
if is_te_min_version("1.5.0"):
extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_overlap_ag", True
)
extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_overlap_rs", True
)
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs"] = False
else:
raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
return out

self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1
parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
# Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
tp_size = gpc.get_world_size(parallel_mode)
tp_group = gpc.get_group(parallel_mode)
explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)

split_mode = "row"
if explicit_expert_comm:
assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size)
in_features = in_features // tp_size
split_mode = None
tp_size = 1
tp_group = None
class TERowParallelLinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
"""

super().__init__(
in_features=in_features,
out_features=out_features,
sequence_parallel=gpc.config.parallel.sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=split_mode,
**extra_kwargs,
)
def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: str = None,
):
# TE returns a zero length Tensor when bias=False and
# return_bias=True. Here we need a single Tensor
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True

extra_kwargs = {"params_dtype": gpc.config.model.dtype}
if is_te_min_version("0.12.0"):
extra_kwargs["device"] = torch.cuda.current_device()

if gpc.config.parallel["tensor"].get("tp_overlap", False):
if is_te_min_version("1.5.0"):
extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_overlap_ag", True
)
extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
"tp_comm_overlap_rs", True
)
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs"] = False
else:
raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name

self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1
parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
# Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
tp_size = gpc.get_world_size(parallel_mode)
tp_group = gpc.get_group(parallel_mode)
explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)

split_mode = "row"
if explicit_expert_comm:
assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size)
in_features = in_features // tp_size
split_mode = None
tp_size = 1
tp_group = None

super().__init__(
in_features=in_features,
out_features=out_features,
sequence_parallel=gpc.config.parallel.sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=split_mode,
**extra_kwargs,
)

for param in self.parameters():
setattr(param, "allreduce", not (is_expert and self.expert_parallel))
def forward(self, x):
"""Forward."""
_is_first_microbatch = self.is_first_microbatch
x = x.transpose(0, 1)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
out = out.transpose(0, 1)
self.is_first_microbatch = False

def forward(self, x):
"""Forward."""
_is_first_microbatch = self.is_first_microbatch
x = x.transpose(0, 1)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
out = out.transpose(0, 1)
self.is_first_microbatch = False
return out

return out
else:
TEColumnParallelLinear = ColumnParallelLinear
TERowParallelLinear = RowParallelLinear


def new_linear(
Expand Down Expand Up @@ -1217,7 +1226,7 @@ def new_linear(
weight_scale=weight_scale,
norm_head=norm_head,
)
elif split_mode == "column":
elif split_mode == "column" or (split_mode == "tecolumn" and not has_te):
return ColumnParallelLinear(
in_features,
out_features,
Expand All @@ -1236,7 +1245,7 @@ def new_linear(
is_expert,
tp_comm_buffer_name,
)
elif split_mode == "row":
elif split_mode == "row" or (split_mode == "terow" and not has_te):
return RowParallelLinear(
in_features,
out_features,
Expand Down
Loading

0 comments on commit 5036935

Please sign in to comment.