|
4 | 4 | import time |
5 | 5 | import os |
6 | 6 | import shutil |
7 | | - |
8 | | -import mlflow |
9 | 7 | import torch |
10 | 8 | import dill |
11 | 9 | import numbers |
@@ -132,89 +130,58 @@ def get_losses(self): |
132 | 130 |
|
133 | 131 | class MLFlowLogger(BasicLogger): |
134 | 132 | def __init__(self, args=None, savedir='test', verbosity=1, id=None, |
135 | | - stdout=('nstep_dev_loss', 'loop_dev_loss', 'best_loop_dev_loss', |
136 | | - 'nstep_dev_ref_loss', 'loop_dev_ref_loss'), |
| 133 | + stdout=('nstep_dev_loss','loop_dev_loss','best_loop_dev_loss', |
| 134 | + 'nstep_dev_ref_loss','loop_dev_ref_loss'), |
137 | 135 | logout=None): |
138 | | - """ |
139 | | -
|
140 | | - :param args: (Namespace) returned by argparse.ArgumentParser.parse_args() |
141 | | - args.location (str): path to directory on file system to store experiment results via mlflow |
142 | | - or if 'pnl_dadaist_store' then will save to our instance of the pnl mlflow |
143 | | - server. Must have copy of pnl_mlflow_secrets.py (containing a funtion to set |
144 | | - necessary environment variables) located in neuromancer/neuromancer/ |
145 | | - where main modules are located. |
146 | | - :param savedir: Unique folder name to temporarily save artifacts |
147 | | - :param verbosity: (int) Print to stdout every verbosity steps |
148 | | - :param id: (int) Optional unique experiment ID for hyperparameter optimization |
149 | | - :param stdout: (list of str) Metrics to print to stdout. These should correspond to keys in the output dictionary of the Problem |
150 | | - :param logout: (list of str) List of metric names to log via mlflow |
151 | | - """ |
| 136 | + # Lazy import so module import works even if mlflow isn't installed |
| 137 | + try: |
| 138 | + import mlflow |
| 139 | + except Exception as e: |
| 140 | + raise ImportError( |
| 141 | + "MLFlowLogger requires mlflow. Install with " |
| 142 | + "`pip install neuromancer[tracking]` or `pip install mlflow>=2.12`." |
| 143 | + ) from e |
| 144 | + |
| 145 | + self._mlflow = mlflow |
| 146 | + |
| 147 | + # Configure the tracking backend and start (or resume) the run |
| 148 | + self._mlflow.set_tracking_uri(args.location) |
| 149 | + self._mlflow.set_experiment(args.exp) |
| 150 | + self._mlflow.start_run(run_name=args.run, run_id=id) |
152 | 151 |
|
153 | | - mlflow.set_tracking_uri(args.location) |
154 | | - mlflow.set_experiment(args.exp) |
155 | | - mlflow.start_run(run_name=args.run, run_id=id) |
156 | 152 | super().__init__(args=args, savedir=savedir, verbosity=verbosity, stdout=stdout) |
157 | 153 | self.logout = logout |
158 | 154 |
|
159 | 155 | def log_parameters(self): |
160 | | - """ |
161 | | - Print experiment parameters to stdout |
162 | | -
|
163 | | - """ |
164 | 156 | params = {k: getattr(self.args, k) for k in vars(self.args)} |
165 | 157 | print({k: type(v) for k, v in params.items()}) |
166 | | - |
167 | | - mlflow.log_params(params) |
| 158 | + self._mlflow.log_params(params) |
168 | 159 |
|
169 | 160 | def log_weights(self, model): |
170 | | - """ |
171 | | -
|
172 | | - :param model: (nn.Module) |
173 | | - :return: (int) Number of learnable parameters in the model. |
174 | | - """ |
175 | 161 | nweights = super().log_weights(model) |
176 | | - mlflow.log_metric('nparams', float(nweights)) |
| 162 | + self._mlflow.log_metric('nparams', float(nweights)) |
177 | 163 |
|
178 | 164 | def log_metrics(self, output, step=0): |
179 | | - """ |
180 | | - Record metrics to mlflow |
181 | | -
|
182 | | - :param output: (dict {str: tensor}) Will only record 0d torch.Tensors (scalars) |
183 | | - :param step: (int) Epoch of training |
184 | | - """ |
185 | 165 | super().log_metrics(output, step) |
186 | | - _keys = {k for k in output.keys()} |
| 166 | + keys = set(output.keys()) |
187 | 167 | if self.logout is not None: |
188 | | - keys = [] |
189 | | - for k in _keys: |
190 | | - for kp in self.logout: |
191 | | - if kp in k: |
192 | | - keys.append(k) |
193 | | - else: |
194 | | - keys = _keys |
| 168 | + keys = {k for k in keys if any(p in k for p in self.logout)} |
195 | 169 | for k in keys: |
196 | 170 | v = output[k] |
197 | 171 | if isinstance(v, torch.Tensor) and torch.numel(v) == 1: |
198 | | - mlflow.log_metric(k, v.item()) |
| 172 | + self._mlflow.log_metric(k, v.item()) |
199 | 173 | elif isinstance(v, np.ndarray) and v.size == 1: |
200 | | - mlflow.log_metric(k, v.flatten()) |
| 174 | + self._mlflow.log_metric(k, v.item()) |
201 | 175 | elif isinstance(v, numbers.Number): |
202 | | - mlflow.log_metric(k, v) |
| 176 | + self._mlflow.log_metric(k, v) |
203 | 177 |
|
204 | 178 | def log_artifacts(self, artifacts=dict()): |
205 | | - """ |
206 | | - Stores artifacts created in training to mlflow. |
207 | | -
|
208 | | - :param artifacts: (dict {str: Object}) |
209 | | - """ |
210 | 179 | super().log_artifacts(artifacts) |
211 | | - mlflow.log_artifacts(self.savedir) |
| 180 | + self._mlflow.log_artifacts(self.savedir) |
212 | 181 |
|
213 | 182 | def clean_up(self): |
214 | | - """ |
215 | | - Remove temporary files from file system |
216 | | - """ |
217 | 183 | shutil.rmtree(self.savedir) |
218 | | - mlflow.end_run() |
| 184 | + self._mlflow.end_run() |
| 185 | + |
219 | 186 |
|
220 | 187 |
|
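Reviewer note (not part of the diff): a minimal sketch of how the lazy-import path behaves. It assumes the `neuromancer.loggers` import path, a local `mlruns` tracking directory, and that `BasicLogger` creates `savedir`; the experiment/run names and the metric key are illustrative.

```python
from argparse import Namespace

from neuromancer.loggers import MLFlowLogger  # import path assumed

# Hypothetical argparse namespace; only the attributes read in __init__
# (location, exp, run) matter for this sketch.
args = Namespace(location="mlruns", exp="lazy_import_demo", run="trial_0")

try:
    logger = MLFlowLogger(args=args, savedir="test_artifacts", verbosity=1,
                          stdout=("nstep_dev_loss",))
except ImportError as err:
    # Raised by the lazy import above when mlflow is not installed.
    print(f"mlflow unavailable: {err}")
else:
    # A plain float hits the `numbers.Number` branch in log_metrics.
    logger.log_metrics({"nstep_dev_loss": 0.123}, step=0)
    # Assumes BasicLogger created savedir; ends the run via mlflow.end_run().
    logger.clean_up()
```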
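One more note on `log_metrics`: the rewritten filter keeps an output key whenever any `logout` pattern occurs as a substring of that key. A standalone check of that comprehension, with illustrative key names:

```python
# Keys as they might appear in a Problem's output dictionary (illustrative).
output_keys = {"nstep_train_loss", "nstep_dev_loss", "loop_dev_ref_loss", "lr"}
logout = ("dev_loss", "ref_loss")

# Same comprehension as in log_metrics: keep k if any pattern is a substring.
kept = {k for k in output_keys if any(p in k for p in logout)}
print(sorted(kept))  # ['loop_dev_ref_loss', 'nstep_dev_loss']
```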