File tree Expand file tree Collapse file tree 2 files changed +4
-4
lines changed
primus/backends/torchtitan/components/quantization Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Original file line number Diff line number Diff line change 1010 pull_request :
1111
1212env :
13- PRIMUS_TURBO_COMMIT : 0385cdd615cb4eff28a1cbbf3583fccf95d11fe9 # chore: refactor grouped gemm blockwise python code (#142 )
13+ PRIMUS_TURBO_COMMIT : 7ecd1edc31a3f1607e92b536b4250f9d98bfe423 # feat: unify fp8 gemm (#148 )
1414
1515jobs :
1616 code-lint :
Original file line number Diff line number Diff line change 77import torch
88import torch .nn as nn
99from primus_turbo .pytorch .core .float8 import Float8QuantConfig , ScalingGranularity
10- from primus_turbo .pytorch .modules import MXLinear
10+ from primus_turbo .pytorch .modules import Float8Linear
1111from torchtitan .config .job_config import JobConfig
1212from torchtitan .distributed import ParallelDims
1313from torchtitan .protocols .model_converter import (
2121
2222def replace_turbo_mxlinear_modules (model : nn .Module , config : Float8QuantConfig ):
2323 for name , module in model .named_children ():
24- if isinstance (module , torch .nn .Linear ) and not isinstance (module , MXLinear ):
25- mx_linear = MXLinear .from_float (module , config )
24+ if isinstance (module , torch .nn .Linear ) and not isinstance (module , Float8Linear ):
25+ mx_linear = Float8Linear .from_float (module , config )
2626 setattr (model , name , mx_linear )
2727 else :
2828 replace_turbo_mxlinear_modules (module , config )
You can’t perform that action at this time.
0 commit comments