Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
from keras.src.ops.numpy import flip as flip
from keras.src.ops.numpy import floor as floor
from keras.src.ops.numpy import floor_divide as floor_divide
from keras.src.ops.numpy import fmod as fmod
from keras.src.ops.numpy import full as full
from keras.src.ops.numpy import full_like as full_like
from keras.src.ops.numpy import gcd as gcd
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
from keras.src.ops.numpy import flip as flip
from keras.src.ops.numpy import floor as floor
from keras.src.ops.numpy import floor_divide as floor_divide
from keras.src.ops.numpy import fmod as fmod
from keras.src.ops.numpy import full as full
from keras.src.ops.numpy import full_like as full_like
from keras.src.ops.numpy import gcd as gcd
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,12 @@ def mod(x1, x2):
return jnp.mod(x1, x2)


def fmod(x1, x2):
    """Element-wise truncated remainder (C-style ``fmod``); sign follows `x1`.

    Thin wrapper over `jnp.fmod` after converting both operands to tensors.
    """
    return jnp.fmod(convert_to_tensor(x1), convert_to_tensor(x2))


def moveaxis(x, source, destination):
return jnp.moveaxis(x, source=source, destination=destination)

Expand Down
11 changes: 11 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,17 @@ def mod(x1, x2):
return np.mod(x1, x2)


def fmod(x1, x2):
    """Element-wise truncated remainder (C-style ``fmod``); sign follows `x1`.

    Both operands are promoted to their common dtype first; booleans are
    widened to int32 before calling `np.fmod`.
    """
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    common_dtype = dtypes.result_type(x1.dtype, x2.dtype)
    if common_dtype == "bool":
        # Modulo is not meaningful on booleans; widen like the other backends.
        common_dtype = "int32"
    return np.fmod(x1.astype(common_dtype), x2.astype(common_dtype))


def moveaxis(x, source, destination):
return np.moveaxis(x, source=source, destination=destination)

Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ NNOpsDtypeTest::test_glu_
NNOpsDtypeTest::test_polar_
NNOpsDynamicShapeTest::test_glu
NumpyDtypeTest::test_array
NumpyDtypeTest::test_fmod
NumpyDtypeTest::test_maximum_python_types
NumpyDtypeTest::test_minimum_python_types
NumpyDtypeTest::test_nanargmax
Expand All @@ -510,6 +511,9 @@ NumpyOneInputOpsCorrectnessTest::test_vectorize
NumpyOneInputOpsCorrectnessTest::test_view
NumpyOneInputOpsDynamicShapeTest::test_view
NumpyOneInputOpsStaticShapeTest::test_view
NumpyTwoInputOpsCorrectnessTest::test_fmod
NumpyTwoInputOpsDynamicShapeTest::test_fmod
NumpyTwoInputOpsStaticShapeTest::test_fmod
OptimizerTest::test_constraints_are_applied
OptimizerTest::test_ema
OptimizerTest::test_gradient_accumulation
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2827,6 +2827,10 @@ def mod(x1, x2):
return OpenVINOKerasTensor(ov_opset.floor_mod(x1, x2).output(0))


def fmod(x1, x2):
    """`fmod` has no OpenVINO lowering yet; fail loudly rather than silently."""
    msg = "fmod is not supported by openvino backend."
    raise NotImplementedError(msg)


def moveaxis(x, source, destination):
x = get_ov_output(x)
if isinstance(source, int):
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2168,6 +2168,19 @@ def mod(x1, x2):
return tf.math.mod(x1, x2)


