diff --git a/toynlp/bert/model.py b/toynlp/bert/model.py
index 08eae85..900e98b 100644
--- a/toynlp/bert/model.py
+++ b/toynlp/bert/model.py
@@ -60,8 +60,7 @@ def forward(
         attention_weight = q @ k.transpose(-2, -1) / (head_dim**0.5)
         # pad mask
         if mask is not None:
-            # TODO: -inf?
-            attention_weight = attention_weight.masked_fill(mask == 0, float("-10000"))
+            attention_weight = attention_weight.masked_fill(mask == 0, float("-inf"))
 
         attention_score = torch.nn.functional.softmax(attention_weight, dim=-1)
         attention_score = self.dropout(attention_score)
diff --git a/toynlp/transformer/model.py b/toynlp/transformer/model.py
index 0f726bc..07d1b6b 100644
--- a/toynlp/transformer/model.py
+++ b/toynlp/transformer/model.py
@@ -54,7 +54,7 @@ def forward(
         # pad mask
         if mask is not None:
             # TODO: -inf?
-            attention_weight = attention_weight.masked_fill(mask == 0, float("-10000"))
+            attention_weight = attention_weight.masked_fill(mask == 0, float("-inf"))
 
         attention_score = torch.nn.functional.softmax(attention_weight, dim=-1)
 
@@ -253,8 +253,8 @@ def _get_source_mask(self, source_token_ids: torch.Tensor) -> torch.Tensor:
         return (source_token_ids != self.padding_idx).unsqueeze(1).unsqueeze(2)
 
     def _get_target_mask(self, target_token_ids: torch.Tensor) -> torch.Tensor:
-        # shape: (batch_size, 1, target_seq_length, 1)
-        pad_mask = (target_token_ids != self.padding_idx).unsqueeze(1).unsqueeze(3)
+        # shape: (batch_size, 1, 1, target_seq_length)
+        pad_mask = (target_token_ids != self.padding_idx).unsqueeze(1).unsqueeze(2)
         target_seq_length = target_token_ids.size(1)
         # shape: (batch_size, 1, target_seq_length, target_seq_length)
         causal_mask = torch.tril(torch.ones((target_seq_length, target_seq_length), device=self.device)).bool()
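
For context, a minimal standalone sketch of the two behaviours this diff relies on (the tensor sizes, token ids, and variable names below are made up for illustration and are not from this repo): filling padded positions with `-inf` makes softmax assign them exactly zero weight, and a pad mask shaped `(batch_size, 1, 1, seq_len)` broadcasts over the key dimension, so it composes correctly with the causal mask.

```python
import torch

batch_size, num_heads, seq_len, padding_idx = 2, 4, 5, 0

# Fake attention weights and token ids; the last two tokens of sample 0 are padding.
attention_weight = torch.randn(batch_size, num_heads, seq_len, seq_len)
token_ids = torch.tensor([[5, 6, 7, 0, 0], [3, 4, 5, 6, 7]])

# Pad mask over the *key* dimension: shape (batch_size, 1, 1, seq_len),
# which broadcasts over heads and query positions.
pad_mask = (token_ids != padding_idx).unsqueeze(1).unsqueeze(2)

# Causal mask: shape (seq_len, seq_len); combined mask broadcasts to
# (batch_size, 1, seq_len, seq_len), matching the attention weights.
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
mask = pad_mask & causal_mask

# -inf gives masked positions exactly zero probability after softmax;
# a large negative constant such as -10000 only makes them very small.
scores = torch.softmax(attention_weight.masked_fill(mask == 0, float("-inf")), dim=-1)
assert torch.all(scores[0, :, :, 3:] == 0)  # padded keys receive zero attention

# The old pad-mask shape (batch_size, 1, seq_len, 1) would have masked *query*
# rows instead of *key* columns, which is not what a pad mask should do here.
```

One caveat worth keeping in mind: if every key in a row is masked (e.g. an all-padding query row), softmax over a row of `-inf` produces NaN, which is the usual reason some implementations keep a large finite constant like `-10000` instead.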