diff --git a/src/exabiome/nn/models/resnet.py b/src/exabiome/nn/models/resnet.py
index eaa7ce6..e9fe30b 100644
--- a/src/exabiome/nn/models/resnet.py
+++ b/src/exabiome/nn/models/resnet.py
@@ -145,6 +145,8 @@ def __init__(self, hparams):
             hparams.simple_clf = False
         if not hasattr(hparams, 'dropout_clf'):
             hparams.dropout_clf = False
+        if not hasattr(hparams, 'attention'):
+            hparams.attention = False
 
         super(ResNet, self).__init__(hparams)
 
@@ -183,13 +185,21 @@ def __init__(self, hparams):
                                        dilate=replace_stride_with_dilation[2])
 
         n_output_channels = 512 * block.expansion
+        if hparams.attention:
+            hparams.bottleneck = True  # make sure bottleneck is on when using attention
+
         if hparams.bottleneck:
             self.bottleneck = FeatureReduction(n_output_channels, 64 * block.expansion)
             n_output_channels = 64 * block.expansion
         else:
             self.bottleneck = None
-        
+
         self.avgpool = nn.AdaptiveAvgPool1d(1)
+
+        if hparams.attention:
+            self.attention = nn.MultiheadAttention(n_output_channels, 16)
+        else:
+            self.attention = None
 
         if hparams.tgt_tax_lvl == 'all':
             self.fc = HierarchicalClassifier(n_output_channels, hparams.n_taxa_all)
@@ -314,15 +324,22 @@ def _forward_impl(self, x):
         x = self.layer2(x)
         x = self.layer3(x)
         x = self.layer4(x)
-        
-        
+
         if self.bottleneck is not None:
             x = self.bottleneck(x)
 
         x = self.avgpool(x)
+
+        # NOTE(review): x has already been pooled by self.avgpool (AdaptiveAvgPool1d(1))
+        # here, so attention sees a single position per sample -- confirm this placement.
+        if self.attention is not None:  # disabled attention is stored as None, not False
+            x = x.permute(2, 0, 1)  # move the last axis first for nn.MultiheadAttention
+            x, _ = self.attention(x, x, x)  # self-attention; attention weights are discarded
+            x = x.permute(1, 2, 0)  # restore the original axis order
+
         x = torch.flatten(x, 1)
         x = self.fc(x)
-        
+
         return x
 
     def forward(self, x):
diff --git a/src/exabiome/nn/train.py b/src/exabiome/nn/train.py
index 70a6742..ecc03fd 100644
--- a/src/exabiome/nn/train.py
+++ b/src/exabiome/nn/train.py
@@ -60,6 +60,7 @@ def get_conf_args():
         'classify': dict(action='store_true', help='run a classification problem', default=False),
         'manifold': dict(action='store_true', help='run a manifold learning problem', default=False),
         'bottleneck': dict(action='store_true', help='add bottleneck layer at the end of ResNet features', default=True),
+        'attention': dict(action='store_true', help='add an attention layer at end of ResNet features', default=False),
         'tgt_tax_lvl': dict(choices=DeepIndexFile.taxonomic_levels, metavar='LEVEL', default='species',
                            help='the taxonomic level to predict. choices are phylum, class, order, family, genus, species'),
         'simple_clf': dict(action='store_true', help='Use a single FC layer as the classifier for ResNets', default=False),