def fmod(x1, x2):
    """Element-wise truncated remainder (C-style ``fmod``); sign follows `x1`.

    Implemented as ``sign(x1) * floormod(|x1|, |x2|)``, which preserves the
    promoted input dtype. The previous ``x1 / x2`` formulation used true
    division, which produces float results (or fails outright on mixed
    int/float intermediates) for integer inputs — inconsistent with
    `np.fmod` and the other backends.

    Args:
        x1: Dividend tensor (or value convertible to a tensor).
        x2: Divisor tensor (or value convertible to a tensor).

    Returns:
        Tensor of the promoted dtype with the element-wise truncated
        remainder.
    """
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    dtype = dtypes.result_type(x1.dtype, x2.dtype)
    if dtype == "bool":
        # Modulo is not meaningful on booleans; widen like the other backends.
        dtype = "int32"
    x1 = tf.cast(x1, dtype)
    x2 = tf.cast(x2, dtype)
    # sign(x1) * floormod(|x1|, |x2|) equals C fmod for ints and floats,
    # and keeps integer results integral (unlike division-based variants).
    return tf.sign(x1) * tf.math.floormod(tf.abs(x1), tf.abs(x2))
Comment on lines +2179 to +2182
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation can be simplified by using tf.trunc. More importantly, it returns a float tensor for integer inputs, which is inconsistent with other backends and numpy.fmod. The result should be cast back to the original integer dtype if the inputs were integers.

Suggested change
quotient = x1 / x2
truncated = tf.sign(quotient) * tf.math.floor(tf.math.abs(quotient))
return x1 - truncated * x2
quotient = x1 / x2
result = x1 - tf.trunc(quotient) * x2
if "int" in dtype:
return tf.cast(result, dtype)
return result

Comment on lines +2179 to +2182
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would be a faster implementation:

return tf.sign(x1) * tf.math.floormod(tf.abs(x1), tf.abs(x2))



def moveaxis(x, source, destination):
x = convert_to_tensor(x)

Expand Down
10 changes: 10 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,16 @@ def mod(x1, x2):
return torch.remainder(x1, x2)


def fmod(x1, x2):
    """Element-wise truncated remainder (C-style ``fmod``); sign follows `x1`.

    Booleans are widened to int32 to match the other backends before
    delegating to `torch.fmod`.
    """
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    if dtypes.result_type(x1.dtype, x2.dtype) == "bool":
        # Mirror the bool -> int32 promotion used by the other backends.
        x1, x2 = cast(x1, "int32"), cast(x2, "int32")
    return torch.fmod(x1, x2)


def moveaxis(x, source, destination):
x = convert_to_tensor(x)
return torch.moveaxis(x, source=source, destination=destination)
Expand Down
38 changes: 38 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5255,6 +5255,44 @@ def mod(x1, x2):
return backend.numpy.mod(x1, x2)


class Fmod(Operation):
    """Symbolic op for the element-wise truncated remainder (C ``fmod``)."""

    def call(self, x1, x2):
        return backend.numpy.fmod(x1, x2)

    def compute_output_spec(self, x1, x2):
        # Shapes broadcast; dtypes promote, with bool widened to int32 to
        # mirror the backend kernels.
        shape1 = getattr(x1, "shape", [])
        shape2 = getattr(x2, "shape", [])
        out_shape = broadcast_shapes(shape1, shape2)
        out_dtype = dtypes.result_type(
            getattr(x1, "dtype", type(x1)),
            getattr(x2, "dtype", type(x2)),
        )
        if out_dtype == "bool":
            out_dtype = "int32"
        return KerasTensor(out_shape, dtype=out_dtype)


@keras_export(["keras.ops.fmod", "keras.ops.numpy.fmod"])
def fmod(x1, x2):
    """Returns the element-wise remainder of division with truncation.

    Computes the remainder complementary to the `floor_divide` function,
    equivalent to the C library function ``fmod``. The result has the same
    sign as the dividend ``x1``. This is different from `keras.ops.mod`
    which has the same sign as the divisor ``x2``.

    Args:
        x1: First tensor, the dividend.
        x2: Second tensor, the divisor.

    Returns:
        Output tensor, element-wise remainder with truncation.

    Examples:
    >>> x1 = keras.ops.array([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0])
    >>> x2 = keras.ops.array([2.0, 2.0, 2.0, 2.0, 2.0, 2.0])
    >>> keras.ops.fmod(x1, x2)
    array([-1., -0., -1.,  1.,  0.,  1.], dtype=float32)

    >>> keras.ops.fmod(keras.ops.array([1, 2, 3, 4, 5]), -2)
    array([1, 0, 1, 0, 1], dtype=int32)
    """
    if any_symbolic_tensors((x1, x2)):
        return Fmod().symbolic_call(x1, x2)
    return backend.numpy.fmod(x1, x2)


