File tree (expand / collapse file tree): 2 files changed, +5 -1 lines changed
Expand file tree / Collapse file tree: 2 files changed, +5 -1 lines changed
Original file line number | Diff line number | Diff line change
@@ -191,6 +191,8 @@ class TPMeta(BaseModel):
191191
192192
193193def _inplace_pin_memory (files : list [str ], rank : int | None = None ) -> list [MemoryBuffer ]:
194+ device_index = torch .cuda .current_device ()
195+
194196 def _parse_and_pin_from_safetensors (file_path : str ) -> MemoryBuffer :
195197 """
196198 safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
@@ -204,6 +206,7 @@ def _pin(t: torch.Tensor):
204206 Pin the memory of tensor in-place.
205207 See: https://github.com/pytorch/pytorch/issues/32167
206208 """
209+ torch .cuda .set_device (device_index )
207210 cudart = torch .cuda .cudart ()
208211 r = cudart .cudaHostRegister (t .data_ptr (), t .numel () * t .element_size (), 0 )
209212 assert r == 0 , f"pin memory error, error code: { r } "
Original file line number | Diff line number | Diff line change
1414from loguru import logger
1515from safetensors import safe_open
1616
17- from checkpoint_engine .ps import ParameterServer , request_inference_to_update
17+ from checkpoint_engine import request_inference_to_update
18+ from checkpoint_engine .ps import ParameterServer
1819
1920
2021@contextmanager
You can’t perform that action at this time.
0 commit comments