Skip to content

Commit 7747db1

Browse files
authored
Construct EarlyFusion's encoder_token_ids on correct device (#2276)
1 parent 890deab commit 7747db1

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchtune/modules/model_fusion/_early_fusion.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,9 @@ def reset_caches(self):
185185

186186
def _decoder_embed(self, tokens) -> Tuple[torch.Tensor, torch.Tensor]:
187187
"""Embed the text-only tokens with the decoder's tok_embeddings"""
188-
encoder_token_ids = torch.tensor(list(self.encoder_tokens.values()))
188+
encoder_token_ids = torch.tensor(
189+
list(self.encoder_tokens.values()), device=tokens.device
190+
)
189191
# [bsz, seq_len], True indicates the token is not an encoder special token
190192
is_text = ~torch.isin(tokens, encoder_token_ids)
191193
text_tokens = torch.masked_select(tokens, is_text)

0 commit comments

Comments
 (0)