@@ -1,3 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 import functools
 import json
 import os
@@ -270,8 +271,8 @@ def load_dataset(
         """
         if self.is_cached(cache_dir):
             print_log(
-                f'{cache_dir} is cached dataset that will be loaded '
-                'directly; `data_files` and `data_dir` will become'
+                f'{cache_dir} is a cached dataset that will be loaded '
+                'directly; `data_files` and `data_dir` will become '
                 'invalid.',
                 logger='current')
 
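Note: the message above describes a precedence rule, namely that when `cache_dir` already holds a cached dataset it is loaded directly and the raw-data arguments are ignored. A minimal, generic sketch of that pattern (not the xtuner implementation; the file name and JSON layout are assumptions):

import json
import os


def load_or_build(cache_dir: str, data_dir: str) -> list:
    """Load a cached dataset if present; otherwise build it from data_dir and cache it."""
    cache_file = os.path.join(cache_dir, 'dataset.json')
    if os.path.exists(cache_file):
        # Cached dataset found: `data_dir` (and any `data_files`) are ignored.
        with open(cache_file) as f:
            return json.load(f)
    # No cache yet: gather every JSON file under data_dir, then cache the result.
    items = []
    for name in sorted(os.listdir(data_dir)):
        if name.endswith('.json'):
            with open(os.path.join(data_dir, name)) as f:
                items.extend(json.load(f))
    os.makedirs(cache_dir, exist_ok=True)
    with open(cache_file, 'w') as f:
        json.dump(items, f)
    return items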
@@ -359,7 +360,7 @@ def tokenize_dataset(self, dataset: List[dict]) -> List[dict]:
            `labels` and `num_tokens`.
            `input_ids` and `labels` are lists of int, and they should
            have equal lengths.
-            `num_tokens` is an integer,the length of `input_ids`.
+            `num_tokens` is an integer, the length of `input_ids`.
        """
 
        def openai_to_raw_training(item: dict) -> Dict:
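Note: the docstring above pins down an invariant for every tokenized item: `input_ids` and `labels` are equally long lists of ints, and `num_tokens` equals `len(input_ids)`. A small illustrative item (the concrete values, and the use of -100 as an ignored-label marker, are assumptions rather than real xtuner output):

tokenized_item = {
    'input_ids': [1, 456, 789, 2],
    'labels': [-100, -100, 789, 2],  # same length as input_ids
    'num_tokens': 4,  # always len(input_ids)
}
assert len(tokenized_item['input_ids']) == len(tokenized_item['labels'])
assert tokenized_item['num_tokens'] == len(tokenized_item['input_ids'])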
@@ -574,48 +575,27 @@ def __getitem__(self, item: int) -> Dict[str, List]:
         stop_words=['<|im_end|>'],
     )
 
-    from xtuner.dataset.hybrid.mappings import openai_to_raw_training
-
-    data_dir = './llava_data/LLaVA-Instruct-150K/'
-    image_dir = './llava_data/llava_images/'
-    data_files = 'llava_v1_5_mix665k.json'
-
     dataset = TextDataset(
         'internlm/internlm2-chat-1_8b',
         chat_template,
         sample_ratio=1,
         max_length=32 * 1024,
-        data_dir=data_dir,
-        data_files=data_files,
+        data_dir='converted_alpaca',
+        cache_dir='cached_alpaca',
         pack_to_max_length=True,
-        mappings=[openai_to_raw_training],
         num_proc=4)
 
     print(dataset[0])
 
-    dataset.cache('cached_llava')
-    dataset = TextDataset(
-        'internlm/internlm2-chat-1_8b',
-        chat_template,
-        sample_ratio=1,
-        max_length=32 * 1024,
-        cache_dir='cached_llava',
-        pack_to_max_length=True,
-        mappings=[
-            openai_to_raw_training,
-        ],
-        num_proc=4)
-    print(dataset[0])
-
     from mmengine.dataset import DefaultSampler
     from torch.utils.data import DataLoader
 
-    from xtuner.dataset.hybrid.collate import text_collate_fn
+    from xtuner.model import TextFinetune
     loader = DataLoader(
         dataset,
         4,
         num_workers=0,
-        collate_fn=text_collate_fn,
+        collate_fn=TextFinetune.dataloader_collate_fn,
         sampler=DefaultSampler(dataset, shuffle=True))
 
     for data in tqdm(loader):
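Note: the example now routes batching through `TextFinetune.dataloader_collate_fn`. For reference, a self-contained sketch of what a collate step over items shaped like `{'input_ids', 'labels', 'num_tokens'}` has to do, i.e. pad to a common length and stack into tensors (the padding id, label ignore value, and returned keys are assumptions here; the real xtuner collate function may differ):

from typing import Dict, List

import torch


def naive_text_collate_fn(items: List[dict],
                          pad_id: int = 0,
                          ignore_id: int = -100) -> Dict[str, torch.Tensor]:
    """Pad every item to the longest sequence in the batch and stack."""
    max_len = max(item['num_tokens'] for item in items)
    input_ids, labels = [], []
    for item in items:
        pad = max_len - item['num_tokens']
        input_ids.append(item['input_ids'] + [pad_id] * pad)
        labels.append(item['labels'] + [ignore_id] * pad)
    return {
        'input_ids': torch.tensor(input_ids, dtype=torch.long),
        'labels': torch.tensor(labels, dtype=torch.long),
        'num_tokens': torch.tensor([item['num_tokens'] for item in items]),
    }

# Usage would mirror the docstring example, e.g.:
# loader = DataLoader(dataset, 4, collate_fn=naive_text_collate_fn)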