Skip to content

Commit f41ffd8

Browse files
authored
Follow up on #1139 (#1141)
* rename Signed-off-by: xadupre <[email protected]> * Improve rendering of one example Signed-off-by: xadupre <[email protected]> * fix title Signed-off-by: xadupre <[email protected]> --------- Signed-off-by: xadupre <[email protected]>
1 parent 5893ed6 commit f41ffd8

File tree

3 files changed

+145
-86
lines changed

3 files changed

+145
-86
lines changed

docs/examples/output_onnx_single_probability.py

-86
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""
2+
Append onnx nodes to the converted model
3+
========================================
4+
5+
This example show how to append some onnx nodes to the converted
6+
model to produce the desired output. In this case, it removes the second
7+
column of the output probabilies.
8+
9+
To be completly accurate, most of the code was generated using a LLM
10+
and modified to accomodate with the latest changes.
11+
"""
12+
13+
from sklearn.datasets import load_iris
14+
from sklearn.linear_model import LogisticRegression
15+
from sklearn.model_selection import train_test_split
16+
from skl2onnx import convert_sklearn
17+
from skl2onnx.common.data_types import FloatTensorType
18+
import onnx
19+
20+
iris = load_iris()
21+
X, y = iris.data, iris.target
22+
X_train, X_test, y_train, y_test = train_test_split(X, y)
23+
clr = LogisticRegression(max_iter=500)
24+
clr.fit(X_train, y_train)
25+
26+
27+
################################################
28+
# model_to_convert refers to the scikit-learn classifier to convert.
29+
model_to_convert = clr # model to convert
30+
X_test = X_test[:1] # data used to test or train, one row is enough
31+
32+
################################################
33+
# Set the output filename for the modified ONNX model
34+
output_filename = "output_file.onnx" # Replace with your desired output filename
35+
36+
################################################
37+
# Step 1: Convert the model to ONNX format,
38+
# disabling the output of labels.
39+
# Define the input type for the ONNX model.
40+
# The input type is a float tensor with shape
41+
# [None, X_test.shape[1]], where None indicates that the
42+
# number of input samples can be flexible,
43+
# and X_test.shape[1] is the number of features for each input sample.
44+
# A "tensor" is essentially a multi-dimensional array,
45+
# commonly used in machine learning to represent data.
46+
# A "float tensor" specifically contains floating-point
47+
# numbers, which are numbers with decimals.
48+
initial_type = [("float_input", FloatTensorType([None, X_test.shape[1]]))]
49+
50+
################################################
51+
# Convert the model to ONNX format.
52+
# - target_opset=18 specifies the version of ONNX operators to use.
53+
# - options={...} sets parameters for the conversion:
54+
# - "zipmap": False ensures that the output is a raw array
55+
# - of probabilities instead of a dictionary.
56+
# - "output_class_labels": False ensures that the output
57+
# contains only probabilities, not class labels.
58+
# ONNX (Open Neural Network Exchange) is an open format for
59+
# representing machine learning models.
60+
# It allows interoperability between different machine learning frameworks,
61+
# enabling the use of models across various platforms.
62+
onx = convert_sklearn(
63+
model_to_convert,
64+
initial_types=initial_type,
65+
target_opset={"": 18, "ai.onnx.ml": 3},
66+
options={
67+
id(model_to_convert): {"zipmap": False, "output_class_labels": False}
68+
}, # Ensures the output is only probabilities, not labels
69+
)
70+
71+
################################################
72+
# Step 2: Load the ONNX model for further modifications if needed
73+
# Load the ONNX model from the serialized string representation.
74+
# An ONNX file is essentially a serialized representation of a machine learning
75+
# model that can be shared and used across different systems.
76+
onnx_model = onnx.load_model_from_string(onx.SerializeToString())
77+
78+
################################################
79+
# Assuming the first output in this model should be the probability tensor
80+
# Extract the name of the output tensor representing the probabilities.
81+
# If there are multiple outputs, select the second one, otherwise, select the first.
82+
prob_output_name = (
83+
onnx_model.graph.output[1].name
84+
if len(onnx_model.graph.output) > 1
85+
else onnx_model.graph.output[0].name
86+
)
87+
88+
################################################
89+
# Add a Gather node to extract only the probability
90+
# of the positive class (index 1)
91+
# Create a tensor to specify the index to gather
92+
# (index 1), which represents the positive class.
93+
indices = onnx.helper.make_tensor(
94+
"indices", onnx.TensorProto.INT64, (1,), [1]
95+
) # Index 1 to gather positive class
96+
97+
################################################
98+
# Create a "Gather" node in the ONNX graph to extract the probability of the positive class.
99+
# - inputs: [prob_output_name, "indices"] specify the inputs
100+
# to this node (probability tensor and index tensor).
101+
# - outputs: ["positive_class_prob"] specify the name of the output of this node.
102+
# - axis=1 indicates gathering along the columns (features) of the probability tensor.
103+
# A "Gather" node is used to extract specific elements from a tensor.
104+
# Here, it extracts the probability for the positive class.
105+
gather_node = onnx.helper.make_node(
106+
"Gather",
107+
inputs=[prob_output_name, "indices"],
108+
outputs=["positive_class_prob"],
109+
axis=1, # Gather along columns (axis 1)
110+
)
111+
112+
################################################
113+
# Add the Gather node to the ONNX graph
114+
onnx_model.graph.node.append(gather_node)
115+
116+
################################################
117+
# Add the tensor initializer for indices (needed for the Gather node)
118+
# Initializers in ONNX are used to define constant tensors that are used in the computation.
119+
onnx_model.graph.initializer.append(indices)
120+
121+
################################################
122+
# Remove existing outputs and add only the new output for the positive class probability
123+
# Clear the existing output definitions to replace them with the new output.
124+
del onnx_model.graph.output[:]
125+
126+
################################################
127+
# Define new output for the positive class probability
128+
# Create a new output tensor specification with the name "positive_class_prob".
129+
positive_class_output = onnx.helper.make_tensor_value_info(
130+
"positive_class_prob", onnx.TensorProto.FLOAT, [None, 1]
131+
)
132+
onnx_model.graph.output.append(positive_class_output)
133+
134+
################################################
135+
# Step 3: Save the modified ONNX model
136+
# Save the modified ONNX model to the specified output filename.
137+
# The resulting ONNX file can then be loaded and used in different environments
138+
# that support ONNX, such as inference servers or other machine learning frameworks.
139+
onnx.save(onnx_model, output_filename)
140+
141+
142+
################################################
143+
# The model can be printed as follows.
144+
print(onnx.printer.to_text(onnx_model))

docs/tutorial_4_advanced.rst

+1
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ with issues and resolved issues.
1313
auto_tutorial/plot_ngrams
1414
auto_tutorial/plot_usparse_xgboost
1515
auto_tutorial/plot_woe_transformer
16+
auto_tutorial/plot_output_onnx_single_probability

0 commit comments

Comments
 (0)