In the Transformer, you currently run the attention between datapoints in two separate calls:
# training data attends to itself
src_left = self.self_attention_between_datapoints(src[:,:train_test_split_index], src[:,:train_test_split_index], src[:,:train_test_split_index])[0]
# test data attends to the training data
src_right = self.self_attention_between_datapoints(src[:,train_test_split_index:], src[:,:train_test_split_index], src[:,:train_test_split_index])[0]
src = torch.cat([src_left, src_right], dim=1)+src
Both calls use the same keys and values (the training data), so wouldn't it be easier (and potentially faster) to do the following in a single call?
# training + test data attends only to training data
out = self.self_attention_between_datapoints(src, src[:,:train_test_split_index], src[:,:train_test_split_index])[0]
src = src + out
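To check the intuition, here is a minimal sketch in plain Python. It assumes a toy single-head scaled dot-product attention with no projections and no dropout; the `attention` helper, the data, and `split` are hypothetical stand-ins for `self_attention_between_datapoints`, `src`, and `train_test_split_index`. Each query row is processed independently of the others, so stacking the train and test queries into one call should give the same rows as the two separate calls.

```python
import math


def attention(queries, keys, values):
    """Toy single-head scaled dot-product attention over lists of vectors."""
    scale = 1.0 / math.sqrt(len(keys[0]))
    out = []
    for q in queries:
        # Scores of this query against every key (the training rows).
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        # Numerically stable softmax over the keys.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        # Weighted average of the values for this query.
        out.append([
            sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))
        ])
    return out


# Hypothetical data: 3 training rows followed by 2 test rows, 2 features each.
src = [[0.1, 0.2], [0.3, -0.1], [-0.5, 0.4], [0.7, 0.0], [-0.2, 0.9]]
split = 3
train = src[:split]

# Two calls: train attends to train, test attends to train, then concatenate.
two_calls = attention(src[:split], train, train) + attention(src[split:], train, train)

# One call: train + test queries attend to the training rows only.
one_call = attention(src, train, train)

max_diff = max(
    abs(a - b)
    for row_a, row_b in zip(two_calls, one_call)
    for a, b in zip(row_a, row_b)
)
print("max abs difference:", max_diff)
```

With a real attention module the same argument should carry over as long as the layer is deterministic; with dropout active in training mode, the two variants would draw different random masks and match only in expectation.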