Commit 4a73109

fix: set current CUDA device in _inplace_pin_memory function (MoonshotAI#77)
1 parent 84e99ad commit 4a73109

2 files changed: 5 additions, 1 deletion


checkpoint_engine/pin_memory.py
3 additions, 0 deletions

@@ -191,6 +191,8 @@ class TPMeta(BaseModel):
 
 
 def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
+    device_index = torch.cuda.current_device()
+
     def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
         """
         safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
@@ -204,6 +206,7 @@ def _pin(t: torch.Tensor):
             Pin the memory of tensor in-place.
             See: https://github.com/pytorch/pytorch/issues/32167
             """
+            torch.cuda.set_device(device_index)
             cudart = torch.cuda.cudart()
             r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
             assert r == 0, f"pin memory error, error code: {r}"
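
Note on the change above: the CUDA runtime tracks the current device per host thread, and a thread that has never selected a device uses device 0. The sketch below is a pure-Python analogy of the capture-then-reapply pattern in this diff, not the project's code. It assumes that `_pin` may run on worker threads, which the diff itself does not show. `current_device`, `set_device`, and `pin_all` are hypothetical stand-ins built on `threading.local`, so no GPU or `torch` install is needed to run it.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

# Hypothetical stand-in for the CUDA runtime's per-thread "current device".
_state = threading.local()


def current_device() -> int:
    # A thread that never called set_device sees the default device 0.
    return getattr(_state, "device", 0)


def set_device(index: int) -> None:
    # Only affects the calling thread, like cudaSetDevice.
    _state.device = index


def pin_all(n_items: int, fixed: bool) -> list[int]:
    # Capture the caller's device once, before any worker thread starts.
    device_index = current_device()

    def _pin(_item: int) -> int:
        if fixed:
            # Re-apply the captured device inside the worker thread.
            set_device(device_index)
        # Report which device this worker would register memory against.
        return current_device()

    with ThreadPoolExecutor(max_workers=2) as pool:
        return list(pool.map(_pin, range(n_items)))


if __name__ == "__main__":
    set_device(3)  # the caller has selected device 3
    print(pin_all(4, fixed=False))  # workers fall back to the default device
    print(pin_all(4, fixed=True))   # workers use the caller's device
```

Capturing the index in the enclosing function and re-applying it inside the inner function keeps each worker on the caller's device without changing the signature of `_pin`.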

examples/update.py
2 additions, 1 deletion

@@ -14,7 +14,8 @@
 from loguru import logger
 from safetensors import safe_open
 
-from checkpoint_engine.ps import ParameterServer, request_inference_to_update
+from checkpoint_engine import request_inference_to_update
+from checkpoint_engine.ps import ParameterServer
 
 
 @contextmanager
