Description
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow Model Analysis): No
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04.2 LTS Linux 5.4.0-65-generic
- TensorFlow Model Analysis installed from (source or binary): pip install tensorflow-model-analysis
- TensorFlow Model Analysis version (use command below): 0.41.1
- Python version: Python 3.8.5
- Jupyter Notebook version: NA
Describe the problem
I am trying to get tfma.analyze_raw_data to work with MultiClassConfusionMatrixPlot on data that has multiple prediction values per record. Is this not supported? I am happy to provide further details or run further tests.
Details
Currently, tfma.analyze_raw_data does not seem to work with metrics for multi-class classification tasks (e.g., tfma.metrics.MultiClassConfusionMatrixPlot). However, I do not see this limitation documented anywhere.
The prediction column for a multi-class classification task is a Series whose values are lists or arrays (e.g., pd.DataFrame({'predictions': [[0.2, .3, .5]], 'label': [1]})).
The tfma.analyze_raw_data function uses tfx_bsl.arrow.table_util.DataFrameToRecordBatch to convert a Pandas DataFrame to an Arrow RecordBatch. The problem, however, is that it encodes columns with a dtype of object as a pyarrow binary type. Since a column that has lists or arrays as values has a dtype of object, such a column is encoded as binary instead of the relevant pyarrow list-like type.
Source code / logs
import tensorflow_model_analysis as tfma
from google.protobuf import text_format
import pandas as pd

eval_config = text_format.Parse("""
  ## Model information
  model_specs {
    label_key: "label",
    prediction_key: "predictions"
  }
  ## Post training metric information. These will be merged with any built-in
  ## metrics from training.
  metrics_specs {
    metrics { class_name: "MultiClassConfusionMatrixPlot" }
  }
  ## Slicing information
  slicing_specs {}  # overall slice
""", tfma.EvalConfig())

df = pd.DataFrame({'predictions': [[0.2, .3, .5]], 'label': [1]})
tfma.analyze_raw_data(df, eval_config)
Error
---------------------------------------------------------------------------
ArrowTypeError Traceback (most recent call last)
/tmp/ipykernel_206830/3947320198.py in <cell line: 23>()
21
22 df = pd.DataFrame({'predictions': [[0.2, .3, .5]], 'label': [1]})
---> 23 tfma.analyze_raw_data(df, eval_config)
/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/tensorflow_model_analysis/api/model_eval_lib.py in analyze_raw_data(data, eval_config, output_path, add_metric_callbacks)
1511
1512 arrow_data = table_util.CanonicalizeRecordBatch(
-> 1513 table_util.DataFrameToRecordBatch(data))
1514 beam_data = beam.Create([arrow_data])
1515
/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/tfx_bsl/arrow/table_util.py in DataFrameToRecordBatch(dataframe)
122 continue
123 arrow_fields.append(pa.field(col_name, arrow_type))
--> 124 return pa.RecordBatch.from_pandas(dataframe, schema=pa.schema(arrow_fields))
125
126
/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/table.pxi in pyarrow.lib.RecordBatch.from_pandas()
/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/pandas_compat.py in dataframe_to_arrays(df, schema, preserve_index, nthreads, columns, safe)
592
593 if nthreads == 1:
--> 594 arrays = [convert_column(c, f)
595 for c, f in zip(columns_to_convert, convert_fields)]
596 else:
/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/pandas_compat.py in <listcomp>(.0)
592
593 if nthreads == 1:
--> 594 arrays = [convert_column(c, f)
595 for c, f in zip(columns_to_convert, convert_fields)]
596 else:
/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/pandas_compat.py in convert_column(col, field)
579 e.args += ("Conversion failed for column {!s} with type {!s}"
580 .format(col.name, col.dtype),)
--> 581 raise e
582 if not field_nullable and result.null_count > 0:
583 raise ValueError("Field {} was non-nullable but pandas column "
/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/pandas_compat.py in convert_column(col, field)
573
574 try:
--> 575 result = pa.array(col, type=type_, from_pandas=True, safe=safe)
576 except (pa.ArrowInvalid,
577 pa.ArrowNotImplementedError,
/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/array.pxi in pyarrow.lib.array()
/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/array.pxi in pyarrow.lib._ndarray_to_array()
/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/error.pxi in pyarrow.lib.check_status()
ArrowTypeError: ("Expected bytes, got a 'list' object", 'Conversion failed for column predictions with type object')
Temporary fix
If I patch tfx_bsl.arrow.table_util.DataFrameToRecordBatch as follows, it seems to work, but I doubt this is a proper solution.
def DataFrameToRecordBatch(dataframe):
    arrays = []
    for col_name, col_type in zip(dataframe.columns, dataframe.dtypes):
        # Leave the type unset for object columns so that pyarrow infers it
        # (e.g., a list type for list-valued cells) instead of forcing binary.
        arrow_type = None
        if col_type.kind != 'O':
            arrow_type = NumpyKindToArrowType(col_type.kind)
        arrays.append(pa.array(dataframe[col_name].values.tolist(), type=arrow_type))
    return pa.RecordBatch.from_arrays(arrays, names=dataframe.columns)