
Semantics for batched Diagrams? #138

Open
@srush

Description

This is actually a pretty interesting question that I'm stuck on. In JAX, I'm thinking of a diagram as a pytree (https://jax.readthedocs.io/en/latest/pytrees.html), which is basically a tree of arrays. When you vmap a function that returns a diagram, JAX returns a diagram in which every array has an extra leading (batch) dimension.
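A minimal sketch of the pytree view. It assumes a hypothetical `Circle` NamedTuple as a stand-in for a real chalk Primitive; NamedTuples are pytrees in JAX out of the box. Vmapping a function that returns one gives back a `Circle` whose every leaf has gained a leading axis of size 10. Leaves that do not depend on the input, like `color` here, are broadcast along that axis.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Circle(NamedTuple):
    """Hypothetical stand-in for a chalk Primitive: every field is an array leaf."""
    radius: jax.Array
    color: jax.Array


def circle(r):
    # Radius depends on the input; color is the same constant (red) for every circle.
    return Circle(radius=jnp.asarray(r, dtype=jnp.float32),
                  color=jnp.array([1.0, 0.0, 0.0]))


batched = jax.vmap(circle)(jnp.arange(10))
print(type(batched).__name__)  # Circle: the tree structure is preserved
print(batched.radius.shape)    # (10,)
print(batched.color.shape)     # (10, 3): the constant leaf is broadcast
```

The tree structure is unchanged. Only the leaf shapes change, which is why the batch can end up "hidden" inside any node of the diagram.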

import jax
import numpy as np
from chalk import circle  # assumes chalk's JAX-backed Diagram API

@jax.vmap
def draw(i):
    return circle(i)
draw(np.arange(10))

This object is a Primitive whose transform/style arrays have a leading batch dimension of 10. By default I am interpreting this as a concatenation of 10 primitives. However, one might also interpret it as an animation with 10 frames.
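One way to make the two readings concrete, sketched with the same kind of hypothetical `Circle` stand-in: slice every leaf along the leading axis to recover a Python list of unbatched primitives. The helper name `unbatch` is illustrative, not part of chalk. The same list serves both readings. The only difference is whether the list index is treated as draw order (all 10 overlaid in one picture) or as a frame number (one picture per entry).

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Circle(NamedTuple):
    """Hypothetical stand-in for a chalk Primitive."""
    radius: jax.Array


def circle(r):
    return Circle(radius=jnp.asarray(r, dtype=jnp.float32))


def unbatch(tree):
    """Split a pytree whose leaves share a leading batch axis into a list of pytrees."""
    n = jax.tree_util.tree_leaves(tree)[0].shape[0]
    return [jax.tree_util.tree_map(lambda leaf: leaf[i], tree) for i in range(n)]


batched = jax.vmap(circle)(jnp.arange(10))
parts = unbatch(batched)

# Reading 1 (concat): a single picture containing all 10 circles, drawn in list order.
concat = [parts]
# Reading 2 (animation): 10 pictures, one circle per frame.
frames = [[p] for p in parts]

print(len(concat), len(concat[0]))  # 1 10
print(len(frames), len(frames[0]))  # 10 1
```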

The question is: what happens if you try to compose this with another object? I think in this case you just have 15 composed elements:

draw(np.arange(10)) + draw(np.arange(5))
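A sketch of the "15 elements" reading. It assumes hypothetical `Circle` and `Compose` NamedTuples in place of chalk's classes. An explicit `Compose(...)` stands in for `+`, since `+` on a NamedTuple would just concatenate tuples. The two children carry batch axes of different sizes (10 and 5), so their leaves cannot be stacked into one array. The composition has to keep them as separate children and concatenate their elements.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Circle(NamedTuple):  # hypothetical stand-in for a chalk Primitive
    radius: jax.Array


class Compose(NamedTuple):  # hypothetical stand-in for the node `+` builds
    left: object
    right: object


def circle(r):
    return Circle(radius=jnp.asarray(r, dtype=jnp.float32))


def unbatch(tree):
    """Split a pytree whose leaves share a leading batch axis into a list of pytrees."""
    n = jax.tree_util.tree_leaves(tree)[0].shape[0]
    return [jax.tree_util.tree_map(lambda leaf: leaf[i], tree) for i in range(n)]


def elements(diagram):
    """Flatten a diagram whose batched children may have different batch sizes."""
    if isinstance(diagram, Compose):
        # Children are concatenated one after another: all of left, then all of right.
        return elements(diagram.left) + elements(diagram.right)
    return unbatch(diagram)


draw = jax.vmap(circle)
both = Compose(draw(jnp.arange(10)), draw(jnp.arange(5)))
flat = elements(both)
print(len(flat))                      # 15
print([int(c.radius) for c in flat])  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4]
```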

But is that the same as the following case? Here you have a Compose node that itself has a batch dimension, which applies to both of its children:

@jax.vmap
def draw(i):
    return circle(i) + circle(i+10)
draw(np.arange(10))
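Under the vmap, `+` produces a single Compose whose leaves all share one batch axis of size 10, so there are at least two reasonable ways to flatten it. The sketch below uses the same hypothetical `Circle`/`Compose` stand-ins. "Child-major" treats it exactly like `draw(...) + draw(...)` (all left children, then all right children). "Batch-major" splits the batch first and keeps each `circle(i)` next to its `circle(i+10)`. Both give 20 primitives and differ only in draw order, which matters whenever shapes overlap.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Circle(NamedTuple):  # hypothetical stand-in for a chalk Primitive
    radius: jax.Array


class Compose(NamedTuple):  # hypothetical stand-in for the node `+` builds
    left: object
    right: object


def circle(r):
    return Circle(radius=jnp.asarray(r, dtype=jnp.float32))


def unbatch(tree):
    n = jax.tree_util.tree_leaves(tree)[0].shape[0]
    return [jax.tree_util.tree_map(lambda leaf: leaf[i], tree) for i in range(n)]


def in_order(diagram):
    """Primitives of an *unbatched* diagram, left to right."""
    if isinstance(diagram, Compose):
        return in_order(diagram.left) + in_order(diagram.right)
    return [diagram]


def child_major(diagram):
    # Same rule as for draw(...) + draw(...): every left element, then every right one.
    if isinstance(diagram, Compose):
        return child_major(diagram.left) + child_major(diagram.right)
    return unbatch(diagram)


def batch_major(diagram):
    # Split the shared batch axis first, then flatten each unbatched Compose.
    return [p for d in unbatch(diagram) for p in in_order(d)]


batched = jax.vmap(lambda i: Compose(circle(i), circle(i + 10)))(jnp.arange(10))
print([int(c.radius) for c in child_major(batched)][:4])  # [0, 1, 2, 3]
print([int(c.radius) for c in batch_major(batched)][:4])  # [0, 10, 1, 11]
```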

There is also the case of multiple (nested) vmaps. Here I think the concat should just flatten them and draw them in order:

@jax.vmap
def draw(i):
    @jax.vmap
    def draw_inner(j):
        return circle(i) + circle(j+10)
    return draw_inner(np.arange(5))  # inner batch size of 5 is illustrative
draw(np.arange(10))
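With nested vmaps every leaf picks up one leading axis per vmap, here shape (10, 5). A minimal sketch of "flatten them and draw them in order" follows, assuming the same hypothetical `Circle`/`Compose` stand-ins. It merges the leading batch axes with a row-major reshape, so the outer index `i` varies slowest. After that, the single-batch-axis rules apply unchanged. The inner batch size of 5 and the helper name `merge_batch_axes` are illustrative assumptions.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Circle(NamedTuple):  # hypothetical stand-in for a chalk Primitive
    radius: jax.Array


class Compose(NamedTuple):  # hypothetical stand-in for the node `+` builds
    left: object
    right: object


def circle(r):
    return Circle(radius=jnp.asarray(r, dtype=jnp.float32))


@jax.vmap
def draw(i):
    @jax.vmap
    def draw_inner(j):
        return Compose(circle(i), circle(j + 10))
    return draw_inner(jnp.arange(5))  # inner batch size 5 is illustrative


def merge_batch_axes(tree, k):
    """Collapse the first k batch axes of every leaf into one, row-major (outer index slowest)."""
    return jax.tree_util.tree_map(lambda x: x.reshape((-1,) + x.shape[k:]), tree)


nested = draw(jnp.arange(10))
print(nested.left.radius.shape)  # (10, 5)

flat = merge_batch_axes(nested, 2)
print(flat.left.radius.shape)                   # (50,)
print([int(r) for r in flat.left.radius[:6]])   # [0, 0, 0, 0, 0, 1]
print([int(r) for r in flat.right.radius[:6]])  # [10, 11, 12, 13, 14, 10]
```

Note that `circle(i)` does not depend on `j`, so the inner vmap broadcasts it to all 5 inner entries before the outer vmap adds its own axis.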

Either way, I think it would be nice if whatever we do here works for both animation and drawing. There seems to be some notion of a composable sequence of diagrams, either in z-space (draw order) or in time, that corresponds to this tree idea.

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
