typed_einsum.py
from fancy_einsum import einsum as _einsum
from typing import TypeVar, Type, Tuple, Any

from phantom_tensors import parse
from phantom_tensors._parse import HasShape
from phantom_tensors.torch import Tensor
from phantom_tensors.alphabet import A, B, C

T = TypeVar("T", bound=HasShape)


def _type_to_einstr(x) -> str:
    # Turn a shape-typed annotation such as Tensor[A, B] into the
    # corresponding einsum index string, e.g. "A B".
    return " ".join(v.__name__ for v in x.__args__)


def einsum(*in_types, out_type: Type[T], tensors: Tuple[Any, ...]) -> T:
    """
    Compute an einsum over ``tensors``, inferring the contraction string from
    the shape-typed annotations ``in_types`` and ``out_type``, and check the
    result against ``out_type`` at runtime.

    Examples
    --------
    import torch as tr

    x, y = parse(
        (tr.ones(2, 3), Tensor[A, B]),
        (tr.ones(3, 4), Tensor[B, C]),
    )
    out = einsum(
        Tensor[A, B],
        Tensor[B, C],
        out_type=Tensor[A, C],
        tensors=(x, y),
    )
    out  # type checker sees: Tensor[A, C]
    """
    assert len(in_types) == len(tensors), "each input type needs a matching tensor"
    # Build the einsum signature from the shape-typed annotations, e.g.
    # (Tensor[A, B], Tensor[B, C]) with out_type=Tensor[A, C] gives "A B, B C -> A C".
    in_str = ", ".join(_type_to_einstr(tp) for tp in in_types)
    out_str = _type_to_einstr(out_type)
    out = _einsum(f"{in_str} -> {out_str}", *tensors)

    # Runtime-check every input tensor, and the result, against its declared type.
    in_types += (out_type,)
    tensors = tuple(tensors) + (out,)
    parse(*zip(tensors, in_types))
    # `parse` also narrows the result so static checkers see it as `out_type`.
    return parse(out, out_type)
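

# A minimal usage sketch (assumes `torch` is installed), mirroring the
# docstring example so the module can be run directly as a quick smoke test.
if __name__ == "__main__":
    import torch as tr

    x, y = parse(
        (tr.ones(2, 3), Tensor[A, B]),
        (tr.ones(3, 4), Tensor[B, C]),
    )
    out = einsum(
        Tensor[A, B],
        Tensor[B, C],
        out_type=Tensor[A, C],
        tensors=(x, y),
    )
    print(out.shape)  # expected: torch.Size([2, 4]); a static checker sees Tensor[A, C]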