Skip to content

Commit d9bfa4f

Browse files
committed
refactoring transformation function engine
1 parent e87aab3 commit d9bfa4f

2 files changed

Lines changed: 68 additions & 106 deletions

File tree

python/.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,3 +6,9 @@ repos:
66
- id: ruff
77
args: [--fix]
88
- id: ruff-format
9+
- repo: https://github.com/jshwi/docsig
10+
rev: v0.80.0
11+
hooks:
12+
- id: docsig
13+
files: ^(hopsworks_common|hopsworks|hsfs|hsml)/
14+
types: [python]

python/hsfs/core/transformation_function_engine.py

Lines changed: 62 additions & 106 deletions
Original file line number | Diff line number | Diff line change
@@ -141,44 +141,36 @@ def __str__(self) -> str:
141141
if not self.nodes:
142142
return "Transformation Execution DAG (empty)"
143143

144+
arrow = " \u2502\n \u25bc"
144145
_, raw_inputs, tf_outputs, pass_through_features = self._build_lookups()
145146

146-
lines = [
147-
"Transformation Execution DAG",
148-
"\u2550" * 35,
149-
"",
150-
]
147+
lines = ["Transformation Execution DAG", "\u2550" * 35, ""]
151148

152149
if raw_inputs:
153150
lines.append(f"Input Features: {', '.join(sorted(raw_inputs))}")
154-
lines.append(" \u2502")
155-
lines.append(" \u25bc")
151+
lines.append(arrow)
156152

157153
depths = self._compute_depths()
158-
159-
# Group TFs by depth level
160154
levels: dict[int, list] = {}
161155
for tf in self.nodes:
162156
levels.setdefault(depths[id(tf)], []).append(tf)
163157

164-
for level_idx in sorted(levels.keys()):
158+
max_level = max(levels)
159+
for level_idx in sorted(levels):
165160
for tf in levels[level_idx]:
166161
udf = tf.hopsworks_udf
167162
header = f" {udf.function_name} (mode: {udf.execution_mode.value})"
168163
if udf.dropped_features:
169164
header += f" [drops: {', '.join(udf.dropped_features)}]"
170165
lines.append(header)
166+
if level_idx < max_level:
167+
lines.append(arrow)
171168

172-
if level_idx < max(levels.keys()):
173-
lines.append(" \u2502")
174-
lines.append(" \u25bc")
175-
176-
# Output features = all TF outputs + pass-through input features
177-
tf_output_cols = [col for cols in tf_outputs.values() for col in cols]
178-
all_output_cols = pass_through_features + tf_output_cols
169+
all_output_cols = pass_through_features + [
170+
col for cols in tf_outputs.values() for col in cols
171+
]
179172
if all_output_cols:
180-
lines.append(" \u2502")
181-
lines.append(" \u25bc")
173+
lines.append(arrow)
182174
lines.append(f"Output Features: {', '.join(all_output_cols)}")
183175

184176
return "\n".join(lines)
@@ -266,39 +258,22 @@ def _to_graphviz(self, orient: str = "TB") -> graphviz.Digraph:
266258
edge_attr={"fontname": "Helvetica", "fontsize": "9"},
267259
)
268260

261+
io_style = {"shape": "rectangle", "style": "filled", "margin": "0.2"}
269262
if raw_inputs:
270263
dot.node(
271-
"input",
272-
"<<b>Input Features</b>>",
273-
shape="rectangle",
274-
style="filled",
275-
fillcolor="#E8F4FD",
276-
margin="0.2",
264+
"input", "<<b>Input Features</b>>", fillcolor="#E8F4FD", **io_style
277265
)
278266

279-
# TF nodes
280267
for tf in self.nodes:
281268
udf = tf.hopsworks_udf
282-
label = (
283-
f"<<b>{udf.function_name}</b>"
284-
f"<br/><i>mode: {udf.execution_mode.value}</i>"
285-
)
269+
label = f"<<b>{udf.function_name}</b><br/><i>mode: {udf.execution_mode.value}</i>"
286270
if udf.dropped_features:
287-
dropped = ", ".join(udf.dropped_features)
288-
label += f'<br/><font color="#EA5556">drops: {dropped}</font>'
289-
label += ">"
290-
dot.node(str(id(tf)), label)
291-
292-
# Output node: shown if there are TF outputs or pass-through features
293-
has_outputs = tf_outputs or pass_through_features
294-
if has_outputs:
271+
label += f'<br/><font color="#EA5556">drops: {", ".join(udf.dropped_features)}</font>'
272+
dot.node(str(id(tf)), label + ">")
273+
274+
if tf_outputs or pass_through_features:
295275
dot.node(
296-
"output",
297-
"<<b>Output Features</b>>",
298-
shape="rectangle",
299-
style="filled",
300-
fillcolor="#D4EDDA",
301-
margin="0.2",
276+
"output", "<<b>Output Features</b>>", fillcolor="#D4EDDA", **io_style
302277
)
303278

304279
# Edges: Input -> TF
@@ -307,8 +282,8 @@ def _to_graphviz(self, orient: str = "TB") -> graphviz.Digraph:
307282
for feat in tf.hopsworks_udf.transformation_features:
308283
if feat not in output_to_tf:
309284
input_edges.setdefault(id(tf), []).append(feat)
310-
for tf_id, features in input_edges.items():
311-
dot.edge("input", str(tf_id), label=", ".join(features))
285+
for tf_id, feats in input_edges.items():
286+
dot.edge("input", str(tf_id), label=", ".join(feats))
312287

