Skip to content

Commit 57c1004

Browse files
authored
Merge pull request #94 from usnistgov/bugfix_apdist
Fix import error handling in APDist PipelineOp
2 parents a3de58a + 3e6175c commit 57c1004

1 file changed

Lines changed: 18 additions & 5 deletions

File tree

AFL/double_agent/AmplitudePhaseDistance.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
from apdist.distances import AmplitudePhaseDistance as numpy_apdist
1111
import torch
1212
except ImportError as e:
13-
raise RuntimeError((
14-
"ImportError encountered: {}\n"
15-
"To use amplitude-distance as a similarity measure, please install:\n"
16-
"pip install git+https://github.com/kiranvad/Amplitude-Phase-Distance"
17-
).format(str(e)))
13+
# Keep this metric optional so AgentDriver can start without apdist/torch.
14+
_APDIST_IMPORT_ERROR = e
15+
torch_apdist = None
16+
numpy_apdist = None
17+
torch = None
18+
else:
19+
_APDIST_IMPORT_ERROR = None
1820

1921
class AmplitudePhaseDistance(PairMetric):
2022
"""Computes pairwise amplitude phase distance between samples
@@ -79,8 +81,18 @@ def __init__(
7981
)
8082
self.method = method
8183

84+
@staticmethod
85+
def _ensure_apdist_available() -> None:
86+
if _APDIST_IMPORT_ERROR is not None:
87+
raise RuntimeError((
88+
"AmplitudePhaseDistance requires optional dependency 'apdist' (and 'torch' for continuous mode).\n"
89+
"Install with: pip install git+https://github.com/kiranvad/Amplitude-Phase-Distance\n"
90+
"Original import error: {}"
91+
).format(str(_APDIST_IMPORT_ERROR)))
92+
8293
def calculate(self, dataset: xr.Dataset) -> Self:
8394
"""Apply this `PipelineOp` to the supplied `xarray.Dataset`"""
95+
self._ensure_apdist_available()
8496
data = self._get_variable(dataset)
8597

8698
domain_variable = [d for d in data.dims if d != self.sample_dim][0]
@@ -125,6 +137,7 @@ def calculate(self, dataset: xr.Dataset) -> Self:
125137
return self
126138

127139
def _get_pairiwise_ap(self, method, t, f_ref, f_query, **kwargs):
140+
self._ensure_apdist_available()
128141

129142
alpha = kwargs.get("alpha", 0.5)
130143
if method=="continuous":

0 commit comments

Comments
 (0)