We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 890deab commit 7747db1Copy full SHA for 7747db1
torchtune/modules/model_fusion/_early_fusion.py
@@ -185,7 +185,9 @@ def reset_caches(self):
185
186
def _decoder_embed(self, tokens) -> Tuple[torch.Tensor, torch.Tensor]:
187
"""Embed the text-only tokens with the decoder's tok_embeddings"""
188
- encoder_token_ids = torch.tensor(list(self.encoder_tokens.values()))
+ encoder_token_ids = torch.tensor(
189
+ list(self.encoder_tokens.values()), device=tokens.device
190
+ )
191
# [bsz, seq_len], True indicates the token is not an encoder special token
192
is_text = ~torch.isin(tokens, encoder_token_ids)
193
text_tokens = torch.masked_select(tokens, is_text)
0 commit comments