Skip to content

Commit ca02c5d

Browse files
committed
finish
1 parent 449d9fc commit ca02c5d

13 files changed

Lines changed: 203 additions & 120 deletions

File tree

.github/requirements-test.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
datasets
2+
pillow
3+
pytest
4+
ruff
5+
torch
6+
transformers

.github/workflows/tests.yml

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
name: tests
22

33
on:
4+
workflow_dispatch:
45
push:
56
branches:
67
- "main"
@@ -28,6 +29,10 @@ jobs:
2829

2930
runs-on: ${{ matrix.os }}
3031

32+
concurrency:
33+
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python-version }}
34+
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
35+
3136
steps:
3237
- name: Checkout
3338
uses: actions/checkout@v4
@@ -37,13 +42,22 @@ jobs:
3742
with:
3843
python-version: ${{ matrix.python-version }}
3944
cache: "pip"
40-
cache-dependency-path: "setup.py"
45+
cache-dependency-path: "**/requirements*.txt"
4146

4247
- name: Install dependencies
4348
run: |
4449
python -m pip install --upgrade pip
45-
python -m pip install ruff
50+
python -m pip install -r .github/requirements-test.txt --index-url https://download.pytorch.org/whl/cpu
51+
python -m pip install --no-deps .
4652
4753
- name: Check quality
4854
run: |
4955
make style && make quality
56+
57+
- name: Check license
58+
run: |
59+
make license
60+
61+
- name: Test with pytest
62+
run: |
63+
make test

Makefile

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
.PHONY: build commit quality style
1+
.PHONY: build commit license quality style test
22

3-
check_dirs := examples scripts verl setup.py
3+
check_dirs := examples scripts tests verl setup.py
44

55
build:
66
python3 setup.py sdist bdist_wheel
@@ -9,10 +9,16 @@ commit:
99
pre-commit install
1010
pre-commit run --all-files
1111

12+
license:
13+
python3 tests/check_license.py $(check_dirs)
14+
1215
quality:
1316
ruff check $(check_dirs)
1417
ruff format --check $(check_dirs)
1518

1619
style:
1720
ruff check $(check_dirs) --fix
1821
ruff format $(check_dirs)
22+
23+
test:
24+
pytest -vv tests/

examples/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ data:
1313
override_chat_template: null
1414
shuffle: true
1515
seed: 1
16-
max_pixels: 4194304
1716
min_pixels: 262144
17+
max_pixels: 4194304
1818
filter_overlong_prompts: true
1919

2020
algorithm:

tests/check_license.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sys
16+
from pathlib import Path
17+
18+
19+
KEYWORDS = ("Copyright", "2024", "Bytedance")


