Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 82 additions & 37 deletions tests/experimental/agent_loop/test_multi_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@
from verl.utils import hf_tokenizer


def parse_multi_modal_type(messages: list[dict]) -> str:
message = messages[-1]
if isinstance(message["content"], str):
return "text"

for content in message["content"]:
if content["type"] == "image":
return "image"
elif content["type"] == "video":
return "video"

return "text"


@pytest.fixture
def init_config() -> DictConfig:
from hydra import compose, initialize_config_dir
Expand All @@ -49,7 +63,7 @@ def init_config() -> DictConfig:
config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"]
config.actor_rollout_ref.rollout.mode = "async"
config.actor_rollout_ref.rollout.enforce_eager = True
config.actor_rollout_ref.rollout.prompt_length = 4096
config.actor_rollout_ref.rollout.prompt_length = 10240
config.actor_rollout_ref.rollout.response_length = 4096
config.actor_rollout_ref.rollout.n = 4
config.actor_rollout_ref.rollout.agent.num_workers = 2
Expand Down Expand Up @@ -157,6 +171,25 @@ def test_multimodal_tool_agent(init_config):
[
{"role": "user", "content": "How are you?"},
],
[
{
"role": "user",
"content": [
{
"type": "video",
"video": os.path.expanduser("~/models/hf_data/test-videos/space_woaudio.mp4"),
"min_pixels": 4 * 32 * 32,
"max_pixels": 256 * 32 * 32,
"total_pixels": 4096 * 32 * 32,
},
{
"type": "text",
"text": "Describe this video. Then you must call the "
"image generator tool to generate a green image for me.",
},
],
},
],
[
{"role": "user", "content": "Please generate a red image for me."},
],
Expand Down Expand Up @@ -189,14 +222,23 @@ def test_multimodal_tool_agent(init_config):

