from __future__ import annotations

__copyright__ = """MIT License

Copyright (c) 2025 - IBM Research

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE."""


import random
from pathlib import Path
from typing import TYPE_CHECKING, Generator

import torch
import tqdm
from typing_extensions import Unpack

from mblm.data.datasets import DistributedDataset, DistributedDatasetConfig
from mblm.data.types import BatchMaskedForMLM, ModelMode
from mblm.data.utils import Bytes

if TYPE_CHECKING:
    from mblm.train.mblm import TrainMaskedEntryConfig


# @masked_dataset_registry.register("maskedPG19")
class PG19Masked(DistributedDataset[BatchMaskedForMLM]):
| 45 | + """ |
| 46 | + https://github.com/google-deepmind/pg19 |
| 47 | +
|
| 48 | + The `data_dir` is expected to be the of the exact structure as the |
| 49 | + original dataset, although only the test, train and validation folders |
| 50 | + are strictly needed: |
| 51 | + ├── LICENSE |
| 52 | + ├── README.md |
| 53 | + ├── metadata.csv |
| 54 | + ├── test |
| 55 | + ├── train |
| 56 | + └── validation |
| 57 | +
|
| 58 | + """ |
| 59 | + |
    def __init__(
        self,
        data_dir: str | Path,
        mode: ModelMode,
        masked_token_id: int = -100,
        masking_proba: float = 0.15,
        load_mininterval: int = 30,
        display_load_progress: bool = True,
        padding_token_id: int = -101,
        **config: Unpack[DistributedDatasetConfig],
    ):
        if masked_token_id == padding_token_id:
            # Validate before the (potentially slow) data loading below
            raise ValueError("masked_token_id and padding_token_id must be distinct values")
        root = Path(data_dir)
        if mode == ModelMode.VALID:
            data_path = root / "validation"
        else:
            data_path = root / mode.value
        self.txt_files = list(data_path.iterdir())
        self.masking_proba = masking_proba
        data_buff = bytearray()
        for file in tqdm.tqdm(
            self.txt_files,
            desc=f"Loading pg19 {data_path}",
            mininterval=load_mininterval,
            disable=not display_load_progress,
        ):
            with file.open("rb") as f:
                data_buff.extend(f.read())
        # The whole split is kept in memory as a single flat byte tensor
        self.data = Bytes.bytes_to_tensor(data_buff)
        self.masked_token_id = masked_token_id
        self.padding_token_id = padding_token_id

        super().__init__(
            data_size=self.data.numel(),
            is_sequential=True,
            **config,
        )

    @staticmethod
    def from_train_entry_config(
        config: TrainMaskedEntryConfig,
        mode: ModelMode,
        worker_id: int,
        num_workers: int,
    ) -> DistributedDataset[BatchMaskedForMLM]:
        return PG19Masked(
            data_dir=config.io.dataset_dir,
            masking_proba=config.train.masking_proba,
            masked_token_id=config.params.mask_token_id,
            mode=mode,
            padding_token_id=config.params.mblm_config.pad_token_id,
            seq_len=config.params.input_seq_len,
            worker_id=worker_id,
            num_workers=num_workers,
        )

    @staticmethod
    def supports_test_mode() -> bool:
        return True

    def get_sample(self, from_idx: int) -> BatchMaskedForMLM:
        """
        Get a sample with a loss mask. This method is required by the
        DistributedDataset superclass.
        """
        sample = self.data[from_idx : from_idx + self.seq_len].long()
        # True marks masked positions, which are the only ones the loss is computed on
        mask = torch.rand(sample.size()) < self.masking_proba
        tokens_masked = sample.clone()

        # TODO: Implement the same strategy as BERT, where a masked position
        # sometimes keeps the correct token instead of the masked_token_id
        tokens_masked[mask] = self.masked_token_id
        # Pad if necessary; should only be needed when from_idx == len(self)
        if sample.size(-1) != self.seq_len:
            num_pad = self.seq_len - sample.size(-1)
            # Use a long pad tensor so the dtypes match in the concatenations below
            pad_tensor = torch.full((num_pad,), self.padding_token_id, dtype=torch.long)
            tokens_masked = torch.concat((tokens_masked, pad_tensor))
            # Extend the mask with False so the loss is never computed over the
            # padding tokens (True marks masked elements, False everything else)
            mask = torch.concat((mask, torch.zeros(num_pad, dtype=torch.bool)))
            sample = torch.concat((sample, pad_tensor))
        return tokens_masked.long(), mask.bool(), sample.long()

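    # A sketch of what `get_sample` yields, assuming seq_len=8 over the
    # hypothetical byte stream b"Verbatim" (masking is random, values illustrative):
    #   tokens_masked: tensor([ 86, -100,  114,   98,   97,  116, -100,  109])
    #   mask:          tensor([False, True, False, False, False, False, True, False])
    #   targets:       tensor([ 86,  101,  114,   98,   97,  116,  105,  109])
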
    def book(self, name: str) -> str:
        """
        Get a book by its name (e.g., `44381.txt`) and return its content as
        a string
        """
        for candidate in self.txt_files:
            if candidate.name == name:
                with candidate.open("r", encoding="utf8") as f:
                    return f.read()
        raise ValueError(f"Book {name} does not exist")

    def iter_sequences_rand(self) -> Generator[torch.Tensor, None, None]:
        """
        Iterate over random fixed-length byte sequences drawn uniformly from
        the concatenated books of PG19. A sequence may span book boundaries
        """
        max_sample_start_idx = len(self.data) - self.seq_len - 1
        while True:
            idx = random.randint(0, max_sample_start_idx)
            yield self.data[idx : idx + self.seq_len]

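    # Hypothetical usage: draw four random windows for a quick evaluation pass
    #   gen = ds.iter_sequences_rand()
    #   windows = [next(gen) for _ in range(4)]
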
    def iter_books(self, shuffle: bool = False) -> Generator[tuple[str, str], None, None]:
        """
        Iterate over all the books in PG19, possibly in random order. Return an
        iterator over the file name of each book and its content as a string
        """
        txt_file_idxs = list(range(len(self.txt_files)))
        if shuffle:
            random.shuffle(txt_file_idxs)
        for i in txt_file_idxs:
            book = self.txt_files[i]
            with book.open("r", encoding="utf8") as f:
                yield book.name, f.read()
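

# A minimal usage sketch, assuming a local copy of PG19 under ./pg19 (the path
# and seq_len are hypothetical; seq_len, worker_id and num_workers are part of
# DistributedDatasetConfig):
if __name__ == "__main__":
    ds = PG19Masked(
        data_dir="./pg19",
        mode=ModelMode.VALID,
        seq_len=1024,
        worker_id=0,
        num_workers=1,
    )
    tokens_masked, mask, targets = ds.get_sample(0)
    # The MLM loss is computed only at positions where mask is True
    print(tokens_masked.shape, mask.sum().item(), targets.shape)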