
Commit a3ae080

remove softmax api from fluid (#48388)
* move softmax to paddle2.0
* fix some bugs
* resolve conflict
* remove some code
* modify code style
* fix bugs
* fix code
* fix move code
* fix some bugs
* fix code
* fix some code
* modify the header file
* fix bugs
* fix some examples
* fix mish example
* fix code
1 parent ea5ca55 commit a3ae080

33 files changed, +69 -205 lines changed
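Every file in this commit applies the same substitution: calls to the fluid 1.x softmax entry points (fluid.layers.softmax, layers.softmax, and nn.softmax inside fluid) become calls to the Paddle 2.0 functional API, paddle.nn.functional.softmax. A minimal before/after sketch of the migration, using a made-up input tensor for illustration:

    import paddle

    x = paddle.to_tensor([[2.0, 3.0, 4.0, 5.0],
                          [7.0, 8.0, 8.0, 9.0]], dtype='float32')

    # Removed 1.x API (deprecated since 2.0.0, deleted in this commit):
    #     out = paddle.fluid.layers.softmax(x, axis=-1, use_cudnn=True)
    # Paddle 2.0 replacement used throughout the diff below:
    out = paddle.nn.functional.softmax(x, axis=-1)
    print(out)  # each row sums to 1.0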

python/paddle/fluid/layers/detection.py

+1-1
@@ -626,7 +626,7 @@ class number, M is number of bounding boxes.
 target_box=loc,
 code_type='decode_center_size',
 )
-scores = nn.softmax(input=scores)
+scores = paddle.nn.functional.softmax(scores)
 scores = paddle.transpose(scores, perm=[0, 2, 1])
 scores.stop_gradient = True
 nmsed_outs = helper.create_variable_for_type_inference(

python/paddle/fluid/layers/nn.py

+12-152
@@ -68,7 +68,6 @@
 'linear_chain_crf',
 'crf_decoding',
 'conv2d',
-'softmax',
 'pool2d',
 'batch_norm',
 'dropout',
@@ -145,7 +144,7 @@ def _get_reduce_dim(dim, input):
 else:
 raise TypeError(
 "The type of dim must be int, list, tuple or range, but received {}".format(
-type(axis)
+type(dim)
 )
 )
 if dim is None:
@@ -679,7 +678,7 @@ def _pull_gpups_sparse(
 size(int|list of int): The embedding size parameter of each input, which indicates the size of
 each embedding vector respectively.
 dtype(str): The dtype refers to the data type of output tensor. Only supports
-float32 now.
+float32 now.
 
 Returns:
 Variable|list of Variable: The tensor variable storing the embeddings of the \
@@ -742,7 +741,7 @@ def _pull_box_sparse(
 size(int): The embedding size parameter, which indicates the size of
 each embedding vector respectively.
 dtype(str): The dtype refers to the data type of output tensor. Only supports
-float32 now.
+float32 now.
 
 Returns:
 Variable|list of Variable: The tensor variable storing the embeddings of the \
@@ -1123,147 +1122,6 @@ def get_attrs(prog, dropout_prob, is_test, seed):
 return out
 
 
-@deprecated(since="2.0.0", update_to="paddle.nn.functional.softmax")
-def softmax(input, use_cudnn=True, name=None, axis=-1):
-r"""
-This operator implements the softmax layer. The calculation process is as follows:
-
-1. The dimension :attr:`axis` of the ``input`` will be permuted to the last.
-
-2. Then the input tensor will be logically flattened to a 2-D matrix. The matrix's
-second dimension(row length) is the same as the dimension :attr:`axis` of the input
-tensor, and the first dimension(column length) is the product of all other
-dimensions of the input tensor. For each row of the matrix, the softmax operator
-squashes the K-dimensional(K is the width of the matrix, which is also the size
-of the input tensor's dimension :attr:`axis`) vector of arbitrary real values to a
-K-dimensional vector of real values in the range [0, 1] that add up to 1.
-
-3. After the softmax operation is completed, the inverse operations of steps 1 and 2
-are performed to restore the two-dimensional matrix to the same dimension as the ``input``.
-
-It computes the exponential of the given dimension and the sum of exponential
-values of all the other dimensions in the K-dimensional vector input.
-Then the ratio of the exponential of the given dimension and the sum of
-exponential values of all the other dimensions is the output of the softmax
-operator.
-
-For each row :math:`i` and each column :math:`j` in the matrix, we have:
-
-.. math::
-
-Out[i, j] = \\frac{\\exp(X[i, j])}{\\sum_j(exp(X[i, j])}
-
-Example:
-
-.. code-block:: text
-
-Case 1:
-Input:
-X.shape = [2, 3, 4]
-X.data = [[[2.0, 3.0, 4.0, 5.0],
-[3.0, 4.0, 5.0, 6.0],
-[7.0, 8.0, 8.0, 9.0]],
-[[1.0, 2.0, 3.0, 4.0],
-[5.0, 6.0, 7.0, 8.0],
-[6.0, 7.0, 8.0, 9.0]]]
-
-Attrs:
-axis = -1
-
-Output:
-Out.shape = [2, 3, 4]
-Out.data = [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
-[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
-[0.07232949, 0.19661193, 0.19661193, 0.53444665]],
-[[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
-[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
-[0.0320586 , 0.08714432, 0.23688282, 0.64391426]]]
-
-Case 2:
-Input:
-X.shape = [2, 3, 4]
-X.data = [[[2.0, 3.0, 4.0, 5.0],
-[3.0, 4.0, 5.0, 6.0],
-[7.0, 8.0, 8.0, 9.0]],
-[[1.0, 2.0, 3.0, 4.0],
-[5.0, 6.0, 7.0, 8.0],
-[6.0, 7.0, 8.0, 9.0]]]
-Attrs:
-axis = 1
-
-Output:
-Out.shape = [2, 3, 4]
-Out.data = [[[0.00657326, 0.00657326, 0.01714783, 0.01714783],
-[0.01786798, 0.01786798, 0.04661262, 0.04661262],
-[0.97555875, 0.97555875, 0.93623955, 0.93623955]],
-[[0.00490169, 0.00490169, 0.00490169, 0.00490169],
-[0.26762315, 0.26762315, 0.26762315, 0.26762315],
-[0.72747516, 0.72747516, 0.72747516, 0.72747516]]]
-
-Args:
-input (Tensor): The input tensor. A multi-dimension ``Tensor`` with type float32 or float64.
-use_cudnn (bool, optional): Use cudnn kernel or not, it is valid only when the cudnn \
-library is installed. To improve performance, set use_cudnn to True by default.
-name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . Default: None.
-will be named automatically. Default: None.
-axis (int, optional): The index of dimension to perform softmax calculations, it should
-be in range :math:`[-1, rank - 1]`, while :math:`rank` is the rank of
-input tensor. Default: -1. -1 means the last dimension.
-
-Returns:
-Tensor: ``Tensor`` indicates the output of softmax. The data type and shape are the same as ``input`` .
-
-Examples:
-
-.. code-block:: python
-
-import paddle
-import paddle.nn.functional as F
-
-x = paddle.to_tensor([[[2.0, 3.0, 4.0, 5.0],
-[3.0, 4.0, 5.0, 6.0],
-[7.0, 8.0, 8.0, 9.0]],
-[[1.0, 2.0, 3.0, 4.0],
-[5.0, 6.0, 7.0, 8.0],
-[6.0, 7.0, 8.0, 9.0]]], dtype='float32')
-y = F.softmax(x, axis=1)
-print(y)
-# [[[0.00657326, 0.00657326, 0.01714783, 0.01714783],
-# [0.01786798, 0.01786798, 0.04661262, 0.04661262],
-# [0.97555870, 0.97555870, 0.93623954, 0.93623954]],
-# [[0.00490169, 0.00490169, 0.00490169, 0.00490169],
-# [0.26762316, 0.26762316, 0.26762316, 0.26762316],
-# [0.72747517, 0.72747517, 0.72747517, 0.72747517]]]
-
-"""
-
-if in_dygraph_mode():
-return _C_ops.softmax(input, axis)
-
-if _non_static_mode():
-return _legacy_C_ops.softmax(
-input, 'axis', axis, 'use_cudnn', use_cudnn
-)
-
-inputs = {"X": [input]}
-attrs = {"axis": axis, "use_cudnn": use_cudnn}
-
-helper = LayerHelper('softmax', **locals())
-check_variable_and_dtype(
-input, 'input/x', ['float16', 'float32', 'float64'], 'softmax'
-)
-
-dtype = helper.input_dtype()
-softmax_out = helper.create_variable_for_type_inference(dtype)
-helper.append_op(
-type="softmax",
-inputs={"X": input},
-outputs={"Out": softmax_out},
-attrs=attrs,
-)
-return softmax_out
-
-
 def conv2d(
 input,
 num_filters,
@@ -1788,7 +1646,7 @@ def is_list_or_tuple(ele):
 if pool_padding == "VALID":
 padding_algorithm = "VALID"
 pool_padding = [0, 0]
-if ceil_mode != False:
+if ceil_mode is not False:
 raise ValueError(
 "When Attr(pool_padding) is \"VALID\", Attr(ceil_mode) must be False. "
 "Received ceil_mode: True."
@@ -6643,7 +6501,7 @@ def deformable_roi_pooling(
 )
 
 input_channels = input.shape[1]
-if position_sensitive == False:
+if position_sensitive is False:
 output_channels = input_channels
 else:
 output_channels = input_channels / pooled_height / pooled_width
@@ -6841,11 +6699,11 @@ def mish(x, threshold=20, name=None):
 
 .. math::
 
-out = \\begin{cases}
-x \\ast \\tanh(x), \\text{if } x > \\text{threshold} \\\\
-x \\ast \\tanh(e^{x}), \\text{if } x < -\\text{threshold} \\\\
-x \\ast \\tanh(\\ln(1 + e^{x})), \\text{otherwise}
-\\end{cases}
+out = \\begin{cases}
+x \\ast \\tanh(x), \\text{if } x > \\text{threshold} \\\\
+x \\ast \\tanh(e^{x}), \\text{if } x < -\\text{threshold} \\\\
+x \\ast \\tanh(\\ln(1 + e^{x})), \\text{otherwise}
+\\end{cases}
 
 Args:
 x (Variable): Input feature, multi-dimensional Tensor. The data type
@@ -6867,9 +6725,11 @@ def mish(x, threshold=20, name=None):
 
 .. code-block:: python
 
+import paddle
 import paddle.fluid as fluid
 import numpy as np
 
+paddle.enable_static()
 DATATYPE='float32'
 
 x_data = np.array([i for i in range(1,5)]).reshape([1,1,4]).astype(DATATYPE)
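The removed docstring above describes softmax as a three-step procedure (permute the chosen axis to the last position, flatten to a 2-D matrix and normalize each row by its exponential sum, then undo the reshape and permutation), and the mish hunks in the same file carry a piecewise formula guarded by a threshold. A NumPy sketch of both, with hypothetical helper names softmax_ref and mish_ref; the max-subtraction in softmax_ref is a standard numerical-stability trick, not part of the original docstring:

    import numpy as np

    def softmax_ref(x, axis=-1):
        x = np.asarray(x, dtype=np.float64)
        # Step 1: permute the softmax axis to the last position.
        x = np.moveaxis(x, axis, -1)
        shape = x.shape
        # Step 2: flatten to a 2-D matrix and normalize each row.
        x2d = x.reshape(-1, shape[-1])
        e = np.exp(x2d - x2d.max(axis=1, keepdims=True))
        out = e / e.sum(axis=1, keepdims=True)
        # Step 3: restore the original shape and axis order.
        return np.moveaxis(out.reshape(shape), -1, axis)

    def mish_ref(x, threshold=20.0):
        # Piecewise softplus from the mish docstring: x for large inputs,
        # exp(x) for very negative inputs, ln(1 + exp(x)) otherwise.
        x = np.asarray(x, dtype=np.float64)
        sp = np.where(
            x > threshold,
            x,
            np.where(
                x < -threshold,
                np.exp(x),
                np.log1p(np.exp(np.clip(x, -threshold, threshold))),
            ),
        )
        return x * np.tanh(sp)

    x = np.array([[[2.0, 3.0, 4.0, 5.0],
                   [3.0, 4.0, 5.0, 6.0],
                   [7.0, 8.0, 8.0, 9.0]]])
    # Matches the first batch entry of the axis=1 case in the removed docstring.
    print(softmax_ref(x, axis=1))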

python/paddle/fluid/layers/rnn.py

+2-2
@@ -1304,7 +1304,7 @@ def _beam_search_step(self, time, logits, next_cell_states, beam_state):
 self.noend_mask_tensor, "float64"
 )
 
-step_log_probs = paddle.log(nn.softmax(logits))
+step_log_probs = paddle.log(paddle.nn.functional.softmax(logits))
 step_log_probs = self._mask_probs(step_log_probs, beam_state.finished)
 log_probs = nn.elementwise_add(
 x=step_log_probs, y=beam_state.log_probs, axis=0
@@ -2330,7 +2330,7 @@ def sample(self, time, outputs, states):
 if self.softmax_temperature is not None
 else outputs
 )
-probs = nn.softmax(logits)
+probs = paddle.nn.functional.softmax(logits)
 # TODO: remove this stop_gradient. The stop_gradient of sample_ids can
 # not pass to probs, since sampling_id op does not have corresponding
 # grad op and thus can not pass.
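As an aside on the rnn.py hunks above: wrapping the softmax output in paddle.log, as _beam_search_step does, turns the raw logits into per-token log probabilities for beam search. A tiny sketch with made-up logits (log_softmax is mentioned only as the usual fused alternative; this commit does not change that part of the logic):

    import paddle

    logits = paddle.to_tensor([[2.0, 1.0, 0.1]])  # hypothetical scores for 3 tokens
    step_log_probs = paddle.log(paddle.nn.functional.softmax(logits, axis=-1))
    # paddle.nn.functional.log_softmax(logits, axis=-1) yields the same values
    # in a single, numerically safer step.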

python/paddle/fluid/tests/unittests/collective/fleet/parallel_dygraph_se_resnext.py

+1-1
@@ -354,7 +354,7 @@ def run_one_loop(self, model, opt, data):
 label.stop_gradient = True
 
 out = model(img)
-softmax_out = fluid.layers.softmax(out, use_cudnn=False)
+softmax_out = paddle.nn.functional.softmax(out, use_cudnn=False)
 loss = fluid.layers.cross_entropy(input=softmax_out, label=label)
 avg_loss = paddle.mean(x=loss)
 return avg_loss

python/paddle/fluid/tests/unittests/collective/fleet/parallel_dygraph_transformer.py

+2-2
@@ -342,7 +342,7 @@ def forward(self, queries, keys, values, attn_bias):
 )
 if attn_bias is not None:
 product += attn_bias
-weights = fluid.layers.softmax(product)
+weights = paddle.nn.functional.softmax(product)
 if self._dropout_rate:
 weights_droped = fluid.layers.dropout(
 weights,
@@ -849,7 +849,7 @@ def forward(self, dec_inputs=None, enc_output=None):
 
 if dec_inputs is None:
 # Return probs for independent decoder program.
-predict_out = fluid.layers.softmax(predict)
+predict_out = paddle.nn.functional.softmax(predict)
 return predict_out
 return predict

python/paddle/fluid/tests/unittests/dist_transformer.py

+3-3
@@ -1177,7 +1177,7 @@ def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
 product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
 if attn_bias:
 product += attn_bias
-weights = layers.softmax(product)
+weights = paddle.nn.functional.softmax(product)
 if dropout_rate:
 weights = layers.dropout(
 weights,
@@ -1715,7 +1715,7 @@ def wrap_decoder(
 bias_attr=const_bias_attr,
 )
 if dec_inputs is None:
-predict = layers.softmax(predict)
+predict = paddle.nn.functional.softmax(predict)
 return predict
 
 
@@ -1834,7 +1834,7 @@ def beam_search():
 logits = paddle.reshape(logits, (-1, trg_vocab_size))
 
 topk_scores, topk_indices = layers.topk(
-input=layers.softmax(logits), k=beam_size
+input=paddle.nn.functional.softmax(logits), k=beam_size
 )
 accu_scores = layers.elementwise_add(
 x=paddle.log(topk_scores),

python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_dygraph_model.py

+4-2
@@ -435,7 +435,9 @@ def beam_search(self, inputs):
 cell_outputs = self._split_batch_beams(step_input)
 cell_outputs = self.fc(cell_outputs)
 
-step_log_probs = paddle.log(fluid.layers.softmax(cell_outputs))
+step_log_probs = paddle.log(
+paddle.nn.functional.softmax(cell_outputs)
+)
 noend_array = [-self.kinf] * self.tar_vocab_size
 noend_array[self.beam_end_token] = 0
 noend_mask_tensor = to_variable(
@@ -703,7 +705,7 @@ def attention(self, query, enc_output, mask=None):
 attn = paddle.transpose(attn, [1, 0, 2])
 attn = paddle.add(attn, mask * 1000000000)
 attn = paddle.transpose(attn, [1, 0, 2])
-weight = fluid.layers.softmax(attn)
+weight = paddle.nn.functional.softmax(attn)
 weight_memory = fluid.layers.matmul(weight, memory)
 
 return weight_memory

python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py

+2-2
@@ -67,7 +67,7 @@ def forward(self, input, cache=None):
 cache["k"], cache["v"] = k, v
 
 weight = fluid.layers.matmul(x=q, y=k, transpose_y=True)
-weight = fluid.layers.softmax(weight)
+weight = paddle.nn.functional.softmax(weight)
 out = fluid.layers.matmul(weight, v)
 
 return out
@@ -113,7 +113,7 @@ def forward(self, input, max_len=4):
 # Test to call function defined outside of class.
 def update_cache(cache):
 for k, val in cache.items():
-cache[k] = fluid.layers.softmax(val)
+cache[k] = paddle.nn.functional.softmax(val)
 
 return cache

python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py

+1-1
@@ -308,7 +308,7 @@ def forward(self, x, label=None):
 
 # Test to call function behind caller.
 def softmax(x):
-return fluid.layers.softmax(x)
+return paddle.nn.functional.softmax(x)
 
 
 class TestNetWithExternalFunc(TestDygraphIfElseNet):

python/paddle/fluid/tests/unittests/dygraph_to_static/test_mobile_net.py

+1-1
@@ -535,7 +535,7 @@ def train_mobilenet(args, to_static):
 out = net(img)
 
 t_end = time.time()
-softmax_out = fluid.layers.softmax(out, use_cudnn=False)
+softmax_out = paddle.nn.functional.softmax(out)
 loss = fluid.layers.cross_entropy(
 input=softmax_out, label=label
 )

python/paddle/fluid/tests/unittests/dygraph_to_static/test_reinforcement_learning.py

+1-1
@@ -48,7 +48,7 @@ def forward(self, x):
 x = fluid.layers.relu(x)
 action_scores = self.affine2(x)
 
-log_prob = fluid.layers.softmax(action_scores, axis=1)
+log_prob = paddle.nn.functional.softmax(action_scores, axis=1)
 
 return log_prob

python/paddle/fluid/tests/unittests/dygraph_to_static/test_se_resnet.py

+1-1
@@ -343,7 +343,7 @@ def forward(self, inputs, label):
 y = paddle.reshape(y, shape=[-1, self.pool2d_avg_output])
 out = self.out(y)
 
-softmax_out = fluid.layers.softmax(out)
+softmax_out = paddle.nn.functional.softmax(out)
 loss = fluid.layers.cross_entropy(input=softmax_out, label=label)
 avg_loss = paddle.mean(x=loss)
