This is actually a pretty interesting question that I'm stuck on. In JAX, I'm thinking of a diagram as a pytree (https://jax.readthedocs.io/en/latest/pytrees.html). A pytree is basically a tree of arrays. When you vmap a function that returns a diagram, JAX returns a diagram with an extra leading dimension on every array.
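To make the "extra dimension on every array" point concrete, here is a minimal sketch. It assumes a hypothetical `Circle` NamedTuple and `circle` constructor as stand-ins for the library's real primitive; only `jax.vmap` and pytree handling are real JAX behaviour.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Circle(NamedTuple):
    # Hypothetical stand-in for a diagram primitive; every field is an array leaf.
    center: jax.Array  # shape (2,) when unbatched
    radius: jax.Array  # shape () when unbatched


def circle(i):
    # Hypothetical constructor: a circle centred at (i, 0) whose radius grows with i.
    i = jnp.asarray(i, dtype=jnp.float32)
    return Circle(center=jnp.stack([i, 0.0]), radius=1.0 + 0.1 * i)


# NamedTuples are pytrees out of the box, so vmap maps over every leaf
# and stacks the results along a new leading axis.
batched = jax.vmap(circle)(jnp.arange(10))
print(batched.center.shape)  # (10, 2)
print(batched.radius.shape)  # (10,)
```

The result is still a single `Circle`; the batch lives only in the leading axis of its leaves.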
```python
import jax
import numpy as np

@jax.vmap
def draw(i):
    return circle(i)

draw(np.arange(10))
```
This object is a Primitive whose transform and style have a leading batch dimension of size 10. By default I interpret this as a concatenation of 10 primitives. However, one might also interpret it as an animation with 10 frames.
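Both readings start from the same operation: slicing every leaf at the same batch index to recover one unbatched diagram per index. The list can then be drawn together (concat) or played back one element per frame (animation). A minimal sketch, again assuming a hypothetical `Circle` primitive; `unbatch` and `batch_size` are illustrative helpers, not library API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Circle(NamedTuple):
    # Hypothetical stand-in for a diagram primitive.
    center: jax.Array
    radius: jax.Array


def circle(i):
    i = jnp.asarray(i, dtype=jnp.float32)
    return Circle(center=jnp.stack([i, 0.0]), radius=1.0 + 0.1 * i)


def batch_size(d):
    # Every leaf shares the same leading (batch) dimension.
    return jax.tree_util.tree_leaves(d)[0].shape[0]


def unbatch(d):
    # Slice every leaf at index k, giving one unbatched diagram per index.
    return [jax.tree_util.tree_map(lambda x: x[k], d) for k in range(batch_size(d))]


batched = jax.vmap(circle)(jnp.arange(10))
items = unbatch(batched)
print(len(items))  # 10
```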
The question is: what happens if you try to compose this with another object? I think in this case you just have 15 composed elements:
```python
draw(np.arange(10)) + draw(np.arange(5))
```
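If both sides have the same tree structure, the "15 composed elements" reading amounts to concatenating matching leaves along the batch axis. A minimal sketch under that assumption (hypothetical `Circle`; `concat` is an illustrative helper, and composing heterogeneous diagrams would need something more general than `tree_map`):

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Circle(NamedTuple):
    # Hypothetical stand-in for a diagram primitive.
    center: jax.Array
    radius: jax.Array


def circle(i):
    i = jnp.asarray(i, dtype=jnp.float32)
    return Circle(center=jnp.stack([i, 0.0]), radius=1.0 + 0.1 * i)


def concat(a, b):
    # Join two batched diagrams with identical structure along the batch axis.
    return jax.tree_util.tree_map(lambda x, y: jnp.concatenate([x, y], axis=0), a, b)


both = concat(jax.vmap(circle)(jnp.arange(10)), jax.vmap(circle)(jnp.arange(5)))
print(both.radius.shape)  # (15,)
```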
But is that the same as the following case? Here you have a Compose node that also has a batch dimension on it, which applies to both of its children.
```python
@jax.vmap
def draw(i):
    return circle(i) + circle(i + 10)

draw(np.arange(10))
```
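One way to see how this can differ from the 15-element case: when the batch axis sits on the Compose node, flattening it into a single drawing order forces a choice between "all of the first child, then all of the second" and "first and second for index 0, then index 1, ...". A minimal sketch, assuming hypothetical `Circle` and `Compose` NamedTuples rather than the library's real node types:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Circle(NamedTuple):
    # Hypothetical stand-in for a diagram primitive.
    center: jax.Array
    radius: jax.Array


class Compose(NamedTuple):
    # Hypothetical composition node holding two child diagrams.
    first: Circle
    second: Circle


def circle(i):
    i = jnp.asarray(i, dtype=jnp.float32)
    return Circle(center=jnp.stack([i, 0.0]), radius=1.0 + 0.1 * i)


def draw(i):
    return Compose(circle(i), circle(i + 10))


batched = jax.vmap(draw)(jnp.arange(10))
# The batch axis lives on the Compose node, so both children carry it.
print(batched.first.radius.shape, batched.second.radius.shape)  # (10,) (10,)

# Reading 1: child-major order (every `first`, then every `second`).
child_major = [c.center[k, 0] for c in batched for k in range(10)]
# Reading 2: batch-major order (first_0, second_0, first_1, second_1, ...).
batch_major = [c.center[k, 0] for k in range(10) for c in batched]
print([int(x) for x in child_major[:4]])  # [0, 1, 2, 3]
print([int(x) for x in batch_major[:4]])  # [0, 10, 1, 11]
```

Both orders contain the same 20 circles; they differ only in stacking (or frame) order.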
You also have the case where there are multiple vmaps. In this case I think the concat should just flatten them and draw them in order.
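```python
@jax.vmap
def draw(i):
    @jax.vmap
    def draw_inner(j):
        return circle(i) + circle(j + 10)
    # The inner batch size (5) is only illustrative.
    return draw_inner(np.arange(5))

draw(np.arange(10))
```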
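Nested vmaps give every leaf several leading batch axes, here `(10, 5)`. "Flatten them and draw them in order" can then be read as a row-major reshape of those axes into one, so index `k * 5 + j` corresponds to outer `i = k` and inner `j`. A minimal sketch, assuming hypothetical `Circle`/`Compose` NamedTuples and an illustrative `flatten_batch` helper:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Circle(NamedTuple):
    # Hypothetical stand-in for a diagram primitive.
    center: jax.Array
    radius: jax.Array


class Compose(NamedTuple):
    # Hypothetical composition node holding two child diagrams.
    first: Circle
    second: Circle


def circle(i):
    i = jnp.asarray(i, dtype=jnp.float32)
    return Circle(center=jnp.stack([i, 0.0]), radius=1.0 + 0.1 * i)


@jax.vmap
def draw(i):
    @jax.vmap
    def draw_inner(j):
        # circle(i) does not depend on j; vmap broadcasts it along the inner axis.
        return Compose(circle(i), circle(j + 10))
    return draw_inner(jnp.arange(5))


def flatten_batch(d, n_batch_dims):
    # Merge the leading n_batch_dims axes of every leaf into one (row-major order).
    return jax.tree_util.tree_map(
        lambda x: x.reshape((-1,) + x.shape[n_batch_dims:]), d
    )


nested = draw(jnp.arange(10))
print(nested.first.radius.shape)  # (10, 5)
flat = flatten_batch(nested, 2)
print(flat.first.center.shape)  # (50, 2)
```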
However, I think it would be nice if whatever we do here works for both animation and drawing: some notion of a composable sequence of diagrams, ordered either in z-space (stacking depth) or in time, that corresponds to this tree idea.