Possible optimization #3

@dholzmueller

In the Transformer, you currently do the following:

# training data attends to itself
src_left = self.self_attention_between_datapoints(src[:,:train_test_split_index], src[:,:train_test_split_index], src[:,:train_test_split_index])[0]
# test data attends to the training data
src_right = self.self_attention_between_datapoints(src[:,train_test_split_index:], src[:,:train_test_split_index], src[:,:train_test_split_index])[0]
src = torch.cat([src_left, src_right], dim=1) + src

Wouldn't it be easier (and potentially faster) to do the following?

# training and test data both attend only to the training data
out = self.self_attention_between_datapoints(src, src[:,:train_test_split_index], src[:,:train_test_split_index])[0]
src = src + out
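
The two variants should give the same result because attention processes each query row independently: the softmax is taken over the keys only, so splitting or concatenating the queries does not change any single row's output. Below is a minimal sketch of that argument in plain Python. The `attention` helper is a hypothetical single-head stand-in for `self_attention_between_datapoints`, not the repository's module, and it assumes no attention mask and no dropout inside the attention call (with dropout active in training mode, the two variants would agree only in distribution, not element by element).

```python
import math
import random


def attention(queries, keys, values):
    """Hypothetical single-head scaled dot-product attention on lists of vectors."""
    scale = 1.0 / math.sqrt(len(keys[0]))
    outputs = []
    for q in queries:
        # Scores of this one query against every key; other queries play no role.
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        # Numerically stable softmax over the keys.
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        # Weighted average of the value vectors.
        out = [
            sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))
        ]
        outputs.append(out)
    return outputs


random.seed(0)
n_train, n_test, dim = 5, 3, 4  # illustrative sizes, not taken from the model
src = [[random.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_train + n_test)]
train = src[:n_train]

# Current approach: train rows and test rows are queried in two separate calls.
two_calls = attention(train, train, train) + attention(src[n_train:], train, train)

# Proposed approach: all rows are queries in one call, keys/values are train only.
one_call = attention(src, train, train)

max_diff = max(
    abs(a - b)
    for row_a, row_b in zip(two_calls, one_call)
    for a, b in zip(row_a, row_b)
)
print("largest elementwise difference:", max_diff)
```

With a real PyTorch attention module, the same comparison would likely need a tolerance (for example `torch.allclose`) rather than exact equality, since batched kernels may sum in a different order.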
