Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 19 additions & 23 deletions mache/deploy/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def _resolve_login_pixi_env(
prefix: str,
pixi_cfg: dict[str, Any],
runtime: dict[str, Any],
compute_mpi: str,
compute_mpi: str | None,
) -> dict[str, str] | None:
runtime_pixi = runtime.get('pixi')
if not isinstance(runtime_pixi, dict):
Expand All @@ -563,6 +563,8 @@ def _resolve_login_pixi_env(
resolved_login_mpi, login_mpi_prefix = _get_mpi_settings(
pixi_cfg={'mpi': login_mpi}
)
if resolved_login_mpi is None or login_mpi_prefix is None:
raise ValueError('login_mpi must resolve to a non-empty MPI setting')

if 'login_prefix' in runtime_pixi:
login_prefix_raw = runtime_pixi.get('login_prefix')
Expand Down Expand Up @@ -885,15 +887,12 @@ def _resolve_pixi_mpi(
*,
pixi_cfg: dict[str, Any],
runtime: dict[str, Any],
) -> tuple[str, str]:
mpi_override = None
) -> tuple[str | None, str | None]:
mpi_cfg = pixi_cfg
runtime_pixi = runtime.get('pixi')
if isinstance(runtime_pixi, dict):
mpi_override = runtime_pixi.get('mpi')
if mpi_override is not None:
mpi_cfg: dict[str, Any] = {'mpi': str(mpi_override)}
else:
mpi_cfg = pixi_cfg
if 'mpi' in runtime_pixi:
mpi_cfg = {'mpi': runtime_pixi.get('mpi')}
return _get_mpi_settings(pixi_cfg=mpi_cfg)


Expand Down Expand Up @@ -1346,7 +1345,7 @@ def _write_load_script(
software_version: str,
runtime_version_cmd: str | None,
machine: str | None,
compute_pixi_mpi: str,
compute_pixi_mpi: str | None,
toolchain_compiler: str | None,
toolchain_mpi: str | None,
spack_library_view: str | None,
Expand Down Expand Up @@ -1434,7 +1433,7 @@ def _write_load_script(
runtime_version_cmd_sh=shlex.quote(runtime_version_cmd or ''),
machine=machine or '',
load_script=os.path.abspath(script_path),
compute_pixi_mpi=compute_pixi_mpi,
compute_pixi_mpi=compute_pixi_mpi or '',
login_prefix=login_prefix_abs,
login_pixi_toml=login_pixi_toml,
login_pixi_mpi=login_pixi_mpi,
Expand Down Expand Up @@ -1467,7 +1466,7 @@ def _write_load_scripts(
software_version: str,
runtime_version_cmd: str | None,
machine: str | None,
compute_pixi_mpi: str,
compute_pixi_mpi: str | None,
toolchain_pairs: list[tuple[str, str]],
spack_results: Any,
spack_software_env: Any,
Expand Down Expand Up @@ -1640,29 +1639,26 @@ def _pixi_install(

def _get_mpi_settings(
pixi_cfg: dict[str, Any],
) -> tuple[str, str]:
) -> tuple[str | None, str | None]:
"""Determine MPI-related template replacements.

Returns
-------
mpi : str
mpi : str or None
The conda package name for MPI (e.g. "mpich", "openmpi", or "nompi").
mpi_prefix : str
mpi_prefix : str or None
The conda-forge variant prefix used in build-string selectors
(e.g. "nompi", "mpi_mpich", "mpi_openmpi").
"""

if 'mpi' not in pixi_cfg:
raise ValueError(
"'mpi' not found in [pixi] section of deploy/config.yaml.j2"
)
return None, None

mpi_raw = pixi_cfg.get('mpi')
mpi = str(mpi_raw).strip().lower() if mpi_raw is not None else ''
if not mpi:
raise ValueError(
"'mpi' in [pixi] section of deploy/config.yaml.j2 is empty"
)
mpi = _normalize_optional_token(pixi_cfg.get('mpi'))
if mpi is None:
return None, None

mpi = mpi.lower()
if any(ch.isspace() for ch in mpi):
raise ValueError(
"'mpi' in [pixi] section of deploy/config.yaml.j2 must not "
Expand Down
1 change: 1 addition & 0 deletions mache/deploy/templates/config.yaml.j2.j2
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ pixi:

# MPI provider for conda packages.
# Supported values in `mache deploy run`:
# - null if MPI is not used in the target software
# - nompi
# - mpich
# - openmpi
Expand Down
24 changes: 24 additions & 0 deletions tests/test_deploy_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,30 @@ def test_resolve_pixi_omit_dependencies_prefers_runtime_override():
assert omit_dependencies == ['git', 'ncview']


def test_get_mpi_settings_allows_missing_mpi():
mpi, mpi_prefix = deploy_run._get_mpi_settings(pixi_cfg={})

assert mpi is None
assert mpi_prefix is None


def test_get_mpi_settings_allows_null_mpi():
mpi, mpi_prefix = deploy_run._get_mpi_settings(pixi_cfg={'mpi': None})

assert mpi is None
assert mpi_prefix is None


def test_resolve_pixi_mpi_runtime_override_can_disable_mpi():
mpi, mpi_prefix = deploy_run._resolve_pixi_mpi(
pixi_cfg={'mpi': 'openmpi'},
runtime={'pixi': {'mpi': None}},
)

assert mpi is None
assert mpi_prefix is None


def test_resolve_login_pixi_env_uses_distinct_login_prefix():
login_env = deploy_run._resolve_login_pixi_env(
prefix='/tmp/compute-env',
Expand Down
Loading