Skip to content

Commit 4c0e5de

Browse files
committed
Fixes all local file loader metadata to have a uniform shape
Previously we took different approaches depending on whether the result was a dataframe. We'll need a uniform shape for downstream parsing. The new format contains one or more of: 1. sql_metadata 2. file_metadata 3. dataframe_metadata. The utils functions now return dicts keyed by these prefixes, and consumers just have to merge them. This is all internal-facing (used only for diagnostics), so it's OK to change it.
1 parent 6107068 commit 4c0e5de

9 files changed

+48
-44
lines changed

hamilton/io/utils.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@ def get_file_metadata(path: str) -> Dict[str, Any]:
2222
- the current time
2323
"""
2424
return {
25-
"size": os.path.getsize(path),
26-
"path": path,
27-
"last_modified": os.path.getmtime(path),
28-
"timestamp": datetime.now().utcnow().timestamp(),
25+
FILE_METADATA: {
26+
"size": os.path.getsize(path),
27+
"path": path,
28+
"last_modified": os.path.getmtime(path),
29+
"timestamp": datetime.now().utcnow().timestamp(),
30+
}
2931
}
3032

3133

@@ -42,10 +44,12 @@ def get_dataframe_metadata(df: pd.DataFrame) -> Dict[str, Any]:
4244
- the data types
4345
"""
4446
return {
45-
"rows": len(df),
46-
"columns": len(df.columns),
47-
"column_names": list(df.columns),
48-
"datatypes": [str(t) for t in list(df.dtypes)], # for serialization purposes
47+
DATAFRAME_METADATA: {
48+
"rows": len(df),
49+
"columns": len(df.columns),
50+
"column_names": list(df.columns),
51+
"datatypes": [str(t) for t in list(df.dtypes)], # for serialization purposes
52+
}
4953
}
5054

5155

