@@ -47,31 +47,32 @@ OmniXAI includes a rich family of explanation methods integrated in a unified in
4747supports multiple data types (tabular data, images, texts, time-series), multiple types of ML models
4848(traditional ML in Scikit-learn and deep learning models in PyTorch/TensorFlow), and a range of diverse explaination
4949methods including "model-specific" and "model-agnostic" methods (such as feature-attribution explanation,
50- counterfactual explanation, gradient-based explanation, etc). For practitioners, OmniXAI provides an easy-to-use
50+ counterfactual explanation, gradient-based explanation, feature visualization, etc). For practitioners, OmniXAI provides an easy-to-use
5151unified interface to generate the explanations for their applications by only writing a few lines of
5252codes, and also a GUI dashboard for visualization for obtaining more insights about decisions.
5353
5454The following table shows the supported explanation methods and features in our library.
55- We will continue improving this library to make it more comprehensive in the future, e.g., supporting more
56- explanation methods for vision, NLP and time-series tasks.
57-
58- | Method | Model Type | Explanation Type | EDA | Tabular | Image | Text | Timeseries |
59- :-----------------------:| :---: | :---: |:---:| :---: | :---: | :---: | :---:
60- | Feature analysis | NA | Global | ✅ | | | | |
61- | Feature selection | NA | Global | ✅ | | | | |
62- | Prediction metrics | Black box | Global | | ✅ | ✅ | ✅ | ✅ |
63- | Partial dependence plots | Black box | Global | | ✅ | | | |
64- | Accumulated local effects | Black box | Global | | ✅ | | | |
65- | Sensitivity analysis | Black box | Global | | ✅ | | | |
66- | LIME | Black box | Local | | ✅ | ✅ | ✅ | |
67- | SHAP | Black box* | Local | | ✅ | ✅ | ✅ | ✅ |
68- | Integrated gradient | Torch or TF | Local | | ✅ | ✅ | ✅ | |
69- | Counterfactual | Black box* | Local | | ✅ | ✅ | ✅ | ✅ |
70- | Contrastive explanation | Torch or TF | Local | | | ✅ | | |
71- | Grad-CAM, Grad-CAM++ | Torch or TF | Local | | | ✅ | | |
72- | Learning to explain | Black box | Local | | ✅ | ✅ | ✅ | |
73- | Linear models | Linear models | Global and Local | | ✅ | | | |
74- | Tree models | Tree models | Global and Local | | ✅ | | | |
55+ We will continue improving this library to make it more comprehensive in the future.
56+
57+ | Method | Model Type | Explanation Type | EDA | Tabular | Image | Text | Timeseries |
58+ :-------------------------:|:-------------:|:----------------:|:---:|:-------:|:-----:| :---: | :---:
59+ | Feature analysis | NA | Global | ✅ | | | | |
60+ | Feature selection | NA | Global | ✅ | | | | |
61+ | Prediction metrics | Black box | Global | | ✅ | ✅ | ✅ | ✅ |
62+ | Partial dependence plots | Black box | Global | | ✅ | | | |
63+ | Accumulated local effects | Black box | Global | | ✅ | | | |
64+ | Sensitivity analysis | Black box | Global | | ✅ | | | |
65+ | Feature visualization | Torch or TF | Global | | | ✅ | | |
66+ | LIME | Black box | Local | | ✅ | ✅ | ✅ | |
67+ | SHAP | Black box* | Local | | ✅ | ✅ | ✅ | ✅ |
68+ | Integrated gradient | Torch or TF | Local | | ✅ | ✅ | ✅ | |
69+ | Counterfactual | Black box* | Local | | ✅ | ✅ | ✅ | ✅ |
70+ | Contrastive explanation | Torch or TF | Local | | | ✅ | | |
71+ | Grad-CAM, Grad-CAM++ | Torch or TF | Local | | | ✅ | | |
72+ | Learning to explain | Black box | Local | | ✅ | ✅ | ✅ | |
73+ | Linear models | Linear models | Global and Local | | ✅ | | | |
74+ | Tree models | Tree models | Global and Local | | ✅ | | | |
75+ | Feature maps | Torch or TF | Local | | | ✅ | | |
7576
7677* SHAP* accepts black box models for tabular data, PyTorch/Tensorflow models for image data, transformer models
7778for text data. * Counterfactual* accepts black box models for tabular, text and time-series data, and PyTorch/Tensorflow models for
@@ -109,22 +110,29 @@ Some examples:
1091104 . [ Text classification] ( https://github.com/salesforce/OmniXAI/blob/main/tutorials/nlp_imdb.ipynb )
1101115 . [ Time-series anomaly detection] ( https://github.com/salesforce/OmniXAI/blob/main/tutorials/timeseries.ipynb )
1111126 . [ Vision-language tasks] ( https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision/gradcam_vlm.ipynb )
113+ 7 . [ Ranking tasks] ( https://github.com/salesforce/OmniXAI/blob/main/tutorials/tabular/ranking.ipynb )
114+ 8 . [ Feature visualization] ( https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision/feature_visualization_torch.ipynb )
115+ 9 . [ Check feature maps] ( https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision/feature_map_torch.ipynb )
112116
113117To get started, we recommend the linked tutorials in [ tutorials] ( https://opensource.salesforce.com/OmniXAI/latest/tutorials.html ) .
114118In general, we recommend using ` TabularExplainer ` , ` VisionExplainer ` ,
115119` NLPExplainer ` and ` TimeseriesExplainer ` for tabular, vision, NLP and time-series tasks, respectively, and using
116120` DataAnalyzer ` and ` PredictionAnalyzer ` for feature analysis and prediction result analysis.
117- To generate explanations, one only needs to specify
121+ These classes act as the factories of the individual explainers supported in OmniXAI, providing a simpler
122+ interface to generate multiple explanations. To generate explanations, you only need to specify
118123
119124- ** The ML model to explain** : e.g., a scikit-learn model, a tensorflow model, a pytorch model or a black-box prediction function.
120125- ** The pre-processing function** : i.e., converting raw input features into the model inputs.
121126- ** The post-processing function (optional)** : e.g., converting the model outputs into class probabilities.
122127- ** The explainers to apply** : e.g., SHAP, MACE, Grad-CAM.
123128
129+ Besides using these classes, you can also create a single explainer defined in the ` omnixai.explainers ` package, e.g.,
130+ ` ShapTabular ` , ` GradCAM ` , ` IntegratedGradient ` or ` FeatureVisualizer ` .
131+
124132Let's take the income prediction task as an example.
125133The [ dataset] ( https://archive.ics.uci.edu/ml/datasets/adult ) used in this example is for income prediction.
126134We recommend using data class ` Tabular ` to represent a tabular dataset. To create a ` Tabular ` instance given a pandas
127- dataframe, one needs to specify the dataframe, the categorical feature names (if exists) and the target/label
135+ dataframe, you need to specify the dataframe, the categorical feature names (if exists) and the target/label
128136column name (if exists).
129137
130138``` python
@@ -152,8 +160,8 @@ for a `Tabular` instance. `TabularTransform` is a special transform designed for
152160By default, it converts categorical features into one-hot encoding, and keeps continuous-valued features.
153161The method `` transform `` of ` TabularTransform ` transforms a ` Tabular ` instance to a numpy array.
154162If the ` Tabular ` instance has a target/label column, the last column of the numpy array
155- will be the target/label. One can also apply any customized preprocessing functions instead of using ` TabularTransform ` .
156- After data preprocessing, we train a XGBoost classifier for this task.
163+ will be the target/label. You can apply any customized preprocessing functions instead of using ` TabularTransform ` .
164+ After data preprocessing, let's train a XGBoost classifier for this task.
157165
158166``` python
159167from omnixai.preprocessing.tabular import TabularTransform
@@ -172,7 +180,7 @@ train_data = transformer.invert(train)
172180test_data = transformer.invert(test)
173181```
174182
175- To initialize ` TabularExplainer ` , we need to set the following parameters :
183+ To initialize ` TabularExplainer ` , the following parameters need to be set :
176184
177185- `` explainers `` : The names of the explainers to apply, e.g., [ "lime", "shap", "mace", "pdp"] .
178186- `` data `` : The data used to initialize explainers. `` data `` is the training dataset for training the
@@ -185,8 +193,8 @@ To initialize `TabularExplainer`, we need to set the following parameters:
185193- `` mode `` : The task type, e.g., "classification" or "regression".
186194
187195The preprocessing function takes a ` Tabular ` instance as its input and outputs the processed features that
188- the ML model consumes. In this example, we simply call `` transformer.transform `` . If one uses some customized transforms
189- on pandas dataframes, the preprocess function has format: ` lambda z: some_transform(z.to_pd()) ` . If the output of `` model ``
196+ the ML model consumes. In this example, we simply call `` transformer.transform `` . If you use some customized transforms
197+ on pandas dataframes, the preprocess function has this format: ` lambda z: some_transform(z.to_pd()) ` . If the output of `` model ``
190198is not a numpy array, `` postprocess `` needs to be set to convert it into a numpy array.
191199
192200``` python
@@ -222,7 +230,7 @@ global_explanations = explainers.explain_global(
222230```
223231
224232Similarly, we create a ` PredictionAnalyzer ` for computing performance metrics for this classification task.
225- To initialize ` PredictionAnalyzer ` , we set the following parameters:
233+ To initialize ` PredictionAnalyzer ` , the following parameters need to be set :
226234
227235- ` mode ` : The task type, e.g., "classification" or "regression".
228236- ` test_data ` : The test dataset, which should be a ` Tabular ` instance.
@@ -265,6 +273,48 @@ dashboard.show() # Launch the dashboard
265273After opening the Dash app in the browser, we will see a dashboard showing the explanations:
266274![ alt text] ( https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo.gif )
267275
276+ For vision tasks, the same interface is used to create explainers and generate explanations.
277+ Let's take an image classification model as an example.
278+
279+ ``` python
280+ from omnixai.explainers.vision import VisionExplainer
281+ from omnixai.visualization.dashboard import Dashboard
282+
283+ explainer = VisionExplainer(
284+ explainers = [" gradcam" , " lime" , " ig" , " ce" , " feature_visualization" ],
285+ mode = " classification" ,
286+ model = model, # An image classification model, e.g., ResNet50
287+ preprocess = preprocess, # The preprocessing function
288+ postprocess = postprocess, # The postprocessing function
289+ params = {
290+ # Set the target layer for GradCAM
291+ " gradcam" : {" target_layer" : model.layer4[- 1 ]},
292+ # Set the objective for feature visualization
293+ " feature_visualization" :
294+ {" objectives" : [{" layer" : model.layer4[- 3 ], " type" : " channel" , " index" : list (range (6 ))}]}
295+ },
296+ )
297+ # Generate explanations of GradCAM, LIME, IG and CE
298+ local_explanations = explainer.explain(test_img)
299+ # Generate explanations of feature visualization
300+ global_explanations = explainer.explain_global()
301+ # Launch the dashboard
302+ dashboard = Dashboard(
303+ instances = test_img,
304+ local_explanations = local_explanations,
305+ global_explanations = global_explanations
306+ )
307+ dashboard.show()
308+ ```
309+
310+ The following figure shows the dashboard of these explanations:
311+ ![ alt text] ( https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo_vision.gif )
312+
313+ For NLP tasks and time-series forecasting/anomaly detection, OmniXAI also provides the same interface
314+ to generate and visualize explanations. This figure shows a dashboard example of text classification
315+ and time-series anomaly detection:
316+ ![ alt text] ( https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo_nlp_ts.gif )
317+
268318## How to Contribute
269319
270320We welcome the contribution from the open-source community to improve the library!
0 commit comments