def main() -> None:
    """Check that every ``*.py`` file under the given directories carries the license header.

    Usage: ``python check_license.py DIR [DIR ...]``

    A file passes when all :data:`KEYWORDS` appear within its first five
    non-blank-stripped lines. Empty files are skipped. If any file is missing
    the header, every offender is reported on stderr and the process exits
    with status 1 (instead of an ``assert``, which is stripped under ``-O``
    and would stop at the first offender).
    """
    py_files: list[Path] = []
    for check_dir in sys.argv[1:]:
        py_files.extend(Path(check_dir).glob("**/*.py"))

    offenders: list[Path] = []
    for path in py_files:
        with open(path.absolute(), encoding="utf-8") as f:
            header = "\n".join(f.read().strip().split("\n")[:5])

        if not header:  # empty file: nothing to check
            continue

        print(f"Check license: {path}")
        if not all(keyword in header for keyword in KEYWORDS):
            offenders.append(path)

    if offenders:
        for path in offenders:
            print(f"File {path} does not contain license.", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()

tests/test_dataset.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import torch
16+
from PIL.Image import Image
17+
18+
from verl.utils.dataset import RLHFDataset
19+
from verl.utils.tokenizer import get_processor, get_tokenizer
20+
21+
22+
def test_image_dataset():
    """End-to-end check of RLHFDataset on a multimodal (image) HF dataset.

    Loads the geometry3k test split with the Qwen2.5-VL tokenizer/processor
    and verifies the first sample's keys, prompt text, ground truth, token
    ids, masks, position ids, and image payload.
    """
    model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
    tok = get_tokenizer(model_id, use_fast=True)
    proc = get_processor(model_id, use_fast=True)
    ds = RLHFDataset(
        data_path="hiyouga/geometry3k@test",
        tokenizer=tok,
        processor=proc,
        prompt_key="problem",
        answer_key="answer",
        image_key="images",
        max_prompt_length=16,
        truncation="right",
        filter_overlong_prompts=False,
    )
    expected_ids = [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655]
    expected_keys = {
        "problem",
        "ground_truth",
        "input_ids",
        "attention_mask",
        "position_ids",
        "raw_prompt_ids",
        "multi_modal_data",
    }
    assert set(ds[0].keys()) == expected_keys
    expected_problem = (
        "<image>Chords $\\overline{A C}$ and $\\overline{D F}$ are equidistant from the center. "
        "If the radius of $\\odot G$ is 26 find $A C$"
    )
    assert ds[0]["problem"] == expected_problem
    assert ds[0]["ground_truth"] == "48"
    assert torch.all(ds[0]["input_ids"] == torch.tensor(expected_ids))
    assert torch.all(ds[0]["attention_mask"] == torch.ones(16))
    assert torch.all(ds[0]["position_ids"] == torch.arange(16).unsqueeze(0).expand(3, -1))
    # Shape check guards against a fake positive caused by broadcasting above.
    assert list(ds[0]["position_ids"].size()) == [3, 16]
    assert ds[0]["raw_prompt_ids"] == expected_ids
    assert isinstance(ds[0]["multi_modal_data"]["images"][0], Image)


if __name__ == "__main__":
    test_image_dataset()

verl/protocol.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,11 +384,11 @@ def pop(
384384
meta_info_keys = meta_info_keys or []
385385

386386
tensors = {}
387-
for key in batch_keys and key in self.batch:
387+
for key in filter(lambda k: k in self.batch, batch_keys):
388388
tensors[key] = self.batch.pop(key)
389389

390390
non_tensors = {}
391-
for key in non_tensor_batch_keys and key in self.non_tensor_batch:
391+
for key in filter(lambda k: k in self.non_tensor_batch, non_tensor_batch_keys):
392392
non_tensors[key] = self.non_tensor_batch.pop(key)
393393

394394
meta_info = {}

verl/trainer/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ class DataConfig:
4747
override_chat_template: Optional[str] = None
4848
shuffle: bool = True
4949
seed: int = 1
50-
max_pixels: int = 4194304
51-
min_pixels: int = 262144
50+
min_pixels: Optional[int] = 262144
51+
max_pixels: Optional[int] = 4194304
5252
filter_overlong_prompts: bool = True
5353

5454
def post_init(self):

verl/trainer/ray_trainer.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -475,16 +475,14 @@ def fit(self):
475475
break
476476

477477
metrics, timing_raw = {}, {}
478-
batch: DataProto = DataProto.from_single_dict(batch_dict)
479-
batch.meta_info = {
480-
"min_pixels": self.config.data.min_pixels,
481-
"max_pixels": self.config.data.max_pixels,
482-
}
478+
meta_info = {"min_pixels": self.config.data.min_pixels, "max_pixels": self.config.data.max_pixels}
479+
batch: DataProto = DataProto.from_single_dict(batch_dict, meta_info=meta_info)
483480

484481
# pop those keys for generation
485482
gen_batch = batch.pop(
486483
batch_keys=["input_ids", "attention_mask", "position_ids"],
487484
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
485+
meta_info_keys=["min_pixels", "max_pixels"],
488486
)
489487
with timer("step", timing_raw):
490488
# generate a batch

verl/utils/dataset.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
5050
return {**tensors, **non_tensors}
5151

5252

53-
def process_image(image: Union[Dict[str, Any], ImageObject, str], min_pixels: int, max_pixels: int) -> ImageObject:
53+
def process_image(
54+
image: Union[Dict[str, Any], ImageObject, str], min_pixels: Optional[int], max_pixels: Optional[int]
55+
) -> ImageObject:
5456
if isinstance(image, str):
5557
image = Image.open(image)
5658
elif isinstance(image, dict):
@@ -59,12 +61,12 @@ def process_image(image: Union[Dict[str, Any], ImageObject, str], min_pixels: in
5961
image = Image.open(BytesIO(image))
6062

6163
image.load() # avoid "Too many open files" errors
62-
if (image.width * image.height) > max_pixels:
64+
if max_pixels is not None and (image.width * image.height) > max_pixels:
6365
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
6466
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
6567
image = image.resize((width, height))
6668

67-
if (image.width * image.height) < min_pixels:
69+
if min_pixels is not None and (image.width * image.height) < min_pixels:
6870
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
6971
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
7072
image = image.resize((width, height))
@@ -92,8 +94,8 @@ def __init__(
9294
max_prompt_length: int = 1024,
9395
truncation: str = "error",
9496
format_prompt: Optional[str] = None,
95-
max_pixels: Optional[int] = None,
9697
min_pixels: Optional[int] = None,
98+
max_pixels: Optional[int] = None,
9799
filter_overlong_prompts: bool = True,
98100
):
99101
self.tokenizer = tokenizer
@@ -104,8 +106,8 @@ def __init__(
104106
self.image_dir = image_dir
105107
self.max_prompt_length = max_prompt_length
106108
self.truncation = truncation
107-
self.max_pixels = max_pixels
108109
self.min_pixels = min_pixels
110+
self.max_pixels = max_pixels
109111
self.filter_overlong_prompts = filter_overlong_prompts
110112

111113
if "@" in data_path:
@@ -169,17 +171,16 @@ def __getitem__(self, index):
169171
if self.image_key in example:
170172
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
171173
images = example.pop(self.image_key)
172-
if self.image_dir is not None and len(images) != 0 and isinstance(images[0], str): # image paths
174+
if self.image_dir is not None and len(images) != 0 and isinstance(images[0], str): # image paths
173175
images = [os.path.join(self.image_dir, image) for image in images]
174176

175177
resized_images = [
176-
process_image(image, min_pixels=self.min_pixels, max_pixels=self.max_pixels)
177-
for image in images
178+
process_image(image, min_pixels=self.min_pixels, max_pixels=self.max_pixels) for image in images
178179
]
179180
model_inputs = self.processor(resized_images, [prompt], add_special_tokens=False, return_tensors="pt")
180181
input_ids = model_inputs.pop("input_ids")[0]
181182
attention_mask = model_inputs.pop("attention_mask")[0]
182-
example["multi_modal_inputs"] = {"images": images}
183+
example["multi_modal_data"] = {"images": images}
183184
else:
184185
prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
185186
model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")

0 commit comments

Comments
 (0)