Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions merlin/dag/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,17 @@ def _validate_dtypes(self, node, output_data):
def _combine_node_outputs(self, node, transformed, output):
node_output_cols = _get_unique(node.output_schema.column_names)

# dask needs output to be in the same order defined as meta, reorder partitions here
# this also selects columns (handling the case of removing columns from the output using
# "-" overload)
if isinstance(transformed, dict):
selected_cols = {col: transformed[col] for col in node_output_cols}
else:
# dask needs output columns to be defined in the same order as meta,
# so reorder columns here
selected_cols = transformed[node_output_cols]

if output is None:
output = transformed[node_output_cols]
output = selected_cols
else:
output = concat_columns([output, transformed[node_output_cols]])
output = concat_columns([output, selected_cols])

return output

Expand Down Expand Up @@ -390,7 +394,10 @@ def _transform_impl(self, dataset: Dataset, graph: Graph, capture_dtypes=False):

return Dataset(
self._executor.transform(
ddf, graph.output_node, graph.output_dtypes, capture_dtypes=capture_dtypes
ddf,
graph.output_node,
graph.output_dtypes,
capture_dtypes=capture_dtypes,
),
cpu=dataset.cpu,
base_dataset=dataset.base_dataset,
Expand Down Expand Up @@ -546,27 +553,25 @@ def _mask_cpu_only(supported):


def _data_format(transformable):
data = TensorTable(transformable) if isinstance(transformable, dict) else transformable

if cudf and isinstance(data, cudf.DataFrame):
if cudf and isinstance(transformable, cudf.DataFrame):
return DataFormats.CUDF_DATAFRAME
elif pandas and isinstance(data, pandas.DataFrame):
elif pandas and isinstance(transformable, pandas.DataFrame):
return DataFormats.PANDAS_DATAFRAME
elif isinstance(data, dict) and data.values():
first = list(data.values())[0]
if cupy and first and isinstance(first, cupy.ndarray):
elif isinstance(transformable, dict) and transformable.values():
first = list(transformable.values())[0]
if cupy and isinstance(first, cupy.ndarray):
return DataFormats.CUPY_DICT_ARRAY
if numpy and first and isinstance(first, numpy.ndarray):
if numpy and isinstance(first, numpy.ndarray):
return DataFormats.NUMPY_DICT_ARRAY
elif data.column_type is CupyColumn:
elif transformable.column_type is CupyColumn:
return DataFormats.CUPY_TENSOR_TABLE
elif data.column_type is NumpyColumn:
elif transformable.column_type is NumpyColumn:
return DataFormats.NUMPY_TENSOR_TABLE
else:
if isinstance(data, TensorTable):
raise TypeError(f"Unknown type: {data.column_type}")
if isinstance(transformable, TensorTable):
raise TypeError(f"Unknown type: {transformable.column_type}")
else:
raise TypeError(f"Unknown type: {type(data)}")
raise TypeError(f"Unknown type: {type(transformable)}")


def _convert_format(tensors, target_format):
Expand Down