
Commit 157e88b

duduyi2013 authored and facebook-github-bot committed
shardTensor metadata recalc after checkpoint state_dict (pytorch#4146)
Summary:
Pull Request resolved: pytorch#4146
X-link: facebookresearch/FBGEMM#1227

Add ST metadata recalc into the DistributedCheckpointWrapper, so that all state_dict calls invoke recalc on the virtual PMT.

Reviewed By: pradeepfn

Differential Revision: D73567632

fbshipit-source-id: 40dbdaa6f51a0d58dcb48008688c1a9a3c8939e3
1 parent f257102 commit 157e88b

File tree

1 file changed

+6
-1
lines changed


fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py

Lines changed: 6 additions & 1 deletion
@@ -159,6 +159,8 @@ def is_pinned(self):

     @property
     def dtype(self) -> torch.dtype:
+        if isinstance(self._wrapped, torch.Tensor):
+            return self._wrapped.dtype
         mapping = {"c10::Half": "half"}
         dtype_str: str = self._wrapped.dtype_str
         dtype_str = mapping.get(dtype_str, dtype_str)
@@ -169,14 +171,17 @@ def dtype(self) -> torch.dtype:

     @property
     def device(self) -> torch.device:
+        if isinstance(self._wrapped, torch.Tensor):
+            return self._wrapped.device
         device_str: str = self._wrapped.device_str
         device = torch.device(device_str)
         assert isinstance(device, torch.device)
         return device

     @property
     def layout(self) -> torch.layout:
-        pass
+        if isinstance(self._wrapped, torch.Tensor):
+            return self._wrapped.layout
         layout_str_mapping = {
             "SparseCsr": "sparse_csr",
             "Strided": "strided",

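The pattern the diff introduces is a fast-path delegation: when the wrapped object is already a real torch.Tensor (e.g. after checkpoint materialization), metadata properties delegate to it directly; otherwise they are reconstructed from string descriptors on the wrapped handle. A minimal standalone sketch of that pattern (WrappedTensorView is a hypothetical stand-in, not the FBGEMM class):

```python
import torch


class WrappedTensorView:
    """Hypothetical sketch of the delegation pattern in the diff above:
    delegate metadata to a real torch.Tensor when one is wrapped,
    otherwise derive it from string descriptors on the wrapped handle."""

    def __init__(self, wrapped):
        self._wrapped = wrapped

    @property
    def dtype(self) -> torch.dtype:
        # Fast path: a materialized tensor already carries its dtype.
        if isinstance(self._wrapped, torch.Tensor):
            return self._wrapped.dtype
        # Fallback: map a C++ dtype string to the torch attribute name,
        # mirroring the {"c10::Half": "half"} mapping in the diff.
        mapping = {"c10::Half": "half"}
        dtype_str = mapping.get(self._wrapped.dtype_str, self._wrapped.dtype_str)
        return getattr(torch, dtype_str)

    @property
    def device(self) -> torch.device:
        if isinstance(self._wrapped, torch.Tensor):
            return self._wrapped.device
        return torch.device(self._wrapped.device_str)


# Usage: wrapping a plain tensor exercises the fast path.
view = WrappedTensorView(torch.zeros(2, 2, dtype=torch.half))
print(view.dtype)   # torch.float16
print(view.device)  # cpu
```

The point of the change is that before it, these properties assumed a non-Tensor wrapped handle with `dtype_str`/`device_str` attributes (and `layout` was an unimplemented `pass`), which breaks once `state_dict` hands back a real tensor.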