diff --git a/merlin/dag/executors.py b/merlin/dag/executors.py index adcd15244..08e6877dd 100644 --- a/merlin/dag/executors.py +++ b/merlin/dag/executors.py @@ -273,13 +273,17 @@ def _validate_dtypes(self, node, output_data): def _combine_node_outputs(self, node, transformed, output): node_output_cols = _get_unique(node.output_schema.column_names) - # dask needs output to be in the same order defined as meta, reorder partitions here - # this also selects columns (handling the case of removing columns from the output using - # "-" overload) + if isinstance(transformed, dict): + selected_cols = {col: transformed[col] for col in node_output_cols} + else: + # dask needs output columns to be defined in the same order as meta, + # so reorder columns here + selected_cols = transformed[node_output_cols] + if output is None: - output = transformed[node_output_cols] + output = selected_cols else: - output = concat_columns([output, transformed[node_output_cols]]) + output = concat_columns([output, selected_cols]) return output @@ -390,7 +394,10 @@ def _transform_impl(self, dataset: Dataset, graph: Graph, capture_dtypes=False): return Dataset( self._executor.transform( - ddf, graph.output_node, graph.output_dtypes, capture_dtypes=capture_dtypes + ddf, + graph.output_node, + graph.output_dtypes, + capture_dtypes=capture_dtypes, ), cpu=dataset.cpu, base_dataset=dataset.base_dataset, @@ -546,27 +553,25 @@ def _mask_cpu_only(supported): def _data_format(transformable): - data = TensorTable(transformable) if isinstance(transformable, dict) else transformable - - if cudf and isinstance(data, cudf.DataFrame): + if cudf and isinstance(transformable, cudf.DataFrame): return DataFormats.CUDF_DATAFRAME - elif pandas and isinstance(data, pandas.DataFrame): + elif pandas and isinstance(transformable, pandas.DataFrame): return DataFormats.PANDAS_DATAFRAME - elif isinstance(data, dict) and data.values(): - first = list(data.values())[0] - if cupy and first and isinstance(first, cupy.ndarray): + elif isinstance(transformable, dict) and transformable.values(): + first = list(transformable.values())[0] + if cupy and isinstance(first, cupy.ndarray): return DataFormats.CUPY_DICT_ARRAY - if numpy and first and isinstance(first, numpy.ndarray): + if numpy and isinstance(first, numpy.ndarray): return DataFormats.NUMPY_DICT_ARRAY - elif data.column_type is CupyColumn: + elif transformable.column_type is CupyColumn: return DataFormats.CUPY_TENSOR_TABLE - elif data.column_type is NumpyColumn: + elif transformable.column_type is NumpyColumn: return DataFormats.NUMPY_TENSOR_TABLE else: - if isinstance(data, TensorTable): - raise TypeError(f"Unknown type: {data.column_type}") + if isinstance(transformable, TensorTable): + raise TypeError(f"Unknown type: {transformable.column_type}") else: - raise TypeError(f"Unknown type: {type(data)}") + raise TypeError(f"Unknown type: {type(transformable)}") def _convert_format(tensors, target_format):