Description
Now that PyTorch supports tensor subclassing and function overloading via `__torch_function__`, should we add `__array_function__` and `__torch_function__` methods to `funsor.terms.Funsor` so that (some) PyTorch/NumPy code can be evaluated on Funsors?
Here is the meat of a `Funsor.__torch_function__` implementation, modulo handling of edge cases; `__array_function__` for the NumPy backend would be very similar:
```python
import funsor.ops


class Funsor:
    ...

    @classmethod  # PyTorch expects __torch_function__ to be defined as a classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs  # PyTorch may pass kwargs=None
        # exploit our op registry: ops should know how to handle and convert their arguments
        try:
            op = getattr(funsor.ops, func.__name__)
        except AttributeError:
            op = funsor.ops.make_op(func)  # handle e.g. nn.Module or dist.Transform instances
        return op(*args, **kwargs)
```

The motivating application is a much simpler and more general alternative to the dimension tracking in `pyro.contrib.funsor`, which relies on effectful `to_data`/`to_funsor` primitives and is somewhat confusing. This would also simplify @ordabayevy's work in #543 and elsewhere by removing the need for special `torch.Tensor` subclasses that duplicate Funsor broadcasting semantics.
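For the NumPy backend, the analogous `__array_function__` hook could look roughly like the minimal sketch below. It assumes a toy `ToyFunsor` class and a `TOY_OPS` dict registry, both hypothetical stand-ins for `funsor.terms.Funsor` and `funsor.ops` rather than Funsor's actual API. It shows NumPy routing `np.sum` to a handler looked up by `func.__name__` (the same lookup as the PyTorch version), with the handler reducing a dim by *name*, and returning `NotImplemented` for unregistered functions so that NumPy raises `TypeError`.

```python
import numpy as np

# Hypothetical registry standing in for funsor.ops: maps NumPy function
# names to handlers that understand ToyFunsor arguments.
TOY_OPS = {}


def register(name):
    def decorator(fn):
        TOY_OPS[name] = fn
        return fn
    return decorator


class ToyFunsor:
    """Hypothetical stand-in for funsor.terms.Funsor: an array with named dims."""

    def __init__(self, data, dims):
        self.data = np.asarray(data)
        self.dims = tuple(dims)
        assert self.data.ndim == len(self.dims)

    def __array_function__(self, func, types, args, kwargs):
        # Dispatch on the NumPy function's name via the registry.
        op = TOY_OPS.get(func.__name__)
        if op is None:
            return NotImplemented  # NumPy then raises TypeError
        return op(*args, **kwargs)


@register("sum")
def toy_sum(x, axis=None, **kwargs):
    # Interpret `axis` as a dim name, so reductions follow named-dim semantics.
    if axis is None:
        return x.data.sum()
    i = x.dims.index(axis)
    return ToyFunsor(x.data.sum(axis=i), x.dims[:i] + x.dims[i + 1:])
```

For example, `np.sum(ToyFunsor(np.arange(6).reshape(2, 3), ("i", "j")), axis="i")` returns a `ToyFunsor` with dims `("j",)` and data `[3, 5, 7]`. Elementwise ufuncs such as `np.exp` or `np.add` dispatch through the separate `__array_ufunc__` protocol rather than `__array_function__`, so a full NumPy backend would likely need both hooks.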