@@ -385,6 +385,32 @@ def test_idx_lambda_to_hlo():
385
385
(b , a ))
386
386
387
387
388
+ def test_stringify ():
389
+ x = pt .make_placeholder ("x" , (10 , 4 ), np .int64 )
390
+ y = pt .make_placeholder ("y" , (10 , 4 ), np .int64 )
391
+
392
+ assert (str (3 * x + 4 * y )
393
+ == "3*x + 4*y" )
394
+ assert (str (pt .roll (x .reshape (2 , 20 ).reshape (- 1 ), 3 ))
395
+ == "roll(reshape(reshape(x, (2, 20)), 40), 3)" )
396
+ assert (str (pt .roll (x .reshape (2 , 20 ).reshape (- 1 ), 3 ))
397
+ == "roll(reshape(reshape(x, (2, 20)), 40), 3)" )
398
+ assert (str (y * pt .not_equal (x , 3 ))
399
+ == "y*(x != 3)" )
400
+ assert (str (3 * y @ pt .sum (x , axis = 0 ))
401
+ == "3*y @ sum(x, axis=0)" )
402
+ assert (str (x [y [:, 2 :3 ], x [2 , :]])
403
+ == "x[y[::, 2:3:], x[2]]" )
404
+ assert (str (pt .stack ([x [y [:, 2 :3 ], x [2 , :]].T , y [x [:, 2 :3 ], y [2 , :]].T ]))
405
+ == ("stack([transpose(x[y[::, 2:3:], x[2]]),"
406
+ " transpose(y[x[::, 2:3:], y[2]])])" ))
407
+ assert (str (pt .concatenate ([x [y [:, 2 :3 ], x [2 , :]],
408
+ y [x [:, 2 :3 ], y [2 , :]]]))
409
+ == "concatenate([x[y[::, 2:3:], x[2]], y[x[::, 2:3:], y[2]]])" )
410
+ assert (str (pt .einsum ("ij,i->i" , 2 * x , pt .sum (y , axis = 1 )))
411
+ == 'einsum("ij, i -> i", 2*x, sum(y, axis=1))' )
412
+
413
+
388
414
if __name__ == "__main__" :
389
415
if len (sys .argv ) > 1 :
390
416
exec (sys .argv [1 ])
0 commit comments