
sparse_categorical_crossentropy with ignore_class=-1 makes the loss NaN #734

Open
@innat


The loss becomes NaN with Keras 2, while the same code works with Keras 3.


I tried to train a multi-output model whose targets look like this:

y1_dummy = [1,  2,   0, -1,  0,  -1,  -1, -1,  3,  -1]
y2_dummy = [-1, -1, -1,  2,  -1,  0,   3,  1, -1,   2]

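In these two target arrays, -1 marks a missing label, and the arrays are complementary: where y2_dummy has -1 (e.g. y2_dummy[0]), y1_dummy has a valid class (y1_dummy[0] = 1), and vice versa. At training time I set ignore_class=-1 so that these entries are skipped: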
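To make the label layout concrete, here is a small NumPy check (not part of the original reproduction) that every sample carries a label for exactly one of the two outputs. The names has_y1, has_y2 and complementary are illustrative only.

```python
import numpy as np

y1_dummy = np.array([1,  2,   0, -1,  0,  -1,  -1, -1,  3,  -1])
y2_dummy = np.array([-1, -1, -1,  2,  -1,  0,   3,  1, -1,   2])

# True where a sample has a real label for that output (i.e. is not the ignored -1).
has_y1 = y1_dummy != -1
has_y2 = y2_dummy != -1

# XOR: each sample is labelled for exactly one of the two outputs.
complementary = bool(np.all(has_y1 ^ has_y2))
print(complementary)               # True
print(has_y1.sum(), has_y2.sum())  # 5 5
```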

def custom_loss(y_true, y_pred):
    loss = sparse_categorical_crossentropy(
        y_true, y_pred, ignore_class=-1
    )
    return loss
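As I understand ignore_class, ignored entries should contribute no loss and the mean should be taken only over the remaining entries. The NumPy sketch below is a hypothetical reference for those semantics, assuming that behaviour; it is not the Keras implementation, and masked_sparse_ce and guard_empty are made-up names. It shows that when a batch contains only ignored labels, a plain masked mean becomes 0/0 = NaN unless the empty case is guarded.

```python
import numpy as np

def masked_sparse_ce(y_true, y_pred, ignore_class=-1, guard_empty=False):
    """Hypothetical reference: mean sparse categorical crossentropy over
    entries whose label != ignore_class.

    y_true: int labels, shape (batch,)
    y_pred: class probabilities, shape (batch, num_classes)
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    valid = y_true != ignore_class
    # Replace ignored labels with 0 so indexing stays in range; they are zeroed below.
    safe_labels = np.where(valid, y_true, 0)
    picked = y_pred[np.arange(len(y_true)), safe_labels]
    per_sample = np.where(valid, -np.log(np.clip(picked, 1e-7, 1.0)), 0.0)
    n_valid = valid.sum()
    if guard_empty and n_valid == 0:
        # No labelled samples in this batch: return zero loss instead of 0/0.
        return 0.0
    # Masked mean: sum over valid entries divided by their count (NaN if count is 0).
    with np.errstate(invalid="ignore", divide="ignore"):
        return per_sample.sum() / n_valid

probs = np.full((3, 4), 0.25)  # uniform predictions over 4 classes
print(masked_sparse_ce([1, 2, 0], probs))                       # ~1.386 (ln 4)
print(masked_sparse_ce([-1, -1, -1], probs))                    # nan (0 / 0)
print(masked_sparse_ce([-1, -1, -1], probs, guard_empty=True))  # 0.0
```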

The code works in Keras 3, but in Keras 2 the loss becomes NaN. Here is the full code:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.losses import sparse_categorical_crossentropy
from tensorflow.keras.optimizers import Adam

num_samples = 10
num_classes = 4
input_shape = (224, 224, 3) 
x_dummy = np.random.rand(num_samples, *input_shape).astype('float32')
# -1 marks a missing label for that output; each sample is labelled for exactly one output
y1_dummy = [1,  2,   0, -1,  0,  -1,  -1, -1,  3,  -1]
y2_dummy = [-1, -1, -1,  2,  -1,  0,   3,  1, -1,   2]
_sample = tf.data.Dataset.from_tensor_slices(x_dummy)
_labels = tf.data.Dataset.from_tensor_slices(
    (
        y1_dummy, 
        y2_dummy
    )
)
_data = tf.data.Dataset.zip((_sample, _labels))
# 10 samples -> 3 full batches of 3; the last sample is dropped
_data = _data.batch(batch_size=3, drop_remainder=True)

def custom_loss(y_true, y_pred):
    loss = sparse_categorical_crossentropy(
        y_true, y_pred, ignore_class=-1
    )
    return loss

input_layer = keras.Input(shape=input_shape)
flatten_layer = layers.Flatten()(input_layer)
output_layer1 = layers.Dense(
    num_classes, activation='softmax', name='out1'
)(flatten_layer)
output_layer2 = layers.Dense(
    num_classes, activation='softmax', name='out2'
)(flatten_layer)
A = keras.Model(
    inputs=input_layer, 
    outputs=[output_layer1, output_layer2]
)
A.compile(
    optimizer=Adam(),
    loss={
        'out1': custom_loss,
        'out2': custom_loss,
    }
)
A.fit(
    _data, 
    epochs=2,
)
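With batch_size=3, drop_remainder=True and no shuffling, the batches are fixed, so it is easy to count how many non-ignored labels each output sees per batch. This NumPy sketch does that count; it is an illustration only, and valid_counts is a hypothetical name.

```python
import numpy as np

y1_dummy = np.array([1,  2,   0, -1,  0,  -1,  -1, -1,  3,  -1])
y2_dummy = np.array([-1, -1, -1,  2,  -1,  0,   3,  1, -1,   2])
batch_size = 3

# drop_remainder=True keeps only full batches: 10 // 3 = 3 batches, the last sample is dropped.
num_batches = len(y1_dummy) // batch_size
valid_counts = []
for b in range(num_batches):
    sl = slice(b * batch_size, (b + 1) * batch_size)
    # (labels seen by out1, labels seen by out2) in this batch, ignoring -1.
    valid_counts.append((int(np.sum(y1_dummy[sl] != -1)),
                         int(np.sum(y2_dummy[sl] != -1))))

print(valid_counts)  # [(3, 0), (1, 2), (1, 2)]
```

The first batch has no valid labels for out2. If the Keras 2 loss reduction divides by the number of valid entries, that batch would give 0/0, which is one plausible explanation for the NaN; this is a hypothesis, not something confirmed from the Keras 2 source. If it holds, guarding the division in a custom loss (for example with tf.math.divide_no_nan) or batching so that every batch contains labels for both outputs should avoid it.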
