Add array protocol dispatch methods to top-level Funsor class #546

@eb8680

Description

Now that PyTorch supports tensor subtyping and function overloading with __torch_function__, should we add __array_function__ and __torch_function__ methods to funsor.terms.Funsor to allow evaluation of (some) PyTorch/NumPy code on Funsors?

Here is the meat of a Funsor.__torch_function__ implementation, modulo handling of edge cases; __array_function__ for the NumPy backend would be very similar:

import funsor.ops

class Funsor:
    ...
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # exploit our op registry: ops should know how to handle and convert their arguments
        kwargs = {} if kwargs is None else kwargs
        try:
            op = getattr(funsor.ops, func.__name__)
        except AttributeError:
            # handle e.g. nn.Module or dist.Transform instances
            op = funsor.ops.make_op(func)
        return op(*args, **kwargs)
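For reference, the dispatch mechanism itself can be exercised today on a toy wrapper class that has nothing to do with Funsor; the following is only an illustration of how PyTorch routes calls to __torch_function__, not proposed Funsor code, and the Boxed class is made up for this sketch:

import torch

class Boxed:
    """Toy wrapper demonstrating the __torch_function__ dispatch protocol."""

    def __init__(self, value):
        self.value = value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        # Unwrap Boxed arguments, call the underlying torch function, and rewrap the result.
        unwrapped = tuple(a.value if isinstance(a, Boxed) else a for a in args)
        return Boxed(func(*unwrapped, **kwargs))

x = Boxed(torch.randn(3))
y = torch.sigmoid(x)  # dispatches to Boxed.__torch_function__ instead of raising a TypeError
z = torch.add(x, x)   # binary torch functions dispatch the same way
print(type(y).__name__, y.value.shape)

Since __array_function__ "would be very similar", a hedged sketch of the NumPy-backend analogue under the same op-registry lookup might look as follows; note that per NEP 18 the hook is an instance method and receives args and kwargs explicitly, with no defaults:

class Funsor:
    ...
    def __array_function__(self, func, types, args, kwargs):
        # sketch only: same lookup strategy as __torch_function__ above
        try:
            op = getattr(funsor.ops, func.__name__)
        except AttributeError:
            op = funsor.ops.make_op(func)  # wrap unknown callables as ops
        return op(*args, **kwargs)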

The motivating application is a much simpler and more general alternative to the somewhat confusing dimension tracking done via the effectful to_data/to_funsor primitives in pyro.contrib.funsor. This would also simplify @ordabayevy's work in #543 and elsewhere by removing the need for special torch.Tensor subclasses that duplicate Funsor broadcasting semantics.
