Skip to content

Commit 4e5c987

Browse files
committed
fix a bug in tensor.flatten
1 parent ec10537 commit 4e5c987

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

torchax/test/test_functions.py

+6
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ def test_bernoulli_inplace(self):
6060
a = torch.randn((2,3))
6161
a.bernoulli_(0.4)
6262

63+
def test_flatten(self):
64+
with self.env:
65+
a = torch.randn((2,3,4))
66+
a = a.flatten(0, 1)
67+
self.assertEqual(tuple(a.shape), (6, 4))
68+
6369
def test_rnn(self):
6470
model = SeqModel()
6571
x = torch.randn((2, 100, 20))

torchax/torchax/tensor.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def flatten(self, start_dim=0, end_dim=-1):
100100
if end_dim == -1:
101101
end_dim = self.ndim
102102
new_shape = (
103-
self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim:])
103+
self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1:])
104104
new_elem = jnp.reshape(self._elem, new_shape)
105105
return Tensor(new_elem, self._env)
106106
# return torch.reshape(self, new_shape)
@@ -371,6 +371,8 @@ def _handle_tensor_constructor(self, func, args, kwargs):
371371
return func(*args, **kwargs)
372372
with jax.default_device(jax_device):
373373
op = self._ops.get(func)
374+
if op is None and isinstance(func, torch._ops.OpOverload):
375+
op = self._ops.get(func.overloadpacket)
374376
res = op.func(*args, **kwargs)
375377
if isinstance(res, jax.Array):
376378
res = Tensor(res, self)

0 commit comments

Comments
 (0)