File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 77from gllm .dist_utils import (get_tp_size , get_tp_rank , divide ,
88 split_tensor_along_last_dim , tensor_model_parallel_all_reduce )
99from gllm .layers .quantization .fp8 import fp8LinearMethod
10+ from gllm .utils import get_device_capability
1011
1112class LinearBase (torch .nn .Module ):
1213 """Base linear layer.
@@ -53,6 +54,8 @@ def create_weights(self,
5354 requires_grad = False )
5455 self .register_parameter ('weight' , weight )
5556 elif self .quant_config ['quant_method' ] == 'fp8' :
57+ if get_device_capability () < 89 :
58+ raise Exception (f'FP8 quantizaiton method is not supported on device capability less than 89 (current is { get_device_capability ()} )' )
5659 self .activation_scheme = self .quant_config ['activation_scheme' ]
5760 self .block_quant = 'weight_block_size' in self .quant_config
5861 if self .block_quant :
Original file line number Diff line number Diff line change 88import tempfile
99import logging
1010import tqdm
11+ import torch
1112
1213from logger import logger
1314from functools import partial
@@ -167,4 +168,9 @@ def get_dtype_bytes(dtype):
167168 info = torch .finfo (dtype )
168169 else :
169170 info = torch .iinfo (dtype )
170- return info .bits // 8 # bits => bytes
171+ return info .bits // 8 # bits => bytes
172+
def get_device_capability():
    """Return the current CUDA device's compute capability as one integer.

    The (major, minor) pair that torch reports for the currently selected
    CUDA device is folded into ``major * 10 + minor`` so callers can compare
    it against a single threshold (for example, (8, 9) becomes 89).
    """
    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    return 10 * major + minor
You can’t perform that action at this time.
0 commit comments