This repository was archived by the owner on Nov 29, 2022. It is now read-only.

Commit 7d6faa2

Allow higher rank inputs to SparseCategoricalFocalLoss (#6)
* Make SparseCategoricalFocalLoss accept higher-rank tensors (#5)
* Update the NumPy sparse categorical focal loss reference implementation
* Add test cases for higher-rank and unknown-rank inputs
Parent: a92bb39
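
In practice, the change means the loss now applies position-wise to sequence or segmentation outputs, where integer labels have shape (batch, d1, ..., dN) and predictions carry one extra channel axis. A minimal sketch of the new behavior, assuming the package's top-level export and arbitrary toy shapes:

    import numpy as np
    import tensorflow as tf

    from focal_loss import sparse_categorical_focal_loss

    # Toy segmentation-style data (shapes are illustrative assumptions):
    # integer labels of shape (batch, height, width) and probabilities
    # with one extra channel axis, shape (batch, height, width, classes).
    rng = np.random.default_rng(0)
    y_true = rng.integers(0, 4, size=(2, 8, 8))
    y_pred = tf.nn.softmax(tf.constant(rng.normal(size=(2, 8, 8, 4)),
                                       dtype=tf.float32), axis=-1)

    # With this commit, higher-rank inputs are accepted and the loss is
    # returned per position, here with shape (2, 8, 8).
    loss = sparse_categorical_focal_loss(y_true, y_pred, gamma=2)
    print(loss.shape)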

2 files changed: +168 -22 lines

src/focal_loss/_categorical_focal_loss.py

Lines changed: 58 additions & 18 deletions
@@ -6,11 +6,16 @@
 # | | | (_) | | (__ | (_| | | | | | | (_) | \__ \ \__ \
 # |_| \___/ \___| \__,_| |_| |_| \___/ |___/ |___/
 
+import itertools
+
 import tensorflow as tf
 
+_EPSILON = tf.keras.backend.epsilon()
+
 
 def sparse_categorical_focal_loss(y_true, y_pred, gamma, *,
-                                  from_logits: bool = False) -> tf.Tensor:
+                                  from_logits: bool = False, axis: int = -1
+                                  ) -> tf.Tensor:
     r"""Focal loss function for multiclass classification with integer labels.
 
     This loss function generalizes multiclass softmax cross-entropy by
@@ -46,10 +51,10 @@ def sparse_categorical_focal_loss(y_true, y_pred, gamma, *,
 
     Parameters
     ----------
-    y_true : tensor-like, shape (N,)
+    y_true : tensor-like
         Integer class labels.
 
-    y_pred : tensor-like, shape (N, K)
+    y_pred : tensor-like
         Either probabilities or logits, depending on the `from_logits`
         parameter.
 
@@ -63,6 +68,9 @@ def sparse_categorical_focal_loss(y_true, y_pred, gamma, *,
     from_logits : bool, optional
         Whether `y_pred` contains logits or probabilities.
 
+    axis : int, optional
+        Channel axis in the `y_pred` tensor.
+
     Returns
     -------
     :class:`tf.Tensor`
@@ -103,33 +111,64 @@ def sparse_categorical_focal_loss(y_true, y_pred, gamma, *,
         A wrapper around this function that makes it a
         :class:`tf.keras.losses.Loss`.
     """
+    # Process focusing parameter
     gamma = tf.convert_to_tensor(gamma, dtype=tf.dtypes.float32)
-    scalar_gamma = gamma.shape == []
+    gamma_rank = gamma.shape.rank
+    scalar_gamma = gamma_rank == 0
 
+    # Process prediction tensor
     y_pred = tf.convert_to_tensor(y_pred)
-    y_true = tf.dtypes.cast(y_true, dtype=tf.dtypes.int32)
-    base_loss = tf.keras.backend.sparse_categorical_crossentropy(
-        target=y_true, output=y_pred, from_logits=from_logits)
+    y_pred_rank = y_pred.shape.rank
+    if y_pred_rank is not None:
+        axis %= y_pred_rank
+        if axis != y_pred_rank - 1:
+            # Put channel axis last for sparse_softmax_cross_entropy_with_logits
+            perm = list(itertools.chain(range(axis),
+                                        range(axis + 1, y_pred_rank), [axis]))
+            y_pred = tf.transpose(y_pred, perm=perm)
+    elif axis != -1:
+        raise ValueError(
+            f'Cannot compute sparse categorical focal loss with axis={axis} on '
+            'a prediction tensor with statically unknown rank.')
+    y_pred_shape = tf.shape(y_pred)
+
+    # Process ground truth tensor
+    y_true = tf.dtypes.cast(y_true, dtype=tf.dtypes.int64)
+    y_true_rank = y_true.shape.rank
+
+    if y_true_rank is None:
+        raise NotImplementedError('Sparse categorical focal loss not supported '
+                                  'for target/label tensors of unknown rank')
+
+    reshape_needed = (y_true_rank is not None and y_pred_rank is not None and
+                      y_pred_rank != y_true_rank + 1)
+    if reshape_needed:
+        y_true = tf.reshape(y_true, [-1])
+        y_pred = tf.reshape(y_pred, [-1, y_pred_shape[-1]])
 
     if from_logits:
+        logits = y_pred
         probs = tf.nn.softmax(y_pred, axis=-1)
     else:
         probs = y_pred
-    batch_size = tf.shape(y_true)[0]
+        logits = tf.math.log(tf.clip_by_value(y_pred, _EPSILON, 1 - _EPSILON))
 
-    # For some reason y_true becomes shaped like (batch, 1) during training, so
-    # the next line is a hack to ensure it's always rank 1 (needed for stacking)
-    y_true = tf.cond(tf.rank(y_true) == 1, lambda: y_true, lambda: y_true[:, 0])
+    xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        labels=y_true,
+        logits=logits,
+    )
 
-    indices = tf.stack([tf.range(batch_size), y_true], axis=1)
-    probs = tf.gather_nd(probs, indices)
+    y_true_rank = y_true.shape.rank
+    probs = tf.gather(probs, y_true, axis=-1, batch_dims=y_true_rank)
+    if not scalar_gamma:
+        gamma = tf.gather(gamma, y_true, axis=0, batch_dims=y_true_rank)
+    focal_modulation = (1 - probs) ** gamma
+    loss = focal_modulation * xent_loss
 
-    if scalar_gamma:
-        focal_modulation = (1 - probs) ** gamma
-    else:
-        focal_modulation = (1 - probs) ** tf.gather(gamma, y_true)
+    if reshape_needed:
+        loss = tf.reshape(loss, y_pred_shape[:-1])
 
-    return focal_modulation * base_loss
+    return loss
 
 
 @tf.keras.utils.register_keras_serializable()
@@ -198,6 +237,7 @@ class SparseCategoricalFocalLoss(tf.keras.losses.Loss):
     The function that performs the focal loss computation, taking a label
     tensor and a prediction tensor and outputting a loss.
     """
+
     def __init__(self, gamma, from_logits: bool = False, **kwargs):
         super().__init__(**kwargs)
         self.gamma = gamma
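
The heart of the refactor is replacing the old tf.stack + tf.gather_nd indexing, which assumed rank-1 labels, with tf.gather using batch_dims, and computing the cross-entropy term via tf.nn.sparse_softmax_cross_entropy_with_logits rather than the Keras backend helper. A small sketch with assumed toy values, checking that the two indexing styles agree in the original rank-1 case:

    import tensorflow as tf

    # Toy per-class probabilities and integer labels (values assumed).
    probs = tf.constant([[0.7, 0.2, 0.1],
                         [0.1, 0.8, 0.1]])
    y_true = tf.constant([0, 1])

    # Old approach: build (row, label) index pairs, then gather_nd.
    indices = tf.stack([tf.range(tf.shape(y_true)[0]), y_true], axis=1)
    old = tf.gather_nd(probs, indices)

    # New approach: batched gather along the channel axis; this also
    # generalizes to labels of any rank.
    new = tf.gather(probs, y_true, axis=-1, batch_dims=y_true.shape.rank)

    print(old.numpy(), new.numpy())  # both [0.7 0.8]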

src/focal_loss/tests/test_sparse_categorical_focal_loss.py

Lines changed: 110 additions & 4 deletions
@@ -53,20 +53,34 @@
 
 
 def numpy_sparse_categorical_focal_loss(y_true, y_pred, gamma,
-                                        from_logits=False):
+                                        from_logits=False, axis=-1):
     """Simple sparse categorical focal loss implementation using NumPy."""
-    # Convert to arrays
     y_true = np.asarray(y_true)
     y_pred = np.asarray(y_pred)
 
+    if axis != -1:
+        pred_dim = np.ndim(y_pred)
+        axes = list(range(axis)) + list(range(axis + 1, pred_dim)) + [axis]
+        y_pred = np.transpose(y_pred, axes)
+
+    y_pred_shape_original = y_pred.shape
+    n_classes = y_pred_shape_original[-1]
+    y_true = np.reshape(y_true, newshape=[-1])
+    y_pred = np.reshape(y_pred, newshape=[-1, n_classes])
+
     # One-hot encoding of integer labels
-    y_true_one_hot = np.eye(y_pred.shape[-1])[y_true]
+    y_true_one_hot = np.eye(n_classes)[y_true]
 
     if from_logits:
         y_pred = softmax(y_pred, axis=-1)
+    else:
+        y_pred = np.clip(y_pred, 1e-7, 1 - 1e-7)
 
     loss = -y_true_one_hot * (1 - y_pred) ** gamma * np.log(y_pred)
-    return loss.sum(axis=-1)
+    loss = np.sum(loss, axis=-1)
+    loss = np.reshape(loss, y_pred_shape_original[:-1])
+
+    return loss
 
 
 def get_dummy_sparse_multiclass_classifier(n_features, n_classes, gamma,
@@ -250,3 +264,95 @@ def test_save_and_restore(self, gamma, from_logits):
 
         # Delete the created SavedModel directory
         shutil.rmtree(sm_filepath, ignore_errors=True)
+
+    def test_with_higher_rank_inputs(self):
+        """Addresses https://github.com/artemmavrin/focal-loss/issues/5"""
+
+        def build_model():
+            return tf.keras.Sequential([
+                tf.keras.layers.Input((100, 10)),
+                tf.keras.layers.GRU(13, return_sequences=True),
+                tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(13)),
+            ])
+
+        x = np.zeros((20, 100, 10))
+        y = np.ones((20, 100, 1))
+
+        model = build_model()
+        loss = SparseCategoricalFocalLoss(gamma=2)
+        model.compile(loss=loss, optimizer='adam')
+        model.fit(x, y)
+
+    @named_parameters_with_testcase_names(axis=[0, 1, 2],
+                                          from_logits=[False, True])
+    def test_reduce_to_keras_with_higher_rank_and_axis(self, axis, from_logits):
+        labels = tf.convert_to_tensor([[0, 1, 2], [0, 0, 0], [1, 1, 1]],
+                                      dtype=tf.dtypes.int64)
+        logits = tf.reshape(tf.range(27, dtype=tf.dtypes.float32),
+                            shape=[3, 3, 3])
+        probs = tf.nn.softmax(logits, axis=axis)
+
+        y_pred = logits if from_logits else probs
+        keras_loss = tf.keras.losses.sparse_categorical_crossentropy(
+            labels, y_pred, from_logits=from_logits, axis=axis)
+        focal_loss = sparse_categorical_focal_loss(
+            labels, y_pred, gamma=0, from_logits=from_logits, axis=axis)
+        self.assertAllClose(focal_loss, keras_loss)
+
+    @named_parameters_with_testcase_names(gamma=[0, 1, 2], axis=[0, 1, 2],
+                                          from_logits=[False, True])
+    def test_higher_rank_sanity_checks(self, gamma, axis, from_logits):
+        labels = tf.convert_to_tensor([[0, 1, 2], [0, 0, 0], [1, 1, 1]],
+                                      dtype=tf.dtypes.int64)
+        logits = tf.reshape(tf.range(27, dtype=tf.dtypes.float32),
+                            shape=[3, 3, 3])
+        probs = tf.nn.softmax(logits, axis=axis)
+
+        y_pred = logits if from_logits else probs
+        numpy_loss = numpy_sparse_categorical_focal_loss(
+            labels, y_pred, gamma=gamma, from_logits=from_logits, axis=axis)
+        focal_loss = sparse_categorical_focal_loss(
+            labels, y_pred, gamma=gamma, from_logits=from_logits, axis=axis)
+        self.assertAllClose(focal_loss, numpy_loss)
+
+    @named_parameters_with_testcase_names(gamma=[0, 1, 2],
+                                          from_logits=[False, True])
+    def test_with_dynamic_ranks(self, gamma, from_logits):
+        # y_true must have defined rank
+        y_true = tf.keras.backend.placeholder(None, dtype=tf.int64)
+        y_pred = tf.keras.backend.placeholder((None, 2), dtype=tf.float32)
+        with self.assertRaises(NotImplementedError):
+            sparse_categorical_focal_loss(y_true, y_pred, gamma=gamma,
+                                          from_logits=from_logits)
+
+        # If axis is specified, y_pred must have a defined rank
+        y_true = tf.keras.backend.placeholder((None,), dtype=tf.int64)
+        y_pred = tf.keras.backend.placeholder(None, dtype=tf.float32)
+        with self.assertRaises(ValueError):
+            sparse_categorical_focal_loss(y_true, y_pred, gamma=gamma,
+                                          from_logits=from_logits, axis=0)
+
+        # It's fine if y_pred has undefined rank when axis=-1
+        graph = tf.Graph()
+        with graph.as_default():
+            y_true = tf.keras.backend.placeholder((None,), dtype=tf.int64)
+            y_pred = tf.keras.backend.placeholder(None, dtype=tf.float32)
+            focal_loss = sparse_categorical_focal_loss(y_true, y_pred,
                                                        gamma=gamma,
+                                                       from_logits=from_logits)
+
+        labels = [0, 0, 1]
+        logits = [[10., 0.], [5., -5.], [0., 10.]]
+        probs = softmax(logits, axis=-1)
+
+        pred = logits if from_logits else probs
+        loss_numpy = numpy_sparse_categorical_focal_loss(
+            labels, pred, gamma=gamma, from_logits=from_logits)
+
+        with tf.compat.v1.Session(graph=graph) as sess:
+            loss = sess.run(focal_loss,
+                            feed_dict={y_true: labels, y_pred: pred})
+
+        self.assertAllClose(loss, loss_numpy)
+
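
Two properties these tests pin down are worth restating: with gamma=0 the focal modulation (1 - p)**0 is identically 1, so the loss must reduce to ordinary sparse categorical cross-entropy, and the result must match the NumPy reference for every gamma and axis. A quick eager-mode sketch of the first property, with assumed toy shapes:

    import tensorflow as tf

    from focal_loss import sparse_categorical_focal_loss

    # Rank-2 labels and rank-3 logits with the channel axis last.
    labels = tf.constant([[0, 1, 2], [1, 1, 1]], dtype=tf.int64)
    logits = tf.random.normal([2, 3, 4])

    # gamma=0 turns the modulating factor into 1, so focal loss and plain
    # sparse categorical cross-entropy should coincide elementwise.
    focal = sparse_categorical_focal_loss(labels, logits, gamma=0,
                                          from_logits=True)
    keras = tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True)
    tf.debugging.assert_near(focal, keras)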
