@@ -106,21 +106,41 @@ def place_on_path(diagrams: Iterable[Diagram], path: Path) -> Diagram:
106
106
107
107
Cat = Union [Iterable [Diagram ], Diagram ]
108
108
def cat (
109
- diagram : Cat , v : V2_t , sep : Optional [Floating ] = None , axis = 0
109
+ diagram : Cat , v : V2_t , sep : Optional [Floating ] = None
110
110
) -> Diagram :
111
111
if isinstance (diagram , Diagram ):
112
- assert diagram .size () != ()
112
+ axes = diagram .size ()
113
+ axis = len (axes ) - 1
114
+ assert diagram .size () != ()
115
+ diagram = diagram ._normalize ()
113
116
import jax
114
117
from functools import partial
115
- def fn (a : Diagram , b : Diagram ) -> Diagram :
116
- @partial (jax .vmap , in_axes = axis , out_axes = axis )
117
- def merge (a , b ):
118
- new = a .juxtapose (b , v )
119
- if sep is not None :
120
- return new .translate_by (v * sep )
121
- return new
122
- return merge (a , b )
123
- return jax .lax .associative_scan (fn , diagram , axis = axis ).compose_axis ()
118
+ # def fn(a: Diagram, b: Diagram) -> Diagram:
119
+ # @partial(jax.vmap)
120
+ # def merge(a, b):
121
+ # b.get_envelope()(-v)
122
+ # new = a.juxtapose(b, v)
123
+ # return new
124
+ # return merge(a, b)
125
+ def call_scan (diagram ):
126
+ @jax .vmap
127
+ def offset (diagram ):
128
+ env = diagram .get_envelope ()
129
+ right = env (v )
130
+ left = env (- v )
131
+ return right , left
132
+ right , left = offset (diagram )
133
+ off = tx .X .np .roll (right , 1 ) + left
134
+ off = off .at [0 ].set (0 )
135
+ off = tx .X .np .cumsum (off , axis = 0 )
136
+ @jax .vmap
137
+ def translate (off , diagram ):
138
+ return diagram .translate_by (v * off [..., None , None ])
139
+ return translate (off , diagram )
140
+ #return jax.lax.associative_scan(fn, diagram, axis=0).compose_axis()
141
+ for a in range (axis ):
142
+ call_scan = jax .vmap (call_scan , in_axes = a , out_axes = a )
143
+ return call_scan (diagram ).compose_axis ()
124
144
125
145
else :
126
146
diagrams = iter (diagram )
@@ -148,7 +168,12 @@ def concat(diagrams: Iterable[Diagram]) -> Diagram:
148
168
"""
149
169
from chalk .core import BaseDiagram
150
170
151
- return BaseDiagram .concat (diagrams ) # type: ignore
171
+ if isinstance (diagram , Diagram ):
172
+ size = diagram .size ()
173
+ assert size != ()
174
+ return diagram .compose_axis ()
175
+ else :
176
+ return BaseDiagram .concat (diagrams ) # type: ignore
152
177
153
178
154
179
def empty () -> Diagram :
@@ -180,7 +205,7 @@ def vstrut(height: Optional[Floating]) -> Diagram:
180
205
181
206
182
207
def hcat (
183
- diagrams : Iterable [Diagram ], sep : Optional [Floating ] = None , axis = 0
208
+ diagrams : Iterable [Diagram ], sep : Optional [Floating ] = None
184
209
) -> Diagram :
185
210
"""
186
211
Stack diagrams next to each other with `besides`.
@@ -193,12 +218,11 @@ def hcat(
193
218
Diagram: New diagram
194
219
195
220
"""
196
- return cat (diagrams , tx .X .unit_x , sep , axis = axis )
221
+ return cat (diagrams , tx .X .unit_x , sep )
197
222
198
223
199
224
def vcat (
200
- diagrams : Iterable [Diagram ], sep : Optional [Floating ] = None , axis = 0
201
- ) -> Diagram :
225
+ diagrams : Iterable [Diagram ], sep : Optional [Floating ] = None ) -> Diagram :
202
226
"""
203
227
Stack diagrams above each other with `above`.
204
228
@@ -210,7 +234,7 @@ def vcat(
210
234
Diagrams
211
235
212
236
"""
213
- return cat (diagrams , tx .X .unit_y , sep , axis = axis )
237
+ return cat (diagrams , tx .X .unit_y , sep )
214
238
215
239
216
240
# Extra
0 commit comments