
Commit 2827f41

[Gemini] GeminiDDP convert to PyTorch Module. (#2151)
1 parent bdef9df commit 2827f41

File tree

3 files changed: +76 −1 lines changed


colossalai/nn/parallel/utils.py

Lines changed: 28 additions & 0 deletions
@@ -2,6 +2,7 @@
 import torch.distributed as dist
 
 from colossalai.gemini.chunk import Chunk
+from colossalai.tensor import ColoTensor
 from colossalai.utils import get_current_device
 
 
@@ -19,3 +20,30 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
     dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
 
     return total_temp
+
+
+def _add_param(model, name, param):
+    name_list = name.split('.')
+    module = model._modules[name_list[0]]
+    for i in range(1, len(name_list) - 1):
+        module = module._modules[name_list[i]]
+    module._parameters[name_list[-1]] = param
+
+
+def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
+    """convert_to_torch_module
+
+    Args:
+        gemini_ddp_model (GeminiDDP): a gemini ddp model
+
+    Returns:
+        torch.nn.Module: a torch model containing the params of gemini_ddp_model
+    """
+    module = gemini_ddp_model.module
+
+    for n, p in module.named_parameters():
+        if isinstance(p, ColoTensor):
+            p.to_replicate_()
+            _add_param(module, n, p.data)
+
+    return module
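The new _add_param helper resolves a dotted parameter name (as produced by named_parameters()) by walking the module tree through _modules and overwriting the leaf entry in _parameters. Below is a minimal sketch of that traversal on a toy nested model, independent of Gemini; the replace_param name and the toy nn.Sequential are illustrative, not part of this commit.

import torch
import torch.nn as nn

def replace_param(model: nn.Module, name: str, param) -> None:
    """Walk a dotted parameter name through _modules and overwrite the
    leaf entry in _parameters, mirroring the _add_param helper above."""
    name_list = name.split('.')
    module = model._modules[name_list[0]]
    for part in name_list[1:-1]:
        module = module._modules[part]
    module._parameters[name_list[-1]] = param

# Toy nested model so the dotted name has more than one level: '0.0.weight'.
model = nn.Sequential(nn.Sequential(nn.Linear(4, 4)), nn.ReLU())

new_weight = nn.Parameter(torch.zeros(4, 4))
replace_param(model, '0.0.weight', new_weight)
assert model[0][0].weight is new_weight

Writing into _parameters directly side-steps nn.Module.__setattr__, which only accepts nn.Parameter (or None) for an existing parameter slot; that is presumably what lets the commit store the plain p.data tensor in place of a ColoTensor parameter.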

colossalai/tensor/colo_tensor.py

Lines changed: 0 additions & 1 deletion
@@ -103,7 +103,6 @@ def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) ->
                 self.process_group = spec.pg
 
         self._type = TensorType.NONMODEL
-        self._graph_node = None
 
     def has_compute_spec(self) -> bool:
         return self.compute_spec is not None
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+from functools import partial
+
+import pytest
+import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.nn.parallel.utils import convert_to_torch_module
+from colossalai.tensor import ColoTensor
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+
+@parameterize('model_name', ['resnet18', 'bert'])
+def run_convert_torch_module(model_name: str):
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, _, _, _, _ = get_components_func()
+
+    with ColoInitContext(device='cpu'):
+        model = model_builder(checkpoint=False)
+
+    from colossalai.nn.parallel import GeminiDDP
+    model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)
+
+    pytorch_model = convert_to_torch_module(model)
+
+    for n, p in pytorch_model.named_parameters():
+        assert not isinstance(p, ColoTensor)
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_convert_torch_module()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_convert_torch_module(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_convert_torch_module(2)
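The test above drives the conversion through ColossalAI's distributed test harness. Below is a rough standalone sketch of the same flow on a single GPU process, converting back to a plain torch.nn.Module so a vanilla PyTorch checkpoint can be saved; the tiny MLP, the host/port values, and the checkpoint path are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn

import colossalai
from colossalai.nn.parallel import GeminiDDP
from colossalai.nn.parallel.utils import convert_to_torch_module
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# Single-process launch so GeminiDDP has a process group to work with
# (host/port values are illustrative; requires a CUDA device and NCCL).
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

# Any nn.Module works; a tiny MLP keeps the sketch self-contained.
with ColoInitContext(device='cpu'):
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)

# ... training with the Gemini-managed model would normally happen here ...

# Convert back to a plain torch.nn.Module and save an ordinary checkpoint
# ('gemini_converted.pth' is an illustrative path).
torch_model = convert_to_torch_module(model)
torch.save(torch_model.state_dict(), 'gemini_converted.pth')

After the conversion the state_dict holds ordinary torch tensors, so the checkpoint can be reloaded with plain torch.load in an environment without ColossalAI.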
