|
4 | 4 | import os |
5 | 5 | import pickle |
6 | 6 | import random |
7 | | -import socket |
8 | 7 | import threading |
9 | 8 | import time |
10 | 9 | from collections import defaultdict |
11 | 10 | from collections.abc import Callable |
12 | 11 | from datetime import timedelta |
13 | | -from functools import lru_cache |
14 | 12 | from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple |
15 | 13 |
|
16 | 14 | import httpx |
|
23 | 21 | from safetensors.torch import safe_open |
24 | 22 | from torch.multiprocessing.reductions import reduce_tensor |
25 | 23 |
|
26 | | -from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid |
| 24 | +from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid |
27 | 25 |
|
28 | 26 |
|
29 | 27 | if TYPE_CHECKING: |
@@ -261,21 +259,6 @@ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None |
261 | 259 | raise ValueError(f"fail to get physical gpu id {device_index}") from e |
262 | 260 |
|
263 | 261 |
|
264 | | -@lru_cache(maxsize=1) |
265 | | -def _get_ip() -> str: |
266 | | - try: |
267 | | - # try to get ip from network interface |
268 | | - with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: |
269 | | - s.connect(("8.8.8.8", 80)) |
270 | | - return s.getsockname()[0] |
271 | | - except Exception as e: # noqa: BLE001 |
272 | | - # fallback to get ip from hostname |
273 | | - logger.warning( |
274 | | - f"fail to get ip from network interface, fallback to get ip from hostname: {e}" |
275 | | - ) |
276 | | - return socket.gethostbyname(socket.gethostname()) |
277 | | - |
278 | | - |
279 | 262 | def _ibv_get_device_list() -> list[str]: |
280 | 263 | lib = ctypes.CDLL("libibverbs.so.1") |
281 | 264 | lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices |
@@ -600,7 +583,7 @@ def __init__(self, device_manager: DeviceManager): |
600 | 583 | gpu_count = device_manager.device_module.device_count() |
601 | 584 | local_rank = self.rank % gpu_count |
602 | 585 | device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices()) |
603 | | - self.ip = _get_ip() |
| 586 | + self.ip = get_ip() |
604 | 587 |
|
605 | 588 | # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases |
606 | 589 | retry_count = 8 |
@@ -792,7 +775,7 @@ def gather_metas(self, checkpoint_name: str): |
792 | 775 | for x in self._memory_pool.get(checkpoint_name, []) |
793 | 776 | ], |
794 | 777 | p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr, |
795 | | - host_ip=_get_ip(), |
| 778 | + host_ip=get_ip(), |
796 | 779 | device_uuid=self._device_uuid, |
797 | 780 | ) |
798 | 781 |
|
|
0 commit comments