-
Notifications
You must be signed in to change notification settings - Fork 661
[Fusion] [Graph]Add Matmul Allreduce Rmsnorm fusion Pass #5034
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: wxsIcey <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a new fusion pass for Matmul -> AllReduce -> RMSNorm to optimize performance on Ascend hardware. The changes include a new configuration flag, the fusion pass implementation, and its integration into the compilation process. My review has identified a few issues: there are some leftover debugging print statements that should be removed. More critically, the new fusion pass contains a bug where tensor parallel rank and world size are hardcoded to 0, which will cause failures in distributed setups. There are also some logging statements with inappropriately high severity levels that could flood production logs. I've provided suggestions to fix these issues.
| out0, out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(x, weight, residual, rms_norm_weight, | ||
| self.tp_group_name, 0, 0, self.eps, True, True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tpRankSize and tpRankId arguments for torch.ops._C_ascend.matmul_allreduce_add_rmsnorm are hardcoded to 0. This is a critical bug that will cause incorrect behavior in distributed environments. Please use the tensor parallel world size and the correct rank ID.
While self.local_rank is correctly initialized, the world size is missing. You can get it using get_tp_group().world_size. For better performance, consider caching this value in the __init__ method.
| out0, out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(x, weight, residual, rms_norm_weight, | |
| self.tp_group_name, 0, 0, self.eps, True, True) | |
| out0, out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(x, weight, residual, rms_norm_weight, | |
| self.tp_group_name, get_tp_group().world_size, self.local_rank, self.eps, True, True) |
| print("=========torch compile graph=========") | ||
| print(graph.graph) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| logging.info("=========before fusion graph========") | ||
| logging.info(graph.graph) | ||
| self.begin() | ||
| self.matched_count = self.pattern_match_passes.apply(graph) | ||
| logging.info("=========after fusion graph========") | ||
| logging.info(graph.graph) | ||
| logging.warning("Replaced %s patterns", self.matched_count) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logging levels used here are too high. Printing entire graphs with logging.info can be excessively verbose for production environments and should be changed to logging.debug. Additionally, logging.warning should be reserved for potential problems, not for reporting a successful operation like pattern replacement, which should be logged at the info or debug level.
| logging.info("=========before fusion graph========") | |
| logging.info(graph.graph) | |
| self.begin() | |
| self.matched_count = self.pattern_match_passes.apply(graph) | |
| logging.info("=========after fusion graph========") | |
| logging.info(graph.graph) | |
| logging.warning("Replaced %s patterns", self.matched_count) | |
| logging.debug("=========before fusion graph========") | |
| logging.debug(graph.graph) | |
| self.begin() | |
| self.matched_count = self.pattern_match_passes.apply(graph) | |
| logging.debug("=========after fusion graph========") | |
| logging.debug(graph.graph) | |
| logging.info("Replaced %s patterns", self.matched_count) |
|
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according Contributing and Testing. |
Signed-off-by: wxsIcey <[email protected]>
Uh oh!
There was an error while loading. Please reload this page.