|
53 | 53 |
|
54 | 54 |
|
55 | 55 | def numpy_sparse_categorical_focal_loss(y_true, y_pred, gamma, |
56 | | - from_logits=False): |
| 56 | + from_logits=False, axis=-1): |
57 | 57 | """Simple sparse categorical focal loss implementation using NumPy.""" |
58 | | - # Convert to arrays |
59 | 58 | y_true = np.asarray(y_true) |
60 | 59 | y_pred = np.asarray(y_pred) |
61 | 60 |
|
| 61 | + if axis != -1: |
| 62 | + pred_dim = np.ndim(y_pred) |
| 63 | + axes = list(range(axis)) + list(range(axis + 1, pred_dim)) + [axis] |
| 64 | + y_pred = np.transpose(y_pred, axes) |
| 65 | + |
| 66 | + y_pred_shape_original = y_pred.shape |
| 67 | + n_classes = y_pred_shape_original[-1] |
| 68 | + y_true = np.reshape(y_true, newshape=[-1]) |
| 69 | + y_pred = np.reshape(y_pred, newshape=[-1, n_classes]) |
| 70 | + |
62 | 71 | # One-hot encoding of integer labels |
63 | | - y_true_one_hot = np.eye(y_pred.shape[-1])[y_true] |
| 72 | + y_true_one_hot = np.eye(n_classes)[y_true] |
64 | 73 |
|
65 | 74 | if from_logits: |
66 | 75 | y_pred = softmax(y_pred, axis=-1) |
| 76 | + else: |
| 77 | + y_pred = np.clip(y_pred, 1e-7, 1-1e-7) |
67 | 78 |
|
68 | 79 | loss = -y_true_one_hot * (1 - y_pred) ** gamma * np.log(y_pred) |
69 | | - return loss.sum(axis=-1) |
| 80 | + loss = np.sum(loss, axis=-1) |
| 81 | + loss = np.reshape(loss, y_pred_shape_original[:-1]) |
| 82 | + |
| 83 | + return loss |
70 | 84 |
|
71 | 85 |
|
72 | 86 | def get_dummy_sparse_multiclass_classifier(n_features, n_classes, gamma, |
@@ -250,3 +264,95 @@ def test_save_and_restore(self, gamma, from_logits): |
250 | 264 |
|
251 | 265 | # Delete the created SavedModel directory |
252 | 266 | shutil.rmtree(sm_filepath, ignore_errors=True) |
| 267 | + |
| 268 | + def test_with_higher_rank_inputs(self): |
| 269 | + """Addresses https://github.com/artemmavrin/focal-loss/issues/5""" |
| 270 | + |
| 271 | + def build_model(): |
| 272 | + return tf.keras.Sequential([ |
| 273 | + tf.keras.layers.Input((100, 10)), |
| 274 | + tf.keras.layers.GRU(13, return_sequences=True), |
| 275 | + tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(13)), |
| 276 | + ]) |
| 277 | + |
| 278 | + x = np.zeros((20, 100, 10)) |
| 279 | + y = np.ones((20, 100, 1)) |
| 280 | + |
| 281 | + model = build_model() |
| 282 | + loss = SparseCategoricalFocalLoss(gamma=2) |
| 283 | + model.compile(loss=loss, optimizer='adam') |
| 284 | + model.fit(x, y) |
| 285 | + |
| 286 | + @named_parameters_with_testcase_names(axis=[0, 1, 2], |
| 287 | + from_logits=[False, True]) |
| 288 | + def test_reduce_to_keras_with_higher_rank_and_axis(self, axis, from_logits): |
| 289 | + labels = tf.convert_to_tensor([[0, 1, 2], [0, 0, 0], [1, 1, 1]], |
| 290 | + dtype=tf.dtypes.int64) |
| 291 | + logits = tf.reshape(tf.range(27, dtype=tf.dtypes.float32), |
| 292 | + shape=[3, 3, 3]) |
| 293 | + probs = tf.nn.softmax(logits, axis=axis) |
| 294 | + |
| 295 | + y_pred = logits if from_logits else probs |
| 296 | + keras_loss = tf.keras.losses.sparse_categorical_crossentropy( |
| 297 | + labels, y_pred, from_logits=from_logits, axis=axis) |
| 298 | + focal_loss = sparse_categorical_focal_loss( |
| 299 | + labels, y_pred, gamma=0, from_logits=from_logits, axis=axis) |
| 300 | + self.assertAllClose(focal_loss, keras_loss) |
| 301 | + |
| 302 | + @named_parameters_with_testcase_names(gamma=[0, 1, 2], axis=[0, 1, 2], |
| 303 | + from_logits=[False, True]) |
| 304 | + def test_higher_rank_sanity_checks(self, gamma, axis, from_logits): |
| 305 | + labels = tf.convert_to_tensor([[0, 1, 2], [0, 0, 0], [1, 1, 1]], |
| 306 | + dtype=tf.dtypes.int64) |
| 307 | + logits = tf.reshape(tf.range(27, dtype=tf.dtypes.float32), |
| 308 | + shape=[3, 3, 3]) |
| 309 | + probs = tf.nn.softmax(logits, axis=axis) |
| 310 | + |
| 311 | + y_pred = logits if from_logits else probs |
| 312 | + numpy_loss = numpy_sparse_categorical_focal_loss( |
| 313 | + labels, y_pred, gamma=gamma, from_logits=from_logits, axis=axis) |
| 314 | + focal_loss = sparse_categorical_focal_loss( |
| 315 | + labels, y_pred, gamma=gamma, from_logits=from_logits, axis=axis) |
| 316 | + self.assertAllClose(focal_loss, numpy_loss) |
| 317 | + |
| 318 | + @named_parameters_with_testcase_names(gamma=[0, 1, 2], |
| 319 | + from_logits=[False, True]) |
| 320 | + def test_with_dynamic_ranks(self, gamma, from_logits): |
| 321 | + # y_true must have defined rank |
| 322 | + y_true = tf.keras.backend.placeholder(None, dtype=tf.int64) |
| 323 | + y_pred = tf.keras.backend.placeholder((None, 2), dtype=tf.float32) |
| 324 | + with self.assertRaises(NotImplementedError): |
| 325 | + sparse_categorical_focal_loss(y_true, y_pred, gamma=gamma, |
| 326 | + from_logits=from_logits) |
| 327 | + |
| 328 | + # If axis is specified, y_pred must have a defined rank |
| 329 | + y_true = tf.keras.backend.placeholder((None,), dtype=tf.int64) |
| 330 | + y_pred = tf.keras.backend.placeholder(None, dtype=tf.float32) |
| 331 | + with self.assertRaises(ValueError): |
| 332 | + sparse_categorical_focal_loss(y_true, y_pred, gamma=gamma, |
| 333 | + from_logits=from_logits, axis=0) |
| 334 | + |
| 335 | + # It's fine if y_pred has undefined rank if axis=-1 |
| 336 | + graph = tf.Graph() |
| 337 | + with graph.as_default(): |
| 338 | + y_true = tf.keras.backend.placeholder((None,), dtype=tf.int64) |
| 339 | + y_pred = tf.keras.backend.placeholder(None, dtype=tf.float32) |
| 340 | + focal_loss = sparse_categorical_focal_loss(y_true, y_pred, |
| 341 | + gamma=gamma, |
| 342 | + from_logits=from_logits) |
| 343 | + |
| 344 | + labels = [0, 0, 1] |
| 345 | + logits = [[10., 0.], [5., -5.], [0., 10.]] |
| 346 | + probs = softmax(logits, axis=-1) |
| 347 | + |
| 348 | + pred = logits if from_logits else probs |
| 349 | + loss_numpy = numpy_sparse_categorical_focal_loss( |
| 350 | + labels, pred, gamma=gamma, from_logits=from_logits) |
| 351 | + |
| 352 | + with tf.compat.v1.Session(graph=graph) as sess: |
| 353 | + loss = sess.run(focal_loss, |
| 354 | + feed_dict={y_true: labels, y_pred: pred}) |
| 355 | + |
| 356 | + self.assertAllClose(loss, loss_numpy) |
| 357 | + |
| 358 | + |
0 commit comments