import os
import pytest
import torch
import deepspeed
from deepspeed.model_implementations import DeepSpeedTransformerInference
from transformers import AutoConfig, AutoModelForCausalLM
from unit.common import DistributedTest, DistributedFixture


def check_dtype(model, expected_dtype):
    # Walk the module tree for the first kernel-injected transformer layer
    # and return the dtype of its attention QKV weight.
    def find_dtype(module):
        for child in module.children():
            if isinstance(child, DeepSpeedTransformerInference):
                return child.attention.attn_qkvw.dtype
            else:
                found_dtype = find_dtype(child)
                if found_dtype:
                    return found_dtype

    found_dtype = find_dtype(model)
    assert found_dtype, "Did not find DeepSpeedTransformerInference in model"
    assert (
        found_dtype == expected_dtype
    ), f"Expected transformer dtype {expected_dtype}, but found {found_dtype}"


@pytest.fixture(params=[
    "bigscience/bloom-560m",
    "EleutherAI/gpt-j-6B",
    "EleutherAI/gpt-neo-125M",
    "facebook/opt-125m"
])
def model_name(request):
    return request.param


@pytest.fixture(params=[torch.float16, torch.int8], ids=["fp16", "int8"])
def dtype(request):
    return request.param


class save_shard(DistributedFixture):
    world_size = 2

    def run(self, model_name, class_tmpdir):
        # Only write a checkpoint if one does not exist
        if not os.path.isdir(os.path.join(class_tmpdir, model_name)):
            world_size = int(os.getenv("WORLD_SIZE", "1"))
            inf_config = {
                "replace_with_kernel_inject": True,
                "dtype": torch.float16,
                "replace_method": "auto",
                "enable_cuda_graph": False,
                "tensor_parallel": {
                    "tp_size": world_size
                },
                "save_mp_checkpoint_path": os.path.join(str(class_tmpdir), model_name),
            }

            # Load the model and save a sharded checkpoint; init_inference writes
            # the shards (and ds_inference_config.json) to save_mp_checkpoint_path.
            model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
            model = deepspeed.init_inference(model, config=inf_config)


@pytest.mark.seq_inference
class TestCheckpointShard(DistributedTest):
    world_size = 2

    def test(self, model_name, dtype, class_tmpdir, save_shard):
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        inf_config = {
            "replace_with_kernel_inject": True,
            "dtype": dtype,
            "replace_method": "auto",
            "enable_cuda_graph": False,
            "tensor_parallel": {
                "tp_size": world_size
            },
            "checkpoint": os.path.join(class_tmpdir, model_name, "ds_inference_config.json"),
        }

        # Load model on meta tensors
        model_config = AutoConfig.from_pretrained(model_name)
        # Note that we use half precision to load initially, even for int8
        with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
            model = AutoModelForCausalLM.from_config(model_config, torch_dtype=torch.bfloat16)
        model = model.eval()
        model = deepspeed.init_inference(model, config=inf_config)
        # The injected transformer layers should end up in the requested dtype
        check_dtype(model, dtype)
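

# The block below is not part of the test itself; it is a minimal, hypothetical sketch
# of how a sharded checkpoint produced by ``save_shard`` could be reloaded outside the
# pytest fixtures, using the same config keys as the test above. The model name,
# checkpoint directory, and tp_size here are illustrative assumptions only.
if __name__ == "__main__":
    example_model = "facebook/opt-125m"  # hypothetical: any model from the fixture list
    example_ckpt_dir = "/tmp/opt-125m-sharded"  # hypothetical: where the shards were saved

    example_config = {
        "replace_with_kernel_inject": True,
        "dtype": torch.float16,
        "replace_method": "auto",
        "enable_cuda_graph": False,
        "tensor_parallel": {
            "tp_size": int(os.getenv("WORLD_SIZE", "1"))
        },
        # ds_inference_config.json is written alongside the shards by the save step
        "checkpoint": os.path.join(example_ckpt_dir, "ds_inference_config.json"),
    }

    example_cfg = AutoConfig.from_pretrained(example_model)
    with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
        example = AutoModelForCausalLM.from_config(example_cfg, torch_dtype=torch.float16)
    example = deepspeed.init_inference(example.eval(), config=example_config)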