Commit ccb8eb8

Add checkpoint sharding unit tests (#2561)
* added checkpoint sharding tests
1 parent 591744e commit ccb8eb8

4 files changed (+100 −5 lines)

deepspeed/module_inject/replace_module.py (mode 100755 → 100644, +3 −3)
@@ -1022,7 +1022,7 @@ def replace_fn(child, _policy, layer_id=0):
                     if transformer_name not in k
                 }),
                 f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')
-            new_config = json.dumps({
+            ckpt_config = json.dumps({
                 'type':
                 ckpt_name,
                 'base_dir':
@@ -1044,9 +1044,9 @@ def replace_fn(child, _policy, layer_id=0):
                 'dtype':
                 'int8' if quantize else ('float16' if fp16 else 'float32')
             })
-            with open(f"{config.save_mp_checkpoint_path}/ds-inference_config.json",
+            with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json",
                       "w") as cfg:
-                cfg.write(new_config)
+                cfg.write(ckpt_config)
 
             rep_sd = replaced_module.state_dict()
             for n, p in replaced_module.named_parameters():
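This change standardizes the metadata file name to ds_inference_config.json, which the loader changes below and the new tests both rely on. A minimal sketch of reading that file back, assuming only the keys visible in this diff ('type', 'base_dir', 'dtype'); the helper name is hypothetical:

import json
import os

def load_ds_inference_config(save_mp_checkpoint_path):
    # File name matches the one written in the diff above.
    path = os.path.join(save_mp_checkpoint_path, "ds_inference_config.json")
    with open(path) as cfg:
        data = json.load(cfg)
    # 'type', 'base_dir', and 'dtype' are the keys shown in this diff;
    # a real checkpoint config may contain more.
    return data["type"], data.get("base_dir"), data["dtype"]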

deepspeed/module_inject/replace_policy.py (+2 −1)
@@ -112,7 +112,8 @@ def __init__(
         self.is_megatron_v2 = megatron_v2
         self.mlp_act_func_type = mlp_act_func_type
         self.pre_attn_norm = pre_attn_norm
-        self.load_prefix = False
+        self.use_load_prefix = use_load_prefix
+        self.split_qkv = split_qkv
 
     def attention(self):
         """

deepspeed/runtime/state_dict_factory.py (+1 −1)
@@ -31,7 +31,7 @@ def get_sd_loader_json(json_file, checkpoint_engine):
         version = data['version']
         ckpt_type = data.get('parallelization', 'pp')
         mp_size = data.get('mp_size', 0)
-        if 'bloom' in sd_type.lower():
+        if sd_type.lower() in ['bloom', 'ds_model']:
             return data
         return SDLoaderFactory.get_sd_loader(ckpt_list,
                                              checkpoint_engine,
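get_sd_loader_json now returns the parsed metadata as-is for the generic 'ds_model' checkpoint type, not just for BLOOM. An illustrative payload that would take this early-return path; the keys 'type', 'version', 'parallelization', and 'mp_size' appear in the surrounding code, 'base_dir' and 'dtype' come from the replace_module.py diff above, and all values here are made up:

example_sd_config = {
    "type": "ds_model",        # triggers the early return above
    "version": 1.0,
    "parallelization": "tp",   # anything other than the 'pp' default
    "mp_size": 2,
    "base_dir": "/path/to/sharded/checkpoint",
    "dtype": "float16",
}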
New test file (+94 −0)
@@ -0,0 +1,94 @@
import os
import pytest
import torch
import deepspeed
from deepspeed.model_implementations import DeepSpeedTransformerInference
from transformers import AutoConfig, AutoModelForCausalLM
from unit.common import DistributedTest, DistributedFixture


def check_dtype(model, expected_dtype):
    def find_dtype(module):
        for child in module.children():
            if isinstance(child, DeepSpeedTransformerInference):
                return child.attention.attn_qkvw.dtype
            else:
                found_dtype = find_dtype(child)
                if found_dtype:
                    return found_dtype

    found_dtype = find_dtype(model)
    assert found_dtype, "Did not find DeepSpeedTransformerInference in model"
    assert (
        found_dtype == expected_dtype
    ), f"Expected transformer dtype {expected_dtype}, but found {found_dtype}"


@pytest.fixture(params=[
    "bigscience/bloom-560m",
    "EleutherAI/gpt-j-6B",
    "EleutherAI/gpt-neo-125M",
    "facebook/opt-125m"
])
def model_name(request):
    return request.param


@pytest.fixture(params=[torch.float16, torch.int8], ids=["fp16", "int8"])
def dtype(request):
    return request.param


class save_shard(DistributedFixture):
    world_size = 2

    def run(self, model_name, class_tmpdir):
        # Only write a checkpoint if one does not exist
        if not os.path.isdir(os.path.join(class_tmpdir, model_name)):
            world_size = int(os.getenv("WORLD_SIZE", "1"))
            inf_config = {
                "replace_with_kernel_inject": True,
                "dtype": torch.float16,
                "replace_method": "auto",
                "enable_cuda_graph": False,
                "tensor_parallel": {
                    "tp_size": world_size
                },
                "save_mp_checkpoint_path": os.path.join(str(class_tmpdir),
                                                        model_name),
            }

            # Load model and save sharded checkpoint
            model = AutoModelForCausalLM.from_pretrained(model_name,
                                                         torch_dtype=torch.float16)
            model = deepspeed.init_inference(model, config=inf_config)


@pytest.mark.seq_inference
class TestCheckpointShard(DistributedTest):
    world_size = 2

    def test(self, model_name, dtype, class_tmpdir, save_shard):
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        inf_config = {
            "replace_with_kernel_inject": True,
            "dtype": dtype,
            "replace_method": "auto",
            "enable_cuda_graph": False,
            "tensor_parallel": {
                "tp_size": world_size
            },
            "checkpoint": os.path.join(class_tmpdir,
                                       model_name,
                                       "ds_inference_config.json"),
        }

        # Load model on meta tensors
        model_config = AutoConfig.from_pretrained(model_name)
        # Note that we use half precision to load initially, even for int8
        with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
            model = AutoModelForCausalLM.from_config(model_config,
                                                     torch_dtype=torch.bfloat16)
        model = model.eval()
        model = deepspeed.init_inference(model, config=inf_config)
        check_dtype(model, dtype)
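Taken together, the fixture and the test class exercise a save-then-reload round trip: write a tensor-parallel sharded checkpoint, then rebuild the model on meta tensors and load it back via the "checkpoint" config key. A condensed standalone sketch of the same flow, assuming a two-GPU launch (for example via the deepspeed launcher); the model choice and paths are illustrative:

import os
import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

MODEL = "facebook/opt-125m"          # any of the models parametrized above
CKPT_DIR = "/tmp/opt-125m-sharded"   # illustrative path
TP_SIZE = int(os.getenv("WORLD_SIZE", "2"))

# Phase 1: load the full model and write a tensor-parallel sharded
# checkpoint, mirroring the save_shard fixture.
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float16)
deepspeed.init_inference(model,
                         config={
                             "replace_with_kernel_inject": True,
                             "dtype": torch.float16,
                             "replace_method": "auto",
                             "tensor_parallel": {"tp_size": TP_SIZE},
                             "save_mp_checkpoint_path": CKPT_DIR,
                         })

# Phase 2: rebuild the model on meta tensors and reload from the shards,
# mirroring TestCheckpointShard.
model_config = AutoConfig.from_pretrained(MODEL)
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(model_config)
engine = deepspeed.init_inference(
    model.eval(),
    config={
        "replace_with_kernel_inject": True,
        "dtype": torch.float16,
        "replace_method": "auto",
        "tensor_parallel": {"tp_size": TP_SIZE},
        "checkpoint": os.path.join(CKPT_DIR, "ds_inference_config.json"),
    })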
