
Notes for logits and softmax_cross_entropy #45


All the models in TensorNets return a softmax tf.Tensor, while the models in tensorflow/models (which also serve as the backbone of tensorflow/hub) return a logits tf.Tensor (the values before softmax). This is because most standard TensorFlow loss APIs take logits as an argument. As a result:

  • Incorrect: tf.losses.softmax_cross_entropy(onehot_labels=outputs, logits=model),
  • Correct: tf.nn.softmax_cross_entropy_with_logits_v2(labels=outputs, logits=model.logits),
  • Correct and recommended: tf.losses.softmax_cross_entropy(onehot_labels=outputs, logits=model.logits),

where model is the output of any TensorNets model function (e.g., model = nets.MobileNet50(inputs)), and model.logits is equivalent to model.get_outputs()[-2].
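The reason the first form is incorrect is that tf.losses.softmax_cross_entropy applies softmax to its logits argument internally, so feeding it probabilities evaluates the loss against softmax(softmax(logits)) rather than softmax(logits). A minimal NumPy sketch (the values are made up for illustration):

import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.1])
probs = softmax(logits)
# Applying softmax twice flattens the distribution, so a cross entropy
# computed on it no longer reflects the model's predictions.
print(probs)           # ≈ [0.659 0.242 0.099]
print(softmax(probs))  # ≈ [0.448 0.296 0.256], much flatter than probs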

The following script compares the three usages above and serves as a TL;DR.

import numpy as np
import tensorflow as tf
import tensornets as nets

# Define a model
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
outputs = tf.placeholder(tf.float32, [None, 1000])
model = nets.MobileNet50(inputs)
# `model.logits`: available since TensorNets 0.4.1
logits = model.logits  #  alternatively: `model.get_outputs()[-2]`

# `model`: the values after softmax
assert 'probs' in model.name
# `logits`: the values before softmax
assert 'logits' in logits.name

# Load feeds
img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)
label = np.zeros((1, 1000)).astype(np.float32)
label[0, 588] = 1.

# Initialize the model
sess = tf.Session()
sess.run(model.pretrained())

# Get results
# 1. invalid usage: `softmax_cross_entropy`
#    with predicted softmax values
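#    (the loss applies softmax to its `logits` argument internally,
#     so softmax would be applied twice here)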
invalid = tf.losses.softmax_cross_entropy(
    onehot_labels=outputs, logits=model)
# 2. valid usage: `softmax_cross_entropy_with_logits_v2`
#    with predicted logits values
valid1 = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=outputs, logits=logits))
# 3. (recommended) valid usage: `softmax_cross_entropy`
#    with predicted logits values
valid2 = tf.losses.softmax_cross_entropy(
    onehot_labels=outputs, logits=logits)

preds, invalid, valid1, valid2 = sess.run(
    [model, invalid, valid1, valid2],
    {inputs: model.preprocess(img),
     outputs: label})

sess.close()

# Check the results with desired outputs
# cross entropy loss: -\sum_i p_i \log{\hat{p_i}}
desired = -np.sum(label * np.log(preds))
np.testing.assert_allclose(valid1, desired)
np.testing.assert_allclose(valid2, desired)
try:
    np.testing.assert_allclose(invalid, desired)
except AssertionError as e:
    print(e)
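
For completeness, the invalid value is itself well defined: it is the cross entropy against the doubly softmaxed predictions. A sketch that reproduces it, reusing preds, label, and invalid from the run above (the rtol is a loose float32 tolerance):

def np_softmax(x):
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# `invalid` equals the cross entropy against softmax(preds), i.e.,
# against softmax applied twice to the logits.
invalid_manual = -np.sum(label * np.log(np_softmax(preds)))
np.testing.assert_allclose(invalid, invalid_manual, rtol=1e-5)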
