
LlavaForConditionalGeneration._merge_input_ids_with_image_features is incorrect #33976

Open
fpgaminer opened this issue Oct 5, 2024 · 5 comments
Labels: bug

@fpgaminer (Contributor)

System Info

commit: 38f9f10

Who can help?

@amyeroberts @qubvel

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The function _merge_input_ids_with_image_features of the class LlavaForConditionalGeneration gives incorrect results when it has to pad its output, which arises when the rows of a batch contain different numbers of image tokens.

def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):

I created the following reproduction example to show the failure mode simply. Note that it uses a copy of the function with print statements added for debugging; the code is otherwise identical to HEAD.

import gc
import types

import torch
from transformers import AutoTokenizer, LlavaForConditionalGeneration


def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
	num_images, num_image_patches, embed_dim = image_features.shape
	batch_size, sequence_length = input_ids.shape
	left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
	# 1. Create a mask to know where special image tokens are
	special_image_token_mask = input_ids == self.config.image_token_index
	num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
	# Compute the maximum embed dimension
	max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
	batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

	# 2. Compute the positions where text should be written
	# Calculate new positions for text tokens in merged image-text sequence.
	# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
	# `torch.cumsum` computes how each image token shifts subsequent text token positions.
	# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
	new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
	nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
	print(f"nb_image_pad: {nb_image_pad}")
	if left_padding:
		new_token_positions += nb_image_pad[:, None]  # offset for left padding
	text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

	# 3. Create the full embedding, already padded to the maximum position
	final_embedding = torch.zeros(
		batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
	)
	final_attention_mask = torch.zeros(
		batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
	)
	if labels is not None:
		final_labels = torch.full(
			(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
		)
	# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
	# set the corresponding tensors into their correct target device.
	target_device = inputs_embeds.device
	batch_indices, non_image_indices, text_to_overwrite = (
		batch_indices.to(target_device),
		non_image_indices.to(target_device),
		text_to_overwrite.to(target_device),
	)
	attention_mask = attention_mask.to(target_device)

	print(f"max_embed_dim: {max_embed_dim}")
	print(f"batch_indices: {batch_indices}")
	print(f"non_image_indices: {non_image_indices}")
	print(f"text_to_overwrite: {text_to_overwrite}")

	# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
	# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
	final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
	final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
	if labels is not None:
		final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

	# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
	image_to_overwrite = torch.full(
		(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
	)
	image_to_overwrite[batch_indices, text_to_overwrite] = False
	print(f"image_to_overwrite cumsum: {image_to_overwrite.cumsum(-1) - 1}")
	image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
	print(f"image_to_overwrite: {image_to_overwrite}")

	if image_to_overwrite.sum() != image_features.shape[:-1].numel():
		raise ValueError(
			f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
			f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
		)

	final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
	final_attention_mask |= image_to_overwrite
	position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

	# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
	batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
	indices_to_mask = new_token_positions[batch_indices, pad_indices]

	final_embedding[batch_indices, indices_to_mask] = 0

	if labels is None:
		final_labels = None

	return final_embedding, final_attention_mask, final_labels, position_ids


tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf", use_fast=True)

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", device_map=0, torch_dtype=torch.float16)
model.eval()
assert isinstance(model, LlavaForConditionalGeneration)


model._merge_input_ids_with_image_features = types.MethodType(_merge_input_ids_with_image_features, model)


prompt1 = "System prompt<image>Caption 1<pad><pad><pad><pad><pad>"
prompt2 = "System prompt<image>Caption 2<image>Caption 3"           # 12 tokens

print("`System Prompt` tokens: ", tokenizer.encode("System prompt", add_special_tokens=False, truncation=False))
print("`Caption 1` tokens: ", tokenizer.encode("Caption 1", add_special_tokens=False, truncation=False))

# Tokenize the prompt
input_ids = [
	tokenizer.encode(prompt1, return_tensors="pt", add_special_tokens=False, truncation=False).squeeze(0),
	tokenizer.encode(prompt2, return_tensors="pt", add_special_tokens=False, truncation=False).squeeze(0),
]
input_ids = torch.stack(input_ids)
assert isinstance(input_ids, torch.Tensor)

print(f"Input IDs: {input_ids.shape}")
print(f"Input IDs: {input_ids}")


gc.collect()
torch.cuda.empty_cache()

with torch.no_grad():
	input_embeddings = model.get_input_embeddings()(input_ids.to('cuda'))

	embedded_images = torch.randn(3, 5, 4096, device='cuda', dtype=torch.float16)
	attention_mask = torch.ones(2, 12, device='cuda', dtype=torch.bool)
	result = model._merge_input_ids_with_image_features(embedded_images, input_embeddings, input_ids, attention_mask, None)[0]
	print(result.shape)

	print("`System Prompt` diff: ", (result[0, :2] - input_embeddings[0, :2]).abs().max())
	print("`image` diff: ", (result[0, 2:7] - embedded_images[0]).abs().max())
	print("`Caption 1` diff: ", (result[0, 7:11] - input_embeddings[0, 3:7]).abs().max())

	print("`System Prompt` diff: ", (result[1, :2] - input_embeddings[1, :2]).abs().max())
	print("`image` diff: ", (result[1, 2:7] - embedded_images[1]).abs().max())
	print("`Caption 2` diff: ", (result[1, 7:11] - input_embeddings[1, 3:7]).abs().max())
	print("`image` diff: ", (result[1, 11:16] - embedded_images[2]).abs().max())
	print("`Caption 3` diff: ", (result[1, 16:20] - input_embeddings[1, 8:12]).abs().max())

Running this gives:

`System Prompt` tokens:  [2184, 9508]
`Caption 1` tokens:  [9243, 683, 29871, 29896]
Input IDs: torch.Size([2, 12])
Input IDs: tensor([[ 2184,  9508, 32000,  9243,   683, 29871, 29896, 32001, 32001, 32001,
         32001, 32001],
        [ 2184,  9508, 32000,  9243,   683, 29871, 29906, 32000,  9243,   683,
         29871, 29941]])
nb_image_pad: tensor([4, 0])
max_embed_dim: 20
batch_indices: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       device='cuda:0')
non_image_indices: tensor([ 0,  1,  3,  4,  5,  6,  7,  8,  9, 10, 11,  0,  1,  3,  4,  5,  6,  8,
         9, 10, 11], device='cuda:0')
text_to_overwrite: tensor([ 0,  1,  7,  8,  9, 10, 11, 12, 13, 14, 15,  0,  1,  7,  8,  9, 10, 16,
        17, 18, 19], device='cuda:0')
image_to_overwrite cumsum: tensor([[-1, -1,  0,  1,  2,  3,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  5,  6,
          7,  8],
        [-1, -1,  0,  1,  2,  3,  4,  4,  4,  4,  4,  5,  6,  7,  8,  9,  9,  9,
          9,  9]], device='cuda:0')
image_to_overwrite: tensor([[False, False, False, False, False, False,  True, False, False, False,
         False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True, False, False, False,
         False,  True,  True,  True,  True,  True, False, False, False, False]],
       device='cuda:0')
torch.Size([2, 20, 4096])
`System Prompt` diff:  tensor(0., device='cuda:0', dtype=torch.float16)
`image` diff:  tensor(5.0898, device='cuda:0', dtype=torch.float16)
`Caption 1` diff:  tensor(0., device='cuda:0', dtype=torch.float16)
`System Prompt` diff:  tensor(0., device='cuda:0', dtype=torch.float16)
`image` diff:  tensor(0., device='cuda:0', dtype=torch.float16)
`Caption 2` diff:  tensor(0., device='cuda:0', dtype=torch.float16)
`image` diff:  tensor(0., device='cuda:0', dtype=torch.float16)
`Caption 3` diff:  tensor(0., device='cuda:0', dtype=torch.float16)

The last set of prints shows the difference between what the function outputs and what is expected, by comparing the different pieces of the result to the original inputs. As can be seen here, the result is almost correct but has misplaced the first image.

The reason is this line:

image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

It looks like this line is trying to account for padding in the outputs, but assumes the outputs are going to be left padded, whereas the rest of the function dynamically switches between left and right padding depending on the left_padding variable. If the outputs are being right padded, the comparison on this line is incorrect. That can be seen in the debug output above where image_to_overwrite shows the image being scattered around the tensor.
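To make the effect concrete, here is a toy version of just that masking step, using the row-0 numbers from the debug output above (a sketch only; 20 is max_embed_dim and 4 is row 0's nb_image_pad):

import torch

nb_image_pad = torch.tensor([4])                       # row 0 is right padded by 4 slots
text_to_overwrite = torch.tensor([0, 1, 7, 8, 9, 10, 11, 12, 13, 14, 15])
image_to_overwrite = torch.ones(1, 20, dtype=torch.bool)
image_to_overwrite[0, text_to_overwrite] = False

# current condition (assumes left padding): drops the first 4 candidate slots,
# i.e. positions 2-5 where the image actually belongs
buggy = image_to_overwrite & (image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None])
print(buggy.nonzero(as_tuple=True)[1])                 # tensor([ 6, 16, 17, 18, 19])

# what right padding needs: drop the trailing 4 padded slots instead
idxs = torch.arange(20).expand(1, -1)
fixed = image_to_overwrite & (idxs < (20 - nb_image_pad[:, None]))
print(fixed.nonzero(as_tuple=True)[1])                 # tensor([2, 3, 4, 5, 6])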

Importantly, this incorrect behavior can occur for any padding configuration, because left_padding is "automatically" detected from the input. Even if the model is set up for left padding, if a particular batch doesn't have any padding going into this function, the function will assume right padding and thus mangle the output.

One solution is to add an if statement on that line based on left_padding and use a different condition to handle right padding. I think this code (for the right padding case) would work, but I haven't tested it extensively:

	idxs = torch.arange(max_embed_dim, device=image_to_overwrite.device).expand(batch_size, -1)
	image_to_overwrite &= idxs < (max_embed_dim - nb_image_pad[:, None]).to(target_device)
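For concreteness, the branch might look something like this in place of the current line (a sketch reusing the variables already defined in the function, tested only against the reproduction above):

	if left_padding:
		# existing behavior: skip the first nb_image_pad candidate slots, which are the padded ones
		image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
	else:
		# right padding: the padded slots are the last nb_image_pad positions of each row
		idxs = torch.arange(max_embed_dim, device=image_to_overwrite.device).expand(batch_size, -1)
		image_to_overwrite &= idxs < (max_embed_dim - nb_image_pad[:, None]).to(target_device)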

Side note:

I think the current implementation for handling the left padding case could also be replaced with arange-based logic. To me that would be clearer, as it states the intention of the line explicitly (set all padded indices to False), whereas I found the use of image_to_overwrite.cumsum(-1) opaque and hard to decipher.
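A sketch of what that arange-based version could look like for both padding sides, assuming the same variables as in the function above (untested):

	# explicitly mark the padded positions of each row and exclude them,
	# instead of counting candidate image slots with cumsum
	idxs = torch.arange(max_embed_dim, device=image_to_overwrite.device).expand(batch_size, -1)
	if left_padding:
		padded = idxs < nb_image_pad[:, None].to(target_device)
	else:
		padded = idxs >= (max_embed_dim - nb_image_pad[:, None]).to(target_device)
	image_to_overwrite &= ~padded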

Expected behavior

See above

fpgaminer added the bug label Oct 5, 2024
@ArthurZucker (Collaborator)

Hey! Thanks for the report. I think this piece of code has been changed quite a lot since initial porting so @zucchini-nlp will have a better idea of what is going on!

Note that this is the legacy branch, which should not be triggered anymore and which we are deprecating overall!

@fpgaminer (Contributor, Author)

Note that this is the legacy branch, which should not be triggered anymore and which we are deprecating overall!

Ah, I hadn't noticed! Good point, and it looks like this code should disappear soon based on the comment "TODO: @raushan retain only the new behavior after v4.47". So maybe this issue isn't worth fixing?

@zucchini-nlp (Member)

Hey! Yes, it is worth fixing, as I still haven't updated our files on the Hub and we fall back to the legacy logic. I am planning to gradually update the files, and in the meantime we need to be sure the legacy branch is working correctly.

Interesting that padding="right" is not handled correctly, as it should have been covered when the model was added. I'll get back to this on Monday/Tuesday, thanks

@ArthurZucker (Collaborator)

Yeah I remember that we had some tests for this!

@zucchini-nlp (Member)

Okay, verified that the merging is indeed incorrect when we do right padding, and it is really weird that no one noticed it earlier. We added handling for the padding side only in llava-next, because there the number of image tokens can differ and needs unpadding. For the other models it was assumed that the following line handles it correctly:

if left_padding:
	new_token_positions += nb_image_pad[:, None]  # offset for left padding

Also related to #33662, which I didn't have time to look into due to the absence of a clear reproducer. @fpgaminer would you like to open a PR and fix this in all llava models? We might also need a test with dummy inputs and a tiny model, to make sure it doesn't happen in the future :)
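Something along these lines could serve as a starting point for such a test (only a sketch, not in the style of the existing test suite: it reuses the standalone copy of _merge_input_ids_with_image_features from the reproduction above instead of loading a checkpoint, the token ids and shapes are made up, and it fails with the current legacy code but should pass once the right-padding branch is fixed):

import types
from types import SimpleNamespace

import torch


def make_stub_model():
	# minimal stand-in for LlavaForConditionalGeneration: only the attributes the merge function touches
	stub = SimpleNamespace(
		pad_token_id=0,
		config=SimpleNamespace(image_token_index=9, ignore_index=-100),
	)
	stub._merge_input_ids_with_image_features = types.MethodType(_merge_input_ids_with_image_features, stub)
	return stub


def test_right_padded_batch_keeps_image_features_contiguous():
	stub = make_stub_model()
	embed_dim, num_patches = 8, 3
	# row 0: text text <image> text <pad> <pad>;  row 1: text <image> text <image> text text
	input_ids = torch.tensor([[1, 2, 9, 3, 0, 0], [1, 9, 2, 9, 3, 4]])
	inputs_embeds = torch.randn(2, 6, embed_dim)
	image_features = torch.randn(3, num_patches, embed_dim)
	attention_mask = input_ids != 0

	final_emb, *_ = stub._merge_input_ids_with_image_features(
		image_features, inputs_embeds, input_ids, attention_mask, None
	)
	# the single image of row 0 should sit right after its two leading text tokens
	torch.testing.assert_close(final_emb[0, 2 : 2 + num_patches], image_features[0])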
