Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ivy/functional/frontends/jax/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,9 @@ def var(
def swapaxes(self, axis1, axis2):
return jax_frontend.numpy.swapaxes(self, axis1=axis1, axis2=axis2)

def tolist(self):
return ivy.to_list(self.ivy_array)


# Jax supports DeviceArray from 0.4.13 and below
# Hence aliasing it here
Expand Down
3 changes: 3 additions & 0 deletions ivy/functional/frontends/tensorflow/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ def __iter__(self):
for i in range(self.shape[0]):
yield self[i]

def tolist(self):
return ivy.to_list(self.ivy_array)


class TensorShape:
# TODO: there are still some methods that may need implementing
Expand Down
42 changes: 42 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_jax/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2826,3 +2826,45 @@ def test_jax_swapaxes(
method_flags=method_flags,
on_device=on_device,
)


# tolist
@handle_frontend_method(
class_tree=CLASS_TREE,
init_tree="jax.numpy.array",
method_name="tolist",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
min_num_dims=0,
max_num_dims=5,
min_dim_size=1,
max_dim_size=10,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
max_dim_size=10,
max_dim_size=10,
min_value=-1e05,
max_value=1e05,

We need to add this so the values don't overflow and fail the test when running with a high number of examples (--num-examples 100)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed TEST_STATUS_NOTES.md file - Eliminated unnecessary documentation clutter

min_value=-1e05,
max_value=1e05,
),
)
def test_jax_array_tolist(
dtype_and_x,
frontend_method_data,
init_flags,
method_flags,
frontend,
on_device,
backend_fw,
):
input_dtypes, x = dtype_and_x
helpers.test_frontend_method(
init_input_dtypes=input_dtypes,
backend_to_test=backend_fw,
init_all_as_kwargs_np={
"object": x[0],
},
method_input_dtypes=input_dtypes,
method_all_as_kwargs_np={},
frontend_method_data=frontend_method_data,
init_flags=init_flags,
method_flags=method_flags,
frontend=frontend,
on_device=on_device,
test_values=False, # tolist returns Python list, not array
)
42 changes: 42 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_tensorflow/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1607,3 +1607,45 @@ def test_tensorflow_shape(
x.ivy_array.shape, ivy.Shape(shape), as_array=False
)
ivy.previous_backend()


# tolist
@handle_frontend_method(
class_tree=CLASS_TREE,
init_tree="tensorflow.constant",
method_name="tolist",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
min_num_dims=0,
max_num_dims=5,
min_dim_size=1,
max_dim_size=10,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
max_dim_size=10,
max_dim_size=10,
min_value=-1e05,
max_value=1e05,

same here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed docstrings from frontend methods - Now follows Ivy's convention for clean, minimal frontend implementations

min_value=-1e05,
max_value=1e05,
),
)
def test_tensorflow_tensor_tolist(
dtype_and_x,
frontend_method_data,
init_flags,
method_flags,
frontend,
on_device,
backend_fw,
):
input_dtypes, x = dtype_and_x
helpers.test_frontend_method(
init_input_dtypes=input_dtypes,
backend_to_test=backend_fw,
init_all_as_kwargs_np={
"value": x[0],
},
method_input_dtypes=input_dtypes,
method_all_as_kwargs_np={},
frontend_method_data=frontend_method_data,
init_flags=init_flags,
method_flags=method_flags,
frontend=frontend,
on_device=on_device,
test_values=False, # tolist returns Python list, not array
)