15
15
from sklearn .svm import SVC
16
16
from sklearn .tree import DecisionTreeClassifier
17
17
18
+ from hamilton .io .utils import FILE_METADATA
18
19
from hamilton .plugins .sklearn_plot_extensions import SklearnPlotSaver
19
20
20
21
if hasattr (metrics , "PredictionErrorDisplay" ):
@@ -191,7 +192,7 @@ def test_cm_plot_saver(
191
192
metadata = writer .save_data (confusion_matrix_display )
192
193
193
194
assert plot_path .exists ()
194
- assert metadata ["path" ] == plot_path
195
+ assert metadata [FILE_METADATA ][ "path" ] == plot_path
195
196
196
197
197
198
def test_det_curve_display (
@@ -203,7 +204,7 @@ def test_det_curve_display(
203
204
metadata = writer .save_data (det_curve_display )
204
205
205
206
assert plot_path .exists ()
206
- assert metadata ["path" ] == plot_path
207
+ assert metadata [FILE_METADATA ][ "path" ] == plot_path
207
208
208
209
209
210
def test_precision_recall_display (
@@ -215,7 +216,7 @@ def test_precision_recall_display(
215
216
metadata = writer .save_data (precision_recall_display )
216
217
217
218
assert plot_path .exists ()
218
- assert metadata ["path" ] == plot_path
219
+ assert metadata [FILE_METADATA ][ "path" ] == plot_path
219
220
220
221
221
222
@pytest .mark .skipif (sys .version_info < (3 , 8 ), reason = "requires python3.8 or higher" )
@@ -228,7 +229,7 @@ def test_prediction_error_display(
228
229
metadata = writer .save_data (prediction_error_display )
229
230
230
231
assert plot_path .exists ()
231
- assert metadata ["path" ] == plot_path
232
+ assert metadata [FILE_METADATA ][ "path" ] == plot_path
232
233
233
234
234
235
def test_roc_curve_display (
@@ -240,7 +241,7 @@ def test_roc_curve_display(
240
241
metadata = writer .save_data (roc_curve_display )
241
242
242
243
assert plot_path .exists ()
243
- assert metadata ["path" ] == plot_path
244
+ assert metadata [FILE_METADATA ][ "path" ] == plot_path
244
245
245
246
246
247
def test_calibration_display (
@@ -252,7 +253,7 @@ def test_calibration_display(
252
253
metadata = writer .save_data (calibration_display )
253
254
254
255
assert plot_path .exists ()
255
- assert metadata ["path" ] == plot_path
256
+ assert metadata [FILE_METADATA ][ "path" ] == plot_path
256
257
257
258
258
259
@pytest .mark .skipif (sys .version_info < (3 , 8 ), reason = "requires python3.8 or higher" )
@@ -265,7 +266,7 @@ def test_decision_boundary_display(
265
266
metadata = writer .save_data (decision_boundary_display )
266
267
267
268
assert plot_path .exists ()
268
- assert metadata ["path" ] == plot_path
269
+ assert metadata [FILE_METADATA ][ "path" ] == plot_path
269
270
270
271
271
272
@pytest .mark .skipif (sys .version_info < (3 , 8 ), reason = "requires python3.8 or higher" )
@@ -278,7 +279,7 @@ def test_partial_dependence_display(
278
279
metadata = writer .save_data (partial_dependence_display )
279
280
280
281
assert plot_path .exists ()
281
- assert metadata ["path" ] == plot_path
282
+ assert metadata [FILE_METADATA ][ "path" ] == plot_path
282
283
283
284
284
285
@pytest .mark .skipif (sys .version_info < (3 , 8 ), reason = "requires python3.8 or higher" )
@@ -291,7 +292,7 @@ def test_learning_curve_display(
291
292
metadata = writer .save_data (learning_curve_display )
292
293
293
294
assert plot_path .exists ()
294
- assert metadata ["path" ] == plot_path
295
+ assert metadata [FILE_METADATA ][ "path" ] == plot_path
295
296
296
297
297
298
@pytest .mark .skipif (sys .version_info < (3 , 8 ), reason = "requires python3.8 or higher" )
@@ -304,4 +305,4 @@ def test_validation_curve_display(
304
305
metadata = writer .save_data (validation_curve_display )
305
306
306
307
assert plot_path .exists ()
307
- assert metadata ["path" ] == plot_path
308
+ assert metadata [FILE_METADATA ][ "path" ] == plot_path
0 commit comments