-
Notifications
You must be signed in to change notification settings - Fork 462
[not for land] [mxfp8 training] fix TP bug #3985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,6 +40,24 @@ | |
| torch.ops.aten.clone.default, | ||
| torch.ops.aten.transpose.int, | ||
| torch.ops.aten.t.default, | ||
| # required for TP - scatter_ is used to distribute weights | ||
| torch.ops.c10d.scatter_.default, | ||
| # # required for quantizing weights | ||
| # torch.ops.aten.mul.Tensor, | ||
| # torch.ops.aten.abs.default, | ||
| # torch.ops.aten.amax.default, | ||
| # torch.ops.aten.clamp.default, | ||
| # torch.ops.aten.to.dtype, | ||
| # torch.ops.aten.unsqueeze.default, | ||
| # torch.ops.aten.div.Tensor, | ||
| # torch.ops.aten.reshape.default, | ||
| # torch.ops.aten.isnan.default, | ||
| # torch.ops.aten.log2.default, | ||
| # torch.ops.aten.where.default, | ||
| # torch.ops.aten.where.self, | ||
| # torch.ops.aten.ceil.default, | ||
| # torch.ops.aten.view.dtype, | ||
| # torch.ops.aten.squeeze.dim, | ||
| } | ||
|
|
||
|
|
||
|
|
@@ -89,6 +107,7 @@ def __torch_function__(cls, func, types, args, kwargs={}): | |
|
|
||
| @classmethod | ||
| def __torch_dispatch__(cls, func, types, args, kwargs={}): | ||
| print("[TORCH_DISPATCH]: ", func.__name__) | ||
| # unwrap args/kwargs and extract config | ||
| config = None | ||
|
|
||
|
|
@@ -222,10 +241,6 @@ def __torch_function__(cls, func, types, args, kwargs={}): | |
| # Use torchao scaled grouped mm with dynamic quant for | ||
| # "2d x 3d with offsets" case (used for routed experts). | ||
| # Otherwise, fall back to regular grouped mm. | ||
| # | ||
| # TODO: support "3d x 3d without offsets" case, which is | ||
| # used for shared experts. This is basically the grouped_mm | ||
| # kernel handling a bmm. | ||
| A, B = args[0], args[1] | ||
|
|
||
| assert not isinstance(A, cls), f"A should not be a {cls.__name__}" | ||
|
|
@@ -263,14 +278,11 @@ class MXFP8TrainingWeightWrapperTensor(TrainingWeightWrapperBaseTensor): | |
| @classmethod | ||
| def __torch_function__(cls, func, types, args, kwargs={}): | ||
| # grouped_mm op override | ||
| print("[TORCH_FUNCTION]", func.__name__) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. remove these before landing? |
||
| if func.__name__ == "_grouped_mm": | ||
| # Use torchao scaled grouped mm with dynamic quant for | ||
| # "2d x 3d with offsets" case (used for routed experts). | ||
| # Otherwise, fall back to regular grouped mm. | ||
| # | ||
| # TODO: support "3d x 3d without offsets" case, which is | ||
| # used for shared experts. This is basically the grouped_mm | ||
| # kernel handling a bmm. | ||
| A, B = args[0], args[1] | ||
|
|
||
| assert not isinstance(A, cls), f"A should not be a {cls.__name__}" | ||
|
|
@@ -285,26 +297,29 @@ def __torch_function__(cls, func, types, args, kwargs={}): | |
| if A_is_2d and B_is_2d_or_3d and offs is not None: | ||
| return _quantize_then_scaled_grouped_mm( | ||
| A, | ||
| B, | ||
| unwrap_weight(B), | ||
| offs=offs, | ||
| config=config, | ||
| ) | ||
|
|
||
| # linear op override | ||
| elif func.__name__ in ("linear", "mm", "matmul", "addmm"): | ||
| elif func.__name__ in ("linear", "mm", "mm.default"): | ||
| A, B = args[0], args[1] | ||
| assert not isinstance(A, cls), f"A should not be a {cls.__name__}" | ||
|
|
||
| assert not isinstance(A, cls), f"A should not be a {cls.__name__}" | ||
| assert isinstance(B, cls), f"B should be a {cls.__name__}" | ||
|
|
||
| config = B.config | ||
| assert isinstance(config, MXFP8TrainingOpConfig), ( | ||
| "expected MXFP8TrainingOpConfig" | ||
| ) | ||
|
|
||
|
|
||
| # Log weight shard statistics | ||
| weight = B._data | ||
|
|
||
| return _to_mxfp8_then_scaled_mm( | ||
| A, | ||
| B, | ||
| unwrap_weight(B), | ||
| kernel_preference=config.kernel_preference, | ||
| scale_calculation_mode=config.scale_calculation_mode, | ||
| wgrad_with_hp=config.wgrad_with_hp, | ||
|
|
@@ -315,3 +330,19 @@ def __torch_function__(cls, func, types, args, kwargs={}): | |
| # the wrapping behavior of the super() impl, go directly to dispatch | ||
| with torch._C.DisableTorchFunctionSubclass(): | ||
| return func(*args, **kwargs) | ||
|
|
||
|
|
||
| class _UnwrapWeight(torch.autograd.Function): | ||
| """Helper to unwrap the tensor subclass in a differentiable way""" | ||
|
|
||
| @staticmethod | ||
| def forward(ctx, wrapper_tensor): | ||
| return wrapper_tensor._data | ||
|
|
||
| @staticmethod | ||
| def backward(ctx, grad_output): | ||
| return grad_output | ||
|
|
||
|
|
||
| def unwrap_weight(wrapper_tensor): | ||
| return _UnwrapWeight.apply(wrapper_tensor) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -170,7 +170,8 @@ def to_mx( | |
| assert data_hp.shape[-1] % block_size == 0, ( | ||
| f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}" | ||
| ) | ||
| assert data_hp.is_contiguous(), "unsupported" | ||
| if not data_hp.is_contiguous(): | ||
| assert data_hp.is_contiguous(), "unsupported" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. a bit confused by this, if it's not contiguous it would fail like before, so is there a reason behind this change?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @andrewor14 sorry i linked the wrong PR, this is not the one that will address test failures, this is a WIP draft for an issue we are still trying to find a proper solution for - please disregard this PR. |
||
| assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported" | ||
|
|
||
| orig_shape = data_hp.shape | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
so this is the real fix right? Do we need the other ops you commented out?