@@ -54,8 +54,8 @@ def __call__(self, *args, **kwargs):
54
54
55
55
# Set free_indices if not set already
56
56
if not hasattr (obj , 'free_indices' ):
57
- obj .free_indices = unique (chain ( * [ c .free_indices
58
- for c in obj .children ] ))
57
+ obj .free_indices = unique (chain . from_iterable ( c .free_indices
58
+ for c in obj .children ))
59
59
# Set dtype if not set already.
60
60
if not hasattr (obj , 'dtype' ):
61
61
obj .dtype = obj .inherit_dtype_from_children (obj .children )
@@ -306,9 +306,6 @@ def value(self):
306
306
def shape (self ):
307
307
return self .array .shape
308
308
309
- def __getitem__ (self , i ):
310
- return self .array [i ]
311
-
312
309
313
310
class Variable (Terminal ):
314
311
"""Symbolic variable tensor"""
@@ -337,7 +334,7 @@ def __new__(cls, a, b):
337
334
return a
338
335
339
336
if isinstance (a , Constant ) and isinstance (b , Constant ):
340
- return Literal (a .value + b .value , dtype = Node .inherit_dtype_from_children ([ a , b ] ))
337
+ return Literal (a .value + b .value , dtype = Node .inherit_dtype_from_children (( a , b ) ))
341
338
342
339
self = super (Sum , cls ).__new__ (cls )
343
340
self .children = a , b
@@ -361,7 +358,7 @@ def __new__(cls, a, b):
361
358
return a
362
359
363
360
if isinstance (a , Constant ) and isinstance (b , Constant ):
364
- return Literal (a .value * b .value , dtype = Node .inherit_dtype_from_children ([ a , b ] ))
361
+ return Literal (a .value * b .value , dtype = Node .inherit_dtype_from_children (( a , b ) ))
365
362
366
363
self = super (Product , cls ).__new__ (cls )
367
364
self .children = a , b
@@ -385,7 +382,7 @@ def __new__(cls, a, b):
385
382
return a
386
383
387
384
if isinstance (a , Constant ) and isinstance (b , Constant ):
388
- return Literal (a .value / b .value , dtype = Node .inherit_dtype_from_children ([ a , b ] ))
385
+ return Literal (a .value / b .value , dtype = Node .inherit_dtype_from_children (( a , b ) ))
389
386
390
387
self = super (Division , cls ).__new__ (cls )
391
388
self .children = a , b
@@ -676,6 +673,19 @@ def __new__(cls, aggregate, multiindex):
676
673
if isinstance (aggregate , Zero ):
677
674
return Zero (dtype = aggregate .dtype )
678
675
676
+ # Simplify Literal and ListTensor
677
+ if isinstance (aggregate , (Constant , ListTensor )):
678
+ if all (isinstance (i , int ) for i in multiindex ):
679
+ # All indices fixed
680
+ sub = aggregate .array [multiindex ]
681
+ return Literal (sub , dtype = aggregate .dtype ) if isinstance (aggregate , Constant ) else sub
682
+ elif any (isinstance (i , int ) for i in multiindex ) and all (isinstance (i , (int , Index )) for i in multiindex ):
683
+ # Some indices fixed
684
+ slices = tuple (i if isinstance (i , int ) else slice (None ) for i in multiindex )
685
+ sub = aggregate .array [slices ]
686
+ sub = Literal (sub , dtype = aggregate .dtype ) if isinstance (aggregate , Constant ) else ListTensor (sub )
687
+ return Indexed (sub , tuple (i for i in multiindex if not isinstance (i , int )))
688
+
679
689
# Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll)
680
690
if isinstance (aggregate , ComponentTensor ):
681
691
B , = aggregate .children
@@ -689,19 +699,6 @@ def __new__(cls, aggregate, multiindex):
689
699
ll = tuple (rep .get (k , k ) for k in kk )
690
700
return Indexed (C , ll )
691
701
692
- # Simplify Literal and ListTensor
693
- if isinstance (aggregate , (Constant , ListTensor )):
694
- if all (isinstance (i , int ) for i in multiindex ):
695
- # All indices fixed
696
- sub = aggregate [multiindex ]
697
- return Literal (sub , dtype = aggregate .dtype ) if isinstance (aggregate , Constant ) else sub
698
- elif any (isinstance (i , int ) for i in multiindex ) and all (isinstance (i , (int , Index )) for i in multiindex ):
699
- # Some indices fixed
700
- slices = tuple (i if isinstance (i , int ) else slice (None ) for i in multiindex )
701
- sub = aggregate [slices ]
702
- sub = Literal (sub , dtype = aggregate .dtype ) if isinstance (aggregate , Constant ) else ListTensor (sub )
703
- return Indexed (sub , tuple (i for i in multiindex if not isinstance (i , int )))
704
-
705
702
self = super (Indexed , cls ).__new__ (cls )
706
703
self .children = (aggregate ,)
707
704
self .multiindex = multiindex
@@ -945,9 +942,6 @@ def shape(self):
945
942
def __reduce__ (self ):
946
943
return type (self ), (self .array ,)
947
944
948
- def __getitem__ (self , i ):
949
- return self .array [i ]
950
-
951
945
def reconstruct (self , * args ):
952
946
return ListTensor (asarray (args ).reshape (self .array .shape ))
953
947
@@ -958,7 +952,7 @@ def is_equal(self, other):
958
952
"""Common subexpression eliminating equality predicate."""
959
953
if type (self ) is not type (other ):
960
954
return False
961
- if (self .array == other .array ). all ( ):
955
+ if numpy . array_equal (self .array , other .array ):
962
956
self .array = other .array
963
957
return True
964
958
return False
0 commit comments