Skip to content

Commit 138d24d

Browse files
committed
Add check for quantization FP8
1 parent 49a0ce1 commit 138d24d

2 files changed

Lines changed: 10 additions & 1 deletion

File tree

gllm/layers/linear.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from gllm.dist_utils import (get_tp_size, get_tp_rank, divide,
88
split_tensor_along_last_dim, tensor_model_parallel_all_reduce)
99
from gllm.layers.quantization.fp8 import fp8LinearMethod
10+
from gllm.utils import get_device_capability
1011

1112
class LinearBase(torch.nn.Module):
1213
"""Base linear layer.
@@ -53,6 +54,8 @@ def create_weights(self,
5354
requires_grad=False)
5455
self.register_parameter('weight', weight)
5556
elif self.quant_config['quant_method'] == 'fp8':
57+
if get_device_capability() < 89:
58+
raise Exception(f'FP8 quantizaiton method is not supported on device capability less than 89 (current is {get_device_capability()})')
5659
self.activation_scheme = self.quant_config['activation_scheme']
5760
self.block_quant = 'weight_block_size' in self.quant_config
5861
if self.block_quant:

gllm/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import tempfile
99
import logging
1010
import tqdm
11+
import torch
1112

1213
from logger import logger
1314
from functools import partial
@@ -167,4 +168,9 @@ def get_dtype_bytes(dtype):
167168
info = torch.finfo(dtype)
168169
else:
169170
info = torch.iinfo(dtype)
170-
return info.bits // 8 # bits => bytes
171+
return info.bits // 8 # bits => bytes
172+
173+
def get_device_capability():
    """Return the current CUDA device's compute capability as one integer.

    The ``(major, minor)`` pair reported by
    ``torch.cuda.get_device_capability`` is folded into ``major * 10 + minor``
    (for example ``(8, 9)`` becomes ``89``) so that callers can gate features
    with a plain integer comparison.

    Returns:
        int: ``major * 10 + minor`` for the current CUDA device, or ``0`` when
        CUDA is not available. Returning ``0`` makes every
        ``get_device_capability() < N`` gate fail with the caller's own
        message instead of an opaque CUDA initialisation error from torch.
    """
    # Querying the current device initialises CUDA and raises on hosts without
    # a usable GPU, so check availability first.
    if not torch.cuda.is_available():
        return 0
    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    return major * 10 + minor

0 commit comments

Comments
 (0)