Skip to content

Commit 24e798c

Browse files
Add New Models (#22)
* Add New Models * Add MiniMaxAI/MiniMax-M2 * Change Default Duration of Running of All Models to 2 Hours
1 parent e893a15 commit 24e798c

7 files changed

Lines changed: 138 additions & 76 deletions

File tree

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
image = "/capstor/store/cscs/swissai/infra01/container-images/sglang_kimi_k2.5_cuda13.sqsh"
2+
3+
# "src_path:trg_path" mounts the src_path on the host inside the container at the trg_path.
4+
mounts = [
5+
"/iopsstor/store/cscs/swissai/a09/xyao/bin:/ocfbin",
6+
"/capstor",
7+
"/iopsstor",
8+
"/usr/lib64/libhwloc.so.15:/usr/lib/libhwloc.so.15",
9+
"/usr/lib64/libpciaccess.so.0:/usr/lib/libpciaccess.so.0",
10+
"/usr/lib64/libxml2.so.2:/usr/lib/libxml2.so.2",
11+
"/usr/lib64/libnuma.so.1:/usr/lib/libnuma.so.1",
12+
]
13+
14+
workdir = "/opt"
15+
16+
[env]
17+
# NCCL_DEBUG = "INFO" # uncomment for debugging
18+
# NCCL_DEBUG_SUBSYS = "INIT,NET" # uncomment for debugging
19+
NCCL_NET = "AWS Libfabric"
20+
NCCL_CROSS_NIC = "1"
21+
NCCL_NET_GDR_LEVEL = "PHB"
22+
NCCL_SOCKET_IFNAME = "hsn"
23+
NCCL_PROTO = "^LL128"
24+
FI_CXI_COMPAT = "0"
25+
FI_MR_CACHE_MONITOR = "userfaultfd"
26+
FI_CXI_RX_MATCH_MODE = "software"
27+
FI_CXI_DEFAULT_CQ_SIZE = "131072"
28+
FI_CXI_DEFAULT_TX_SIZE = "32768"
29+
FI_CXI_DISABLE_HOST_REGISTER = "1"
30+
OFI_NCCL_DISABLE_DMABUF = "1"
31+
SGL_ENABLE_JIT_DEEPGEMM = "0"
32+
33+
[annotations]
34+
com.hooks.aws_ofi_nccl.enabled = "true"
35+
com.hooks.aws_ofi_nccl.variant = "cuda13"
36+
com.hooks.cxi.enabled = "true"

src/swiss_ai_model_launch/assets/models.json

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
"environment": null,
77
"workers": 1,
88
"nodes_per_worker": 1,
9-
"time": "00:30:00"
9+
"time": "02:00:00",
10+
"framework_args": ""
1011
},
1112
{
1213
"vendor": "swiss-ai",
@@ -15,7 +16,8 @@
1516
"environment": null,
1617
"workers": 1,
1718
"nodes_per_worker": 1,
18-
"time": "00:30:00"
19+
"time": "02:00:00",
20+
"framework_args": ""
1921
},
2022
{
2123
"vendor": "swiss-ai",
@@ -24,7 +26,8 @@
2426
"environment": null,
2527
"workers": 1,
2628
"nodes_per_worker": 1,
27-
"time": "00:30:00"
29+
"time": "02:00:00",
30+
"framework_args": ""
2831
},
2932
{
3033
"vendor": "swiss-ai",
@@ -33,24 +36,37 @@
3336
"environment": null,
3437
"workers": 1,
3538
"nodes_per_worker": 1,
36-
"time": "00:30:00"
39+
"time": "02:00:00",
40+
"framework_args": ""
3741
},
3842
{
39-
"vendor": "zai-org",
40-
"model_name": "GLM-4.7-Flash",
43+
"vendor": "Qwen",
44+
"model_name": "Qwen3-235B-A22B-Instruct-2507",
4145
"framework": "sglang",
4246
"environment": null,
4347
"workers": 1,
44-
"nodes_per_worker": 1,
45-
"time": "00:30:00"
48+
"nodes_per_worker": 2,
49+
"time": "02:00:00",
50+
"framework_args": "--tp-size 8"
4651
},
4752
{
48-
"vendor": "zai-org",
49-
"model_name": "GLM-4.7-Flash",
50-
"framework": "vllm",
53+
"vendor": "moonshotai",
54+
"model_name": "Kimi-K2.5",
55+
"framework": "sglang",
56+
"environment": "src/swiss_ai_model_launch/assets/envs/sglang_kimi.toml",
57+
"workers": 1,
58+
"nodes_per_worker": 4,
59+
"time": "02:00:00",
60+
"framework_args": "--tp-size 16 --trust-remote-code --tool-call-parser kimi_k2 --reasoning-parser kimi_k2"
61+
},
62+
{
63+
"vendor": "MiniMaxAI",
64+
"model_name": "MiniMax-M2",
65+
"framework": "sglang",
5166
"environment": null,
5267
"workers": 1,
53-
"nodes_per_worker": 1,
54-
"time": "00:30:00"
68+
"nodes_per_worker": 2,
69+
"time": "02:00:00",
70+
"framework_args": "--tp-size 8 --ep-size 8 --tool-call-parser minimax-m2 --reasoning-parser minimax-append-think --trust-remote-code --mem-fraction-static 0.85"
5571
}
5672
]

