All the models in TensorNets return a softmax `tf.Tensor`, while the models in tensorflow/models (also the backbone of tensorflow/hub) return a logits `tf.Tensor` (the values before softmax). This matters because most TensorFlow loss APIs expect logits as an argument. As a result:
- Incorrect: `tf.losses.softmax_cross_entropy(onehot_labels=outputs, logits=model)`,
- Correct: `tf.nn.softmax_cross_entropy_with_logits_v2(labels=outputs, logits=model.logits)`,
- Correct and recommended: `tf.losses.softmax_cross_entropy(onehot_labels=outputs, logits=model.logits)`,

where `model` is any model function in TensorNets (e.g., `model = nets.MobileNet50(inputs)`) and `model.logits` is equivalent to `model.get_outputs()[-2]`.
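Why the first usage is wrong: a logits-based loss applies softmax internally, so passing the model's probabilities applies softmax twice. The following is a minimal pure-Python sketch of that effect, assuming a hypothetical 3-class example; the helper names (`softmax`, `cross_entropy_from_logits`) are illustrative and not part of TensorNets or TensorFlow.

```python
import math

def softmax(xs):
    # Subtract the max for numerical stability before exponentiating.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_from_logits(onehot, logits):
    # What a logits-based loss computes: -sum_i y_i * log(softmax(logits)_i).
    probs = softmax(logits)
    return -sum(y * math.log(p) for y, p in zip(onehot, probs))

# Hypothetical example: a confident, correct prediction for class 0.
logits = [8.0, 1.0, -2.0]
label = [1.0, 0.0, 0.0]
probs = softmax(logits)  # analogous to what a TensorNets model returns

# Passing logits: the loss is close to 0, as it should be.
correct = cross_entropy_from_logits(label, logits)
# Passing probabilities: softmax is applied a second time.
incorrect = cross_entropy_from_logits(label, probs)
# Cross entropy computed directly from the probabilities.
desired = -sum(y * math.log(p) for y, p in zip(label, probs))

print(correct, desired, incorrect)
```

Here `correct` matches `desired`, while `incorrect` is far larger even though the prediction is confidently right.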
The list above is the TL;DR; the following script compares the three losses in detail against a cross entropy computed by hand with NumPy.
```python
import numpy as np
import tensorflow as tf
import tensornets as nets

# Define a model
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
outputs = tf.placeholder(tf.float32, [None, 1000])
model = nets.MobileNet50(inputs)

# `model.logits`: available since TensorNets 0.4.1
logits = model.logits  # alternatively: `model.get_outputs()[-2]`

# `model`: the values after softmax
assert 'probs' in model.name
# `logits`: the values before softmax
assert 'logits' in logits.name

# Load feeds
img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)
label = np.zeros((1, 1000)).astype(np.float32)
label[0, 588] = 1.

# Initialize the model
sess = tf.Session()
sess.run(model.pretrained())

# Get results
# 1. invalid usage: `softmax_cross_entropy`
#    with predicted softmax values
invalid = tf.losses.softmax_cross_entropy(
    onehot_labels=outputs, logits=model)
# 2. valid usage: `softmax_cross_entropy_with_logits_v2`
#    with predicted logits values
#    (`reduce_sum` matches case 3, which averages over the batch by default,
#    only because the batch size here is 1)
valid1 = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=outputs, logits=logits))
# 3. (recommended) valid usage: `softmax_cross_entropy`
#    with predicted logits values
valid2 = tf.losses.softmax_cross_entropy(
    onehot_labels=outputs, logits=logits)

preds, invalid, valid1, valid2 = sess.run(
    [model, invalid, valid1, valid2],
    {inputs: model.preprocess(img),
     outputs: label})
sess.close()

# Check the results with desired outputs
# cross entropy loss: -\sum_i p_i \log{\hat{p_i}}
# (`rtol` loosened to allow for float32 rounding)
desired = -np.sum(label * np.log(preds))
np.testing.assert_allclose(valid1, desired, rtol=1e-5)
np.testing.assert_allclose(valid2, desired, rtol=1e-5)
try:
    np.testing.assert_allclose(invalid, desired, rtol=1e-5)
except AssertionError as e:
    print(e)  # expected: the invalid loss does not match
```
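Beyond not matching, the invalid loss is a poor training signal. Because probabilities lie in [0, 1], softmax applied on top of them is nearly uniform, so for 1000 classes the invalid loss stays within roughly one nat no matter how good or bad the prediction is, whereas the true cross entropy ranges from 0 to infinity. The following pure-Python sketch illustrates this under stated assumptions: the helpers (`log_sum_exp`, `invalid_loss`, `one_hot`) are hypothetical and only mimic what passing `model` as `logits` computes.

```python
import math

def log_sum_exp(xs):
    # Numerically stable log(sum(exp(x))).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def invalid_loss(true_class, probs):
    # Cross entropy that treats `probs` as logits, i.e. applies softmax a
    # second time: -log(softmax(probs)[k]) = logsumexp(probs) - probs[k].
    return log_sum_exp(probs) - probs[true_class]

def one_hot(index, num_classes):
    vec = [0.0] * num_classes
    vec[index] = 1.0
    return vec

NUM_CLASSES = 1000
TRUE_CLASS = 588

# Perfect prediction: all probability mass on the true class.
best = invalid_loss(TRUE_CLASS, one_hot(TRUE_CLASS, NUM_CLASSES))
# Completely wrong prediction: all mass on another class.
worst = invalid_loss(TRUE_CLASS, one_hot(0, NUM_CLASSES))
# Uninformative prediction: uniform probabilities.
uniform = invalid_loss(TRUE_CLASS, [1.0 / NUM_CLASSES] * NUM_CLASSES)

print(best, uniform, worst)
```

The perfect and completely wrong predictions differ by exactly 1 in this invalid loss, and the uniform prediction sits between them at log(1000).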