How to implement triplet loss with mxnet #6909
Replies: 9 comments
-
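It's simple once you know how to use

```python
import mxnet as mx

# Inputs: anchor, positive ("same") and negative ("diff") samples.
anchor = mx.sym.Variable('anchor')
same = mx.sym.Variable('same')
diff = mx.sym.Variable('diff')
# Assumed: the margin is an input named 'one', fed as ones of shape (batch_size, 1).
one = mx.sym.Variable('one')

# Declare each parameter once; passing the same Variables to all three
# branches is what makes them share weights.
# feature_size and get_conv() come from the full example script (not shown here).
kernels = [(1, feature_size), (2, feature_size), (3, feature_size)]
conv_weight, conv_bias = [], []
for i in range(len(kernels)):
    conv_weight.append(mx.sym.Variable('conv' + str(i) + '_weight'))
    conv_bias.append(mx.sym.Variable('conv' + str(i) + '_bias'))
# Assumed: the entity (FC) parameters are shared Variables declared the same way.
entity_weight = mx.sym.Variable('entity_weight')
entity_bias = mx.sym.Variable('entity_bias')

fa = get_conv(data=anchor, kernels=kernels, conv_weight=conv_weight, conv_bias=conv_bias,
              entity_weight=entity_weight, entity_bias=entity_bias, feature_name='fa')
fs = get_conv(data=same, kernels=kernels, conv_weight=conv_weight, conv_bias=conv_bias,
              entity_weight=entity_weight, entity_bias=entity_bias, feature_name='fs')
fd = get_conv(data=diff, kernels=kernels, conv_weight=conv_weight, conv_bias=conv_bias,
              entity_weight=entity_weight, entity_bias=entity_bias, feature_name='fd')

# Triplet loss on squared L2 distances.
fs = fa - fs
fd = fa - fd
fs = fs * fs
fd = fd * fd
fs = mx.sym.sum(fs, axis=1, keepdims=True)  # d(a, p), shape (batch_size, 1)
fd = mx.sym.sum(fd, axis=1, keepdims=True)  # d(a, n), shape (batch_size, 1)
loss = fd - fs
loss = one - loss  # margin - (d(a, n) - d(a, p))
loss = mx.sym.Activation(data=loss, act_type='relu')  # hinge: max(0, .)
triple_loss = mx.sym.MakeLoss(loss)
```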
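To sanity-check what the symbol graph above computes, here is a minimal pure-Python sketch of the same per-sample formula, max(0, margin - (d(a, n) - d(a, p))) with squared L2 distances. The function names are hypothetical, and `margin` stands in for the `one` input; this is a reference calculation, not MXNet code.

```python
def squared_distance(u, v):
    """Sum of squared differences, like mx.sym.sum((fa - fs) * (fa - fs), axis=1)."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge on the distance gap: max(0, margin - (d(a, n) - d(a, p)))."""
    d_pos = squared_distance(anchor, positive)
    d_neg = squared_distance(anchor, negative)
    return max(0.0, margin - (d_neg - d_pos))


def batch_triplet_loss(anchors, positives, negatives, margin=1.0):
    """One loss value per triplet, mirroring the (batch_size, 1) output above."""
    return [triplet_loss(a, p, n, margin) for a, p, n in zip(anchors, positives, negatives)]


# Easy triplet: the negative is already far enough away, so the loss is 0.
print(triplet_loss([0.0, 0.0], [0.0, 1.0], [3.0, 0.0]))
# Hard triplet: the negative is closer than the positive, so the loss is positive.
print(triplet_loss([0.0, 0.0], [1.0, 1.0], [1.0, 0.0]))
```

The ReLU in the graph plays the role of `max(0.0, ...)` here: triplets that already satisfy the margin contribute zero loss and zero gradient.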
-
Refer to https://github.com/xlvector/learning-dl/tree/master/mxnet/triple-loss
-
@zihaolucky
I wonder why the anchor is treated as the label. Will the model be affected if I change the label to 'one', put 'anchor' in the data, and change the batch generator accordingly?
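In the snippet above, `anchor`, `same`, `diff` and `one` are all plain inputs to the symbol graph, and `MakeLoss` does not read a label, so moving `anchor` into the data and feeding `one` as the label should not change the computed loss. A minimal pure-Python sketch of such a batch generator follows. The input names match the snippet; the function name `triplet_batches` and the list-of-tuples input format are hypothetical, and a real MXNet data iterator would wrap these lists in NDArrays.

```python
def triplet_batches(triplets, batch_size):
    """Yield (data, label) dicts from a list of (anchor, same, diff) tuples.

    All three samples go into `data`; the constant margin input 'one'
    (shape (batch_size, 1)) goes into `label`. An incomplete final batch is dropped.
    """
    for start in range(0, len(triplets) - batch_size + 1, batch_size):
        chunk = triplets[start:start + batch_size]
        data = {
            'anchor': [t[0] for t in chunk],
            'same': [t[1] for t in chunk],
            'diff': [t[2] for t in chunk],
        }
        label = {'one': [[1.0] for _ in chunk]}
        yield data, label


# Five toy triplets with batch_size=2 give two full batches.
toy = [([i], [i + 1], [i + 2]) for i in range(5)]
for data, label in triplet_batches(toy, 2):
    print(data['anchor'], label['one'])
```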
-
The shape is up to you. Check out the example in the link.
-
@zihaolucky nice work!
-
Hi @zihaolucky, I think the triplet-loss example you mentioned above is very helpful, but it is not a complete getting-started tutorial. Could you explain how to prepare the training data for that triplet loss? The script loads some *.npy files, but I don't know exactly what they contain: https://github.com/xlvector/learning-dl/blob/master/mxnet/triple-loss/triplet_loss.py#L155-L160
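The contents of those *.npy files are not documented in this thread, so the following is not the linked script's pipeline. It is a generic, minimal sketch of one common way to build triplet training data from labeled samples: random sampling where the anchor and positive share a class and the negative comes from a different class. The function name `sample_triplets` and its arguments are hypothetical.

```python
import random
from collections import defaultdict


def sample_triplets(samples, labels, num_triplets, seed=0):
    """Return (anchor, positive, negative) tuples sampled at random.

    Anchor and positive are distinct samples with the same label;
    the negative is a sample with a different label.
    """
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for x, y in zip(samples, labels):
        by_label[y].append(x)
    # Only classes with at least two samples can supply an anchor/positive pair.
    anchor_classes = [y for y, xs in by_label.items() if len(xs) >= 2]
    all_classes = list(by_label)
    if not anchor_classes or len(all_classes) < 2:
        raise ValueError("need a class with >= 2 samples and at least 2 classes")
    triplets = []
    for _ in range(num_triplets):
        pos_class = rng.choice(anchor_classes)
        neg_class = rng.choice([y for y in all_classes if y != pos_class])
        anchor, positive = rng.sample(by_label[pos_class], 2)
        negative = rng.choice(by_label[neg_class])
        triplets.append((anchor, positive, negative))
    return triplets


print(sample_triplets(['a1', 'a2', 'b1', 'b2'], ['a', 'a', 'b', 'b'], 3))
```

Random sampling yields many "easy" triplets with zero loss; practical setups often mine harder negatives, but that is beyond this sketch.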
-
@zihaolucky,
-
@apache/mxnet-committers: This issue has been inactive for the past 90 days. It has no label and needs triage. For general "how-to" questions, our user forum (and Chinese version) is a good place to get help.
-
Please remove the "Need Triage" label, since the tutorial needs to be updated.
-
I referred to #1502, and the forum link provided there cannot be opened.
I think a tutorial about triplet loss could be added to the tutorials.