Skip to content

Commit a88486c

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Fix warnings in array_interoperability_test.
PiperOrigin-RevId: 747586780
1 parent a64e7dc commit a88486c

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

tests/BUILD

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,6 @@ jax_multiplatform_test(
102102
enable_configs = [
103103
"gpu_h100x2",
104104
],
105-
env = {
106-
"PYTHONWARNINGS": "default", # TODO(b/394123878): protobuf, via TensorFlow, issues a Python warning under Python 3.12+ sometimes.
107-
},
108105
tags = ["multiaccelerator"],
109106
deps = py_deps("tensorflow_core"),
110107
)

tests/array_interoperability_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ def setUp(self):
9595
message="Calling from_dlpack with a DLPack tensor",
9696
category=DeprecationWarning,
9797
)
98+
@jtu.ignore_warning(
99+
message="jax.dlpack.to_dlpack was deprecated.*",
100+
category=DeprecationWarning,
101+
)
98102
def testJaxRoundTrip(self, shape, dtype, copy, use_stream):
99103
rng = jtu.rand_default(self.rng())
100104
np = rng(shape, dtype)
@@ -188,6 +192,10 @@ def testTensorFlowToJax(self, shape, dtype):
188192
dtype=dlpack_dtypes,
189193
)
190194
@unittest.skipIf(not tf, "Test requires TensorFlow")
195+
@jtu.ignore_warning(
196+
message="jax.dlpack.to_dlpack was deprecated.*",
197+
category=DeprecationWarning,
198+
)
191199
def testJaxToTensorFlow(self, shape, dtype):
192200
if (not config.enable_x64.value and
193201
dtype in [jnp.int64, jnp.uint64, jnp.float64]):

0 commit comments

Comments
 (0)