Skip to content

Commit f95e0f7

Browse files
committed
Add OS Env Var Support
1 parent 0e868f4 commit f95e0f7

4 files changed

Lines changed: 61 additions & 34 deletions

File tree

README.md

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,19 @@ There are three ways to initialize the launcher:
5050

5151
For health check, you will be prompted to provide your CSCS API key. If you don't have the API key, follow the instructions in the [Appendix](#acquiring-cscs-api-key).
5252

53-
All prompts can be pre-filled via CLI arguments to skip interactive prompts:
54-
55-
| Argument | Description |
56-
| ------------------------------ | ------------------------------------------------------ |
57-
| `--launcher` | Job submission method (`firecrest`, `remote`, `slurm`) |
58-
| `--firecrest-url` | FirecREST API URL |
59-
| `--firecrest-token-uri` | FirecREST token URI |
60-
| `--firecrest-client-id` | FirecREST client ID |
61-
| `--firecrest-client-secret` | FirecREST client secret |
62-
| `--remote-launcher-address` | Remote launcher address (if using `remote`) |
63-
| `--remote-launcher-auth-token` | Remote launcher auth token (if using `remote`) |
64-
| `--cscs-api-key` | CSCS API key for health checks |
65-
| `--telemetry-endpoint` | Endpoint for telemetry reports |
53+
All prompts can be pre-filled to skip interactive prompts:
54+
55+
| CLI Argument | Environment Variable | Description |
56+
| --------------------------- | -------------------------------- | ------------------------------------------------------ |
57+
| `--launcher` | | Job submission method (`firecrest`, `remote`, `slurm`) |
58+
| `--firecrest-url` | | FirecREST API URL |
59+
| `--firecrest-token-uri` | | FirecREST token URI |
60+
| | `SML_FIRECREST_CLIENT_ID` | FirecREST client ID |
61+
| | `SML_FIRECREST_CLIENT_SECRET` | FirecREST client secret |
62+
| `--remote-launcher-address` | | Remote launcher address (if using `remote`) |
63+
| | `SML_REMOTE_LAUNCHER_AUTH_TOKEN` | Auth token for remote launcher (if using `remote`) |
64+
| | `SML_CSCS_API_KEY` | CSCS API key for health checks |
65+
| `--telemetry-endpoint` | | Endpoint for telemetry reports |
6666

6767
### Launching a Model (`sml quickstart`)
6868

