This repository was archived by the owner on Nov 29, 2022. It is now read-only.

Description
Hi, I'm trying to use `SparseCategoricalFocalLoss` in my image segmentation task. The dataset is built with a `tf.data.Dataset` and the model is a Keras `Model`, compiled like this:
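```python
import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

# Class-wise focusing parameters, one per class (list elided here).
loss_gamma = [0.5, 1., ...]
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr),  # `lr` is a deprecated alias
    loss=SparseCategoricalFocalLoss(gamma=loss_gamma),
    ...)
model.fit(...)
```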
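For reference, a minimal NumPy sketch of what a per-pixel sparse categorical focal loss with class-wise gamma computes, assuming `y_true` holds integer class indices of shape (BATCH, HEIGHT, WIDTH) and the predictions are softmax probabilities of shape (BATCH, HEIGHT, WIDTH, NUM_CLASSES). The function name `sparse_focal_loss_reference` is hypothetical and this is not the library's implementation; it only illustrates the shape relationship (labels have one fewer axis than predictions) and how a per-class gamma is looked up for each pixel.

```python
import numpy as np


def sparse_focal_loss_reference(y_true, probs, gamma):
    """Hypothetical per-pixel sparse categorical focal loss (sketch).

    y_true: int array, shape (BATCH, HEIGHT, WIDTH), class indices.
    probs:  float array, shape (BATCH, HEIGHT, WIDTH, NUM_CLASSES).
    gamma:  scalar, or 1-D array with one focusing parameter per class.
    Returns an array of shape (BATCH, HEIGHT, WIDTH).
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    probs = np.asarray(probs, dtype=np.float64)
    # This sketch requires labels to have exactly one fewer axis than predictions.
    assert probs.ndim == y_true.ndim + 1, "expected rank(y_true) == rank(probs) - 1"
    # Probability assigned to the true class at each pixel.
    p_t = np.take_along_axis(probs, y_true[..., None], axis=-1)[..., 0]
    # Class-wise gamma: pick the focusing parameter of each pixel's true class.
    gamma = np.asarray(gamma, dtype=np.float64)
    g = gamma[y_true] if gamma.ndim == 1 else gamma
    eps = 1e-7  # avoid log(0)
    return -((1.0 - p_t) ** g) * np.log(np.clip(p_t, eps, 1.0))


y_true = np.array([[[0, 1], [2, 1]]])     # (1, 2, 2)
probs = np.full((1, 2, 2, 3), 1.0 / 3.0)  # uniform over 3 classes
loss = sparse_focal_loss_reference(y_true, probs, gamma=[0.5, 1.0, 2.0])
print(loss.shape)
```

With `gamma=0` every term reduces to plain cross-entropy, which is a quick way to sanity-check the shapes.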
While training, an exception is raised because the shape of the `y_true` tensor is unknown. It comes from this check:
https://github.com/artemmavrin/focal-loss/blob/master/src/focal_loss/_categorical_focal_loss.py#L136-L141
How should I define the ground-truth tensor? In my task it has shape (BATCH, HEIGHT, WIDTH).

Environment: Ubuntu 18.04, TensorFlow 2.2.0 (in a virtual environment).