Skip to content

Commit ef0b650

Browse files
committed
feat: implement floor and floor_ methods for TensorFlow, JAX, and NumPy frontends
- Add floor() and floor_() methods to TensorFlow EagerTensor frontend - Add floor() and floor_() methods to JAX Array frontend - Add floor() and floor_() methods to NumPy ndarray frontend - All floor_() methods are in-place operations using ivy.inplace_update - Add comprehensive test cases for all implementations - Resolves issue #21930 The floor_() method performs in-place floor operation on arrays/tensors, providing compatibility with native framework APIs. This completes the floor_ functionality across all major Ivy frontends. Co-authored-by: Kallal Mukherjee <ritamukherje62@gmail.com>
1 parent 56642ab commit ef0b650

File tree

6 files changed

+246
-0
lines changed

6 files changed

+246
-0
lines changed

ivy/functional/frontends/jax/array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,13 @@ def var(
413413
def swapaxes(self, axis1, axis2):
414414
return jax_frontend.numpy.swapaxes(self, axis1=axis1, axis2=axis2)
415415

416+
def floor(self):
417+
return jax_frontend.numpy.floor(self)
418+
419+
def floor_(self):
420+
self._ivy_array = ivy.inplace_update(self._ivy_array, ivy.floor(self._ivy_array))
421+
return self
422+
416423
def tolist(self):
417424
return ivy.to_list(self.ivy_array)
418425

ivy/functional/frontends/numpy/ndarray/ndarray.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,13 @@ def __lshift__(self, value, /):
649649
def __ilshift__(self, value, /):
650650
return ivy.bitwise_left_shift(self.ivy_array, value, out=self)
651651

652+
def floor(self):
653+
return np_frontend.floor(self)
654+
655+
def floor_(self):
656+
self._ivy_array = ivy.inplace_update(self._ivy_array, ivy.floor(self._ivy_array))
657+
return self
658+
652659
def round(self, decimals=0, out=None):
653660
return np_frontend.round(self, decimals=decimals, out=out)
654661

ivy/functional/frontends/tensorflow/tensor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,13 @@ def __iter__(self):
225225
for i in range(self.shape[0]):
226226
yield self[i]
227227

228+
def floor(self):
229+
return tensorflow_frontend.floor(self)
230+
231+
def floor_(self):
232+
self.ivy_array = ivy.inplace_update(self.ivy_array, ivy.floor(self.ivy_array))
233+
return self
234+
228235
def tolist(self):
229236
return ivy.to_list(self.ivy_array)
230237

ivy_tests/test_ivy/test_frontends/test_jax/test_array.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2868,3 +2868,78 @@ def test_jax_array_tolist(
28682868
on_device=on_device,
28692869
test_values=False, # tolist returns Python list, not array
28702870
)
2871+
2872+
2873+
# floor
2874+
@handle_frontend_method(
2875+
class_tree=CLASS_TREE,
2876+
init_tree="jax.numpy.array",
2877+
method_name="floor",
2878+
dtype_and_x=helpers.dtype_and_values(
2879+
available_dtypes=helpers.get_dtypes("float"),
2880+
min_value=-1e05,
2881+
max_value=1e05,
2882+
),
2883+
)
2884+
def test_jax_array_floor(
2885+
dtype_and_x,
2886+
frontend_method_data,
2887+
init_flags,
2888+
method_flags,
2889+
frontend,
2890+
on_device,
2891+
backend_fw,
2892+
):
2893+
input_dtypes, x = dtype_and_x
2894+
helpers.test_frontend_method(
2895+
init_input_dtypes=input_dtypes,
2896+
backend_to_test=backend_fw,
2897+
init_all_as_kwargs_np={
2898+
"object": x[0],
2899+
},
2900+
method_input_dtypes=input_dtypes,
2901+
method_all_as_kwargs_np={},
2902+
frontend_method_data=frontend_method_data,
2903+
init_flags=init_flags,
2904+
method_flags=method_flags,
2905+
frontend=frontend,
2906+
on_device=on_device,
2907+
)
2908+
2909+
2910+
# floor_
2911+
@handle_frontend_method(
2912+
class_tree=CLASS_TREE,
2913+
init_tree="jax.numpy.array",
2914+
method_name="floor_",
2915+
dtype_and_x=helpers.dtype_and_values(
2916+
available_dtypes=helpers.get_dtypes("float"),
2917+
min_value=-1e05,
2918+
max_value=1e05,
2919+
),
2920+
test_inplace=st.just(True),
2921+
)
2922+
def test_jax_array_floor_(
2923+
dtype_and_x,
2924+
frontend_method_data,
2925+
init_flags,
2926+
method_flags,
2927+
frontend,
2928+
on_device,
2929+
backend_fw,
2930+
):
2931+
input_dtypes, x = dtype_and_x
2932+
helpers.test_frontend_method(
2933+
init_input_dtypes=input_dtypes,
2934+
backend_to_test=backend_fw,
2935+
init_all_as_kwargs_np={
2936+
"object": x[0],
2937+
},
2938+
method_input_dtypes=input_dtypes,
2939+
method_all_as_kwargs_np={},
2940+
frontend_method_data=frontend_method_data,
2941+
init_flags=init_flags,
2942+
method_flags=method_flags,
2943+
frontend=frontend,
2944+
on_device=on_device,
2945+
)

ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3878,3 +3878,78 @@ def test_numpy_view(
38783878
frontend_method_data=frontend_method_data,
38793879
on_device=on_device,
38803880
)
3881+
3882+
3883+
# floor
3884+
@handle_frontend_method(
3885+
class_tree=CLASS_TREE,
3886+
init_tree="numpy.array",
3887+
method_name="floor",
3888+
dtype_and_x=helpers.dtype_and_values(
3889+
available_dtypes=helpers.get_dtypes("float"),
3890+
min_value=-1e05,
3891+
max_value=1e05,
3892+
),
3893+
)
3894+
def test_numpy_ndarray_floor(
3895+
dtype_and_x,
3896+
frontend_method_data,
3897+
init_flags,
3898+
method_flags,
3899+
frontend,
3900+
on_device,
3901+
backend_fw,
3902+
):
3903+
input_dtypes, x = dtype_and_x
3904+
helpers.test_frontend_method(
3905+
init_input_dtypes=input_dtypes,
3906+
backend_to_test=backend_fw,
3907+
init_all_as_kwargs_np={
3908+
"object": x[0],
3909+
},
3910+
method_input_dtypes=input_dtypes,
3911+
method_all_as_kwargs_np={},
3912+
init_flags=init_flags,
3913+
method_flags=method_flags,
3914+
frontend=frontend,
3915+
frontend_method_data=frontend_method_data,
3916+
on_device=on_device,
3917+
)
3918+
3919+
3920+
# floor_
3921+
@handle_frontend_method(
3922+
class_tree=CLASS_TREE,
3923+
init_tree="numpy.array",
3924+
method_name="floor_",
3925+
dtype_and_x=helpers.dtype_and_values(
3926+
available_dtypes=helpers.get_dtypes("float"),
3927+
min_value=-1e05,
3928+
max_value=1e05,
3929+
),
3930+
test_inplace=st.just(True),
3931+
)
3932+
def test_numpy_ndarray_floor_(
3933+
dtype_and_x,
3934+
frontend_method_data,
3935+
init_flags,
3936+
method_flags,
3937+
frontend,
3938+
on_device,
3939+
backend_fw,
3940+
):
3941+
input_dtypes, x = dtype_and_x
3942+
helpers.test_frontend_method(
3943+
init_input_dtypes=input_dtypes,
3944+
backend_to_test=backend_fw,
3945+
init_all_as_kwargs_np={
3946+
"object": x[0],
3947+
},
3948+
method_input_dtypes=input_dtypes,
3949+
method_all_as_kwargs_np={},
3950+
init_flags=init_flags,
3951+
method_flags=method_flags,
3952+
frontend=frontend,
3953+
frontend_method_data=frontend_method_data,
3954+
on_device=on_device,
3955+
)

ivy_tests/test_ivy/test_frontends/test_tensorflow/test_tensor.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,3 +1649,78 @@ def test_tensorflow_tensor_tolist(
16491649
on_device=on_device,
16501650
test_values=False, # tolist returns Python list, not array
16511651
)
1652+
1653+
1654+
# floor
1655+
@handle_frontend_method(
1656+
class_tree=CLASS_TREE,
1657+
init_tree="tensorflow.constant",
1658+
method_name="floor",
1659+
dtype_and_x=helpers.dtype_and_values(
1660+
available_dtypes=helpers.get_dtypes("float"),
1661+
min_value=-1e05,
1662+
max_value=1e05,
1663+
),
1664+
)
1665+
def test_tensorflow_tensor_floor(
1666+
dtype_and_x,
1667+
frontend_method_data,
1668+
init_flags,
1669+
method_flags,
1670+
frontend,
1671+
on_device,
1672+
backend_fw,
1673+
):
1674+
input_dtypes, x = dtype_and_x
1675+
helpers.test_frontend_method(
1676+
init_input_dtypes=input_dtypes,
1677+
backend_to_test=backend_fw,
1678+
init_all_as_kwargs_np={
1679+
"value": x[0],
1680+
},
1681+
method_input_dtypes=input_dtypes,
1682+
method_all_as_kwargs_np={},
1683+
frontend_method_data=frontend_method_data,
1684+
init_flags=init_flags,
1685+
method_flags=method_flags,
1686+
frontend=frontend,
1687+
on_device=on_device,
1688+
)
1689+
1690+
1691+
# floor_
1692+
@handle_frontend_method(
1693+
class_tree=CLASS_TREE,
1694+
init_tree="tensorflow.constant",
1695+
method_name="floor_",
1696+
dtype_and_x=helpers.dtype_and_values(
1697+
available_dtypes=helpers.get_dtypes("float"),
1698+
min_value=-1e05,
1699+
max_value=1e05,
1700+
),
1701+
test_inplace=st.just(True),
1702+
)
1703+
def test_tensorflow_tensor_floor_(
1704+
dtype_and_x,
1705+
frontend_method_data,
1706+
init_flags,
1707+
method_flags,
1708+
frontend,
1709+
on_device,
1710+
backend_fw,
1711+
):
1712+
input_dtypes, x = dtype_and_x
1713+
helpers.test_frontend_method(
1714+
init_input_dtypes=input_dtypes,
1715+
backend_to_test=backend_fw,
1716+
init_all_as_kwargs_np={
1717+
"value": x[0],
1718+
},
1719+
method_input_dtypes=input_dtypes,
1720+
method_all_as_kwargs_np={},
1721+
frontend_method_data=frontend_method_data,
1722+
init_flags=init_flags,
1723+
method_flags=method_flags,
1724+
frontend=frontend,
1725+
on_device=on_device,
1726+
)

0 commit comments

Comments
 (0)