Skip to content

Commit c486350

Browse files
committed
tests stringification of arrays
1 parent cc6bfc1 commit c486350

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

test/test_pytato.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,32 @@ def test_idx_lambda_to_hlo():
424424
== BroadcastOp(a))
425425

426426

427+
def test_stringify():
428+
x = pt.make_placeholder("x", (10, 4), np.int64)
429+
y = pt.make_placeholder("y", (10, 4), np.int64)
430+
431+
assert(str(3*x+4*y)
432+
== "3*x + 4*y")
433+
assert(str(pt.roll(x.reshape(2, 20).reshape(-1), 3))
434+
== "roll(reshape(reshape(x, (2, 20)), 40), 3)")
435+
assert(str(pt.roll(x.reshape(2, 20).reshape(-1), 3))
436+
== "roll(reshape(reshape(x, (2, 20)), 40), 3)")
437+
assert(str(y * pt.not_equal(x, 3))
438+
== "y*(x != 3)")
439+
assert(str(3 * y @ pt.sum(x, axis=0))
440+
== "3*y @ sum(x, axis=0)")
441+
assert(str(x[y[:, 2:3], x[2, :]])
442+
== "x[y[::, 2:3:], x[2]]")
443+
assert(str(pt.stack([x[y[:, 2:3], x[2, :]].T, y[x[:, 2:3], y[2, :]].T]))
444+
== ("stack([transpose(x[y[::, 2:3:], x[2]]),"
445+
" transpose(y[x[::, 2:3:], y[2]])])"))
446+
assert(str(pt.concatenate([x[y[:, 2:3], x[2, :]],
447+
y[x[:, 2:3], y[2, :]]]))
448+
== "concatenate([x[y[::, 2:3:], x[2]], y[x[::, 2:3:], y[2]]])")
449+
assert(str(pt.einsum("ij,i->i", 2*x, pt.sum(y, axis=1)))
450+
== 'einsum("ij, i -> i", 2*x, sum(y, axis=1))')
451+
452+
427453
if __name__ == "__main__":
428454
if len(sys.argv) > 1:
429455
exec(sys.argv[1])

0 commit comments

Comments
 (0)