Skip to content

Commit 4abca3c

Browse files
authored
Merge pull request #314 from kylebgorman/error
Raises an error when unsupported features encoders are used
2 parents ac1964f + c09a00d commit 4abca3c

File tree

5 files changed

+25
-11
lines changed

5 files changed

+25
-11
lines changed

yoyodyne/models/hard_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def _loss(
100100
def beam_decode(self, *args, **kwargs):
101101
"""Overrides incompatible implementation inherited from RNNModel."""
102102
raise NotImplementedError(
103-
f"Beam search not implemented for {self.name} model"
103+
f"Beam search is not supported by {self.name} model"
104104
)
105105

106106
def decode(

yoyodyne/models/pointer_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def _check_layer_sizes(self) -> None:
211211
def beam_decode(self, *args, **kwargs):
212212
"""Overrides incompatible implementation inherited from RNNModel."""
213213
raise NotImplementedError(
214-
f"Beam search not implemented for {self.name} model"
214+
f"Beam search is not supported by {self.name} model"
215215
)
216216

217217
@abc.abstractmethod

yoyodyne/models/rnn.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def beam_decode(
5959
batch_size = encoder_mask.size(0)
6060
if batch_size != 1:
6161
raise NotImplementedError(
62-
"Beam search is not implemented for batch_size > 1"
62+
"Beam search is not supported for batch_size > 1"
6363
)
6464
# Initializes hidden states for decoder LSTM.
6565
decoder_hiddens = self.init_hiddens(batch_size)
@@ -176,7 +176,16 @@ def forward(
176176
log-probabilities) for each prediction; greedy search returns
177177
a tensor of predictions of shape
178178
B x seq_len x target_vocab_size.
179+
180+
Raises:
181+
NotImplementedError: separate features encoders are not supported.
179182
"""
183+
# TODO(#313): add support for this.
184+
if self.has_features_encoder:
185+
raise NotImplementedError(
186+
"Separate features encoders are not supported "
187+
f"by {self.name} model"
188+
)
180189
encoder_out = self.source_encoder(batch.source).output
181190
# This function has a polymorphic return because beam search needs to
182191
# return two tensors. For greedy, the return has not been modified to

yoyodyne/models/transducer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _get_loss_func(
6262
def beam_decode(self, *args, **kwargs):
6363
"""Overrides incompatible implementation inherited from RNNModel."""
6464
raise NotImplementedError(
65-
f"Beam search not implemented for {self.name} model"
65+
f"Beam search is not supported by {self.name} model"
6666
)
6767

6868
@property

yoyodyne/models/transformer.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ class TransformerModel(base.BaseModel):
2121
"""
2222

2323
# Model arguments.
24-
source_attention_heads: int
25-
# Constructed inside __init__.
24+
source_attention_heads: int # Constructed inside __init__.
2625
classifier: nn.Linear
2726

2827
def __init__(
@@ -57,21 +56,27 @@ def init_embeddings(
5756

5857
def beam_decode(self, *args, **kwargs):
5958
raise NotImplementedError(
60-
f"Beam search not implemented for {self.name} model"
59+
f"Beam search is not supported by {self.name} model"
6160
)
6261

63-
def forward(
64-
self,
65-
batch: data.PaddedBatch,
66-
) -> torch.Tensor:
62+
def forward(self, batch: data.PaddedBatch) -> torch.Tensor:
6763
"""Runs the encoder-decoder.
6864
6965
Args:
7066
batch (data.PaddedBatch).
7167
7268
Returns:
7369
torch.Tensor.
70+
71+
Raises:
72+
NotImplementedError: separate features encoders are not supported.
7473
"""
74+
# TODO(#313): add support for this.
75+
if self.has_features_encoder:
76+
raise NotImplementedError(
77+
"Separate features encoders are not supported by "
78+
"{self.name} model"
79+
)
7580
if self.training and self.teacher_forcing:
7681
assert (
7782
batch.has_target

0 commit comments

Comments
 (0)