
Moving to CUDA after safe_open is slow #672

@francois-rozet

Description

System Info

Depending on the file system, moving tensors read with `safe_open` to CUDA is about 10x slower than cloning them first and then moving them. This has significant implications for other HF libraries (see huggingface/diffusers#12599).

In my case, the file system hosting `scratch` is BeeGFS, which is common on HPC clusters.
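Because the slowdown depends on the file system, it can help to confirm which one actually backs a given path before benchmarking. Below is a minimal sketch, assuming a Linux host where `/proc/mounts` lists one mount per line. The helper name `filesystem_type` is hypothetical. It returns the type of the longest mount point that contains the path, for example `ext4`, `nfs`, or presumably `beegfs` for a BeeGFS scratch directory.

```python
import os


def filesystem_type(path: str, mounts_file: str = "/proc/mounts") -> str:
    """Return the type of the file system mounted at the longest mount point
    containing `path`, or "unknown" if none matches.

    Assumes a Linux-style mounts table with whitespace-separated fields:
    device, mount point, file system type, options, dump, pass.
    """
    target = os.path.realpath(path)
    best_mount, best_type = "", "unknown"
    with open(mounts_file) as fh:
        for line in fh:
            fields = line.split()
            if len(fields) < 3:
                continue
            # /proc/mounts escapes spaces in mount points as \040
            mount_point = fields[1].replace("\\040", " ")
            fs_type = fields[2]
            prefix = mount_point.rstrip("/") + "/"
            if target == mount_point or target.startswith(prefix):
                # Longest match wins; on ties the later (most recent) mount wins
                if len(mount_point) >= len(best_mount):
                    best_mount, best_type = mount_point, fs_type
    return best_type
```

Usage: `filesystem_type("scratch/model.safetensors")`.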


Reproduction

First, create a safetensors file on the file system under test.

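```python
import os
import time

import torch
from safetensors import safe_open
from safetensors.torch import save_file, load

os.makedirs("scratch", exist_ok=True)  # make sure the target directory exists

# Seven float32 tensors of shape (1024, 1024 + i)
weights = {}

for i in range(7):
    weights[f"weight.{i}"] = torch.randn((1024, 1024 + i))

save_file(weights, "scratch/model.safetensors")
```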
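For scale, the tensor payload of this file follows directly from the shapes: seven float32 tensors of shape (1024, 1024 + i) at 4 bytes per element. The snippet below counts only raw tensor bytes; it ignores the 8-byte length prefix and JSON header that a safetensors file starts with, so the file on disk is slightly larger.

```python
# Raw payload of the 7 float32 tensors created above (header excluded).
BYTES_PER_FLOAT32 = 4
shapes = [(1024, 1024 + i) for i in range(7)]

payload_bytes = sum(rows * cols * BYTES_PER_FLOAT32 for rows, cols in shapes)
payload_mib = payload_bytes / 2**20

print(payload_bytes)             # 29446144
print(f"{payload_mib:.2f} MiB")  # 28.08 MiB
```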

Moving the tensors to CUDA directly after reading them with `safe_open` is slow:

```python
%%timeit

weights = {}

with safe_open("scratch/model.safetensors", framework="pt", device="cpu") as f:
    for k in f.keys():
        weights[k] = f.get_tensor(k)

temp = [w.cuda() for w in weights.values()]

torch.cuda.synchronize()
# 903 ms ± 7.45 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

Cloning the tensors before moving them to CUDA is about 6.5x faster here (138 ms vs. 903 ms):

```python
%%timeit

weights = {}

with safe_open("scratch/model.safetensors", framework="pt", device="cpu") as f:
    for k in f.keys():
        weights[k] = f.get_tensor(k)

temp = [w.clone().cuda() for w in weights.values()]  # clone then move

torch.cuda.synchronize()
# 138 ms ± 1.27 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

Reading the whole file into memory and deserializing it with `load` (no memory mapping) is faster than either approach above:

```python
%%timeit

with open("scratch/model.safetensors", "rb") as fh:
    weights = load(fh.read())  # read the whole file, then deserialize in memory

temp = [w.cuda() for w in weights.values()]

torch.cuda.synchronize()
# 31.7 ms ± 460 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
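To put the three reported timings in perspective, dividing the raw float32 payload (about 29.4 MB, header excluded) by each mean time gives a rough end-to-end rate. This uses only the numbers reported in this issue, so treat the results as approximate.

```python
# Effective throughput of each strategy, from the mean timings reported above.
payload_bytes = sum(1024 * (1024 + i) * 4 for i in range(7))  # float32 payload

timings_s = {
    "safe_open -> cuda": 0.903,
    "safe_open -> clone -> cuda": 0.138,
    "read + load -> cuda": 0.0317,
}

for name, seconds in timings_s.items():
    mb_per_s = payload_bytes / seconds / 1e6
    speedup = timings_s["safe_open -> cuda"] / seconds
    print(f"{name:28s} {mb_per_s:7.1f} MB/s  {speedup:5.1f}x")
```

This works out to roughly 33 MB/s for the direct transfer, 213 MB/s with `clone()`, and 929 MB/s when the whole file is read first, i.e. about 6.5x and 28x faster than the direct path.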

Expected behavior

Moving a tensor obtained from `safe_open` directly to CUDA should not be slower than cloning it first and then moving it.
