
Conversation

@wxsIcey (Collaborator) commented Dec 15, 2025

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a new fusion pass for Matmul -> AllReduce -> RMSNorm to optimize performance on Ascend hardware. The changes include a new configuration flag, the fusion pass implementation, and its integration into the compilation process. My review has identified a few issues: there are some leftover debugging print statements that should be removed. More critically, the new fusion pass contains a bug where tensor parallel rank and world size are hardcoded to 0, which will cause failures in distributed setups. There are also some logging statements with inappropriately high severity levels that could flood production logs. I've provided suggestions to fix these issues.
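For reviewers unfamiliar with the pattern, below is a rough sketch of the unfused sequence this pass is meant to match and replace with the fused torch.ops._C_ascend.matmul_allreduce_add_rmsnorm call. The helper is purely illustrative: the actual pass matches FX graph nodes, the real code would pass the tensor-parallel process group explicitly, and the exact residual/output handling in the PR may differ.

import torch
import torch.distributed as dist

def unfused_reference(x, weight, residual, rms_norm_weight, eps):
    # 1. Matmul on the local tensor-parallel shard
    out = torch.matmul(x, weight)
    # 2. All-reduce the partial result across the tensor-parallel group
    #    (a sketch; real code would pass group=... for the TP group)
    dist.all_reduce(out)
    # 3. Residual add followed by RMSNorm
    hidden = out + residual
    variance = hidden.pow(2).mean(-1, keepdim=True)
    normed = hidden * torch.rsqrt(variance + eps) * rms_norm_weight
    # The fused op appears to return both tensors (normed output and updated residual)
    return normed, hidden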

Comment on lines +64 to +65
out0, out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(x, weight, residual, rms_norm_weight,
self.tp_group_name, 0, 0, self.eps, True, True)

critical

The tpRankSize and tpRankId arguments for torch.ops._C_ascend.matmul_allreduce_add_rmsnorm are hardcoded to 0. This is a critical bug that will cause incorrect behavior in distributed environments. Please use the tensor parallel world size and the correct rank ID.

While self.local_rank is correctly initialized, the world size is missing. You can get it using get_tp_group().world_size. For better performance, consider caching this value in the __init__ method.

Suggested change
- out0, out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(x, weight, residual, rms_norm_weight,
-                                                                self.tp_group_name, 0, 0, self.eps, True, True)
+ out0, out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(x, weight, residual, rms_norm_weight,
+                                                                self.tp_group_name, get_tp_group().world_size, self.local_rank, self.eps, True, True)
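
To make the caching suggestion concrete, a minimal sketch of the __init__ change (the attribute names self.tp_world_size and self.tp_rank are illustrative, not from the PR):

# In __init__, resolve the TP topology once instead of on every pattern call
self.tp_world_size = get_tp_group().world_size
self.tp_rank = self.local_rank  # already initialized, per the note above

# In the pattern replacement, pass the cached values instead of the hardcoded 0s
out0, out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
    x, weight, residual, rms_norm_weight,
    self.tp_group_name, self.tp_world_size, self.tp_rank,
    self.eps, True, True)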

Comment on lines +52 to +54
print("=========torch compile graph=========")
print(graph.graph)


high

These print statements appear to be for debugging and should be removed from production code to keep logs clean.

Comment on lines +84 to +90
logging.info("=========before fusion graph========")
logging.info(graph.graph)
self.begin()
self.matched_count = self.pattern_match_passes.apply(graph)
logging.info("=========after fusion graph========")
logging.info(graph.graph)
logging.warning("Replaced %s patterns", self.matched_count)

high

The logging levels used here are too high. Printing entire graphs with logging.info can be excessively verbose for production environments and should be changed to logging.debug. Additionally, logging.warning should be reserved for potential problems, not for reporting a successful operation like pattern replacement, which should be logged at the info or debug level.

Suggested change
- logging.info("=========before fusion graph========")
- logging.info(graph.graph)
- self.begin()
- self.matched_count = self.pattern_match_passes.apply(graph)
- logging.info("=========after fusion graph========")
- logging.info(graph.graph)
- logging.warning("Replaced %s patterns", self.matched_count)
+ logging.debug("=========before fusion graph========")
+ logging.debug(graph.graph)
+ self.begin()
+ self.matched_count = self.pattern_match_passes.apply(graph)
+ logging.debug("=========after fusion graph========")
+ logging.debug(graph.graph)
+ logging.info("Replaced %s patterns", self.matched_count)

@github-actions bot commented

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

Signed-off-by: wxsIcey <[email protected]>
