
Commit 189797f

Fixes new formatting complaints
1 parent: ba1c473

17 files changed (+110 -40 lines)
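Every hunk below applies the same mechanical fix: an argument list, literal, or condition that the formatter now flags is either exploded one element per line with a trailing comma, or wrapped in explicit parentheses instead of being broken mid-expression. A rough before/after illustration of the pattern (generic names, not code from this repository):

# Before: a single long call that exceeds the line-length limit
result = build_report(first_input, second_input, third_input, options=default_options)

# After: one argument per line, closed with a trailing "magic" comma
result = build_report(
    first_input,
    second_input,
    third_input,
    options=default_options,
)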

contrib/docs/compile_docs.py (+1)

@@ -10,6 +10,7 @@
 dataflow python files and information we have.
 6. We then will trigger a build of the docs; the docs can serve the latest commit version!
 """
+
 import json
 import os
 import shutil

hamilton/io/materialization.py (+15 -8)

@@ -155,7 +155,8 @@ def generate_nodes(self, fn_graph: graph.FunctionGraph) -> List[node.Node]:

 class MaterializerFactory:
     """Basic factory for creating materializers. Note that this should only ever be instantiated
-    through `to.<name>`, which conducts polymorphic lookup to find the appropriate materializer."""
+    through `to.<name>`, which conducts polymorphic lookup to find the appropriate materializer.
+    """

     def __init__(
         self,
@@ -193,7 +194,11 @@ def sanitize_dependencies(self, module_set: Set[str]) -> "MaterializerFactory":
         """
         final_vars = common.convert_output_values(self.dependencies, module_set)
         return MaterializerFactory(
-            self.id, self.savers, self.result_builder, final_vars, **self.data_saver_kwargs
+            self.id,
+            self.savers,
+            self.result_builder,
+            final_vars,
+            **self.data_saver_kwargs,
         )

     def _resolve_dependencies(self, fn_graph: graph.FunctionGraph) -> List[node.Node]:
@@ -241,9 +246,9 @@ def join_function(**kwargs):
             doc_string=f"Builds the result for {self.id} materializer",
             callabl=join_function,
             input_types={dep.name: dep.type for dep in node_dependencies},
-            originating_functions=None
-            if self.result_builder is None
-            else [self.result_builder.build_result],
+            originating_functions=(
+                None if self.result_builder is None else [self.result_builder.build_result]
+            ),
         )
         out.append(join_node)
         save_dep = join_node
@@ -268,13 +273,13 @@ def __call__(
         combine: lifecycle.ResultBuilder = None,
         **kwargs: Union[str, SingleDependency],
     ) -> MaterializerFactory:
-        ...
+        pass


 @typing.runtime_checkable
 class _ExtractorFactoryProtocol(Protocol):
     def __call__(self, target: str, **kwargs: Union[str, SingleDependency]) -> ExtractorFactory:
-        ...
+        pass


 def partial_materializer(data_savers: List[Type[DataSaver]]) -> _MaterializerFactoryProtocol:
@@ -297,7 +302,9 @@ def create_materializer_factory(
     return create_materializer_factory


-def partial_extractor(data_loaders: List[Type[DataLoader]]) -> _ExtractorFactoryProtocol:
+def partial_extractor(
+    data_loaders: List[Type[DataLoader]],
+) -> _ExtractorFactoryProtocol:
     """Creates a partial materializer, with the specified data savers."""

     def create_extractor_factory(
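One note on the protocol hunks above: replacing `...` with `pass` does not change behavior, since a protocol stub body is never executed and `@typing.runtime_checkable` only checks that the named members exist at `isinstance` time. A minimal standalone sketch with generic names (not Hamilton's API):

import typing


@typing.runtime_checkable
class _FactoryProtocol(typing.Protocol):
    def __call__(self, target: str, **kwargs: str) -> list:
        pass  # interchangeable with `...`: the stub body never runs


def make_factory(target: str, **kwargs: str) -> list:
    # Hypothetical callable used only to exercise the isinstance check.
    return [target, *kwargs.values()]


# Passes because the object is callable; runtime_checkable does not validate the signature.
assert isinstance(make_factory, _FactoryProtocol)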

hamilton/plugins/h_spark.py (+7 -2)

@@ -885,7 +885,10 @@ def new_callable(__callable=node_.callable, **kwargs) -> Any:
        if key != transformation_target and key not in dependent_columns_from_dataframe
    }
    # Thus we put that linear dependency in
-    new_input_types[linear_df_dependency_name] = (DataFrame, node.DependencyType.REQUIRED)
+    new_input_types[linear_df_dependency_name] = (
+        DataFrame,
+        node.DependencyType.REQUIRED,
+    )
    # Then we go through all "logical" dependencies -- columns we want to add to make lineage
    # look nice
    for item in dependent_columns_from_upstream:
@@ -1191,7 +1194,9 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node
            self.select if self.select is not None else [item.name for item in output_nodes]
        )
        select_node = with_columns.create_selector_node(
-            upstream_name=current_dataframe_node, columns=select_columns, node_name="_select"
+            upstream_name=current_dataframe_node,
+            columns=select_columns,
+            node_name="_select",
        )
        output_nodes.append(select_node)
        current_dataframe_node = select_node.name
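For context on the first h_spark.py hunk, the values in a Hamilton node's input-type mapping are (type, DependencyType) tuples, which is why the reformat splits the tuple across lines and adds the trailing comma. An illustrative snippet (the dictionary and key name are made up; only the tuple shape is taken from the diff above):

from pyspark.sql import DataFrame

from hamilton import node

# Maps a parameter name to (parameter type, whether the dependency is required).
new_input_types = {
    "upstream_df": (
        DataFrame,
        node.DependencyType.REQUIRED,
    ),
}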

hamilton/plugins/numpy_extensions.py (+4 -1)

@@ -26,7 +26,10 @@ class NumpyNpyWriter(DataSaver):

     def save_data(self, data: np.ndarray) -> Dict[str, Any]:
         np.save(
-            file=self.path, arr=data, allow_pickle=self.allow_pickle, fix_imports=self.fix_imports
+            file=self.path,
+            arr=data,
+            allow_pickle=self.allow_pickle,
+            fix_imports=self.fix_imports,
         )
         return utils.get_file_metadata(self.path)

hamilton/plugins/pandas_extensions.py (+13 -6)

@@ -136,7 +136,8 @@ class PandasCSVReader(DataLoader):
     comment: Optional[str] = None
     encoding: str = "utf-8"
     encoding_errors: Union[
-        Literal["strict", "ignore", "replace", "backslashreplace", "surrogateescape"], str
+        Literal["strict", "ignore", "replace", "backslashreplace", "surrogateescape"],
+        str,
     ] = "strict"
     dialect: Optional[Union[str, csv.Dialect]] = None
     on_bad_lines: Union[Literal["error", "warn", "skip"], Callable] = "error"
@@ -446,9 +447,9 @@ class PandasPickleReader(DataLoader):
     """

     filepath_or_buffer: Union[str, Path, BytesIO, BufferedReader] = None
-    path: Union[
-        str, Path, BytesIO, BufferedReader
-    ] = None  # alias for `filepath_or_buffer` to keep reading/writing args symmetric.
+    path: Union[str, Path, BytesIO, BufferedReader] = (
+        None  # alias for `filepath_or_buffer` to keep reading/writing args symmetric.
+    )
     # kwargs:
     compression: Union[str, Dict[str, Any], None] = "infer"
     storage_options: Optional[Dict[str, Any]] = None
@@ -732,7 +733,10 @@ def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]:
         df = pd.read_sql(self.query_or_table, self.db_connection, **self._get_loading_kwargs())
         sql_metadata = utils.get_sql_metadata(self.query_or_table, df)
         df_metadata = utils.get_dataframe_metadata(df)
-        metadata = {utils.SQL_METADATA: sql_metadata, utils.DATAFRAME_METADATA: df_metadata}
+        metadata = {
+            utils.SQL_METADATA: sql_metadata,
+            utils.DATAFRAME_METADATA: df_metadata,
+        }
         return df, metadata

     @classmethod
@@ -789,7 +793,10 @@ def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
         results = data.to_sql(self.table_name, self.db_connection, **self._get_saving_kwargs())
         sql_metadata = utils.get_sql_metadata(self.table_name, results)
         df_metadata = utils.get_dataframe_metadata(data)
-        metadata = {utils.SQL_METADATA: sql_metadata, utils.DATAFRAME_METADATA: df_metadata}
+        metadata = {
+            utils.SQL_METADATA: sql_metadata,
+            utils.DATAFRAME_METADATA: df_metadata,
+        }
         return metadata

     @classmethod

hamilton/plugins/plotly_extensions.py (+3 -3)

@@ -65,9 +65,9 @@ class PlotlyInteractiveWriter(DataSaver):
     path: Union[str, pathlib.Path, IO]
     config: Optional[Dict] = None
     auto_play: bool = True
-    include_plotlyjs: Union[
-        bool, str
-    ] = True  # or "cdn", "directory", "require", "False", "other string .js"
+    include_plotlyjs: Union[bool, str] = (
+        True  # or "cdn", "directory", "require", "False", "other string .js"
+    )
     include_mathjax: Union[bool, str] = False  # "cdn", "string .js"
     post_script: Union[str, List[str], None] = None
     full_html: bool = True
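The inline comments above list the values plotly accepts for include_plotlyjs when exporting HTML. A hedged usage sketch of the underlying plotly call these dataclass fields appear to mirror (the figure and file name are arbitrary):

import plotly.graph_objects as go

fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
# "cdn" keeps the exported file small by referencing plotly.js from a CDN instead of inlining the bundle.
fig.write_html("figure.html", include_plotlyjs="cdn", full_html=True, auto_play=True)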

plugin_tests/h_dask/test_h_dask.py (+4 -1)

@@ -105,7 +105,10 @@ def test_smoke_screen_module(client):
         ),
         # dataframe_and_series
         (
-            {"a": pd.Series([1, 2, 3]), "b": pd.DataFrame({"b": [1, 2, 3], "c": [1, 1, 1]})},
+            {
+                "a": pd.Series([1, 2, 3]),
+                "b": pd.DataFrame({"b": [1, 2, 3], "c": [1, 1, 1]}),
+            },
             pd.DataFrame({"a": [1, 2, 3], "b.b": [1, 2, 3], "b.c": [1, 1, 1]}),
         ),
         # multiple_series_and_scalar

plugin_tests/h_ray/test_h_ray.py (+6 -2)

@@ -25,7 +25,9 @@ def test_ray_graph_adapter(init):
         "spend": pd.Series([10, 10, 20, 40, 40, 50]),
     }
     dr = driver.Driver(
-        initial_columns, example_module, adapter=h_ray.RayGraphAdapter(base.PandasDataFrameResult())
+        initial_columns,
+        example_module,
+        adapter=h_ray.RayGraphAdapter(base.PandasDataFrameResult()),
     )
     output_columns = [
         "spend",
@@ -47,7 +49,9 @@ def test_ray_graph_adapter(init):
 def test_smoke_screen_module(init):
     config = {"region": "US"}
     dr = driver.Driver(
-        config, smoke_screen_module, adapter=h_ray.RayGraphAdapter(base.PandasDataFrameResult())
+        config,
+        smoke_screen_module,
+        adapter=h_ray.RayGraphAdapter(base.PandasDataFrameResult()),
     )
     output_columns = [
         "raw_acquisition_cost",

plugin_tests/h_spark/test_h_spark.py (+48 -15)

@@ -47,7 +47,9 @@ def test_koalas_spark_graph_adapter(spark_session):
         initial_columns,
         example_module,
         adapter=h_spark.SparkKoalasGraphAdapter(
-            spark_session, result_builder=base.PandasDataFrameResult(), spine_column="spend"
+            spark_session,
+            result_builder=base.PandasDataFrameResult(),
+            spine_column="spend",
         ),
     )
     output_columns = [
@@ -79,7 +81,9 @@ def test_smoke_screen_module(spark_session):
         config,
         smoke_screen_module,
         adapter=h_spark.SparkKoalasGraphAdapter(
-            spark_session, result_builder=base.PandasDataFrameResult(), spine_column="weeks"
+            spark_session,
+            result_builder=base.PandasDataFrameResult(),
+            spine_column="weeks",
         ),
     )
     output_columns = [
@@ -110,7 +114,12 @@ def test_smoke_screen_module(spark_session):
         (lambda df: ({"a": df}, (df, {}))),
         (lambda df: ({"a": df, "b": 1}, (df, {"b": 1}))),
     ],
-    ids=["no_kwargs", "one_plain_kwarg", "one_df_kwarg", "one_df_kwarg_and_one_plain_kwarg"],
+    ids=[
+        "no_kwargs",
+        "one_plain_kwarg",
+        "one_df_kwarg",
+        "one_df_kwarg_and_one_plain_kwarg",
+    ],
 )
 def test__inspect_kwargs(input_and_expected_fn, spark_session):
     """A unit test for inspect_kwargs."""
@@ -230,7 +239,11 @@ def base_func(a: int, b: int) -> int:
     base_spark_df = spark_session.createDataFrame(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
     node_ = node.Node.from_fn(base_func)
     new_df = h_spark._lambda_udf(base_spark_df, node_, {})
-    assert new_df.collect() == [Row(a=1, b=4, test=5), Row(a=2, b=5, test=7), Row(a=3, b=6, test=9)]
+    assert new_df.collect() == [
+        Row(a=1, b=4, test=5),
+        Row(a=2, b=5, test=7),
+        Row(a=3, b=6, test=9),
+    ]


 def test__lambda_udf_pandas_func(spark_session):
@@ -243,7 +256,11 @@ def base_func(a: pd.Series, b: pd.Series) -> htypes.column[pd.Series, int]:
     node_ = node.Node.from_fn(base_func)

     new_df = h_spark._lambda_udf(base_spark_df, node_, {})
-    assert new_df.collect() == [Row(a=1, b=4, test=5), Row(a=2, b=5, test=7), Row(a=3, b=6, test=9)]
+    assert new_df.collect() == [
+        Row(a=1, b=4, test=5),
+        Row(a=2, b=5, test=7),
+        Row(a=3, b=6, test=9),
+    ]


 def test__lambda_udf_pandas_func_error(spark_session):
@@ -348,11 +365,13 @@ def test_get_spark_type_numpy_types(return_type, expected_spark_type):

 # 4. Unsupported types
 @pytest.mark.parametrize(
-    "unsupported_return_type", [dict, set, tuple]  # Add other unsupported types as needed
+    "unsupported_return_type",
+    [dict, set, tuple],  # Add other unsupported types as needed
 )
 def test_get_spark_type_unsupported(unsupported_return_type):
     with pytest.raises(
-        ValueError, match=f"Currently unsupported return type {unsupported_return_type}."
+        ValueError,
+        match=f"Currently unsupported return type {unsupported_return_type}.",
     ):
         h_spark.get_spark_type(unsupported_return_type)

@@ -470,19 +489,19 @@ def test_base_spark_executor_end_to_end_multiple_with_columns(spark_session):


 def _only_pyspark_dataframe_parameter(foo: DataFrame) -> DataFrame:
-    ...
+    pass


 def _no_pyspark_dataframe_parameter(foo: int) -> int:
-    ...
+    pass


 def _one_pyspark_dataframe_parameter(foo: DataFrame, bar: int) -> DataFrame:
-    ...
+    pass


 def _two_pyspark_dataframe_parameters(foo: DataFrame, bar: int, baz: DataFrame) -> DataFrame:
-    ...
+    pass


 @pytest.mark.parametrize(
@@ -603,7 +622,11 @@ def df_as_pandas(df: DataFrame) -> pd.DataFrame:

     nodes = dec.generate_nodes(df_as_pandas, {})
     nodes_by_names = {n.name: n for n in nodes}
-    assert set(nodes_by_names.keys()) == {"df_as_pandas.c", "df_as_pandas", "df_as_pandas._select"}
+    assert set(nodes_by_names.keys()) == {
+        "df_as_pandas.c",
+        "df_as_pandas",
+        "df_as_pandas._select",
+    }


 def test_with_columns_generate_nodes_specify_namespace():
@@ -640,7 +663,10 @@ def test__format_standard_udf():

 def test_sparkify_node():
     def foo(
-        a_from_upstream: pd.Series, b_from_upstream: pd.Series, c_from_df: pd.Series, d_fixed: int
+        a_from_upstream: pd.Series,
+        b_from_upstream: pd.Series,
+        c_from_df: pd.Series,
+        d_fixed: int,
     ) -> htypes.column[pd.Series, int]:
         return a_from_upstream + b_from_upstream + c_from_df + d_fixed

@@ -679,7 +705,10 @@ def test_pyspark_mixed_pandas_udfs_end_to_end():
     # inputs={"spark_session": spark_session},
     # )
     results = dr.execute(
-        ["processed_df_as_pandas_dataframe_with_injected_dataframe", "processed_df_as_pandas"],
+        [
+            "processed_df_as_pandas_dataframe_with_injected_dataframe",
+            "processed_df_as_pandas",
+        ],
         inputs={"spark_session": spark_session},
     )
     processed_df_as_pandas = results["processed_df_as_pandas"]
@@ -774,7 +803,11 @@ def test_create_selector_node(spark_session):
     selector_node = h_spark.with_columns.create_selector_node("foo", ["a", "b"], "select")
     assert selector_node.name == "select"
     pandas_df = pd.DataFrame(
-        {"a": [10, 10, 20, 40, 40, 50], "b": [1, 10, 50, 100, 200, 400], "c": [1, 2, 3, 4, 5, 6]}
+        {
+            "a": [10, 10, 20, 40, 40, 50],
+            "b": [1, 10, 50, 100, 200, 400],
+            "c": [1, 2, 3, 4, 5, 6],
+        }
     )
     df = spark_session.createDataFrame(pandas_df)
     transformed = selector_node(foo=df).toPandas()

tests/function_modifiers/test_combined.py (+1)

@@ -3,6 +3,7 @@
 it is useful to have a few tests that demonstrate that common use-cases are supported.

 Note we also have some more end-to-end cases in test_layered.py"""
+
 from typing import Dict

 import pandas as pd

tests/resources/bad_functions.py (+1)

@@ -1,6 +1,7 @@
 """
 Module for more dummy functions to test graph things with.
 """
+
 # we import this to check we don't pull in this function when parsing this module.
 from tests.resources import only_import_me  # noqa: F401

tests/resources/cyclic_functions.py (+1)

@@ -1,6 +1,7 @@
 """
 Module for cyclic functions to test graph things with.
 """
+
 # we import this to check we don't pull in this function when parsing this module.
 from tests.resources import only_import_me  # noqa: F401

tests/resources/dummy_functions.py (+1)

@@ -1,6 +1,7 @@
 """
 Module for dummy functions to test graph things with.
 """
+
 # we import this to check we don't pull in this function when parsing this module.
 from tests.resources import only_import_me

tests/resources/functions_with_generics.py (+1)

@@ -1,6 +1,7 @@
 """
 Module for functions with genercis to test graph things with.
 """
+
 from typing import Dict, List, Mapping, Tuple

tests/resources/smoke_screen_module.py (+1)

@@ -21,6 +21,7 @@
 neutral_net_acquisition_cost
 optimistic_net_acquisition_cost
 """
+
 from typing import Dict

 import numpy as np
