forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_run_multi_mpi_comm_tasks.py
More file actions
46 lines (35 loc) · 1.45 KB
/
_run_multi_mpi_comm_tasks.py
File metadata and controls
46 lines (35 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
from typing import Literal
import click
from tensorrt_llm.executor.utils import (
LlmLauncherEnvs, get_spawn_proxy_process_ipc_hmac_key_env)
from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient
from tensorrt_llm.llmapi.utils import print_colored
def run_task(task_type: Literal["submit", "submit_sync"]):
    """Send ten ``print_colored`` tasks through a remote MPI comm session client.

    The client is built from the address stored in the
    ``TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR`` environment variable and the HMAC
    key returned by ``get_spawn_proxy_process_ipc_hmac_key_env()``.

    Args:
        task_type: ``"submit"`` dispatches each task with ``client.submit``;
            ``"submit_sync"`` dispatches with ``client.submit_sync`` and
            prints the value it returns. Any other value dispatches nothing.

    Raises:
        KeyError: If ``TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR`` is not set.
    """
    # Indexing os.environ raises a bare KeyError for a missing variable and
    # never yields None, so an `is not None` assertion on it cannot fire (and
    # asserts vanish under `python -O`). Look the value up once and raise the
    # same exception type with an explicit message instead.
    ipc_addr = os.environ.get(LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR)
    if ipc_addr is None:
        raise KeyError("TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set")

    hmac_key = get_spawn_proxy_process_ipc_hmac_key_env()
    client = RemoteMpiCommSessionClient(ipc_addr, hmac_key=hmac_key)

    for task in range(10):
        if task_type == "submit":
            client.submit(print_colored, f"{task}\n", "green")
        elif task_type == "submit_sync":
            res = client.submit_sync(print_colored, f"{task}\n", "green")
            print(res)
def run_multi_tasks(task_type: Literal["submit", "submit_sync"]):
    """Run ``run_task`` three times in a row, announcing each round.

    A green banner is printed through ``print_colored`` before and after
    every round.

    Args:
        task_type: Forwarded unchanged to ``run_task``.
    """
    num_rounds = 3
    banner_color = "green"

    for round_idx in range(num_rounds):
        print_colored(f"Running MPI comm task {round_idx}\n", banner_color)
        run_task(task_type)
        print_colored(f"MPI comm task {round_idx} completed\n", banner_color)
# Command-line entry point: `--task_type` selects how tasks are dispatched
# ("submit" or "submit_sync", defaulting to "submit") and is handed to
# run_multi_tasks.
# NOTE: documented with comments rather than a docstring on purpose -- click
# uses a command's docstring as its --help text.
@click.command()
@click.option(
    "--task_type",
    type=click.Choice(["submit", "submit_sync"]),
    default="submit",
)
def main(task_type: Literal["submit", "submit_sync"]):
    run_multi_tasks(task_type)
# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()