
Commit d42afd3

[builder] runtime adam and fused_optim builder (#2184)
1 parent 550f8f8 commit d42afd3

File tree

7 files changed: +205 -9 lines changed
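This commit changes how ColossalAI's C++/CUDA optimizer kernels are loaded: instead of failing when the prebuilt `colossalai._C` extensions are absent, each call site now falls back to compiling the kernels just-in-time through the new `op_builder` package. A minimal sketch of the pattern (names taken from the diffs below; the exception handling varies slightly per file):

    try:
        # fast path: the ahead-of-time extension built at install time
        from colossalai._C import fused_optim
    except ImportError:
        # fallback: JIT-compile the same sources via torch.utils.cpp_extension
        from colossalai.kernel.op_builder import FusedOptimBuilder
        fused_optim = FusedOptimBuilder().load()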

colossalai/amp/naive_amp/_fp16_optimizer.py

Lines changed: 4 additions & 2 deletions
@@ -5,9 +5,11 @@
 import torch.distributed as dist

 try:
-    import colossalai._C.fused_optim
+    from colossalai._C import fused_optim
 except:
     print('Colossalai should be built with cuda extension to use the FP16 optimizer')
+    from colossalai.kernel.op_builder.fused_optim import FusedOptimBuilder
+    fused_optim = FusedOptimBuilder().load()

 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
@@ -35,7 +37,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
     if overflow_buf:
         overflow_buf.fill_(0)
         # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
+        multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
     else:
         for this_, that_ in zip(this, that):
             that_.copy_(this_)
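For context, `multi_tensor_scale` scales a source list of tensors into a destination list in one fused kernel launch, so scaling by 1.0 performs a batched copy. A sketch of a direct call, assuming a CUDA device and the apex-style `multi_tensor_applier(op, noop_flag, tensor_lists, *args)` convention this file uses (source list first, matching `[this, that]` above):

    import torch

    src = [torch.randn(4, device='cuda'), torch.randn(8, device='cuda')]
    dst = [torch.empty_like(t) for t in src]
    overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')  # noop/overflow flag

    # scale every tensor in `src` by 1.0 into `dst`: a fused multi-tensor copy
    multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [src, dst], 1.0)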
colossalai/kernel/op_builder/__init__.py  (new file; path inferred from the package imports elsewhere in this commit)

Lines changed: 4 additions & 0 deletions

from .cpu_adam import CPUAdamBuilder
from .fused_optim import FusedOptimBuilder

__all__ = ['CPUAdamBuilder', 'FusedOptimBuilder']
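With the package `__init__.py` in place, both builders are importable from the package root, which is exactly how `hybrid_adam.py` consumes them below:

    from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder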
colossalai/kernel/op_builder/builder.py  (new file; path inferred from `from .builder import Builder` in the sibling modules)

Lines changed: 45 additions & 0 deletions

import os
import sys
from pathlib import Path


class Builder(object):

    def colossalai_src_path(self, code_path):
        if os.path.isabs(code_path):
            return code_path
        else:
            return os.path.join(Path(__file__).parent.parent.absolute(), code_path)

    def strip_empty_entries(self, args):
        '''
        Drop any empty strings from the list of compile and link flags
        '''
        return [x for x in args if len(x) > 0]

    def load(self, verbose=True):
        """
        load and compile cpu_adam lib at runtime

        Args:
            verbose (bool, optional): show detailed info. Defaults to True.
        """
        import time

        from torch.utils.cpp_extension import load
        start_build = time.time()

        op_module = load(name=self.name,
                         sources=self.strip_empty_entries(self.sources),
                         extra_include_paths=self.strip_empty_entries(self.extra_include_paths),
                         extra_cflags=self.extra_cxx_flags,
                         extra_cuda_cflags=self.extra_cuda_flags,
                         extra_ldflags=[],
                         verbose=verbose)

        build_duration = time.time() - start_build
        if verbose:
            print(f"Time to load {self.name} op: {build_duration} seconds")

        return op_module
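`Builder.load()` reads `self.name`, `self.sources`, `self.extra_include_paths`, `self.extra_cxx_flags` and `self.extra_cuda_flags`, so a subclass must set all five before calling it. A hypothetical subclass sketch illustrating that contract (the op name and source path are illustrative, not part of this commit):

    from colossalai.kernel.op_builder.builder import Builder

    class MyOpBuilder(Builder):

        def __init__(self):
            self.name = 'my_op'    # name of the resulting extension module (hypothetical)
            # relative paths are resolved against colossalai/kernel by colossalai_src_path
            self.sources = [self.colossalai_src_path('cuda_native/csrc/my_op.cu')]  # hypothetical source
            self.extra_include_paths = []
            self.extra_cxx_flags = ['-O3']
            self.extra_cuda_flags = []

    my_op = MyOpBuilder().load()    # compiles on first call; torch caches the build afterwards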
colossalai/kernel/op_builder/cpu_adam.py  (new file; path inferred from `from colossalai.kernel.op_builder.cpu_adam import CPUAdamBuilder` in the test below)

Lines changed: 84 additions & 0 deletions

import os
import sys
from pathlib import Path

from .builder import Builder


class CPUAdamBuilder(Builder):
    NAME = "cpu_adam"
    BASE_DIR = "cuda_native"

    def __init__(self):
        self.name = CPUAdamBuilder.NAME
        super().__init__()

        self.sources = [self.colossalai_src_path(path) for path in self.sources_files()]
        self.extra_include_paths = [self.colossalai_src_path(path) for path in self.include_paths()]
        self.extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native']
        self.extra_cuda_flags = [
            '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
            '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
        ]
        self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']

    def sources_files(self):
        return [
            os.path.join(CPUAdamBuilder.BASE_DIR, "csrc/cpu_adam.cpp"),
        ]

    def include_paths(self):
        import torch
        from torch.utils.cpp_extension import CUDA_HOME
        cuda_include = os.path.join(CUDA_HOME, "include")
        return [os.path.join(CPUAdamBuilder.BASE_DIR, "includes"), cuda_include]

    def colossalai_src_path(self, code_path):
        if os.path.isabs(code_path):
            return code_path
        else:
            return os.path.join(Path(__file__).parent.parent.absolute(), code_path)

    def strip_empty_entries(self, args):
        '''
        Drop any empty strings from the list of compile and link flags
        '''
        return [x for x in args if len(x) > 0]

    def builder(self):
        from torch.utils.cpp_extension import CUDAExtension
        return CUDAExtension(
            name=self.name,
            sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources],
            include_dirs=self.extra_include_paths,
            extra_compile_args={
                'cxx': ['-O3'] + self.version_dependent_macros + self.extra_cxx_flags,
                'nvcc': ['-O3', '--use_fast_math'] + self.extra_cuda_flags
            })

    def load(self, verbose=True):
        """
        load and compile cpu_adam lib at runtime

        Args:
            verbose (bool, optional): show detailed info. Defaults to True.
        """
        import time

        from torch.utils.cpp_extension import load
        start_build = time.time()

        op_module = load(name=self.name,
                         sources=self.strip_empty_entries(self.sources),
                         extra_include_paths=self.strip_empty_entries(self.extra_include_paths),
                         extra_cflags=self.extra_cxx_flags,
                         extra_cuda_cflags=self.extra_cuda_flags,
                         extra_ldflags=[],
                         verbose=verbose)

        build_duration = time.time() - start_build
        if verbose:
            print(f"Time to load {self.name} op: {build_duration} seconds")

        return op_module
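Usage mirrors the test file below: load the module, then construct the optimizer object it exposes. A sketch with placeholder hyperparameters:

    from colossalai.kernel.op_builder import CPUAdamBuilder

    lib = CPUAdamBuilder().load()    # JIT-compiles csrc/cpu_adam.cpp on first use
    # (lr, beta1, beta2, eps, weight_decay, adamw_mode): placeholder values
    cpu_adam_op = lib.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)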
colossalai/kernel/op_builder/fused_optim.py  (new file; path inferred from `from colossalai.kernel.op_builder.fused_optim import FusedOptimBuilder` above)

Lines changed: 53 additions & 0 deletions

import os
import re

import torch

from .builder import Builder


class FusedOptimBuilder(Builder):
    NAME = "fused_optim"
    BASE_DIR = "cuda_native/csrc"

    def __init__(self):
        self.name = FusedOptimBuilder.NAME
        super().__init__()

        self.extra_cxx_flags = []
        self.extra_cuda_flags = ['-lineinfo']
        for arch in torch.cuda.get_arch_list():
            res = re.search(r'sm_(\d+)', arch)
            if res:
                arch_cap = res[1]
                if int(arch_cap) >= 60:
                    self.extra_cuda_flags.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])

        self.sources = [self.colossalai_src_path(path) for path in self.sources_files()]
        self.extra_include_paths = [self.colossalai_src_path(path) for path in self.include_paths()]
        self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']

    def sources_files(self):
        return [
            os.path.join(FusedOptimBuilder.BASE_DIR, fname) for fname in [
                'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',
                'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu'
            ]
        ]

    def include_paths(self):
        import torch
        from torch.utils.cpp_extension import CUDA_HOME
        cuda_include = os.path.join(CUDA_HOME, "include")
        return [os.path.join(FusedOptimBuilder.BASE_DIR, "includes"), cuda_include]

    def builder(self):
        from torch.utils.cpp_extension import CUDAExtension
        return CUDAExtension(
            name=self.name,
            sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources],
            include_dirs=self.extra_include_paths,
            extra_compile_args={
                'cxx': ['-O3'] + self.version_dependent_macros + self.extra_cxx_flags,
                'nvcc': ['-O3', '--use_fast_math'] + self.extra_cuda_flags
            })
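The `-gencode` loop only emits flags for compute capability 6.0 and above, restricting the fused kernels to the architectures PyTorch was built for. A standalone sketch of the same computation, runnable without triggering a build (the output depends on the local PyTorch/CUDA install):

    import re
    import torch

    flags = []
    for arch in torch.cuda.get_arch_list():    # e.g. ['sm_60', 'sm_70', 'sm_80']
        res = re.search(r'sm_(\d+)', arch)
        if res and int(res[1]) >= 60:
            flags.extend(['-gencode', f'arch=compute_{res[1]},code={arch}'])

    print(flags)    # e.g. ['-gencode', 'arch=compute_60,code=sm_60', ...]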

colossalai/nn/optimizer/hybrid_adam.py

Lines changed: 6 additions & 6 deletions
@@ -77,15 +77,15 @@ def __init__(self,
         super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
         self.adamw_mode = adamw_mode
         try:
-            import colossalai._C.cpu_optim
-            import colossalai._C.fused_optim
+            from colossalai._C import cpu_optim, fused_optim
         except ImportError:
-            raise ImportError('Please install colossalai from source code to use HybridAdam')
+            from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
+            fused_optim = FusedOptimBuilder().load()
+            cpu_optim = CPUAdamBuilder().load()

-        self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay,
-                                                                    adamw_mode)
+        self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

-        self.gpu_adam_op = colossalai._C.fused_optim.multi_tensor_adam
+        self.gpu_adam_op = fused_optim.multi_tensor_adam
         self._dummy_overflow_buf = torch.cuda.IntTensor([0])

     @torch.no_grad()
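After this change, constructing `HybridAdam` on a machine without the prebuilt extensions triggers a one-time JIT build instead of raising `ImportError`. A minimal usage sketch; the import path is assumed from the file's location under `colossalai/nn/optimizer/`:

    import torch
    from colossalai.nn.optimizer import HybridAdam    # assumed package re-export

    model = torch.nn.Linear(16, 16).cuda()
    optimizer = HybridAdam(model.parameters(), lr=1e-3)    # may compile kernels on first construction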

tests/test_optimizer/test_cpu_adam.py

Lines changed: 9 additions & 1 deletion
@@ -69,8 +69,12 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
     try:
         import colossalai._C.cpu_optim
         cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
+        print("use prebuilt CPUAdamOptimizer")
     except:
-        raise ImportError("Import cpu adam error, please install colossal from source code")
+        from colossalai.kernel.op_builder.cpu_adam import CPUAdamBuilder
+        lib = CPUAdamBuilder().load()
+        cpu_adam_op = lib.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
+        print("build CPUAdamOptimizer at runtime")

     cpu_adam_op.step(
         step,
@@ -115,3 +119,7 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
     assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
     max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
     assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
+
+
+if __name__ == '__main__':
+    test_cpu_adam()
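Note that the bare `except:` routes any failure of the prebuilt import, not just `ImportError`, to the runtime build, and the two `print` calls make it observable which branch ran. A small standalone check (sketch) for whether the prebuilt extension is present, i.e. which branch will be taken:

    import importlib.util

    try:
        # None means the prebuilt AOT extension is missing
        prebuilt = importlib.util.find_spec('colossalai._C.cpu_optim') is not None
    except ModuleNotFoundError:
        prebuilt = False
    print('use prebuilt CPUAdamOptimizer' if prebuilt else 'will build CPUAdamOptimizer at runtime')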
