|
10 | 10 | from apdist.distances import AmplitudePhaseDistance as numpy_apdist |
11 | 11 | import torch |
12 | 12 | except ImportError as e: |
13 | | - raise RuntimeError(( |
14 | | - "ImportError encountered: {}\n" |
15 | | - "To use amplitude-distance as a similarity measure, please install:\n" |
16 | | - "pip install git+https://github.com/kiranvad/Amplitude-Phase-Distance" |
17 | | - ).format(str(e))) |
| 13 | + # Keep this metric optional so AgentDriver can start without apdist/torch. |
| 14 | + _APDIST_IMPORT_ERROR = e |
| 15 | + torch_apdist = None |
| 16 | + numpy_apdist = None |
| 17 | + torch = None |
| 18 | +else: |
| 19 | + _APDIST_IMPORT_ERROR = None |
18 | 20 |
|
19 | 21 | class AmplitudePhaseDistance(PairMetric): |
20 | 22 | """Computes pairwise amplitude phase distance between samples |
@@ -79,8 +81,18 @@ def __init__( |
79 | 81 | ) |
80 | 82 | self.method = method |
81 | 83 |
|
| 84 | + @staticmethod |
| 85 | + def _ensure_apdist_available() -> None: |
| 86 | + if _APDIST_IMPORT_ERROR is not None: |
| 87 | + raise RuntimeError(( |
| 88 | + "AmplitudePhaseDistance requires optional dependency 'apdist' (and 'torch' for continuous mode).\n" |
| 89 | + "Install with: pip install git+https://github.com/kiranvad/Amplitude-Phase-Distance\n" |
| 90 | + "Original import error: {}" |
| 91 | + ).format(str(_APDIST_IMPORT_ERROR))) |
| 92 | + |
82 | 93 | def calculate(self, dataset: xr.Dataset) -> Self: |
83 | 94 | """Apply this `PipelineOp` to the supplied `xarray.Dataset`""" |
| 95 | + self._ensure_apdist_available() |
84 | 96 | data = self._get_variable(dataset) |
85 | 97 |
|
86 | 98 | domain_variable = [d for d in data.dims if d != self.sample_dim][0] |
@@ -125,6 +137,7 @@ def calculate(self, dataset: xr.Dataset) -> Self: |
125 | 137 | return self |
126 | 138 |
|
127 | 139 | def _get_pairiwise_ap(self, method, t, f_ref, f_query, **kwargs): |
| 140 | + self._ensure_apdist_available() |
128 | 141 |
|
129 | 142 | alpha = kwargs.get("alpha", 0.5) |
130 | 143 | if method=="continuous": |
|
0 commit comments