Skip to content

Commit a2110f5

Browse files
committed
fix trace
1 parent 695de84 commit a2110f5

18 files changed

+387
-216
lines changed

chalk/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
chalk.core.ApplyTransform,
4646
chalk.core.ComposeAxis,
4747
chalk.envelope.EnvDistance,
48+
chalk.trace.TraceDistances,
4849
chalk.style.StyleHolder,
4950
chalk.shapes.Trail,
5051
chalk.shapes.Path,

chalk/backend/cairo.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -146,12 +146,22 @@ def render_cairo_prims(
146146
if even_odd:
147147
ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
148148
shape_renderer = ToCairoShape()
149-
print("PRIM:", len(prims))
149+
print("prim", len(prims))
150+
d = {}
150151
for prim in prims:
151-
# print(p.transform.shape)
152-
# for prim in p:
152+
print(prim.order)
153+
for ind, i in tx.X.np.ndenumerate(prim.order):
154+
d[i] = (prim, ind)
155+
156+
print(list(d.keys()))
157+
for j in sorted(d.keys()):
158+
prim, ind = d[j]
159+
print(ind)
160+
print(prim.order.shape)
161+
print(prim.size())
162+
prim = prim.split(ind)
163+
print(prim.size())
153164
print(prim.transform.shape)
154-
155165
for i in range(prim.transform.shape[0]):
156166
# apply transformation
157167
matrix = tx_to_cairo(prim.transform[i : i + 1])

chalk/combinators.py

+41-17
Original file line numberDiff line numberDiff line change
@@ -106,21 +106,41 @@ def place_on_path(diagrams: Iterable[Diagram], path: Path) -> Diagram:
106106

107107
Cat = Union[Iterable[Diagram], Diagram]
108108
def cat(
109-
diagram: Cat, v: V2_t, sep: Optional[Floating] = None, axis=0
109+
diagram: Cat, v: V2_t, sep: Optional[Floating] = None
110110
) -> Diagram:
111111
if isinstance(diagram, Diagram):
112-
assert diagram.size() != ()
112+
axes = diagram.size()
113+
axis = len(axes) - 1
114+
assert diagram.size() != ()
115+
diagram = diagram._normalize()
113116
import jax
114117
from functools import partial
115-
def fn(a: Diagram, b: Diagram) -> Diagram:
116-
@partial(jax.vmap, in_axes=axis, out_axes=axis)
117-
def merge(a, b):
118-
new = a.juxtapose(b, v)
119-
if sep is not None:
120-
return new.translate_by(v * sep)
121-
return new
122-
return merge(a, b)
123-
return jax.lax.associative_scan(fn, diagram, axis=axis).compose_axis()
118+
# def fn(a: Diagram, b: Diagram) -> Diagram:
119+
# @partial(jax.vmap)
120+
# def merge(a, b):
121+
# b.get_envelope()(-v)
122+
# new = a.juxtapose(b, v)
123+
# return new
124+
# return merge(a, b)
125+
def call_scan(diagram):
126+
@jax.vmap
127+
def offset(diagram):
128+
env = diagram.get_envelope()
129+
right = env(v)
130+
left = env(-v)
131+
return right, left
132+
right, left = offset(diagram)
133+
off = tx.X.np.roll(right, 1) + left
134+
off = off.at[0].set(0)
135+
off = tx.X.np.cumsum(off, axis=0)
136+
@jax.vmap
137+
def translate(off, diagram):
138+
return diagram.translate_by(v * off[..., None, None])
139+
return translate(off, diagram)
140+
#return jax.lax.associative_scan(fn, diagram, axis=0).compose_axis()
141+
for a in range(axis):
142+
call_scan = jax.vmap(call_scan, in_axes=a, out_axes=a)
143+
return call_scan(diagram).compose_axis()
124144

125145
else:
126146
diagrams = iter(diagram)
@@ -148,7 +168,12 @@ def concat(diagrams: Iterable[Diagram]) -> Diagram:
148168
"""
149169
from chalk.core import BaseDiagram
150170

151-
return BaseDiagram.concat(diagrams) # type: ignore
171+
if isinstance(diagram, Diagram):
172+
size = diagram.size()
173+
assert size != ()
174+
return diagram.compose_axis()
175+
else:
176+
return BaseDiagram.concat(diagrams) # type: ignore
152177

153178

154179
def empty() -> Diagram:
@@ -180,7 +205,7 @@ def vstrut(height: Optional[Floating]) -> Diagram:
180205

181206

182207
def hcat(
183-
diagrams: Iterable[Diagram], sep: Optional[Floating] = None, axis=0
208+
diagrams: Iterable[Diagram], sep: Optional[Floating] = None
184209
) -> Diagram:
185210
"""
186211
Stack diagrams next to each other with `besides`.
@@ -193,12 +218,11 @@ def hcat(
193218
Diagram: New diagram
194219
195220
"""
196-
return cat(diagrams, tx.X.unit_x, sep, axis=axis)
221+
return cat(diagrams, tx.X.unit_x, sep)
197222

198223

199224
def vcat(
200-
diagrams: Iterable[Diagram], sep: Optional[Floating] = None, axis=0
201-
) -> Diagram:
225+
diagrams: Iterable[Diagram], sep: Optional[Floating] = None) -> Diagram:
202226
"""
203227
Stack diagrams above each other with `above`.
204228
@@ -210,7 +234,7 @@ def vcat(
210234
Diagrams
211235
212236
"""
213-
return cat(diagrams, tx.X.unit_y, sep, axis=axis)
237+
return cat(diagrams, tx.X.unit_y, sep)
214238

215239

216240
# Extra

0 commit comments

Comments
 (0)