Commit 047f15a

[Fix] import MXLinear from Primus Turbo (#272)
Signed-off-by: Gene Der Su <[email protected]>
1 parent c9bfb4b commit 047f15a

File tree

2 files changed (+4, -4 lines)

  • .github/workflows/ci.yaml
  • primus/backends/torchtitan/components/quantization/mx.py

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ on:
   pull_request:

 env:
-  PRIMUS_TURBO_COMMIT: 0385cdd615cb4eff28a1cbbf3583fccf95d11fe9 # chore: refactor grouped gemm blockwise python code (#142)
+  PRIMUS_TURBO_COMMIT: 7ecd1edc31a3f1607e92b536b4250f9d98bfe423 # feat: unify fp8 gemm (#148)

 jobs:
   code-lint:

primus/backends/torchtitan/components/quantization/mx.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@
 import torch
 import torch.nn as nn
 from primus_turbo.pytorch.core.float8 import Float8QuantConfig, ScalingGranularity
-from primus_turbo.pytorch.modules import MXLinear
+from primus_turbo.pytorch.modules import Float8Linear
 from torchtitan.config.job_config import JobConfig
 from torchtitan.distributed import ParallelDims
 from torchtitan.protocols.model_converter import (
@@ -21,8 +21,8 @@

 def replace_turbo_mxlinear_modules(model: nn.Module, config: Float8QuantConfig):
     for name, module in model.named_children():
-        if isinstance(module, torch.nn.Linear) and not isinstance(module, MXLinear):
-            mx_linear = MXLinear.from_float(module, config)
+        if isinstance(module, torch.nn.Linear) and not isinstance(module, Float8Linear):
+            mx_linear = Float8Linear.from_float(module, config)
             setattr(model, name, mx_linear)
         else:
             replace_turbo_mxlinear_modules(module, config)
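
For context on what the touched function does: replace_turbo_mxlinear_modules walks a model with named_children(), swaps each plain torch.nn.Linear for its quantized counterpart via setattr, and recurses into all other submodules. The sketch below reproduces only that traversal pattern with stock PyTorch; QuantizedLinear and its from_float constructor are hypothetical stand-ins for Primus Turbo's Float8Linear / Float8QuantConfig, not the library's actual API.

# Minimal sketch of the recursive swap pattern used in mx.py above.
# QuantizedLinear.from_float is a hypothetical stand-in for Primus Turbo's
# Float8Linear.from_float(module, config); only the traversal/replacement
# logic mirrors the real code.
import torch.nn as nn


class QuantizedLinear(nn.Linear):
    """Placeholder for a quantized linear layer (e.g. Float8Linear)."""

    @classmethod
    def from_float(cls, module: nn.Linear, config: dict) -> "QuantizedLinear":
        new = cls(module.in_features, module.out_features, bias=module.bias is not None)
        new.load_state_dict(module.state_dict())  # copy weight (and bias, if present)
        new.quant_config = config  # a real implementation would set up fp8 scaling here
        return new


def replace_linear_modules(model: nn.Module, config: dict) -> None:
    # Same shape as replace_turbo_mxlinear_modules: swap plain Linear layers
    # in place, skip already-converted ones, and recurse into other submodules.
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and not isinstance(module, QuantizedLinear):
            setattr(model, name, QuantizedLinear.from_float(module, config))
        else:
            replace_linear_modules(module, config)


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Sequential(nn.Linear(32, 8)))
    replace_linear_modules(model, config={"granularity": "mxfp8"})
    print(model)  # both Linear layers are now QuantizedLinear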

0 commit comments
