5050order : (property) get/set integration order
5151output : (property) get/set output flag
5252matrix : (property) get/set matrix flag
53- query : query line
53+ query : query line
5454__call__ : transform state
5555__len__ : get number of elements (first level)
5656__getitem__ : get (first level) element by key
7676from torch import Tensor
7777
7878from model .library .element import Element
79+ from model .library .element import transform
7980
8081from model .command .util import rotate
8182
@@ -143,7 +144,7 @@ def __init__(self,
143144
144145 elements :list [Element ] = [* self .scan ('name' )]
145146 for element in elements :
146- element .lines .add (name )
147+ element .lines .add (name )
147148
148149 self ._propagate : bool = propagate
149150 if self ._propagate :
@@ -1488,6 +1489,10 @@ def __call__(self,
14881489 if self .matrix :
14891490 container_matrix : list [Tensor ] = []
14901491
1492+ if alignment and self .alignment :
1493+ state = transform (self , state , data )
1494+ return state
1495+
14911496 for element in self .sequence :
14921497 state = element (state , alignment = alignment , data = data .get (element .name ))
14931498 if self .output :
@@ -1508,6 +1513,8 @@ def __call__(self,
15081513 self .container_matrix = torch .vstack (container_matrix ) if self ._propagate else torch .stack (container_matrix )
15091514
15101515 return state
1516+
1517+
15111518
15121519
15131520 def __len__ (self ) -> int :
0 commit comments