System Info
Depending on the file system, moving tensors loaded from a safetensors file to CUDA is 10x slower than first cloning the tensors and then moving them. This issue has heavy implications for other HF libraries (see huggingface/diffusers#12599).
In this case my file system (scratch) is BeeGFS, which is very common in HPC clusters.
Information
- The official example scripts
- My own modified scripts
Reproduction
First create a safetensors file on your filesystem.
import torch
from safetensors import safe_open
from safetensors.torch import save_file, load

weights = {}
for i in range(7):
    weights[f"weight.{i}"] = torch.randn((1024, 1024 + i))
save_file(weights, "scratch/model.safetensors")
Moving the tensors to CUDA after reading them is slow.
%%timeit
weights = {}
with safe_open("scratch/model.safetensors", framework="pt", device="cpu") as f:
    for k in f.keys():
        weights[k] = f.get_tensor(k)
temp = [w.cuda() for w in weights.values()]
torch.cuda.synchronize()
# 903 ms ± 7.45 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Cloning the tensors and then moving them to CUDA is 10x faster.
%%timeit
weights = {}
with safe_open("scratch/model.safetensors", framework="pt", device="cpu") as f:
    for k in f.keys():
        weights[k] = f.get_tensor(k)
temp = [w.clone().cuda() for w in weights.values()]  # clone then move
torch.cuda.synchronize()
# 138 ms ± 1.27 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Loading all tensors into memory at once (no memory mapping) is faster than both of the above.
%%timeit
with open("scratch/model.safetensors", "rb") as f:
    weights = load(f.read())
temp = [w.cuda() for w in weights.values()]
torch.cuda.synchronize()
# 31.7 ms ± 460 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
These numbers suggest the memory mapping is the culprit: with safe_open the host-to-device copy is served by page faults that hit the network filesystem on demand, while .clone() (or reading the whole file up front) materializes the data in RAM with sequential reads first.
Expected behavior
Moving memory-mapped tensors to CUDA should not be slower than first cloning them and then moving them.
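In the meantime, a minimal workaround sketch based on the fastest variant above, wrapping the whole-file read in a helper (the function name load_safetensors_to_cuda is illustrative, not a safetensors API):

import torch
from safetensors.torch import load

def load_safetensors_to_cuda(path):
    # Read the whole file into RAM in one sequential pass (no memory mapping),
    # then deserialize and move each tensor to the GPU.
    with open(path, "rb") as f:
        weights = load(f.read())
    return {k: w.cuda() for k, w in weights.items()}

weights = load_safetensors_to_cuda("scratch/model.safetensors")

The trade-off is peak host memory: the entire file is held in RAM at once, which may matter for models much larger than the ~30 MB example here.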