Skip to content

Commit 40949ba

Browse files
committed
minor fix
1 parent 1bf4898 commit 40949ba

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

llava/train/train.py

+8
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,14 @@ def preprocess_mpt(
561561

562562
if has_image:
563563
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') for prompt in conversations], dim=0)
564+
else:
565+
input_ids = tokenizer(
566+
conversations,
567+
return_tensors="pt",
568+
padding="longest",
569+
max_length=tokenizer.model_max_length,
570+
truncation=True,
571+
).input_ids
564572

565573
targets = input_ids.clone()
566574
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

0 commit comments

Comments
 (0)