@@ -424,6 +424,32 @@ def test_idx_lambda_to_hlo():
424
424
== BroadcastOp (a ))
425
425
426
426
427
+ def test_stringify ():
428
+ x = pt .make_placeholder ("x" , (10 , 4 ), np .int64 )
429
+ y = pt .make_placeholder ("y" , (10 , 4 ), np .int64 )
430
+
431
+ assert (str (3 * x + 4 * y )
432
+ == "3*x + 4*y" )
433
+ assert (str (pt .roll (x .reshape (2 , 20 ).reshape (- 1 ), 3 ))
434
+ == "roll(reshape(reshape(x, (2, 20)), 40), 3)" )
435
+ assert (str (pt .roll (x .reshape (2 , 20 ).reshape (- 1 ), 3 ))
436
+ == "roll(reshape(reshape(x, (2, 20)), 40), 3)" )
437
+ assert (str (y * pt .not_equal (x , 3 ))
438
+ == "y*(x != 3)" )
439
+ assert (str (3 * y @ pt .sum (x , axis = 0 ))
440
+ == "3*y @ sum(x, axis=0)" )
441
+ assert (str (x [y [:, 2 :3 ], x [2 , :]])
442
+ == "x[y[::, 2:3:], x[2]]" )
443
+ assert (str (pt .stack ([x [y [:, 2 :3 ], x [2 , :]].T , y [x [:, 2 :3 ], y [2 , :]].T ]))
444
+ == ("stack([transpose(x[y[::, 2:3:], x[2]]),"
445
+ " transpose(y[x[::, 2:3:], y[2]])])" ))
446
+ assert (str (pt .concatenate ([x [y [:, 2 :3 ], x [2 , :]],
447
+ y [x [:, 2 :3 ], y [2 , :]]]))
448
+ == "concatenate([x[y[::, 2:3:], x[2]], y[x[::, 2:3:], y[2]]])" )
449
+ assert (str (pt .einsum ("ij,i->i" , 2 * x , pt .sum (y , axis = 1 )))
450
+ == 'einsum("ij, i -> i", 2*x, sum(y, axis=1))' )
451
+
452
+
427
453
if __name__ == "__main__" :
428
454
if len (sys .argv ) > 1 :
429
455
exec (sys .argv [1 ])
0 commit comments