# Check turns
num_turns = result.non_tensor_batch["__num_turns__"]
multi_modal_inputs = result.non_tensor_batch["multi_modal_inputs"]
print(f"num_turns: {num_turns}")
for i in range(len(num_turns)):
if i // n == 0:
multi_modal_type = parse_multi_modal_type(raw_prompts[i // n])
if multi_modal_type == "video":
assert "pixel_values_videos" in multi_modal_inputs[i], f"Sample {i} should have pixel_values_videos"
assert "video_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have video_grid_thw"

if i // n <= 1:
# TODO: prompt with video not generate tool call as expected
# First prompt: "How are you?" - should have 2 turns [user, assistant]
assert num_turns[i] == 2, f"Expected 2 turns but got {num_turns[i]} for sample {i}"
else:
# Tool-calling prompts should have 4 turns [user, assistant, tool, assistant]
assert num_turns[i] == 4, f"Expected 4 turns but got {num_turns[i]} for sample {i}"
assert "pixel_values" in multi_modal_inputs[i], f"Sample {i} should have pixel_values"
assert "image_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have image_grid_thw"

# Check that images were properly returned in the tool responses
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
Expand Down Expand Up @@ -275,18 +317,21 @@ def test_multimodal_single_turn_agent(init_config):
test_image2 = Image.new("RGB", (512, 512), (100, 150, 200))

raw_prompts = [
# text
[
{"role": "user", "content": "Hello, how are you?"},
],
# image
[
{
"role": "user",
"content": [
{"type": "image"},
{"type": "image", "image": test_image},
{"type": "text", "text": "What color is this image?"},
],
},
],
# system + image
[
{
"role": "system",
Expand All @@ -295,18 +340,27 @@ def test_multimodal_single_turn_agent(init_config):
{
"role": "user",
"content": [
{"type": "image"},
{"type": "image", "image": test_image2},
{"type": "text", "text": "Describe this image in detail."},
],
},
],
]

# Prepare multi_modal_data for each prompt
multi_modal_data_list = [
None, # First prompt: text only
{"image": test_image}, # Second prompt: with image
{"image": test_image2}, # Third prompt: with image
# video
[
{
"role": "user",
"content": [
{
"type": "video",
"video": os.path.expanduser("~/models/hf_data/test-videos/space_woaudio.mp4"),
"min_pixels": 4 * 32 * 32,
"max_pixels": 256 * 32 * 32,
"total_pixels": 4096 * 32 * 32,
},
{"type": "text", "text": "Describe this video."},
],
},
],
]

batch = DataProto(
Expand All @@ -318,10 +372,6 @@ def test_multimodal_single_turn_agent(init_config):
},
)

# Add multi_modal_data to batch
multi_modal_data_array = np.array([data if data else {} for data in multi_modal_data_list], dtype=object)
batch.non_tensor_batch["multi_modal_data"] = multi_modal_data_array

batch = batch.repeat(n)
result = agent_loop_manager.generate_sequences(prompts=batch)
assert len(result) == len(raw_prompts) * n
Expand All @@ -337,7 +387,11 @@ def test_multimodal_single_turn_agent(init_config):
prompts = result.batch["prompts"]
responses = result.batch["responses"]
response_mask = result.batch["response_mask"]
input_ids = result.batch["input_ids"]
position_ids = result.batch["position_ids"]
multi_modal_inputs = result.non_tensor_batch["multi_modal_inputs"]
assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
assert position_ids.size() == (input_ids.size(0), 4, input_ids.size(1)) # (batch_size, 4, seq_len)

# Check for image pads in prompts
image_pad_count = 0
Expand All @@ -354,14 +408,17 @@ def test_multimodal_single_turn_agent(init_config):
print(f"Prompt length: {len(prompt_ids)} tokens")
print(f"Has image_pad: {has_image_pad}")

if sample_idx != 0: # Samples 1 and 2 should have images
if has_image_pad:
image_pad_count += 1
# Count the number of image_pad tokens
num_image_pads = prompt_text.count("<|image_pad|>")
print(f"Number of <|image_pad|> tokens: {num_image_pads}")
else:
print("WARNING: Expected image_pad but not found!")
# Check multi-modal type
multi_modal_type = parse_multi_modal_type(raw_prompts[sample_idx])

if multi_modal_type == "text":
assert len(multi_modal_inputs[i]) == 0, f"Sample {i} should not have multi-modal inputs"
elif multi_modal_type == "image":
assert "pixel_values" in multi_modal_inputs[i], f"Sample {i} should have pixel_values"
assert "image_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have image_grid_thw"
else:
assert "pixel_values_videos" in multi_modal_inputs[i], f"Sample {i} should have pixel_values_videos"
assert "video_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have video_grid_thw"

# Show first 200 chars of prompt
print(f"Prompt text (first 200 chars): {prompt_text[:200]}...")
Expand All @@ -374,7 +431,6 @@ def test_multimodal_single_turn_agent(init_config):
# Verify that we found image pads in multimodal samples
expected_multimodal_samples = 2 * n # 2 prompts with images, repeated n times
print(f"\nFound {image_pad_count} samples with image_pad out of {expected_multimodal_samples} expected")
assert image_pad_count > 0, "No image_pad tokens found in multimodal samples!"

print("Single turn multimodal test passed!")
ray.shutdown()
Expand Down Expand Up @@ -427,7 +483,7 @@ def test_multimodal_partial_single_turn_agent(init_config):
{
"role": "user",
"content": [
{"type": "image"},
{"type": "image", "image": test_image},
{"type": "text", "text": "What do you see in this image?"},
],
},
Expand All @@ -440,20 +496,13 @@ def test_multimodal_partial_single_turn_agent(init_config):
{
"role": "user",
"content": [
{"type": "image"},
{"type": "image", "image": test_image2},
{"type": "text", "text": "Analyze the colors in this image."},
],
},
],
]

# Prepare multi_modal_data for each prompt
multi_modal_data_list = [
None, # First prompt: text only
{"image": test_image}, # Second prompt: with image
{"image": test_image2}, # Third prompt: with image
]

batch = DataProto(
non_tensor_batch={
"raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
Expand All @@ -463,10 +512,6 @@ def test_multimodal_partial_single_turn_agent(init_config):
},
)

# Add multi_modal_data to batch
multi_modal_data_array = np.array([data if data else {} for data in multi_modal_data_list], dtype=object)
batch.non_tensor_batch["multi_modal_data"] = multi_modal_data_array

batch = batch.repeat(n)
result = agent_loop_manager.generate_sequences(prompts=batch)
assert len(result) == len(raw_prompts) * n
Expand Down
Loading
Loading