Skip to content

Commit f0fc35a

Browse files
committed
feat: split __main__.py from ps.py
1 parent e48c442 commit f0fc35a

File tree

2 files changed

+28
-23
lines changed

2 files changed

+28
-23
lines changed

checkpoint_engine/__main__.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import argparse
2+
import os
3+
4+
from loguru import logger
5+
6+
from checkpoint_engine.api import _init_api
7+
from checkpoint_engine.ps import ParameterServer
8+
9+
10+
@logger.catch(reraise=True)
def run_from_cli():
    """Start the parameter server API on a Unix domain socket.

    Parses ``--uds`` from the command line, logs the parsed arguments together
    with the ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables, builds a
    ``ParameterServer(auto_pg=True)`` and serves ``_init_api(ps)`` with uvicorn
    on the given socket path.

    Raises:
        SystemExit: raised by ``argparse`` (exit status 2) when ``--uds`` is
            missing or empty.
    """
    # Imported inside the function, so uvicorn is only loaded when the CLI runs.
    import uvicorn

    parser = argparse.ArgumentParser(description="Parameter Server")
    parser.add_argument("--uds", type=str)

    args = parser.parse_args()
    logger.info(
        f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
    )

    # Validate explicitly instead of with `assert`: asserts are removed under
    # `python -O`, which would let a missing/empty path reach uvicorn.
    if not args.uds:
        parser.error(f"--uds must be a non-empty socket path, got {args.uds!r}")

    ps = ParameterServer(auto_pg=True)
    uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)
25+
26+
27+
if __name__ == "__main__":
28+
run_from_cli()

checkpoint_engine/ps.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import argparse
21
import os
32
import threading
43
import time
@@ -13,7 +12,6 @@
1312
from loguru import logger
1413
from torch.multiprocessing.reductions import reduce_tensor
1514

16-
from checkpoint_engine.api import _init_api
1715
from checkpoint_engine.data_types import (
1816
BucketRange,
1917
DataToGather,
@@ -822,24 +820,3 @@ def _update_per_bucket(
822820
self._p2p_store.unregister_named_tensors([h2d_buffer_name])
823821

824822
self.device_manager.device_module.empty_cache()
825-
826-
827-
@logger.catch(reraise=True)
828-
def run_from_cli():
829-
import uvicorn
830-
831-
parser = argparse.ArgumentParser(description="Parameter Server")
832-
parser.add_argument("--uds", type=str)
833-
834-
args = parser.parse_args()
835-
logger.info(
836-
f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
837-
)
838-
839-
assert args.uds and len(args.uds) > 0, args.uds
840-
ps = ParameterServer(auto_pg=True)
841-
uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)
842-
843-
844-
if __name__ == "__main__":
845-
run_from_cli()

0 commit comments

Comments
 (0)