Dear all, this relates to issue #19609. In MXNet 2.0 the operator GroupNorm does not exist (yet!), so I am using the following workaround to create a GroupNorm operator compatible with mxnet.np.ndarray:

import mxnet as mx
from mxnet import gluon

class GroupNormHack(gluon.nn.HybridBlock):
    """
    This is a partial fix for issue #19609,
    see https://github.com/apache/incubator-mxnet/issues/19609
    """
    def __init__(self, num_groups, **kwargs):
        super().__init__(**kwargs)
        self.norm = gluon.nn.GroupNorm(num_groups=num_groups, **kwargs)

    def forward(self, input):
        # Convert the numpy ndarray to a legacy ndarray, switch off numpy
        # mode, run the legacy GroupNorm, then restore numpy mode.
        tinput = input.as_nd_ndarray() if mx.npx.is_np_array() else input
        mx.npx.reset_np()
        out = self.norm(tinput)
        mx.npx.set_np()
        out = out.as_np_ndarray()
        return out

Using this operator as a constituent of a deeper network, I can run my code fine, but when I try to load from a checkpoint, I get an error upon performing a forward operation:

File "/scratch1/dia021/Software/mxprosthesis/models/changedetection/mantis/mantis_dn_features.py", line 112, in forward
    conv1_t1 = self.conv_first(input_t1)
  File "/scratch1/dia021/Software/mxnet/gluon/block.py", line 1644, in __call__
    return super().__call__(x, *args)
  File "/scratch1/dia021/Software/mxnet/gluon/block.py", line 851, in __call__
    out = self.forward(*args)
  File "/scratch1/dia021/Software/mxprosthesis/nn/layers/conv2Dnormed.py", line 34, in forward
    x = self.norm_layer(x)
  File "/scratch1/dia021/Software/mxnet/gluon/block.py", line 1644, in __call__
    return super().__call__(x, *args)
  File "/scratch1/dia021/Software/mxnet/gluon/block.py", line 851, in __call__
    out = self.forward(*args)
  File "/scratch1/dia021/Software/mxprosthesis/utils/get_norm.py", line 18, in forward
    out = self.norm(tinput)
  File "/scratch1/dia021/Software/mxnet/gluon/block.py", line 1625, in __call__
    return super().__call__(x, *args)
  File "/scratch1/dia021/Software/mxnet/gluon/block.py", line 851, in __call__
    out = self.forward(*args)
  File "/scratch1/dia021/Software/mxnet/gluon/block.py", line 1681, in forward
    return self.hybrid_forward(ndarray, x, *args, **params)
  File "/scratch1/dia021/Software/mxnet/gluon/nn/basic_layers.py", line 870, in hybrid_forward
    norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon)
  File "<string>", line 66, in GroupNorm
  File "/scratch1/dia021/Software/mxnet/ndarray/register.py", line 92, in _verify_all_legacy_ndarrays
    raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
TypeError: Operator `GroupNorm` registered in backend is known as `GroupNorm` in Python. This is a legacy operator which can only accept legacy ndarrays, while received an MXNet numpy ndarray. Please call `as_nd_ndarray()` upon the numpy ndarray to convert it to a legacy ndarray, and then feed the converted array to this operator.
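For reference, this is what the error message asks for in isolation (a minimal sketch with illustrative shapes and values). Note that my hack above already performs this conversion on the normal forward path, yet the checkpoint-loading path apparently still reaches the legacy operator with a numpy ndarray:

import mxnet as mx
mx.npx.set_np()

x = mx.np.ones((2, 4, 8))   # numpy-mode input, as produced inside the network
gamma = mx.nd.ones((4,))    # the legacy operator expects legacy ndarrays
beta = mx.nd.zeros((4,))

# Convert the input to a legacy ndarray, apply the op, convert back.
out = mx.nd.GroupNorm(data=x.as_nd_ndarray(), gamma=gamma, beta=beta,
                      num_groups=2, eps=1e-5)
out = out.as_np_ndarray()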
Any ideas/pointers on where to look to fix this would be most appreciated. Many thanks,
Dear all, I finally ended up using this GroupNorm definition, and it works for both running and loading models:

import mxnet as mx
from mxnet.gluon import HybridBlock
from mxnet.gluon.parameter import Parameter

@mx.use_np
class GroupNorm(HybridBlock):
r"""
Applies group normalization to the n-dimensional input array.
This operator takes an n-dimensional input array where the leftmost 2 axis are
`batch` and `channel` respectively:
.. math::
x = x.reshape((N, num_groups, C // num_groups, ...))
axis = (2, ...)
out = \frac{x - mean[x, axis]}{ \sqrt{Var[x, axis] + \epsilon}} * gamma + beta
    Parameters
    ----------
    num_groups: int, default 1
        Number of groups to separate the channel axis into.
    epsilon: float, default 1e-5
        Small float added to variance to avoid dividing by zero.
    center: bool, default True
        If True, add offset of `beta` to normalized tensor.
        If False, `beta` is ignored.
    scale: bool, default True
        If True, multiply by `gamma`. If False, `gamma` is not used.
    beta_initializer: str or `Initializer`, default 'zeros'
        Initializer for the beta weight.
    gamma_initializer: str or `Initializer`, default 'ones'
        Initializer for the gamma weight.

    Inputs:
        - **data**: input tensor with shape (N, C, ...).

    Outputs:
        - **out**: output tensor with the same shape as `data`.

    References
    ----------
        `Group Normalization
        <https://arxiv.org/pdf/1803.08494.pdf>`_
    Examples
    --------
    # Input of shape (2, 3, 4); note the class runs in numpy mode (@mx.use_np)
    x = mx.np.array([[[ 0,  1,  2,  3],
                      [ 4,  5,  6,  7],
                      [ 8,  9, 10, 11]],
                     [[12, 13, 14, 15],
                      [16, 17, 18, 19],
                      [20, 21, 22, 23]]])
    # Group normalization is calculated with the above formula
    layer = GroupNorm()
    layer.initialize(ctx=mx.cpu(0))
    layer(x)
    array([[[-1.5932543, -1.3035717, -1.0138891, -0.7242065],
            [-0.4345239, -0.1448413,  0.1448413,  0.4345239],
            [ 0.7242065,  1.0138891,  1.3035717,  1.5932543]],
           [[-1.5932543, -1.3035717, -1.0138891, -0.7242065],
            [-0.4345239, -0.1448413,  0.1448413,  0.4345239],
            [ 0.7242065,  1.0138891,  1.3035717,  1.5932543]]])
    """
    def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
                 beta_initializer='zeros', gamma_initializer='ones',
                 in_channels=0):
        super(GroupNorm, self).__init__()
        self._kwargs = {'eps': epsilon, 'num_groups': num_groups,
                        'center': center, 'scale': scale}
        self._num_groups = num_groups
        self._epsilon = epsilon
        self._center = center
        self._scale = scale
        self.gamma = Parameter('gamma', grad_req='write' if scale else 'null',
                               shape=(in_channels,), init=gamma_initializer,
                               allow_deferred_init=True)
        self.beta = Parameter('beta', grad_req='write' if center else 'null',
                              shape=(in_channels,), init=beta_initializer,
                              allow_deferred_init=True)
    def infer_shape(self, in_data):
        # Necessary for MXNet 2.0: with deferred initialization, Gluon calls
        # infer_shape with the first input, so the parameter shapes can be
        # fixed from the channel dimension.
        tshape = in_data.shape
        self.gamma.shape = (tshape[1],)
        self.beta.shape = (tshape[1],)
    def forward(self, x):
        # The backend GroupNorm is a legacy operator: convert the numpy-mode
        # input and parameters to legacy ndarrays, run the op, convert back.
        gamma = self.gamma.data().as_nd_ndarray()
        beta = self.beta.data().as_nd_ndarray()
        x = mx.nd.GroupNorm(data=x.as_nd_ndarray(), gamma=gamma, beta=beta,
                            num_groups=self._num_groups, eps=self._epsilon)
        return x.as_np_ndarray()
    def __repr__(self):
        s = '{name}({content}'
        in_channels = self.gamma.shape[0]
        s += ', in_channels={0}'.format(in_channels)
        s += ')'
        return s.format(name=self.__class__.__name__,
                        content=', '.join(['='.join([k, v.__repr__()])
                                           for k, v in self._kwargs.items()]))
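For completeness, here is a minimal sketch of how I exercise the block, including the save/load round trip that broke the earlier hack (shapes and the file name are just for illustration):

import mxnet as mx
from mxnet import np, npx
npx.set_np()

net = GroupNorm(num_groups=2)
net.initialize(ctx=mx.cpu(0))

x = np.random.uniform(size=(2, 4, 8))
y = net(x)                                # deferred init fixes gamma/beta shapes

net.save_parameters('groupnorm.params')   # illustrative file name
net2 = GroupNorm(num_groups=2)
net2.load_parameters('groupnorm.params')  # loads fine, no forward pass needed first
y2 = net2(x)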