4 changes: 2 additions & 2 deletions vllm_ascend/ascend_config.py
@@ -224,7 +224,7 @@ class AscendCompilationConfig:
    deployed on Ascend platforms.
    """

    def __init__(self, fuse_norm_quant: bool = True, **kwargs):
    def __init__(self, fuse_norm_quant: bool = True, fuse_allreduce_rms: bool = True, **kwargs):
        """
        Initialize the configuration.

@@ -236,7 +236,7 @@ def __init__(self, fuse_norm_quant: bool = True, **kwargs):
            **kwargs: Additional optional parameters for forward compatibility and configuration extension.
        """
        self.fuse_norm_quant = fuse_norm_quant
        # Add more compilation related configs here as needed
        self.fuse_allreduce_rms = fuse_allreduce_rms


class XliteGraphConfig:
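For reference (not part of the diff): assuming the new flag is consumed the same way as the existing fuse_norm_quant knob, disabling the fusion from the config object would look like this minimal sketch:

    # Hypothetical direct construction, only to illustrate the new keyword.
    cfg = AscendCompilationConfig(fuse_norm_quant=True, fuse_allreduce_rms=False)
    assert cfg.fuse_allreduce_rms is False  # the new fusion pass will be skipped
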
4 changes: 3 additions & 1 deletion vllm_ascend/compilation/compiler_interface.py
@@ -49,7 +49,9 @@ def fusion_pass_compile(
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:

        print("=========torch compile graph=========")
        print(graph.graph)

Comment on lines +52 to +54
Contributor

high

These print statements appear to be for debugging and should be removed from production code to keep logs clean.
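
One way to address this while keeping the graph dump available for troubleshooting (a sketch, not the reviewer's suggestion) is to gate it behind debug-level logging:

    # Sketch only: replaces the two print() calls above.
    import logging
    logger = logging.getLogger(__name__)
    logger.debug("torch compile graph:\n%s", graph.graph)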

        def compile_inner(graph, example_inputs):
            current_pass_manager = compiler_config["graph_fusion_manager"]
            graph = current_pass_manager(graph, runtime_shape)
4 changes: 3 additions & 1 deletion vllm_ascend/compilation/graph_fusion_pass_manager.py
@@ -50,4 +50,6 @@ def configure(self, config: VllmConfig):
            from .passes.norm_quant_fusion_pass import \
                AddRMSNormQuantFusionPass
            self.passes.append(AddRMSNormQuantFusionPass(config))
        # Add more passes here as needed
        if self.ascend_compilation_config.get("fuse_allreduce_rms", True):
            from .passes.allreduce_rmsnorm_fusion_pass import MatmulAllReduceAddRMSNormPass
            self.passes.append(MatmulAllReduceAddRMSNormPass(config))
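
For orientation before the new pass file below: the pattern it targets is matmul -> tensor-parallel all-reduce -> fused add + RMSNorm, which gets replaced by the single torch.ops._C_ascend.matmul_allreduce_add_rmsnorm op. As a single-rank reference for what that sequence computes (a sketch, not part of the diff; the real graph all-reduces y across the TP group before the add):

    # Assumes `import torch` as in the pass file.
    def matmul_allreduce_add_rmsnorm_ref(x, weight, residual, rms_norm_weight, eps=1e-6):
        y = torch.nn.functional.linear(x, weight)           # matmul
        # ...all-reduce of y across tensor-parallel ranks happens here...
        added = y + residual                                 # residual add
        inv_rms = torch.rsqrt(added.pow(2).mean(-1, keepdim=True) + eps)
        return added * inv_rms * rms_norm_weight, added      # (normed output, new residual)
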
99 changes: 99 additions & 0 deletions vllm_ascend/compilation/passes/allreduce_rmsnorm_fusion_pass.py
@@ -0,0 +1,99 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.distributed.parallel_state import get_tp_group
from vllm.config import VllmConfig

import logging

class MatmulAllReduceAddRMSNormPattern:

    def __init__(self, vllm_config, eps=1e-6):
        self.vllm_config = vllm_config
        self.eps = eps
        device_group = get_tp_group().device_group
        self.local_rank = torch.distributed.get_rank(group=device_group)
        backend = device_group._get_backend(torch.device("npu"))
        self.tp_group_name = backend.get_hccl_comm_name(self.local_rank)

    def get_inputs(self):
        """
        Generate example inputs for the MatmulAllReduceAddRMSNorm fusion pattern.
        """
        x = torch.randn(2, 4, device="npu")
        weight = torch.randn(8, 4, device="npu")
        residual = torch.randn(2, 8, device="npu")
        rms_norm_weight = torch.randn(8, device="npu")
        return [x, weight, residual, rms_norm_weight]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(x, weight, residual, rms_norm_weight):
            """
            Pattern for the MatmulAllReduceAddRMSNorm fusion.
            """
            tmp = torch.nn.functional.linear(x, weight)
            # all_reduce_ = torch.ops.vllm.all_reduce(tmp, group_name="")
            all_reduce_ = torch.ops.vllm.tensor_model_parallel_all_reduce(tmp)
            output = torch.ops.npu.npu_add_rms_norm(all_reduce_, residual, rms_norm_weight, self.eps)
            out0 = output[0]
            out1 = output[2]

            return out0, out1

        def replacement(x, weight, residual, rms_norm_weight):
            """
            Replacement for the MatmulAllReduceAddRMSNorm fusion.
            """
            out0, out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(x, weight, residual, rms_norm_weight,
                self.tp_group_name, 0, 0, self.eps, True, True)
Comment on lines +65 to +66
Contributor

critical

The tpRankSize and tpRankId arguments for torch.ops._C_ascend.matmul_allreduce_add_rmsnorm are hardcoded to 0. This is a critical bug that will cause incorrect behavior in distributed environments. Please use the tensor parallel world size and the correct rank ID.

While self.local_rank is correctly initialized, the world size is missing. You can get it using get_tp_group().world_size. For better performance, consider caching this value in the __init__ method.

Suggested change
out0, out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(x, weight, residual, rms_norm_weight,
self.tp_group_name, 0, 0, self.eps, True, True)
out0, out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(x, weight, residual, rms_norm_weight,
self.tp_group_name, get_tp_group().world_size, self.local_rank, self.eps, True, True)
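
If the caching the comment recommends is adopted, it could look like this (a sketch, not part of the diff, assuming get_tp_group() is usable in __init__ just as it already is for local_rank):

    # In MatmulAllReduceAddRMSNormPattern.__init__:
    self.tp_world_size = get_tp_group().world_size
    # ...and in replacement(), pass the cached values instead of the literals:
    # self.tp_group_name, self.tp_world_size, self.local_rank, self.eps, True, True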

            return out0, out1

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class MatmulAllReduceAddRMSNormPass(VllmInductorPass):
    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)
        self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
            pass_name="allreduce_rmsnorm_fusion_pass")

        common_epsilons = [1e-5, 1e-6]
        for eps in common_epsilons:
            MatmulAllReduceAddRMSNormPattern(vllm_config,
                                             eps=eps).register(self.pattern_match_passes)

    def __call__(self, graph: torch.fx.Graph):
        logging.info("=========before fusion graph========")
        logging.info(graph.graph)
        self.begin()
        self.matched_count = self.pattern_match_passes.apply(graph)
        logging.info("=========after fusion graph========")
        logging.info(graph.graph)
        logging.warning("Replaced %s patterns", self.matched_count)
Comment on lines +85 to +91
Contributor

high

The logging levels used here are too high. Printing entire graphs with logging.info can be excessively verbose for production environments and should be changed to logging.debug. Additionally, logging.warning should be reserved for potential problems, not for reporting a successful operation like pattern replacement, which should be logged at the info or debug level.

Suggested change
logging.info("=========before fusion graph========")
logging.info(graph.graph)
self.begin()
self.matched_count = self.pattern_match_passes.apply(graph)
logging.info("=========after fusion graph========")
logging.info(graph.graph)
logging.warning("Replaced %s patterns", self.matched_count)
logging.debug("=========before fusion graph========")
logging.debug(graph.graph)
self.begin()
self.matched_count = self.pattern_match_passes.apply(graph)
logging.debug("=========after fusion graph========")
logging.debug(graph.graph)
logging.info("Replaced %s patterns", self.matched_count)

        self.end_and_log()


    def is_applicable(self, runtime_shape):
        """
        Check if the pass is applicable for the current configuration.
        """
        return True
9 changes: 9 additions & 0 deletions vllm_ascend/ops/register_custom_ops.py
@@ -256,6 +256,9 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
    else:
        return tensor_model_parallel_all_reduce(final_hidden_states)

def tensor_model_parallel_all_reduce_impl(
        final_hidden_states: torch.Tensor) -> torch.Tensor:
    return tensor_model_parallel_all_reduce(final_hidden_states)

def _matmul_and_reduce_impl(input_parallel: torch.Tensor,
                            layer_name: str) -> torch.Tensor:
@@ -336,6 +339,12 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="tensor_model_parallel_all_reduce",
                          op_func=tensor_model_parallel_all_reduce_impl,
                          fake_impl=lambda x: x,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="matmul_and_reduce",
                          op_func=_matmul_and_reduce_impl,
                          fake_impl=_matmul_and_reduce_impl_fake,
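
A note on the fake_impl=lambda x: x used for the new op (an observation, not part of the diff): for FakeTensor tracing the fake implementation only has to reproduce shape and dtype, and an all-reduce preserves both, so returning the input is sufficient. An equivalent, more explicit sketch:

    # Assumes `import torch` as in register_custom_ops.py.
    def tensor_model_parallel_all_reduce_impl_fake(
            final_hidden_states: torch.Tensor) -> torch.Tensor:
        # All-reduce is shape- and dtype-preserving, so any same-shaped
        # tensor works for tracing.
        return torch.empty_like(final_hidden_states)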