Skip to content

Commit f1fd8ce

Browse files
committed
tests stringification of arrays
1 parent 2fe3e46 commit f1fd8ce

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

test/test_pytato.py

+26
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,32 @@ def test_idx_lambda_to_hlo():
418418
(b, a))
419419

420420

421+
def test_stringify():
422+
x = pt.make_placeholder("x", (10, 4), np.int64)
423+
y = pt.make_placeholder("y", (10, 4), np.int64)
424+
425+
assert(str(3*x+4*y)
426+
== "3*x + 4*y")
427+
assert(str(pt.roll(x.reshape(2, 20).reshape(-1), 3))
428+
== "roll(reshape(reshape(x, (2, 20)), 40), 3)")
429+
assert(str(pt.roll(x.reshape(2, 20).reshape(-1), 3))
430+
== "roll(reshape(reshape(x, (2, 20)), 40), 3)")
431+
assert(str(y * pt.not_equal(x, 3))
432+
== "y*(x != 3)")
433+
assert(str(3 * y @ pt.sum(x, axis=0))
434+
== "3*y @ sum(x, axis=0)")
435+
assert(str(x[y[:, 2:3], x[2, :]])
436+
== "x[y[::, 2:3:], x[2]]")
437+
assert(str(pt.stack([x[y[:, 2:3], x[2, :]].T, y[x[:, 2:3], y[2, :]].T]))
438+
== ("stack([transpose(x[y[::, 2:3:], x[2]]),"
439+
" transpose(y[x[::, 2:3:], y[2]])])"))
440+
assert(str(pt.concatenate([x[y[:, 2:3], x[2, :]],
441+
y[x[:, 2:3], y[2, :]]]))
442+
== "concatenate([x[y[::, 2:3:], x[2]], y[x[::, 2:3:], y[2]]])")
443+
assert(str(pt.einsum("ij,i->i", 2*x, pt.sum(y, axis=1)))
444+
== 'einsum("ij, i -> i", 2*x, sum(y, axis=1))')
445+
446+
421447
if __name__ == "__main__":
422448
if len(sys.argv) > 1:
423449
exec(sys.argv[1])

0 commit comments

Comments
 (0)