src/swiss_ai_model_launch/cli/configuration/models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,11 @@ async def _resolve_options(
185185
return await cast(Callable[[], Awaitable[OptionsDict]], self.options_factory)()
186186

187187
async def aconfigure(self, get_value: GetValueFn | None = None) -> None:
188-
self.value = await self._build_question(
189-
await self._resolve_options(get_value)
190-
).ask_async()
188+
options = await self._resolve_options(get_value)
189+
if len(options) == 1:
190+
self.value = next(iter(options))
191+
else:
192+
self.value = await self._build_question(options).ask_async()
191193
self._on_answer()
192194

193195

src/swiss_ai_model_launch/cli/healthcheck/checker.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ async def check_model_health(served_model_name: str, api_key: str) -> ModelHealt
2323
},
2424
timeout=_TIMEOUT_SECONDS,
2525
)
26-
return ModelHealth.HEALTHY if response.is_success else ModelHealth.ERROR
26+
return (
27+
ModelHealth.HEALTHY if response.is_success else ModelHealth.NOT_RESPONDING
28+
)
2729
except (httpx.TransportError, httpx.TimeoutException):
28-
return ModelHealth.NOT_RESPONDING
30+
return ModelHealth.ERROR

src/swiss_ai_model_launch/cli/main.py

Lines changed: 62 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import re
3-
from collections.abc import Awaitable, Callable
43

54
import firecrest as f7t
65

@@ -87,11 +86,18 @@ async def _get_partitions(
8786
)
8887

8988