src/swiss_ai_model_launch/cli/configuration/init_wizard.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,14 @@ class InitConfig(ChainConfiguration):
6464
PasswordConfiguration(
6565
name="firecrest_client_id",
6666
prompt="What is your FirecREST client ID?",
67+
env_var="SML_FIRECREST_CLIENT_ID",
68+
expose_as_arg=False,
6769
),
6870
PasswordConfiguration(
6971
name="firecrest_client_secret",
7072
prompt="What is your FirecREST client secret?",
73+
env_var="SML_FIRECREST_CLIENT_SECRET",
74+
expose_as_arg=False,
7175
),
7276
],
7377
),
@@ -82,6 +86,8 @@ class InitConfig(ChainConfiguration):
8286
name="remote_launcher_auth_token",
8387
prompt="What is your token for authenticating in "
8488
"remote launcher?",
89+
env_var="SML_REMOTE_LAUNCHER_AUTH_TOKEN",
90+
expose_as_arg=False,
8591
),
8692
],
8793
),
@@ -92,6 +98,8 @@ class InitConfig(ChainConfiguration):
9298
name="cscs_api_key",
9399
prompt="What is your CSCS API key? "
94100
"(https://serving.swissai.svc.cscs.ch)",
101+
env_var="SML_CSCS_API_KEY",
102+
expose_as_arg=False,
95103
),
96104
ChainConfiguration(
97105
name="telemetry_endpoint_configuration",

src/swiss_ai_model_launch/cli/configuration/models.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22
import inspect
3+
import os
34
from collections.abc import Awaitable, Callable
45
from typing import Annotated, Any, Literal, cast
56

@@ -47,24 +48,38 @@ def set_value(self, name: str, value: str) -> None:
4748
class _ResolvableConfiguration(_Configuration):
4849
value: str | None = None
4950
prompt: str | None = Field(default=None, exclude=True)
51+
env_var: str | None = Field(default=None, exclude=True)
52+
expose_as_arg: bool = Field(default=True, exclude=True)
5053

5154
def _get_question(self) -> Question:
5255
raise NotImplementedError # pragma: no cover
5356

5457
def _on_answer(self) -> None:
5558
pass
5659

60+
def _try_resolve_without_prompt(
61+
self, args: argparse.Namespace | None
62+
) -> str | None:
63+
if self.expose_as_arg and args is not None:
64+
arg_value = getattr(args, self.name, None)
65+
if arg_value is not None:
66+
return str(arg_value)
67+
if self.env_var is not None:
68+
env_value = os.environ.get(self.env_var)
69+
if env_value is not None:
70+
return env_value
71+
return None
72+
5773
async def aconfigure(
5874
self,
5975
get_value: GetValueFn | None = None,
6076
args: argparse.Namespace | None = None,
6177
) -> None:
62-
if args is not None:
63-
arg_value = getattr(args, self.name, None)
64-
if arg_value is not None:
65-
self.value = arg_value
66-
self._on_answer()
67-
return
78+
resolved = self._try_resolve_without_prompt(args)
79+
if resolved is not None:
80+
self.value = resolved
81+
self._on_answer()
82+
return
6883
self.value = await self._get_question().ask_async()
6984
self._on_answer()
7085

@@ -125,6 +140,8 @@ def _resolve_validator(self, get_value: GetValueFn | None) -> ValidatorFn | None
125140
return cast(ValidatorFn, self.validator)
126141

127142
def add_to_parser(self, parser: argparse.ArgumentParser) -> None:
143+
if not self.expose_as_arg:
144+
return
128145
parser.add_argument(
129146
f"--{self.name.replace('_', '-')}",
130147
dest=self.name,
@@ -138,12 +155,11 @@ async def aconfigure(
138155
get_value: GetValueFn | None = None,
139156
args: argparse.Namespace | None = None,
140157
) -> None:
141-
if args is not None:
142-
arg_value = getattr(args, self.name, None)
143-
if arg_value is not None:
144-
self.value = arg_value
145-
self._on_answer()
146-
return
158+
resolved = self._try_resolve_without_prompt(args)
159+
if resolved is not None:
160+
self.value = resolved
161+
self._on_answer()
162+
return
147163
self.value = await questionary.text(
148164
self.prompt or self.name,
149165
default=await self._resolve_default(get_value) or "",
@@ -166,6 +182,8 @@ def load_from_keyring(self) -> "PasswordConfiguration":
166182
return self
167183

168184
def add_to_parser(self, parser: argparse.ArgumentParser) -> None:
185+
if not self.expose_as_arg:
186+
return
169187
parser.add_argument(
170188
f"--{self.name.replace('_', '-')}",
171189
dest=self.name,
@@ -233,6 +251,8 @@ async def _resolve_options(
233251
return await cast(Callable[[], Awaitable[OptionsDict]], self.options_factory)()
234252

235253
def add_to_parser(self, parser: argparse.ArgumentParser) -> None:
254+
if not self.expose_as_arg:
255+
return
236256
kwargs: dict[str, Any] = {
237257
"dest": self.name,
238258
"default": None,
@@ -249,14 +269,13 @@ async def aconfigure(
249269
get_value: GetValueFn | None = None,
250270
args: argparse.Namespace | None = None,
251271
) -> None:
252-
if args is not None:
253-
arg_value = getattr(args, self.name, None)
254-
if arg_value is not None:
255-
options = await self._resolve_options(get_value)
256-
if not options or arg_value in options:
257-
self.value = arg_value
258-
self._on_answer()
259-
return
272+
resolved = self._try_resolve_without_prompt(args)
273+
if resolved is not None:
274+
options = await self._resolve_options(get_value)
275+
if not options or resolved in options:
276+
self.value = resolved
277+
self._on_answer()
278+
return
260279
options = await self._resolve_options(get_value)
261280
if len(options) == 1:
262281
self.value = next(iter(options))

src/swiss_ai_model_launch/cli/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _make_launch_request_config(
114114
validator=lambda v: bool(
115115
re.fullmatch(r"[0-9]{1,2}:[0-5][0-9]:[0-5][0-9]", v)
116116
),
117-
default_factory=None,
117+
default_factory=time_default_factory,
118118
),
119119
],
120120
)

0 commit comments

Comments
 (0)