Abnormal GPU Utilization When Evaluating Qwen 2.5 VL 72B #738

@LiamLian0727

Low GPU Utilization and Errors During Qwen 2.5 VL 72B Evaluation

First, I'd like to commend the lmms-eval team for developing such a robust evaluation framework - it's been incredibly valuable for multimodal research.

Problem Description

When evaluating Qwen 2.5 VL 72B on 8x H100 80G GPUs using the following command:

accelerate launch --num_processes=1 --main_process_port=12346 -m lmms_eval \
  --model qwen2_5_vl \
  --model_args="pretrained=/mnt/mypath/models/Qwen2.5-VL-72B-Instruct,max_pixels=12845056,attn_implementation=flash_attention_2,interleave_visuals=False,device_map=auto" \
  --tasks super_clevr \
  --output_path /mnt/mypath/qwen72b \
  --batch_size 1 \
  --verbosity DEBUG

My GPU utilization only reaches a maximum of 10%.
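
One possible explanation, offered as an assumption rather than a confirmed diagnosis: with a single process and device_map=auto, the model's layers are spread across the GPUs and executed one after another, so only one GPU computes at any moment. Under that assumption, a rough upper bound on average per-GPU utilization can be sketched as follows (the function name is hypothetical):

```python
def sequential_utilization_bound(num_gpus: int) -> float:
    """Upper bound on average per-GPU utilization when GPUs run strictly in turn."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    # If exactly one GPU is busy at any moment, each GPU is busy
    # for at most 1/num_gpus of the total time.
    return 1.0 / num_gpus


print(f"{sequential_utilization_bound(8):.1%}")  # 12.5%
```

A bound of 12.5% for 8 GPUs is in the same range as the roughly 10% observed, which would be consistent with this explanation but does not prove it.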

Attempted Solutions and Errors

1. Increasing num_processes to 8: results in OOM (out-of-memory) errors despite using 80 GB H100 GPUs.

This is most likely because the code below ignores the device_map passed in by the user when num_processes is greater than 1 and instead uses f"cuda:{accelerator.local_process_index}", so each process tries to place the whole model on a single GPU and the model is not sharded.

accelerator = Accelerator()
if accelerator.num_processes > 1:
    self._device = torch.device(f"cuda:{accelerator.local_process_index}")
    self.device_map = f"cuda:{accelerator.local_process_index}"
else:
    self._device = torch.device(device)
    self.device_map = device_map if device_map else device
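
A back-of-the-envelope check of why a full copy per GPU cannot fit. This is a sketch under stated assumptions: the parameter count is taken from the "72B" in the model name, weights are assumed to be stored in a 2-byte dtype such as bfloat16, and activations and KV cache are ignored. The helper name is hypothetical.

```python
def weights_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Approximate weight memory in GB (1 GB = 1e9 bytes), weights only."""
    return num_params * bytes_per_param / 1e9


NUM_PARAMS = 72e9    # assumption: read off the "72B" in the model name
GPU_MEMORY_GB = 80   # H100 80G, as stated in this issue
NUM_GPUS = 8

total = weights_gb(NUM_PARAMS)          # one full copy of the weights
per_gpu_if_sharded = total / NUM_GPUS   # if one copy is split evenly over 8 GPUs

print(f"full copy: {total:.0f} GB, sharded per GPU: {per_gpu_if_sharded:.0f} GB")
```

Under these assumptions a full copy needs about 144 GB, well above the 80 GB of one GPU, while a single copy split across 8 GPUs needs about 18 GB per GPU before activations.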

2. Increasing batch_size above 1: throws a ValueError:

ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour for Flash Attention version of Qwen2_5_VL. 
Make sure to call `tokenizer.padding_side = 'left'` before tokenizing the input

After that, I changed the line self._tokenizer = AutoTokenizer.from_pretrained(pretrained) in qwen2_5_vl.py to self._tokenizer = AutoTokenizer.from_pretrained(pretrained, padding_side='left'), but it didn't help.
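
For readers unfamiliar with the error, here is a minimal, library-free sketch of what the two padding sides do to a batch. The token ids, the pad id, and the helper name are hypothetical; this is an illustration, not the tokenizer's actual implementation.

```python
from typing import List


def pad_batch(seqs: List[List[int]], pad_id: int, side: str) -> List[List[int]]:
    """Pad every sequence to the longest length on the given side."""
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    width = max(len(s) for s in seqs)
    padded = []
    for s in seqs:
        padding = [pad_id] * (width - len(s))
        padded.append(padding + s if side == "left" else s + padding)
    return padded


PAD = 0  # hypothetical pad token id
batch = [[11, 12, 13], [21]]

right = pad_batch(batch, PAD, "right")  # short prompt ends in pad tokens
left = pad_batch(batch, PAD, "left")    # every prompt ends in a real token

print(right)
print(left)
```

With right padding, generation for the shorter prompt would continue from pad tokens; with left padding, every row ends in its last real token. If the batch is actually tokenized through a separate processor object rather than through self._tokenizer, changing only self._tokenizer would not affect the padding side; that is a guess about the code path, not something confirmed here.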

Request for Assistance

I'd appreciate guidance on any of the following:

  1. How to configure multi-GPU evaluation so that it avoids OOM.
  2. How to configure batched evaluation with the correct padding side.
  3. Any other recommended parameters for efficient evaluation of this large vision-language model.

I'd be grateful for any insights from the team or community members who have successfully evaluated similar large multimodal models.
