Skip to content

Commit 7f96a8b

Browse files
authored
Merge pull request #326 from chaoming0625/master
[compatibility] more operators in pytorch and tensorflow
2 parents f46e78b + 6927195 commit 7f96a8b

21 files changed

+583
-262
lines changed

brainpy/_src/math/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@
3939
from . import activations
4040

4141
# high-level numpy operations
42-
from .arraycreation import *
4342
from .arrayinterporate import *
44-
from .arraycompatible import *
43+
from .compat_numpy import *
44+
from .compat_tensorflow import *
4545
from .others import *
4646
from . import random, linalg, fft
4747

brainpy/_src/math/_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,29 @@ def _compatible_with_brainpy_array(fun: Callable):
3939
@functools.wraps(fun)
4040
def new_fun(*args, **kwargs):
4141
args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf)
42+
out = None
4243
if len(kwargs):
44+
# compatible with PyTorch syntax
45+
if 'dim' in kwargs:
46+
kwargs['axis'] = kwargs.pop('dim')
47+
# compatible with PyTorch syntax
48+
if 'keepdim' in kwargs:
49+
kwargs['keep_dims'] = kwargs.pop('keepdim')
50+
# compatible with TensorFlow syntax
51+
if 'keepdims' in kwargs:
52+
kwargs['keep_dims'] = kwargs.pop('keepdims')
53+
# compatible with NumPy/PyTorch syntax
54+
if 'out' in kwargs:
55+
out = kwargs.get('out')
56+
if not isinstance(out, Array):
57+
raise TypeError(f'"out" must be an instance of brainpy Array. While we got {type(out)}')
58+
# format
4359
kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf)
4460
r = fun(*args, **kwargs)
45-
return tree_map(_return, r)
61+
if out is None:
62+
return tree_map(_return, r)
63+
else:
64+
out.value = r
4665

4766
new_fun.__doc__ = getattr(fun, "__doc__", None)
4867

brainpy/_src/math/arraycreation.py

Lines changed: 0 additions & 101 deletions
This file was deleted.

brainpy/_src/math/arrayoperation.py

Lines changed: 0 additions & 108 deletions
This file was deleted.

0 commit comments

Comments
 (0)