Skip to content

Commit 9dc1142

Browse files
committed
remove old collate fn
1 parent cf8e8af commit 9dc1142

File tree

5 files changed

+10
-85
lines changed

5 files changed

+10
-85
lines changed

xtuner/dataset/hybrid/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
from .collate import text_collate_fn
1+
# Copyright (c) OpenMMLab. All rights reserved.
22
from .dataset import TextDataset
33
from .mappings import map_protocol, map_sequential, openai_to_raw_training
44

55
__all__ = [
6-
'text_collate_fn',
76
'TextDataset',
87
'map_protocol',
98
'map_sequential',

xtuner/dataset/hybrid/collate.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

xtuner/dataset/hybrid/dataset.py

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
12
import functools
23
import json
34
import os
@@ -270,8 +271,8 @@ def load_dataset(
270271
"""
271272
if self.is_cached(cache_dir):
272273
print_log(
273-
f'{cache_dir} is cached dataset that will be loaded '
274-
'directly; `data_files` and `data_dir` will become'
274+
f'{cache_dir} is a cached dataset that will be loaded '
275+
'directly; `data_files` and `data_dir` will become '
275276
'invalid.',
276277
logger='current')
277278

@@ -359,7 +360,7 @@ def tokenize_dataset(self, dataset: List[dict]) -> List[dict]:
359360
`labels` and `num_tokens`.
360361
`input_ids` and `labels` are lists of int, and they should
361362
have equal lengths.
362-
`num_tokens` is an integerthe length of `input_ids`.
363+
`num_tokens` is an integer, the length of `input_ids`.
363364
"""
364365

365366
def openai_to_raw_training(item: dict) -> Dict:
@@ -574,48 +575,27 @@ def __getitem__(self, item: int) -> Dict[str, List]:
574575
stop_words=['<|im_end|>'],
575576
)
576577

577-
from xtuner.dataset.hybrid.mappings import openai_to_raw_training
578-
579-
data_dir = './llava_data/LLaVA-Instruct-150K/'
580-
image_dir = './llava_data/llava_images/'
581-
data_files = 'llava_v1_5_mix665k.json'
582-
583578
dataset = TextDataset(
584579
'internlm/internlm2-chat-1_8b',
585580
chat_template,
586581
sample_ratio=1,
587582
max_length=32 * 1024,
588-
data_dir=data_dir,
589-
data_files=data_files,
583+
data_dir='converted_alpaca',
584+
cache_dir='cached_alpaca',
590585
pack_to_max_length=True,
591-
mappings=[openai_to_raw_training],
592586
num_proc=4)
593587

594588
print(dataset[0])
595589

596-
dataset.cache('cached_llava')
597-
dataset = TextDataset(
598-
'internlm/internlm2-chat-1_8b',
599-
chat_template,
600-
sample_ratio=1,
601-
max_length=32 * 1024,
602-
cache_dir='cached_llava',
603-
pack_to_max_length=True,
604-
mappings=[
605-
openai_to_raw_training,
606-
],
607-
num_proc=4)
608-
print(dataset[0])
609-
610590
from mmengine.dataset import DefaultSampler
611591
from torch.utils.data import DataLoader
612592

613-
from xtuner.dataset.hybrid.collate import text_collate_fn
593+
from xtuner.model import TextFinetune
614594
loader = DataLoader(
615595
dataset,
616596
4,
617597
num_workers=0,
618-
collate_fn=text_collate_fn,
598+
collate_fn=TextFinetune.dataloader_collate_fn,
619599
sampler=DefaultSampler(dataset, shuffle=True))
620600

621601
for data in tqdm(loader):

xtuner/model/text/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
12
from .finetune import TextFinetune
23

34
__all__ = ['TextFinetune']

xtuner/model/text/finetune.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
32
from collections import OrderedDict
43
from typing import Dict, List, Optional, Union
54

0 commit comments

Comments
 (0)