313288
# Edges: Input -> Output (pass-through features not dropped by any TF)
314289
if pass_through_features:
@@ -570,8 +545,9 @@ def apply_transformation_functions(
570545
# PYTHON PATH — pre-compute metadata
571546
# ============================================================
572547
dropped_features: set[str] = set()
573-
if is_dataframe:
574-
column_order = list(data.columns)
548+
# Collect all TF output columns in topo order for final column ordering.
549+
tf_output_cols: list[str] = []
550+
tf_output_set: set[str] = set()
575551

576552
for tf in execution_graph.nodes:
577553
udf = tf.hopsworks_udf
@@ -582,45 +558,43 @@ def apply_transformation_functions(
582558
if expected_features
583559
else udf.dropped_features
584560
)
585-
if is_dataframe:
586-
for col in udf.output_column_names:
587-
if col in column_order:
588-
column_order.remove(col)
589-
column_order.append(col)
561+
for col in udf.output_column_names:
562+
tf_output_set.add(col)
563+
tf_output_cols.append(col)
564+
565+
if is_dataframe:
566+
# Original columns (minus those overwritten by TFs) + TF outputs in topo order
567+
column_order = [
568+
c for c in data.columns if c not in tf_output_set
569+
] + tf_output_cols
590570
if request_parameters:
591571
data = TransformationFunctionEngine._update_request_parameter_data(
592572
data, request_parameters
593573
)
594574

595575
# --- Dict/list: sequential, topo order, in-place update ---
596576
if isinstance(data, (dict, list)):
597-
if isinstance(data, list):
598-
transformed_data = [row.copy() for row in data]
599-
for tf in execution_graph.nodes:
600-
for row in transformed_data:
601-
result = TransformationFunctionEngine.execute_udf(
577+
rows = (
578+
[row.copy() for row in data]
579+
if isinstance(data, list)
580+
else [data.copy()]
581+
)
582+
eng_type = engine.get_type()
583+
for tf in execution_graph.nodes:
584+
for row in rows:
585+
row.update(
586+
TransformationFunctionEngine.execute_udf(
602587
udf=tf.hopsworks_udf,
603588
data=row,
604589
online=online,
605-
engine_type=engine.get_type(),
590+
engine_type=eng_type,
606591
)
607-
row.update(result)
608-
return [
609-
{k: v for k, v in row.items() if k not in dropped_features}
610-
for row in transformed_data
611-
]
612-
transformed_data = data.copy()
613-
for tf in execution_graph.nodes:
614-
result = TransformationFunctionEngine.execute_udf(
615-
udf=tf.hopsworks_udf,
616-
data=transformed_data,
617-
online=online,
618-
engine_type=engine.get_type(),
619-
)
620-
transformed_data.update(result)
621-
return {
622-
k: v for k, v in transformed_data.items() if k not in dropped_features
623-
}
592+
)
593+
cleaned = [
594+
{k: v for k, v in row.items() if k not in dropped_features}
595+
for row in rows
596+
]
597+
return cleaned if isinstance(data, list) else cleaned[0]
624598

625599
# --- DataFrame: sequential (n_processes==1) or parallel DAG ---
626600
column_store = {} # col_name -> Series (accumulated results from completed TFs)
@@ -869,20 +843,13 @@ def execute_udf(
869843
shm_name, shm_size, is_polars
870844
)
871845
if columns:
872-
if predecessor_columns:
873-
# Build input: base columns overridden by predecessor results
874-
col_data = {}
875-
for col in columns:
876-
col_data[col] = (
877-
predecessor_columns[col]
878-
if col in predecessor_columns
879-
else data[col]
880-
)
881-
data = (
882-
pl.DataFrame(col_data) if is_polars else pd.DataFrame(col_data)
883-
)
884-
else:
885-
data = data[columns]
846+
col_data = {
847+
c: predecessor_columns[c]
848+
if predecessor_columns and c in predecessor_columns
849+
else data[c]
850+
for c in columns
851+
}
852+
data = pl.DataFrame(col_data) if is_polars else pd.DataFrame(col_data)
886853

887854
# Check dict/list first — these are the dominant types on the online
888855
# serving hot path and avoid the multiple isinstance() checks inside
@@ -938,23 +905,12 @@ def apply_udf_on_dict(
938905
== UDFExecutionMode.PANDAS
939906
)
940907

941-
# Pre-compute prefix and feature list once to avoid repeated property
942-
# access and string concatenation inside the loop.
943908
prefix = udf.feature_name_prefix
944-
945-
if is_pandas_mode:
946-
features = []
947-
for feat in udf.unprefixed_transformation_features:
948-
feature_name = prefix + feat if prefix else feat
949-
val = data[feature_name] if feature_name in data else data[feat]
950-
features.append(pd.Series([val], name=feat))
951-
else:
952-
features = []
953-
for feat in udf.unprefixed_transformation_features:
954-
feature_name = prefix + feat if prefix else feat
955-
features.append(
956-
data[feature_name] if feature_name in data else data[feat]
957-
)
909+
features = []
910+
for feat in udf.unprefixed_transformation_features:
911+
feature_name = prefix + feat if prefix else feat
912+
val = data[feature_name] if feature_name in data else data[feat]
913+
features.append(pd.Series([val], name=feat) if is_pandas_mode else val)
958914

959915
transformed_result = udf.get_udf(online=online, engine_type=engine_type)(
960916
*features

0 commit comments

Comments
 (0)