diff --git a/keras_contrib/layers/crf.py b/keras_contrib/layers/crf.py index 88a64ac69..3d43537d4 100644 --- a/keras_contrib/layers/crf.py +++ b/keras_contrib/layers/crf.py @@ -513,6 +513,7 @@ def recursion(self, input_energy, mask=None, go_backwards=False, constants = [chain_energy] if mask is not None: + mask = K.cast(mask, K.floatx()) mask2 = K.cast(K.concatenate([mask, K.zeros_like(mask[:, :1])], axis=1), K.floatx()) constants.append(mask2)