@@ -418,6 +418,32 @@ def test_idx_lambda_to_hlo():
418
418
(b , a ))
419
419
420
420
421
+ def test_stringify ():
422
+ x = pt .make_placeholder ("x" , (10 , 4 ), np .int64 )
423
+ y = pt .make_placeholder ("y" , (10 , 4 ), np .int64 )
424
+
425
+ assert (str (3 * x + 4 * y )
426
+ == "3*x + 4*y" )
427
+ assert (str (pt .roll (x .reshape (2 , 20 ).reshape (- 1 ), 3 ))
428
+ == "roll(reshape(reshape(x, (2, 20)), 40), 3)" )
429
+ assert (str (pt .roll (x .reshape (2 , 20 ).reshape (- 1 ), 3 ))
430
+ == "roll(reshape(reshape(x, (2, 20)), 40), 3)" )
431
+ assert (str (y * pt .not_equal (x , 3 ))
432
+ == "y*(x != 3)" )
433
+ assert (str (3 * y @ pt .sum (x , axis = 0 ))
434
+ == "3*y @ sum(x, axis=0)" )
435
+ assert (str (x [y [:, 2 :3 ], x [2 , :]])
436
+ == "x[y[::, 2:3:], x[2]]" )
437
+ assert (str (pt .stack ([x [y [:, 2 :3 ], x [2 , :]].T , y [x [:, 2 :3 ], y [2 , :]].T ]))
438
+ == ("stack([transpose(x[y[::, 2:3:], x[2]]),"
439
+ " transpose(y[x[::, 2:3:], y[2]])])" ))
440
+ assert (str (pt .concatenate ([x [y [:, 2 :3 ], x [2 , :]],
441
+ y [x [:, 2 :3 ], y [2 , :]]]))
442
+ == "concatenate([x[y[::, 2:3:], x[2]], y[x[::, 2:3:], y[2]]])" )
443
+ assert (str (pt .einsum ("ij,i->i" , 2 * x , pt .sum (y , axis = 1 )))
444
+ == 'einsum("ij, i -> i", 2*x, sum(y, axis=1))' )
445
+
446
+
421
447
if __name__ == "__main__" :
422
448
if len (sys .argv ) > 1 :
423
449
exec (sys .argv [1 ])
0 commit comments