
TFMA analyze_raw_data function support with MultiClassConfusionMatrixPlot #162

Open
@tybrs

Description

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow Model Analysis): No
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04.2 LTS Linux 5.4.0-65-generic
  • TensorFlow Model Analysis installed from (source or binary): pip install tensorflow-model-analysis
  • TensorFlow Model Analysis version (use command below): 0.41.1
  • Python version: Python 3.8.5
  • Jupyter Notebook version: NA

Describe the problem

I am trying to get tfma.analyze_raw_data to work with MultiClassConfusionMatrixPlot, which needs multiple prediction values per record. Is this not supported? I am happy to provide further details or run further tests.

Details

Currently, tfma.analyze_raw_data does not seem to work with metrics for multi-class classification tasks (e.g., tfma.metrics.MultiClassConfusionMatrixPlot). However, I do not see this limitation documented anywhere.

The prediction column for a multi-class classification task will be a Series whose values are lists or arrays (e.g., pd.DataFrame({'predictions': [[0.2, .3, .5]], 'label': [1]})).

The tfma.analyze_raw_data function uses tfx_bsl.arrow.table_util.DataFrameToRecordBatch to convert a pandas DataFrame to an Arrow RecordBatch. The problem, however, is that it encodes columns with the object dtype as the pyarrow binary type. Since a column whose values are lists or arrays has the object dtype, it is encoded as binary instead of the relevant pyarrow list-like type.

Source code / logs

import tensorflow_model_analysis as tfma
from google.protobuf import text_format
import pandas as pd

eval_config = text_format.Parse("""
  ## Model information
  model_specs {
    label_key: "label",
    prediction_key: "predictions"
  }

  ## Post training metric information. These will be merged with any built-in
  ## metrics from training.
  metrics_specs {
    metrics { class_name: "MultiClassConfusionMatrixPlot" }
  }
  
  ## Slicing information
  slicing_specs {}  # overall slice
""", tfma.EvalConfig())

df = pd.DataFrame({'predictions': [[0.2, .3, .5]], 'label': [1]})
tfma.analyze_raw_data(df, eval_config)

Error

---------------------------------------------------------------------------
ArrowTypeError                            Traceback (most recent call last)
/tmp/ipykernel_206830/3947320198.py in <cell line: 23>()
     21 
     22 df = pd.DataFrame({'predictions': [[0.2, .3, .5]], 'label': [1]})
---> 23 tfma.analyze_raw_data(df, eval_config)

/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/tensorflow_model_analysis/api/model_eval_lib.py in analyze_raw_data(data, eval_config, output_path, add_metric_callbacks)
   1511 
   1512   arrow_data = table_util.CanonicalizeRecordBatch(
-> 1513       table_util.DataFrameToRecordBatch(data))
   1514   beam_data = beam.Create([arrow_data])
   1515 

/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/tfx_bsl/arrow/table_util.py in DataFrameToRecordBatch(dataframe)
    122       continue
    123     arrow_fields.append(pa.field(col_name, arrow_type))
--> 124   return pa.RecordBatch.from_pandas(dataframe, schema=pa.schema(arrow_fields))
    125 
    126 

/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/table.pxi in pyarrow.lib.RecordBatch.from_pandas()

/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/pandas_compat.py in dataframe_to_arrays(df, schema, preserve_index, nthreads, columns, safe)
    592 
    593     if nthreads == 1:
--> 594         arrays = [convert_column(c, f)
    595                   for c, f in zip(columns_to_convert, convert_fields)]
    596     else:

/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/pandas_compat.py in <listcomp>(.0)
    592 
    593     if nthreads == 1:
--> 594         arrays = [convert_column(c, f)
    595                   for c, f in zip(columns_to_convert, convert_fields)]
    596     else:

/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/pandas_compat.py in convert_column(col, field)
    579             e.args += ("Conversion failed for column {!s} with type {!s}"
    580                        .format(col.name, col.dtype),)
--> 581             raise e
    582         if not field_nullable and result.null_count > 0:
    583             raise ValueError("Field {} was non-nullable but pandas column "

/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/pandas_compat.py in convert_column(col, field)
    573 
    574         try:
--> 575             result = pa.array(col, type=type_, from_pandas=True, safe=safe)
    576         except (pa.ArrowInvalid,
    577                 pa.ArrowNotImplementedError,

/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/array.pxi in pyarrow.lib.array()

/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/array.pxi in pyarrow.lib._ndarray_to_array()

/localdisk/twilbers/src/repos/xai-tools/model_card_gen/.venv/lib/python3.8/site-packages/pyarrow/error.pxi in pyarrow.lib.check_status()

ArrowTypeError: ("Expected bytes, got a 'list' object", 'Conversion failed for column predictions with type object')

Temporary fix

If I patch tfx_bsl.arrow.table_util.DataFrameToRecordBatch as follows, it seems to work, but I doubt this is a proper solution.

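import pyarrow as pa
# Import path assumed from the table_util.py module shown in the traceback.
from tfx_bsl.arrow.table_util import NumpyKindToArrowType

def DataFrameToRecordBatch(dataframe):
    arrays = []
    for col_name, col_type in zip(dataframe.columns, dataframe.dtypes):
        # Leave the type unset for object columns so pyarrow infers it
        # (e.g. a list type for list-valued predictions).
        arrow_type = None
        if col_type.kind != 'O':
            arrow_type = NumpyKindToArrowType(col_type.kind)
        arrays.append(pa.array(dataframe[col_name].values.tolist(), type=arrow_type))
    return pa.RecordBatch.from_arrays(arrays, names=list(dataframe.columns))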
