Skip to content

Commit d6b5945

Browse files
committed
update
1 parent efdc88a commit d6b5945

10 files changed

+110
-69
lines changed

chalk/array_types.py

-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def vmap(fn: Callable[[Array], Array]) -> Callable[[Array], Array]:
4848
"Fake jax vmap for numpy as a for loop."
4949
if JAX_MODE:
5050
return vmap(fn)
51-
5251
def vmap2(x: Array) -> Array:
5352
if isinstance(x, tuple):
5453
size = x[-1].size() # type: ignore

chalk/backend/cairo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def animate(
124124
path_frame = "/tmp/frame-{:d}.png"
125125
import imageio
126126

127-
with imageio.get_writer(path, fps=20, loop=0) as writer:
127+
with imageio.get_writer(path, fps=10, loop=0) as writer:
128128
for i in range(shape[0]):
129129
path = path_frame.format(i)
130130
patches_to_file(patches, path, h, w, (i,))

chalk/backend/patch.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def order_patches(
5151
) -> List[Tuple[Tuple[int, ...], Patch, Dict[str, Any]]]:
5252
import numpy as onp
5353

54-
patches = tx.tree_map(onp.asarray, patches)
54+
if tx.JAX_MODE:
55+
patches = tx.tree_map(onp.asarray, patches)
5556

5657
d = {}
5758
for patch in patches:
@@ -99,6 +100,7 @@ def from_path(
99100
for loc_trail in path.loc_trails:
100101
p = loc_trail.location
101102
segments = loc_trail.located_segments()
103+
print("SEG", transform.shape, transform.strides, loc_trail.trail.segments.transform.strides)
102104
vert = segment_to_curve(segments.transform, segments.angles)
103105
if path.scale_invariant is not None:
104106
scale = height / 20
@@ -147,10 +149,11 @@ def patch_from_prim(
147149
style = prim.style if prim.style is not None else style
148150
assert isinstance(prim.prim_shape, Path)
149151
assert prim.order is not None
152+
in_style = style.to_mpl() #tx.multi_vmap(style.to_mpl.__func__, len(size))(style), # type: ignore
150153
patch = Patch.from_path(
151154
prim.prim_shape,
152155
prim.transform,
153-
tx.multi_vmap(style.to_mpl.__func__, len(size))(style), # type: ignore
156+
in_style,
154157
prim.order,
155158
height,
156159
)

chalk/backend/svg.py

+61-41
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,27 @@ def to_svg(patch: Patch, ind: Tuple[int, ...]) -> str:
1212
v, c = patch.vert[ind], patch.command[ind]
1313
if v.shape[0] == 0:
1414
return "<g></g>"
15-
line = ""
15+
parts = []
1616
i = 0
1717
while i < c.shape[0]:
1818
if c[i] == chalk.backend.patch.Command.MOVETO.value:
19-
line += f"M {v[i, 0]} {v[i, 1]}"
19+
parts.append( f"M {v[i, 0]:.2f} {v[i, 1]:.2f}")
2020
i += 1
2121
elif c[i] == chalk.backend.patch.Command.LINETO.value:
22-
line += f"L {v[i, 0]} {v[i, 1]}"
22+
parts.append(f"L {v[i, 0]} {v[i, 1]}")
2323
i += 1
2424
elif c[i] == chalk.backend.patch.Command.CURVE3.value:
25-
line += f"Q {v[i, 0]} {v[i, 1]} {v[i+1, 0]} {v[i+1, 1]}"
25+
parts.append(f"Q {v[i, 0]} {v[i, 1]} {v[i+1, 0]} {v[i+1, 1]}")
2626
i += 2
2727
elif c[i] == chalk.backend.patch.Command.CLOSEPOLY.value:
28-
line += "Z"
28+
parts.append("Z")
2929
i += 1
3030
elif c[i] == chalk.backend.patch.Command.SKIP.value:
3131
i += 1
3232
elif c[i] == chalk.backend.patch.Command.CURVE4.value:
33-
line += f"C {v[i, 0]} {v[i, 1]} {v[i+1, 0]} {v[i+1, 1]} {v[i+2, 0]} {v[i+2, 1]}"
33+
parts.append(f"C {v[i, 0]:.2f} {v[i, 1]:.2f} {v[i+1, 0]:.2f} {v[i+1, 1]:.2f} {v[i+2, 0]:.2f} {v[i+2, 1]:.2f}")
3434
i += 3
35-
return line
35+
return " ".join(parts)
3636

3737

3838
def write_style(d: Dict[str, Any]) -> Dict[str, str]:
@@ -51,48 +51,68 @@ def write_style(d: Dict[str, Any]) -> Dict[str, str]:
5151
return out
5252

5353

54-
def render_svg_patches(patches: List[Patch], animate:bool =False, time_steps=0) -> str:
55-
out = ""
56-
patches = [chalk.backend.patch.order_patches(patches, (step,))
57-
for step in range(time_steps)]
58-
for v in zip(*patches):
59-
out += f"""
60-
<path>
61-
"""
62-
lines = []
63-
css = {}
64-
for ind, patch, style_new in v:
65-
lines.append(to_svg(patch, ind))
66-
s = write_style(style_new)
67-
for k, v in s.items():
68-
css.setdefault(k, []).append(v)
69-
70-
values = ";".join(lines)
71-
out += f"""
72-
<animate attributeName="d" values="{values}" dur="2s" repeatCount="indefinite"/>
54+
def render_svg_patches(patches: List[Patch], animate:bool=False, time_steps:int=0) -> str:
55+
if animate:
56+
out = ""
57+
patches = [chalk.backend.patch.order_patches(patches, (step,))
58+
for step in range(time_steps)]
59+
for v in zip(*patches):
60+
out += "\n\n <path>\n"
61+
lines = []
62+
css = {}
63+
64+
65+
for ind, patch, style_new in v:
66+
67+
lines.append(to_svg(patch, ind))
68+
s = write_style(style_new)
69+
for k, v in s.items():
70+
css.setdefault(k, []).append(v)
71+
s = set(lines)
72+
if len(s) == 1:
73+
out += f"""
74+
<set attributeName="d" to="{list(s)[0]}"/>
75+
"""
76+
else:
77+
values = ";".join(lines)
78+
out += f"""
79+
<animate attributeName="d" values="{values}" dur="2s" repeatCount="indefinite"/>
80+
"""
81+
for k, v in css.items():
82+
s = set(v)
83+
if len(s) == 1:
84+
out += f"""<set attributeName="{k}" to="{list(s)[0]}"/>"""
85+
86+
else:
87+
out += f"""
88+
<animate attributeName="{k}" values="{';'.join(v)}" dur="2s" repeatCount="indefinite"/>
7389
"""
74-
for k, v in css.items():
90+
out += "</path>\n\n"
91+
return out
92+
else:
93+
out = ""
94+
for ind, patch, style_new in chalk.backend.patch.order_patches(patches):
95+
inner = to_svg(patch, ind)
96+
style_t = ";".join([f"{k}:{v}" for k, v in write_style(style_new).items()] )
7597
out += f"""
76-
<animate attributeName="{k}" values="{';'.join(v)}" dur="2s" repeatCount="indefinite"/>
77-
"""
78-
out += """
79-
</path>
80-
"""
81-
return out
82-
98+
<g style="{style_t}">
99+
<path d="{inner}" />
100+
</g>"""
101+
return out
83102
def patches_to_file(
84103
patches: List[Patch], path: str, height: tx.IntLike,
85104
width: tx.IntLike,
86105
animate: bool = False,
87-
time_steps=0
106+
time_steps: int =0
88107
) -> None:
89-
dwg = f"""<?xml version="1.0" encoding="utf-8" ?>
90-
<svg baseProfile="full" height="{int(height)}" version="1.1" width="{int(width)}" xmlns="http://www.w3.org/2000/svg" xmlns:ev="http://www.w3.org/2001/xml-events" xmlns:xlink="http://www.w3.org/1999/xlink">
91-
"""
92-
dwg += render_svg_patches(patches, animate, time_steps)
93-
dwg += "</svg>"
94108
with open(path, "w") as f:
95-
f.write(dwg)
109+
f.write(f"""<?xml version="1.0" encoding="utf-8" ?>
110+
<svg baseProfile="full" height="{int(height)}" version="1.1" width="{int(width)}" xmlns="http://www.w3.org/2000/svg" xmlns:ev="http://www.w3.org/2001/xml-events" xmlns:xlink="http://www.w3.org/1999/xlink">
111+
""")
112+
113+
out = render_svg_patches(patches, animate, time_steps)
114+
f.write(out)
115+
f.write("</svg>")
96116

97117

98118
def render(

chalk/broadcast.py

+25-14
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ def check(x: tx.Array) -> None:
5555
V1 = TypeVar("V1", bound=Diagram)
5656
V2 = TypeVar("V2", bound=Diagram)
5757

58+
def broadcast_to(tree, old_shape: Tuple[int, ...], new_shape: Tuple[int, ...]):
59+
if old_shape == new_shape: return tree
60+
def reshape(x: tx.Array) -> tx.Array:
61+
shape = x.shape
62+
return tx.np.broadcast_to(x, new_shape + shape[len(old_shape):])
63+
64+
return tx.tree_map(reshape, tree)
65+
5866

5967
def broadcast_diagrams(self: V1, other: V2) -> Tuple[V1, V2]:
6068
"""
@@ -65,20 +73,23 @@ def broadcast_diagrams(self: V1, other: V2) -> Tuple[V1, V2]:
6573
other_size = other.size()
6674
if size == other_size:
6775
return self, other
68-
check(size, other_size, str(type(self)), str(type(other)))
69-
ml = max(len(size), len(other_size))
70-
for i in range(ml):
71-
off = -1 - i
72-
if i > len(other_size) - 1:
73-
other = other.add_axis(size[off]) # type: ignore
74-
elif i > len(size) - 1:
75-
self = self.add_axis(other_size[off]) # type: ignore
76-
elif size[off] == 1 and other_size[off] != 1:
77-
self = self.repeat_axis(other_size[off], len(size) + off) # type: ignore
78-
elif size[off] != 1 and other_size[off] == 1:
79-
other = other.repeat_axis(size[off], len(other_size) + off) # type: ignore
80-
check_consistent(self)
81-
check_consistent(other)
76+
new_shape = tx.np.broadcast_shapes(size, other_size)
77+
self = broadcast_to(self, size, new_shape)
78+
other = broadcast_to(other, other_size, new_shape)
79+
80+
# ml = max(len(size), len(other_size))
81+
# for i in range(ml):
82+
# off = -1 - i
83+
# if i > len(other_size) - 1:
84+
# other = other.add_axis(size[off]) # type: ignore
85+
# elif i > len(size) - 1:
86+
# self = self.add_axis(other_size[off]) # type: ignore
87+
# elif size[off] == 1 and other_size[off] != 1:
88+
# self = self.repeat_axis(other_size[off], len(size) + off) # type: ignore
89+
# elif size[off] != 1 and other_size[off] == 1:
90+
# other = other.repeat_axis(size[off], len(other_size) + off) # type: ignore
91+
# check_consistent(self)
92+
# check_consistent(other)
8293
assert (
8394
self.size() == other.size()
8495
), f"{size} {other_size} {self.size()} {other.size()}"

chalk/combinators.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,13 @@ def call_scan(diagram: Diagram) -> Diagram:
102102
env = diagram.get_envelope()
103103
right = env(v)
104104
left = env(-v)
105-
off = tx.np.roll(right, 1) + left + sep
106-
off = tx.index_update(off, 0, 0)
107-
off = tx.np.cumsum(off, axis=0)
105+
off = tx.np.roll(right, 1, axis=axis) + left + sep
106+
off = tx.index_update(off, (Ellipsis, 0), 0)
107+
off = tx.np.cumsum(off, axis=axis)
108108
t = v * off[..., None, None]
109109
return diagram.translate_by(t)
110-
111110

112-
call_scan = tx.multi_vmap(call_scan, axis) # type: ignore
111+
# call_scan = tx.multi_vmap(call_scan, axis) # type: ignore
113112
return call_scan(diagram).compose_axis()
114113

115114

chalk/core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def apply_style(self: B1, style: StyleHolder) -> B: # type: ignore
8585
return ApplyStyle(new_diagram.style, self)
8686

8787
def __repr__(self) -> str:
88-
return f"Diagram[self.shape]"
88+
return f"Diagram[{self.shape}]"
8989

9090
def __tree_pp__(self, **kwargs): # type: ignore
9191
import jax._src.pretty_printer as pp

chalk/path.py

+1
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def from_array(points: P2_t, closed: bool = False) -> Path:
8686

8787

8888
def from_points(points: List[P2_t], closed: bool = False) -> Path:
89+
points = tx.np.broadcast_arrays(*points)
8990
return Path.from_array(tx.np.stack(points, axis=-3))
9091

9192

chalk/style.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,9 @@ def shape(self) -> Tuple[int, ...]:
131131
return self.base.shape[:-1]
132132

133133
def get(self, key: str) -> tx.Scalars:
134-
v = self.base[slice(*STYLE_LOCATIONS[key])]
134+
v = self.base[..., slice(*STYLE_LOCATIONS[key])]
135135
return tx.np.where(
136-
self.mask[slice(*STYLE_LOCATIONS[key])], v, DEFAULTS[key]
136+
self.mask[..., slice(*STYLE_LOCATIONS[key])], v, DEFAULTS[key]
137137
)
138138

139139
@property
@@ -193,8 +193,8 @@ def to_mpl(self) -> Dict[str, Any]:
193193

194194
# Set by observation
195195
lw = self.line_width_
196-
style["linewidth"] = lw.reshape(-1)[0]
197-
style["alpha"] = self.fill_opacity_[0]
196+
style["linewidth"] = lw[...,0]
197+
style["alpha"] = self.fill_opacity_[...,0]
198198
return style
199199

200200

chalk/transform.py

+8
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ def P2(x: Floating, y: Floating) -> P2_t:
8383
s: P2_t = np.stack([x, y, o], axis=-1)[..., None]
8484
return s
8585

86+
@jit
87+
# @partial(vectorize, signature="(),()->(3,1)")
88+
def to_P2(x: Float[Array, "*B 2"]) -> P2_t:
89+
"Map a standard vector to a point"
90+
_, o = np.broadcast_arrays(x[..., :1], ftos(1.0))
91+
s: P2_t = np.concatenate([x, o], axis=-1)[..., None]
92+
return s
93+
8694

8795
@jit
8896
def norm(v: V2_t) -> V2_t:

0 commit comments

Comments
 (0)