Skip to content

Commit 2007994

Browse files
committed
tests stringification of arrays
1 parent 147bd30 commit 2007994

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

test/test_pytato.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,32 @@ def test_idx_lambda_to_hlo():
385385
(b, a))
386386

387387

388+
def test_stringify():
389+
x = pt.make_placeholder("x", (10, 4), np.int64)
390+
y = pt.make_placeholder("y", (10, 4), np.int64)
391+
392+
assert(str(3*x+4*y)
393+
== "3*x + 4*y")
394+
assert(str(pt.roll(x.reshape(2, 20).reshape(-1), 3))
395+
== "roll(reshape(reshape(x, (2, 20)), 40), 3)")
396+
assert(str(pt.roll(x.reshape(2, 20).reshape(-1), 3))
397+
== "roll(reshape(reshape(x, (2, 20)), 40), 3)")
398+
assert(str(y * pt.not_equal(x, 3))
399+
== "y*(x != 3)")
400+
assert(str(3 * y @ pt.sum(x, axis=0))
401+
== "3*y @ sum(x, axis=0)")
402+
assert(str(x[y[:, 2:3], x[2, :]])
403+
== "x[y[::, 2:3:], x[2]]")
404+
assert(str(pt.stack([x[y[:, 2:3], x[2, :]].T, y[x[:, 2:3], y[2, :]].T]))
405+
== ("stack([transpose(x[y[::, 2:3:], x[2]]),"
406+
" transpose(y[x[::, 2:3:], y[2]])])"))
407+
assert(str(pt.concatenate([x[y[:, 2:3], x[2, :]],
408+
y[x[:, 2:3], y[2, :]]]))
409+
== "concatenate([x[y[::, 2:3:], x[2]], y[x[::, 2:3:], y[2]]])")
410+
assert(str(pt.einsum("ij,i->i", 2*x, pt.sum(y, axis=1)))
411+
== 'einsum("ij, i -> i", 2*x, sum(y, axis=1))')
412+
413+
388414
if __name__ == "__main__":
389415
if len(sys.argv) > 1:
390416
exec(sys.argv[1])

0 commit comments

Comments
 (0)