@@ -67,7 +71,7 @@ def get_file_and_dataframe_metadata(path: str, df: pd.DataFrame) -> Dict[str, An
6771
- the column names
6872
- the data types
6973
"""
70-
return {FILE_METADATA: get_file_metadata(path), DATAFRAME_METADATA: get_dataframe_metadata(df)}
74+
return {**get_file_metadata(path), **get_dataframe_metadata(df)}
7175

7276

7377
def get_sql_metadata(query_or_table: str, results: Union[int, pd.DataFrame]) -> Dict[str, Any]:
@@ -91,8 +95,10 @@ def get_sql_metadata(query_or_table: str, results: Union[int, pd.DataFrame]) ->
9195
else:
9296
rows = None
9397
return {
94-
"rows": rows,
95-
"query": query,
96-
"table_name": table_name,
97-
"timestamp": datetime.now().utcnow().timestamp(),
98+
SQL_METADATA: {
99+
"rows": rows,
100+
"query": query,
101+
"table_name": table_name,
102+
"timestamp": datetime.now().utcnow().timestamp(),
103+
}
98104
}

hamilton/plugins/pandas_extensions.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -733,11 +733,7 @@ def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]:
733733
df = pd.read_sql(self.query_or_table, self.db_connection, **self._get_loading_kwargs())
734734
sql_metadata = utils.get_sql_metadata(self.query_or_table, df)
735735
df_metadata = utils.get_dataframe_metadata(df)
736-
metadata = {
737-
utils.SQL_METADATA: sql_metadata,
738-
utils.DATAFRAME_METADATA: df_metadata,
739-
}
740-
return df, metadata
736+
return df, {**sql_metadata, **df_metadata}
741737

742738
@classmethod
743739
def name(cls) -> str:
@@ -793,11 +789,7 @@ def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
793789
results = data.to_sql(self.table_name, self.db_connection, **self._get_saving_kwargs())
794790
sql_metadata = utils.get_sql_metadata(self.table_name, results)
795791
df_metadata = utils.get_dataframe_metadata(data)
796-
metadata = {
797-
utils.SQL_METADATA: sql_metadata,
798-
utils.DATAFRAME_METADATA: df_metadata,
799-
}
800-
return metadata
792+
return {**sql_metadata, **df_metadata}
801793

802794
@classmethod
803795
def name(cls) -> str:

tests/io/test_utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import pandas as pd
22

3-
from hamilton.io.utils import get_sql_metadata
3+
from hamilton.io.utils import SQL_METADATA, get_sql_metadata
44

55

66
def test_get_sql_metadata():
77
results = 5
88
table = "foo"
99
query = "SELECT foo FROM bar"
1010
df = pd.DataFrame({"foo": ["bar"]})
11-
metadata1 = get_sql_metadata(table, df)
12-
metadata2 = get_sql_metadata(query, results)
13-
metadata3 = get_sql_metadata(query, "foo")
11+
metadata1 = get_sql_metadata(table, df)[SQL_METADATA]
12+
metadata2 = get_sql_metadata(query, results)[SQL_METADATA]
13+
metadata3 = get_sql_metadata(query, "foo")[SQL_METADATA]
1414
assert metadata1["table_name"] == table
1515
assert metadata1["rows"] == 1
1616
assert metadata2["query"] == query

tests/plugins/test_lightgbm_extensions.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pytest
77

8+
from hamilton.io.utils import FILE_METADATA
89
from hamilton.plugins.lightgbm_extensions import LightGBMFileReader, LightGBMFileWriter
910

1011

@@ -40,7 +41,7 @@ def test_lightgbm_file_writer(
4041
metadata = writer.save_data(fitted_lightgbm)
4142

4243
assert model_path.exists()
43-
assert metadata["path"] == model_path
44+
assert metadata[FILE_METADATA]["path"] == model_path
4445

4546

4647
@pytest.mark.parametrize(

tests/plugins/test_matplotlib_extensions.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import matplotlib.pyplot as plt
55
import pytest
66

7+
from hamilton.io.utils import FILE_METADATA
78
from hamilton.plugins.matplotlib_extensions import MatplotlibWriter
89

910

@@ -27,4 +28,4 @@ def test_plotly_static_writer(figure: matplotlib.figure.Figure, tmp_path: pathli
2728
metadata = writer.save_data(figure)
2829

2930
assert file_path.exists()
30-
assert metadata["path"] == file_path
31+
assert metadata[FILE_METADATA]["path"] == file_path

tests/plugins/test_numpy_extensions.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pytest
55

6+
from hamilton.io.utils import FILE_METADATA
67
from hamilton.plugins.numpy_extensions import NumpyNpyReader, NumpyNpyWriter
78

89

@@ -18,7 +19,7 @@ def test_numpy_file_writer(array: np.ndarray, tmp_path: pathlib.Path) -> None:
1819
metadata = writer.save_data(array)
1920

2021
assert file_path.exists()
21-
assert metadata["path"] == file_path
22+
assert metadata[FILE_METADATA]["path"] == file_path
2223

2324

2425
def test_numpy_file_reader(array: np.ndarray, tmp_path: pathlib.Path) -> None:

tests/plugins/test_plotly_extensions.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import plotly.graph_objects as go
44
import pytest
55

6+
from hamilton.io.utils import FILE_METADATA
67
from hamilton.plugins.plotly_extensions import PlotlyInteractiveWriter, PlotlyStaticWriter
78

89

@@ -18,7 +19,7 @@ def test_plotly_static_writer(figure: go.Figure, tmp_path: pathlib.Path) -> None
1819
metadata = writer.save_data(figure)
1920

2021
assert file_path.exists()
21-
assert metadata["path"] == file_path
22+
assert metadata[FILE_METADATA]["path"] == file_path
2223

2324

2425
def test_plotly_interactive_writer(figure: go.Figure, tmp_path: pathlib.Path) -> None:
@@ -28,4 +29,4 @@ def test_plotly_interactive_writer(figure: go.Figure, tmp_path: pathlib.Path) ->
2829
metadata = writer.save_data(figure)
2930

3031
assert file_path.exists()
31-
assert metadata["path"] == file_path
32+
assert metadata[FILE_METADATA]["path"] == file_path

tests/plugins/test_sklearn_plot_extensions.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sklearn.svm import SVC
1616
from sklearn.tree import DecisionTreeClassifier
1717

18+
from hamilton.io.utils import FILE_METADATA
1819
from hamilton.plugins.sklearn_plot_extensions import SklearnPlotSaver
1920

2021
if hasattr(metrics, "PredictionErrorDisplay"):
@@ -191,7 +192,7 @@ def test_cm_plot_saver(
191192
metadata = writer.save_data(confusion_matrix_display)
192193

193194
assert plot_path.exists()
194-
assert metadata["path"] == plot_path
195+
assert metadata[FILE_METADATA]["path"] == plot_path
195196

196197

197198
def test_det_curve_display(
@@ -203,7 +204,7 @@ def test_det_curve_display(
203204
metadata = writer.save_data(det_curve_display)
204205

205206
assert plot_path.exists()
206-
assert metadata["path"] == plot_path
207+
assert metadata[FILE_METADATA]["path"] == plot_path
207208

208209

209210
def test_precision_recall_display(
@@ -215,7 +216,7 @@ def test_precision_recall_display(
215216
metadata = writer.save_data(precision_recall_display)
216217

217218
assert plot_path.exists()
218-
assert metadata["path"] == plot_path
219+
assert metadata[FILE_METADATA]["path"] == plot_path
219220

220221

221222
@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
@@ -228,7 +229,7 @@ def test_prediction_error_display(
228229
metadata = writer.save_data(prediction_error_display)
229230

230231
assert plot_path.exists()
231-
assert metadata["path"] == plot_path
232+
assert metadata[FILE_METADATA]["path"] == plot_path
232233

233234

234235
def test_roc_curve_display(
@@ -240,7 +241,7 @@ def test_roc_curve_display(
240241
metadata = writer.save_data(roc_curve_display)
241242

242243
assert plot_path.exists()
243-
assert metadata["path"] == plot_path
244+
assert metadata[FILE_METADATA]["path"] == plot_path
244245

245246

246247
def test_calibration_display(
@@ -252,7 +253,7 @@ def test_calibration_display(
252253
metadata = writer.save_data(calibration_display)
253254

254255
assert plot_path.exists()
255-
assert metadata["path"] == plot_path
256+
assert metadata[FILE_METADATA]["path"] == plot_path
256257

257258

258259
@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
@@ -265,7 +266,7 @@ def test_decision_boundary_display(
265266
metadata = writer.save_data(decision_boundary_display)
266267

267268
assert plot_path.exists()
268-
assert metadata["path"] == plot_path
269+
assert metadata[FILE_METADATA]["path"] == plot_path
269270

270271

271272
@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
@@ -278,7 +279,7 @@ def test_partial_dependence_display(
278279
metadata = writer.save_data(partial_dependence_display)
279280

280281
assert plot_path.exists()
281-
assert metadata["path"] == plot_path
282+
assert metadata[FILE_METADATA]["path"] == plot_path
282283

283284

284285
@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
@@ -291,7 +292,7 @@ def test_learning_curve_display(
291292
metadata = writer.save_data(learning_curve_display)
292293

293294
assert plot_path.exists()
294-
assert metadata["path"] == plot_path
295+
assert metadata[FILE_METADATA]["path"] == plot_path
295296

296297

297298
@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
@@ -304,4 +305,4 @@ def test_validation_curve_display(
304305
metadata = writer.save_data(validation_curve_display)
305306

306307
assert plot_path.exists()
307-
assert metadata["path"] == plot_path
308+
assert metadata[FILE_METADATA]["path"] == plot_path

tests/plugins/test_xgboost_extensions.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import xgboost
55
from sklearn.utils.validation import check_is_fitted
66

7+
from hamilton.io.utils import FILE_METADATA
78
from hamilton.plugins.xgboost_extensions import XGBoostJsonReader, XGBoostJsonWriter
89

910

@@ -30,7 +31,7 @@ def test_xgboost_model_json_writer(
3031
metadata = writer.save_data(fitted_xgboost_model)
3132

3233
assert model_path.exists()
33-
assert metadata["path"] == model_path
34+
assert metadata[FILE_METADATA]["path"] == model_path
3435

3536

3637
def test_xgboost_model_json_reader(
@@ -55,7 +56,7 @@ def test_xgboost_booster_json_writer(
5556
metadata = writer.save_data(fitted_xgboost_booster)
5657

5758
assert booster_path.exists()
58-
assert metadata["path"] == booster_path
59+
assert metadata[FILE_METADATA]["path"] == booster_path
5960

6061

6162
def test_xgboost_booster_json_reader(

0 commit comments

Comments
 (0)