LlavaForConditionalGeneration._merge_input_ids_with_image_features is incorrect #33976

Closed
@fpgaminer

Description

System Info

commit: 38f9f10

Who can help?

@amyeroberts @qubvel

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The function _merge_input_ids_with_image_features of class LlavaForConditionalGeneration gives incorrect results when it has to pad the result, which happens when the rows of a batch contain different numbers of image tokens.

def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):

I created the following reproduction example to show the failure mode simply. Note that it uses a copy of the function with print statements added for debugging, but the code is otherwise an identical copy of the function at HEAD.

import gc
import types

import torch
from transformers import AutoTokenizer, LlavaForConditionalGeneration


def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
	num_images, num_image_patches, embed_dim = image_features.shape
	batch_size, sequence_length = input_ids.shape
	left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
	# 1. Create a mask to know where special image tokens are
	special_image_token_mask = input_ids == self.config.image_token_index
	num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
	# Compute the maximum embed dimension
	max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
	batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

	# 2. Compute the positions where text should be written
	# Calculate new positions for text tokens in merged image-text sequence.
	# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
	# `torch.cumsum` computes how each image token shifts subsequent text token positions.
	# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
	new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
	nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
	print(f"nb_image_pad: {nb_image_pad}")
	if left_padding:
		new_token_positions += nb_image_pad[:, None]  # offset for left padding
	text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

	# 3. Create the full embedding, already padded to the maximum position
	final_embedding = torch.zeros(
		batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
	)
	final_attention_mask = torch.zeros(
		batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
	)
	if labels is not None:
		final_labels = torch.full(
			(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
		)
	# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
	# set the corresponding tensors into their correct target device.
	target_device = inputs_embeds.device
	batch_indices, non_image_indices, text_to_overwrite = (
		batch_indices.to(target_device),
		non_image_indices.to(target_device),
		text_to_overwrite.to(target_device),
	)
	attention_mask = attention_mask.to(target_device)

	print(f"max_embed_dim: {max_embed_dim}")
	print(f"batch_indices: {batch_indices}")
	print(f"non_image_indices: {non_image_indices}")
	print(f"text_to_overwrite: {text_to_overwrite}")

	# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
	# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
	final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
	final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
	if labels is not None:
		final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

	# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
	image_to_overwrite = torch.full(
		(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
	)
	image_to_overwrite[batch_indices, text_to_overwrite] = False
	print(f"image_to_overwrite cumsum: {image_to_overwrite.cumsum(-1) - 1}")
	image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
	print(f"image_to_overwrite: {image_to_overwrite}")

	if image_to_overwrite.sum() != image_features.shape[:-1].numel():
		raise ValueError(
			f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
			f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
		)

	final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
	final_attention_mask |= image_to_overwrite
	position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

	# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
	batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
	indices_to_mask = new_token_positions[batch_indices, pad_indices]

	final_embedding[batch_indices, indices_to_mask] = 0

	if labels is None:
		final_labels = None

	return final_embedding, final_attention_mask, final_labels, position_ids


tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf", use_fast=True)

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", device_map=0, torch_dtype=torch.float16)
model.eval()
assert isinstance(model, LlavaForConditionalGeneration)


model._merge_input_ids_with_image_features = types.MethodType(_merge_input_ids_with_image_features, model)


prompt1 = "System prompt<image>Caption 1<pad><pad><pad><pad><pad>"
prompt2 = "System prompt<image>Caption 2<image>Caption 3"           # 12 tokens

print("`System Prompt` tokens: ", tokenizer.encode("System prompt", add_special_tokens=False, truncation=False))
print("`Caption 1` tokens: ", tokenizer.encode("Caption 1", add_special_tokens=False, truncation=False))

# Tokenize the prompt
input_ids = [
	tokenizer.encode(prompt1, return_tensors="pt", add_special_tokens=False, truncation=False).squeeze(0),
	tokenizer.encode(prompt2, return_tensors="pt", add_special_tokens=False, truncation=False).squeeze(0),
]
input_ids = torch.stack(input_ids)
assert isinstance(input_ids, torch.Tensor)

print(f"Input IDs: {input_ids.shape}")
print(f"Input IDs: {input_ids}")


gc.collect()
torch.cuda.empty_cache()

with torch.no_grad():
	input_embeddings = model.get_input_embeddings()(input_ids.to('cuda'))

	embedded_images = torch.randn(3, 5, 4096, device='cuda', dtype=torch.float16)
	attention_mask = torch.ones(2, 12, device='cuda', dtype=torch.bool)
	result = model._merge_input_ids_with_image_features(embedded_images, input_embeddings, input_ids, attention_mask, None)[0]
	print(result.shape)

	print("`System Prompt` diff: ", (result[0, :2] - input_embeddings[0, :2]).abs().max())
	print("`image` diff: ", (result[0, 2:7] - embedded_images[0]).abs().max())
	print("`Caption 1` diff: ", (result[0, 7:11] - input_embeddings[0, 3:7]).abs().max())

	print("`System Prompt` diff: ", (result[1, :2] - input_embeddings[1, :2]).abs().max())
	print("`image` diff: ", (result[1, 2:7] - embedded_images[1]).abs().max())
	print("`Caption 2` diff: ", (result[1, 7:11] - input_embeddings[1, 3:7]).abs().max())
	print("`image` diff: ", (result[1, 11:16] - embedded_images[2]).abs().max())
	print("`Caption 3` diff: ", (result[1, 16:20] - input_embeddings[1, 8:12]).abs().max())

Running this gives:

`System Prompt` tokens:  [2184, 9508]
`Caption 1` tokens:  [9243, 683, 29871, 29896]
Input IDs: torch.Size([2, 12])
Input IDs: tensor([[ 2184,  9508, 32000,  9243,   683, 29871, 29896, 32001, 32001, 32001,
         32001, 32001],
        [ 2184,  9508, 32000,  9243,   683, 29871, 29906, 32000,  9243,   683,
         29871, 29941]])
nb_image_pad: tensor([4, 0])
max_embed_dim: 20
batch_indices: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       device='cuda:0')
non_image_indices: tensor([ 0,  1,  3,  4,  5,  6,  7,  8,  9, 10, 11,  0,  1,  3,  4,  5,  6,  8,
         9, 10, 11], device='cuda:0')
text_to_overwrite: tensor([ 0,  1,  7,  8,  9, 10, 11, 12, 13, 14, 15,  0,  1,  7,  8,  9, 10, 16,
        17, 18, 19], device='cuda:0')
image_to_overwrite cumsum: tensor([[-1, -1,  0,  1,  2,  3,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  5,  6,
          7,  8],
        [-1, -1,  0,  1,  2,  3,  4,  4,  4,  4,  4,  5,  6,  7,  8,  9,  9,  9,
          9,  9]], device='cuda:0')
image_to_overwrite: tensor([[False, False, False, False, False, False,  True, False, False, False,
         False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True, False, False, False,
         False,  True,  True,  True,  True,  True, False, False, False, False]],
       device='cuda:0')
torch.Size([2, 20, 4096])
`System Prompt` diff:  tensor(0., device='cuda:0', dtype=torch.float16)
`image` diff:  tensor(5.0898, device='cuda:0', dtype=torch.float16)
`Caption 1` diff:  tensor(0., device='cuda:0', dtype=torch.float16)
`System Prompt` diff:  tensor(0., device='cuda:0', dtype=torch.float16)
`image` diff:  tensor(0., device='cuda:0', dtype=torch.float16)
`Caption 2` diff:  tensor(0., device='cuda:0', dtype=torch.float16)
`image` diff:  tensor(0., device='cuda:0', dtype=torch.float16)
`Caption 3` diff:  tensor(0., device='cuda:0', dtype=torch.float16)

The last set of prints compares each piece of the result against the original inputs, showing the difference between what the function outputs and what is expected. As can be seen, the result is almost correct, but the first image has been misplaced.
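
For reference, the expected layout of row 0 in the merged output (which ends up right padded, since row 0 needs 4 positions of output padding) is:

	positions  0-1  : "System prompt" text embeddings
	positions  2-6  : image 0 features (5 patches)
	positions  7-10 : "Caption 1" text embeddings
	positions 11-15 : <pad> token embeddings
	positions 16-19 : output padding

Instead, the image_to_overwrite mask above places image 0's patches at positions 6 and 16-19.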

The reason is this line:

image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

This line appears to account for padding in the outputs, but it assumes the outputs will be left padded, whereas the rest of the function dynamically switches between left and right padding depending on the left_padding variable. If the outputs are right padded, the comparison on this line is incorrect. That can be seen in the debug output above, where image_to_overwrite shows the first image's slots scattered around the tensor.

Importantly, this incorrect behavior can occur regardless of the configured padding side, because left_padding is "automatically" detected from the input. Even if the model is set up for left padding, if a particular batch has no padding going into this function, the function will assume right padding and thus mangle the output.
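
To make the failure concrete, the following standalone sketch (assuming only torch, not the model) reproduces what that line computes for row 0 of the example above, where nb_image_pad is 4 and the output ends up right padded:

import torch

max_embed_dim = 20
nb_image_pad = torch.tensor([4])

# Row 0 of the example: text claims positions 0-1 and 7-15, so the non-text
# slots are 2-6 (where the image belongs) and 16-19 (the right padding).
image_to_overwrite = torch.ones(1, max_embed_dim, dtype=torch.bool)
image_to_overwrite[0, torch.tensor([0, 1, 7, 8, 9, 10, 11, 12, 13, 14, 15])] = False

# The current condition skips the first `nb_image_pad` non-text slots, which is
# only correct when the padding sits on the left.
buggy = image_to_overwrite & (image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None])
print(buggy.nonzero(as_tuple=True)[1].tolist())  # [6, 16, 17, 18, 19] -- image slots scattered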

One solution is to add an if statement on that line based on left_padding and use a different condition to handle right padding. I think this code (for the right padding case) would work, but I haven't tested it extensively:

	idxs = torch.arange(max_embed_dim, device=image_to_overwrite.device).expand(batch_size, -1)
	image_to_overwrite &= idxs < (max_embed_dim - nb_image_pad[:, None]).to(target_device)
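
Put together, a padding-direction-aware version of that step might look like the following sketch (keeping the existing cumsum logic for the left-padded case); again, untested beyond the example above:

	if left_padding:
		# Output padding sits at the start of each row, so skip the first `nb_image_pad` non-text slots.
		image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
	else:
		# Output padding sits at the end of each row, so exclude the last `nb_image_pad` positions.
		idxs = torch.arange(max_embed_dim, device=image_to_overwrite.device).expand(batch_size, -1)
		image_to_overwrite &= idxs < (max_embed_dim - nb_image_pad[:, None]).to(target_device)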

Side note:

I think the current implementation for handling the left padding case could also be replaced with arange-based logic. To me that would be clearer, since it states the intention of the line explicitly (set all padded indices to False), whereas I found the use of image_to_overwrite.cumsum(-1) opaque and difficult to decipher.
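
For the left padding case, an arange-based equivalent would be (again untested; the intent is simply to mask out the first nb_image_pad positions of each row):

	idxs = torch.arange(max_embed_dim, device=image_to_overwrite.device).expand(batch_size, -1)
	image_to_overwrite &= idxs >= nb_image_pad[:, None].to(target_device)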

Expected behavior

See above