class Moveaxis(Operation):
def __init__(self, source, destination, *, name=None):
super().__init__(name=name)
Expand Down
50 changes: 50 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,11 @@ def test_mod(self):
y = KerasTensor((2, None))
self.assertEqual(knp.mod(x, y).shape, (2, 3))

def test_fmod(self):
    # Dynamic dims broadcast: (None, 3) with (2, None) -> (2, 3).
    dividend = KerasTensor((None, 3))
    divisor = KerasTensor((2, None))
    self.assertEqual(knp.fmod(dividend, divisor).shape, (2, 3))

def test_nextafter(self):
x = KerasTensor((None, 3))
y = KerasTensor((1, 3))
Expand Down Expand Up @@ -1002,6 +1007,19 @@ def test_mod(self):
y = KerasTensor((2, 3, 4))
knp.mod(x, y)

def test_fmod(self):
    # Matching static shapes pass straight through.
    a = KerasTensor((2, 3))
    b = KerasTensor((2, 3))
    self.assertEqual(knp.fmod(a, b).shape, (2, 3))

    # Broadcasting against a python scalar keeps the tensor shape.
    self.assertEqual(knp.fmod(KerasTensor((2, 3)), 2).shape, (2, 3))

    # Incompatible static shapes must be rejected.
    with self.assertRaises(ValueError):
        knp.fmod(KerasTensor((2, 3)), KerasTensor((2, 3, 4)))

def test_nextafter(self):
x = KerasTensor((2, 3))
y = KerasTensor((2, 3))
Expand Down Expand Up @@ -3811,6 +3829,17 @@ def test_mod(self):
self.assertAllClose(knp.Mod()(x, 1), np.mod(x, 1))
self.assertAllClose(knp.Mod()(1, x), np.mod(1, x))

def test_fmod(self):
    dividend = np.array([[-3, 7], [5, -2]])
    divisor = np.array([[2, -3], [3, 4]])

    # Functional form against the numpy reference, including scalar
    # broadcasting on either side.
    self.assertAllClose(knp.fmod(dividend, divisor), np.fmod(dividend, divisor))
    self.assertAllClose(knp.fmod(dividend, 2), np.fmod(dividend, 2))
    self.assertAllClose(knp.fmod(1, dividend), np.fmod(1, dividend))

    # Op-class form must agree with the functional form.
    self.assertAllClose(knp.Fmod()(dividend, divisor), np.fmod(dividend, divisor))
    self.assertAllClose(knp.Fmod()(dividend, 2), np.fmod(dividend, 2))
    self.assertAllClose(knp.Fmod()(1, dividend), np.fmod(1, dividend))

def test_nextafter(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
y = np.array([[4, 5, 6], [3, 2, 1]])
Expand Down Expand Up @@ -9520,6 +9549,27 @@ def test_mod(self, dtypes):
expected_dtype,
)

@parameterized.named_parameters(
    named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
)
def test_fmod(self, dtypes):
    import jax.numpy as jnp

    dtype1, dtype2 = dtypes
    # jnp.fmod's promotion rules are the reference for the expected dtype.
    expected_dtype = standardize_dtype(
        jnp.fmod(
            jnp.ones((), dtype=dtype1), jnp.ones((), dtype=dtype2)
        ).dtype
    )

    x1 = knp.ones((), dtype=dtype1)
    x2 = knp.ones((), dtype=dtype2)
    # Eager and symbolic paths must both promote identically.
    self.assertEqual(
        standardize_dtype(knp.fmod(x1, x2).dtype), expected_dtype
    )
    self.assertEqual(
        standardize_dtype(knp.Fmod().symbolic_call(x1, x2).dtype),
        expected_dtype,
    )

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_moveaxis(self, dtype):
import jax.numpy as jnp
Expand Down
Loading