File tree Expand file tree Collapse file tree 5 files changed +25
-11
lines changed
Expand file tree Collapse file tree 5 files changed +25
-11
lines changed Original file line number Diff line number Diff line change @@ -100,7 +100,7 @@ def _loss(
100100 def beam_decode (self , * args , ** kwargs ):
101101 """Overrides incompatible implementation inherited from RNNModel."""
102102 raise NotImplementedError (
103- f"Beam search not implemented for { self .name } model"
103+ f"Beam search is not supported by { self .name } model"
104104 )
105105
106106 def decode (
Original file line number Diff line number Diff line change @@ -211,7 +211,7 @@ def _check_layer_sizes(self) -> None:
211211 def beam_decode (self , * args , ** kwargs ):
212212 """Overrides incompatible implementation inherited from RNNModel."""
213213 raise NotImplementedError (
214- f"Beam search not implemented for { self .name } model"
214+ f"Beam search is not supported by { self .name } model"
215215 )
216216
217217 @abc .abstractmethod
Original file line number Diff line number Diff line change @@ -59,7 +59,7 @@ def beam_decode(
5959 batch_size = encoder_mask .size (0 )
6060 if batch_size != 1 :
6161 raise NotImplementedError (
62- "Beam search is not implemented for batch_size > 1"
62+ "Beam search is not supported for batch_size > 1"
6363 )
6464 # Initializes hidden states for decoder LSTM.
6565 decoder_hiddens = self .init_hiddens (batch_size )
@@ -176,7 +176,16 @@ def forward(
176176 log-probabilities) for each prediction; greedy search returns
177177 a tensor of predictions of shape
178178 B x seq_len x target_vocab_size.
179+
180+ Raises:
181+ NotImplementedError: separate features encoders are not supported.
179182 """
183+ # TODO(#313): add support for this.
184+ if self .has_features_encoder :
185+ raise NotImplementedError (
186+ "Separate features encoders are not supported "
187+ f"by { self .name } model"
188+ )
180189 encoder_out = self .source_encoder (batch .source ).output
181190 # This function has a polymorphic return because beam search needs to
182191 # return two tensors. For greedy, the return has not been modified to
Original file line number Diff line number Diff line change @@ -62,7 +62,7 @@ def _get_loss_func(
6262 def beam_decode (self , * args , ** kwargs ):
6363 """Overrides incompatible implementation inherited from RNNModel."""
6464 raise NotImplementedError (
65- f"Beam search not implemented for { self .name } model"
65+ f"Beam search is not supported by { self .name } model"
6666 )
6767
6868 @property
Original file line number Diff line number Diff line change @@ -21,8 +21,7 @@ class TransformerModel(base.BaseModel):
2121 """
2222
2323 # Model arguments.
24- source_attention_heads : int
25- # Constructed inside __init__.
24+ source_attention_heads : int # Constructed inside __init__.
2625 classifier : nn .Linear
2726
2827 def __init__ (
@@ -57,21 +56,27 @@ def init_embeddings(
5756
5857 def beam_decode (self , * args , ** kwargs ):
5958 raise NotImplementedError (
60- f"Beam search not implemented for { self .name } model"
59+ f"Beam search is not supported by { self .name } model"
6160 )
6261
63- def forward (
64- self ,
65- batch : data .PaddedBatch ,
66- ) -> torch .Tensor :
62+ def forward (self , batch : data .PaddedBatch ) -> torch .Tensor :
6763 """Runs the encoder-decoder.
6864
6965 Args:
7066 batch (data.PaddedBatch).
7167
7268 Returns:
7369 torch.Tensor.
70+
71+ Raises:
72+ NotImplementedError: separate features encoders are not supported.
7473 """
74+ # TODO(#313): add support for this.
75+ if self .has_features_encoder :
76+ raise NotImplementedError (
77+ "Separate features encoders are not supported by "
78+ "{self.name} model"
79+ )
7580 if self .training and self .teacher_forcing :
7681 assert (
7782 batch .has_target
You can’t perform that action at this time.
0 commit comments