Skip to content

Commit b2c8251

Browse files
author
cuixiaojin
committed
[modify] get ip
1 parent 56f0ec2 commit b2c8251

File tree

2 files changed

+22
-23
lines changed

2 files changed

+22
-23
lines changed

checkpoint_engine/device_utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,26 @@
22
import re
33
import socket
44
import subprocess
5-
65
import torch
76

7+
from functools import lru_cache
8+
from loguru import logger
9+
10+
11+
@lru_cache(maxsize=1)
def get_ip() -> str:
    """Return an IPv4 address for this host, computed once and then cached.

    The primary strategy connects a UDP (``SOCK_DGRAM``) socket to
    ``8.8.8.8:80`` and reads back the local address the socket was bound to.
    If anything in that path raises, a warning is logged and the address is
    resolved from the machine's hostname instead.

    Returns:
        The local IP address as a string.

    Raises:
        Whatever ``socket.gethostbyname`` raises if the hostname fallback
        also fails; that error is not caught here.
    """
    probe_target = ("8.8.8.8", 80)
    try:
        # Primary path: read the local endpoint of a socket connected
        # towards a public address.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(probe_target)
            local_endpoint = probe.getsockname()
            return local_endpoint[0]
    except Exception as err:  # noqa: BLE001
        # Deliberately broad: any failure above falls back to the hostname lookup.
        logger.warning(
            f"fail to get ip from network interface, fallback to get ip from hostname: {err}"
        )
        hostname = socket.gethostname()
        return socket.gethostbyname(hostname)
24+
825

926
def npu_generate_uuid() -> str:
1027
str_pid = str(os.getpid())
@@ -21,8 +38,7 @@ def npu_generate_uuid() -> str:
2138
search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :]
2239
match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid)
2340
chip_id = int(match_chip_id.group(1))
24-
server_ip = socket.gethostbyname(socket.gethostname())
25-
return f"{server_ip}-{npu_id * chip_count + chip_id}"
41+
return f"{get_ip()}-{npu_id * chip_count + chip_id}"
2642
ValueError("The current process is not running on the npu device")
2743
except subprocess.CalledProcessError:
2844
ValueError("The current process is not running on the npu device")

checkpoint_engine/ps.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44
import os
55
import pickle
66
import random
7-
import socket
87
import threading
98
import time
109
from collections import defaultdict
1110
from collections.abc import Callable
1211
from datetime import timedelta
13-
from functools import lru_cache
1412
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
1513

1614
import httpx
@@ -23,7 +21,7 @@
2321
from safetensors.torch import safe_open
2422
from torch.multiprocessing.reductions import reduce_tensor
2523

26-
from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
24+
from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
2725

2826

2927
if TYPE_CHECKING:
@@ -261,21 +259,6 @@ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None
261259
raise ValueError(f"fail to get physical gpu id {device_index}") from e
262260

263261

264-
@lru_cache(maxsize=1)
def _get_ip() -> str:
    """Return an IPv4 address for this host; the result is cached after the first call.

    First tries to read the local address of a UDP socket connected to
    ``8.8.8.8:80``. On any exception, logs a warning and resolves the
    address from the hostname instead.

    Returns:
        The local IP address as a string.
    """
    try:
        # Preferred source: the local side of a datagram socket connected outward.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_sock:
            udp_sock.connect(("8.8.8.8", 80))
            address, *_ = udp_sock.getsockname()
            return address
    except Exception as exc:  # noqa: BLE001
        # Best-effort fallback: look the address up via the hostname.
        logger.warning(
            f"fail to get ip from network interface, fallback to get ip from hostname: {exc}"
        )
        this_host = socket.gethostname()
        return socket.gethostbyname(this_host)
277-
278-
279262
def _ibv_get_device_list() -> list[str]:
280263
lib = ctypes.CDLL("libibverbs.so.1")
281264
lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
@@ -600,7 +583,7 @@ def __init__(self, device_manager: DeviceManager):
600583
gpu_count = device_manager.device_module.device_count()
601584
local_rank = self.rank % gpu_count
602585
device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
603-
self.ip = _get_ip()
586+
self.ip = get_ip()
604587

605588
# we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
606589
retry_count = 8
@@ -792,7 +775,7 @@ def gather_metas(self, checkpoint_name: str):
792775
for x in self._memory_pool.get(checkpoint_name, [])
793776
],
794777
p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
795-
host_ip=_get_ip(),
778+
host_ip=get_ip(),
796779
device_uuid=self._device_uuid,
797780
)
798781

0 commit comments

Comments
 (0)