89+
def _split_vendor_model(combined: str) -> tuple[str, str]:
90+
vendor, model_name = combined.split("::", 1)
91+
return vendor, model_name
92+
93+
9094
async def _get_preconfigured_default(
9195
get_value_from_context: GetValueFn, preconfigured: list[LaunchRequest], field: str
9296
) -> str | None:
93-
vendor = get_value_from_context("model_vendor")
94-
model_name = get_value_from_context("model_name")
97+
combined = get_value_from_context("model_vendor_model")
98+
if combined is None:
99+
return None
100+
vendor, model_name = _split_vendor_model(combined)
95101
framework = get_value_from_context("framework")
96102
match = next(
97103
(
@@ -108,43 +114,36 @@ async def _get_preconfigured_default(
108114
return str(getattr(match, field))
109115

110116

111-
def _make_served_model_name_default(
112-
preconfigured: list[LaunchRequest],
113-
) -> Callable[[GetValueFn], Awaitable[str]]:
114-
async def _default(get_value: GetValueFn) -> str:
115-
value = await _get_preconfigured_default(
116-
get_value, preconfigured, "served_model_name"
117-
)
118-
if value and value != "None":
119-
return value
120-
return f"{get_value('model_vendor')}/{get_value('model_name')}-{create_salt(4)}"
121-
122-
return _default
117+
async def _get_router_options(get_value: GetValueFn) -> dict[str, tuple[str, str]]:
118+
workers = get_value("workers")
119+
if workers is not None and int(workers) > 1:
120+
return {
121+
"yes": ("Yes", "Use router to load balance across workers"),
122+
"no": ("No", "Do not use router"),
123+
}
124+
return {
125+
"no": ("No", "Do not use router"),
126+
}
123127

124128

125129
async def _get_launch_request(launcher: Launcher) -> LaunchRequest:
126130
preconfigured_launch_requests = await launcher.get_preconfigured_models()
127131

128-
async def _get_vendors() -> dict[str, tuple[str, str]]:
129-
return {
130-
lr.vendor: (lr.vendor, lr.vendor) for lr in preconfigured_launch_requests
131-
}
132-
133-
async def _get_models(
134-
get_value_from_context: GetValueFn,
135-
) -> dict[str, tuple[str, str]]:
136-
vendor = get_value_from_context("model_vendor")
137-
return {
138-
lr.model_name: (lr.model_name, lr.model_name)
139-
for lr in preconfigured_launch_requests
140-
if lr.vendor == vendor
141-
}
132+
async def _get_vendor_models() -> dict[str, tuple[str, str]]:
133+
seen: dict[str, tuple[str, str]] = {}
134+
for lr in preconfigured_launch_requests:
135+
key = f"{lr.vendor}::{lr.model_name}"
136+
if key not in seen:
137+
seen[key] = (lr.model_name, lr.vendor)
138+
return seen
142139

143140
async def _get_frameworks(
144141
get_value_from_context: GetValueFn,
145142
) -> dict[str, tuple[str, str]]:
146-
vendor = get_value_from_context("model_vendor")
147-
model_name = get_value_from_context("model_name")
143+
combined = get_value_from_context("model_vendor_model")
144+
if combined is None:
145+
return {}
146+
vendor, model_name = _split_vendor_model(combined)
148147
return {
149148
lr.framework: (lr.framework, lr.framework)
150149
for lr in preconfigured_launch_requests
@@ -155,14 +154,9 @@ async def _get_frameworks(
155154
name="launcher_request_configuration",
156155
chain=[
157156
OptionsConfiguration(
158-
name="model_vendor",
159-
prompt="Choose the model vendor.",
160-
options_factory=_get_vendors,
161-
),
162-
OptionsConfiguration(
163-
name="model_name",
157+
name="model_vendor_model",
164158
prompt="Choose the model to launch.",
165-
options_factory=_get_models,
159+
options_factory=_get_vendor_models,
166160
),
167161
OptionsConfiguration(
168162
name="framework",
@@ -177,13 +171,10 @@ async def _get_frameworks(
177171
get_value, preconfigured_launch_requests, "workers"
178172
),
179173
),
180-
TextConfiguration(
181-
name="nodes_per_worker",
182-
prompt="Number of nodes to use per worker for running the model.",
183-
validator=lambda v: v.isdigit() and int(v) > 0,
184-
default_factory=lambda get_value: _get_preconfigured_default(
185-
get_value, preconfigured_launch_requests, "nodes_per_worker"
186-
),
174+
OptionsConfiguration(
175+
name="use_router",
176+
prompt="Use router to load balance across workers.",
177+
options_factory=lambda get_value: _get_router_options(get_value),
187178
),
188179
TextConfiguration(
189180
name="time",
@@ -195,26 +186,35 @@ async def _get_frameworks(
195186
get_value, preconfigured_launch_requests, "time"
196187
),
197188
),
198-
TextConfiguration(
199-
name="served_model_name",
200-
prompt="Served model name.",
201-
validator=lambda s: len(s) > 0,
202-
default_factory=_make_served_model_name_default(
203-
preconfigured_launch_requests
204-
),
205-
),
206189
],
207190
)
208191
await launch_req_config.aconfigure()
209192

193+
vendor, model_name = _split_vendor_model(
194+
launch_req_config.get_non_none_value("model_vendor_model")
195+
)
196+
framework = launch_req_config.get_non_none_value("framework")
197+
preconfigured = next(
198+
(
199+
lr
200+
for lr in preconfigured_launch_requests
201+
if lr.vendor == vendor
202+
and lr.model_name == model_name
203+
and lr.framework == framework
204+
),
205+
None,
206+
)
210207
return LaunchRequest(
211-
vendor=launch_req_config.get_non_none_value("model_vendor"),
212-
model_name=launch_req_config.get_non_none_value("model_name"),
213-
framework=launch_req_config.get_non_none_value("framework"),
208+
vendor=vendor,
209+
model_name=model_name,
210+
framework=framework,
211+
environment=preconfigured.environment if preconfigured else None,
214212
workers=int(launch_req_config.get_non_none_value("workers")),
215-
nodes_per_worker=int(launch_req_config.get_non_none_value("nodes_per_worker")),
213+
nodes_per_worker=preconfigured.nodes_per_worker if preconfigured else 1,
216214
time=launch_req_config.get_non_none_value("time"),
217-
served_model_name=launch_req_config.get_non_none_value("served_model_name"),
215+
served_model_name=f"{vendor}/{model_name}-{create_salt(4)}",
216+
framework_args=preconfigured.framework_args if preconfigured else None,
217+
use_router=launch_req_config.get_non_none_value("use_router") == "yes",
218218
)
219219

220220

@@ -262,3 +262,7 @@ async def _monitor() -> None:
262262

263263
def main() -> None:
264264
asyncio.run(_main())
265+
266+
267+
if __name__ == "__main__":
268+
main()

src/swiss_ai_model_launch/launchers/firecrest_launcher.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def _get_launch_args_from_request(
8585
)
8686
),
8787
telemetry_endpoint=self.telemetry_endpoint,
88+
use_router=launch_request.use_router,
8889
)
8990

9091
def _get_local_env_file_path(self, launch_request: LaunchRequest) -> str:

src/swiss_ai_model_launch/launchers/launch_request.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ class LaunchRequest(BaseModel):
1313
time: str
1414
served_model_name: str | None = None
1515
framework_args: str | None = None
16+
use_router: bool = False

0 commit comments

Comments
 (0)