Skip to content

Commit f2710bb

Browse files
Make data contiguous before the inplace reshape-copy_ function (#2489)
Co-authored-by: Michael Wyatt <[email protected]>
1 parent be5ec50 commit f2710bb

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

deepspeed/module_inject/load_checkpoint.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def load_model_with_checkpoint(r_module,
1515
error_msgs = []
1616

1717
def transpose(data):
18+
data = data.contiguous()
1819
data1 = data.transpose(-1, -2).reshape(-1)
1920
data.reshape(-1).copy_(data1)
2021
data1 = None

deepspeed/module_inject/replace_module.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def replace_attn(child, policy):
215215
attn_module = transformer_inference.DeepSpeedDiffusersAttention(config)
216216

217217
def transpose(data):
218+
data = data.contiguous()
218219
data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
219220
data = data.reshape(data.shape[-1], data.shape[-2])
220221
data.to(torch.cuda.current_device())
@@ -531,7 +532,7 @@ def replace_with_policy(child,
531532
# transpose it here to reduce inference cost!
532533
def transpose(data):
533534
# temp move to cpu to avoid requiring extra GPU memory during the reshape
534-
data = data.to('cpu')
535+
data = data.to('cpu').contiguous()
535536
data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
536537
data = data.reshape(data.shape[-1], data.shape[-2])
537538
data.to(torch.cuda.current_device())

0 commit comments

Comments
 (0)