forked from NVIDIA/physicsnemo
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsharded_conv.py
More file actions
135 lines (102 loc) · 4.55 KB
/
sharded_conv.py
File metadata and controls
135 lines (102 loc) · 4.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.distributed.tensor import (
Shard,
distribute_module,
)
from physicsnemo.distributed import DistributedManager
from physicsnemo.domain_parallel import ShardTensor, scatter_tensor
# Initialize physicsnemo's DistributedManager and grab a handle to it. `dm` is
# used throughout the script for this process's device, its rank, and for
# building the device mesh.
DistributedManager.initialize()
dm = DistributedManager()

###########################
# Single GPU - Create input
###########################
# The input requires grad so that the gradient with respect to the input can
# be compared against the sharded run further down.
original_tensor = torch.randn(1, 8, 1024, 1024, device=dm.device, requires_grad=True)

###########################################
# Single GPU - Create a single-layer model:
###########################################
# 8 -> 8 channels, 3x3 kernel; stride 1 with padding 1 keeps the spatial size.
conv = torch.nn.Conv2d(8, 8, 3, stride=1, padding=1).to(dm.device)

########################################
# Single GPU - forward + loss + backward
########################################
single_gpu_output = conv(original_tensor)

# This isn't really a loss, just a pretend one that's scalar!
single_gpu_output.mean().backward()

# Snapshot the input gradient produced here so it can serve as the reference
# for the sharded run. `detach()` is used instead of the legacy `.data`
# attribute: it yields the same grad-free values while remaining visible to
# autograd's in-place correctness checks, and `clone()` makes the snapshot an
# independent copy.
original_tensor_grad = original_tensor.grad.detach().clone()

####################
# Single GPU - DONE!
####################
#################
# Sharded - Setup
#################
# Build the device mesh through the DistributedManager. A DeviceMesh is a
# pytorch object that can also be initialized directly; going through
# physicsnemo adds flexibility, since it can infer up to one mesh dimension
# for you (requested with -1, like in a tensor.reshape() call).
mesh = dm.initialize_mesh(
    mesh_shape=(-1,),
    mesh_dim_names=("domain_parallel",),
)

# A mesh, by the way, describes devices and not data: here it is a mesh of
# connected GPUs, and the python DeviceMesh can be reused as many times as
# needed. That said, it can be decomposed much like a tensor: it may have
# multiple mesh axes, and you can access sub-meshes along an axis. Each mesh
# also has ways to access process groups for targeted collectives.
###########################
# Sharded - Distribute Data
###########################
# In general, creating a ShardTensor (or DTensor) requires placements: a list
# or tuple of `Shard()` or `Replicate()` objects from torch.distributed.tensor.
#
# Each entry of the placements describes the layout over the corresponding
# mesh dimension (so, len(placements) == mesh.ndim!). The argument given to
# `Shard()` is the **tensor** dimension that is sharded.
#
# Below, the tensor is sharded over tensor dimension 2 (the "height"
# dimension) on mesh dimension 0, giving one tensor spread across all GPUs.
# The second argument (0) selects rank 0 as the source of the data, which is
# why the accuracy checks at the end of the script only run on rank 0.
sharded_tensor = scatter_tensor(
    original_tensor,
    0,
    mesh,
    (Shard(2),),
    requires_grad=True,
)
################################
# Sharded - distribute the model
################################
# Let pytorch know that the convolution is going to operate on distributed
# tensors - and that it lives on the very same mesh the data was scattered over.
distributed_conv = distribute_module(conv, device_mesh=mesh)
#####################################
# Sharded - forward + loss + backward
#####################################
# Now, we can do the distributed convolution:
# (the distributed version of the same convolution, followed by the same
# pretend scalar "loss" as the single-GPU pass, so results are comparable)
sharded_output = distributed_conv(sharded_tensor)
sharded_output.mean().backward()
############################################
# Sharded - gather up outputs to all devices
############################################
# This triggers a collective allgather.
# Being collectives, both calls are issued on every rank (they sit above the
# rank-0-only checks); afterwards each rank holds the full output and the full
# gradient with respect to the sharded input.
full_output = sharded_output.full_tensor()
full_grad = sharded_tensor.grad.full_tensor()
#################
# Accuracy Checks
#################
if dm.rank == 0:
    # Only check on rank 0 because we used its data and weights for the sharded tensor.

    # Check that the output is the same as the single-device output:
    assert torch.allclose(full_output, single_gpu_output, atol=1e-3, rtol=1e-3), (
        "Sharded convolution output does not match the single-GPU output"
    )
    print("Global operation matches local! ")

    # Check that the gradient is correct:
    assert torch.allclose(original_tensor_grad, full_grad), (
        "Sharded input gradient does not match the single-GPU gradient"
    )
    print("Gradient check passed!")

# Every rank reports how the gradient of the sharded input is placed on the
# mesh and the shape of the shard it holds locally. The public `placements`
# property is used rather than reaching into the private `_spec` attribute.
print(
    "Distributed grad sharding and local shape: "
    f"{sharded_tensor.grad.placements}, {sharded_tensor.grad.to_local().shape}"
)