Commit 8c5afcf

feat: Support of encoder (#24)

Authored by Joris Schaller
Co-authored-by: Joris Schaller <[email protected]>
Parent: 61d186a

File tree: 19 files changed, +107354 −18 lines

src/mblm/__init__.py

Lines changed: 6 additions & 3 deletions

```diff
@@ -26,15 +26,18 @@
 __version__ = importlib.metadata.version("mblm")
 
 
-from mblm.model.config import MBLMModelConfig, MBLMReturnType
+from mblm.model.config import MBLMEncoderModelConfig, MBLMModelConfig, MBLMReturnType
 from mblm.model.mamba import MambaBlock
-from mblm.model.mblm import MBLM
-from mblm.model.transformer import TransformerBlock
+from mblm.model.mblm import MBLM, MBLMEncoder
+from mblm.model.transformer import TransformerBlock, TransformerEncoderBlock
 
 __all__ = [
     "MBLM",
+    "MBLMEncoder",
     "MBLMModelConfig",
+    "MBLMEncoderModelConfig",
     "MBLMReturnType",
     "TransformerBlock",
+    "TransformerEncoderBlock",
     "MambaBlock",
 ]
```
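
Note: with this change, all encoder-side symbols are importable from the package root. A quick smoke test of the new surface (assuming the package is installed; this only exercises the names listed in `__all__`):

```python
# Smoke test for the new public API; every name below appears in __all__.
from mblm import (
    MBLM,
    MBLMEncoder,
    MBLMEncoderModelConfig,
    MBLMModelConfig,
    MBLMReturnType,
    MambaBlock,
    TransformerBlock,
    TransformerEncoderBlock,
)

print(MBLMReturnType.HIDDEN_STATE)  # new return type added in this commit
```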

src/mblm/data/dataset/beep.py

Lines changed: 78 additions & 0 deletions

```diff
@@ -0,0 +1,78 @@
+from pathlib import Path
+
+from typing_extensions import Unpack
+
+from mblm.data.datasets import DistributedDataset, DistributedDatasetConfig
+from mblm.data.types import BatchMaskedForMLM, ModelMode
+from mblm.train.mblm import (
+    TrainMaskedEntryConfig,
+    masked_dataset_registry,
+)
+
+
+@masked_dataset_registry.register("beep")
+class Beep(DistributedDataset[BatchMaskedForMLM]):
+    """The beep dataset raw data"""
+
+    def __init__(
+        self,
+        mode: ModelMode,
+        data_dir: str | Path,
+        **args: Unpack[DistributedDatasetConfig],
+    ):
+        # Dummy example - get data from anywhere, e.g., the disk
+        print(f"Reading dataset from {data_dir}")
+        if mode == ModelMode.TRAIN:
+            # TODO Load the BEEP train file
+            data = list(range(10_000))
+        elif mode == ModelMode.VALID:
+            # TODO Load the BEEP validation file
+            data = list(range(2_000))
+        elif mode == ModelMode.TEST:
+            # TODO Load the BEEP test file
+            data = []  # placeholder so `data` is always bound
+        else:
+            raise ValueError("This variant isn't implemented yet, please update the code")
+        self._data = data
+
+        super().__init__(
+            data_size=len(data),
+            is_sequential=True,  # We have a sequential dataset
+            **args,
+        )
+
+    def get_sample(self, from_idx: int):
+        """
+        Tell the superclass how to get a single sample - here, a sequence of
+        the specified length.
+        """
+        # data = torch.tensor(self._data[from_idx : from_idx + self.seq_len])
+        # return torch.ones_like(data), data
+        raise NotImplementedError()
+
+    @staticmethod
+    def from_train_entry_config(
+        config: TrainMaskedEntryConfig,
+        mode: ModelMode,
+        worker_id: int,
+        num_workers: int,
+    ) -> DistributedDataset[BatchMaskedForMLM]:
+        """
+        How to parse a training config to a dataset.
+        """
+        return Beep(
+            data_dir=config.io.dataset_dir,
+            mode=mode,
+            seq_len=config.params.input_seq_len,
+            num_workers=num_workers,
+            worker_id=worker_id,
+        )
+
+    @staticmethod
+    def supports_test_mode() -> bool:
+        """
+        Whether or not this dataset supports a test mode. Some datasets might
+        not expose the answers in their test set, so we cannot evaluate a
+        model on them. Override if necessary.
+        """
+        return True
```
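
Note: `get_sample` is left unimplemented in this template. A minimal sketch of what a masked-LM completion could look like, mirroring `PG19Masked.get_sample` further below; the mask token id and masking probability are illustrative assumptions, not part of this commit:

```python
import torch

MASK_TOKEN_ID = -100   # assumed; must differ from the padding token id
MASKING_PROBA = 0.15   # assumed BERT-style masking rate

def get_sample(self, from_idx: int):
    """Hypothetical completion of Beep.get_sample for masked modeling."""
    # Slice a fixed-length window out of the flat token list.
    sample = torch.tensor(self._data[from_idx : from_idx + self.seq_len]).long()
    # Randomly pick ~15% of the positions to mask out.
    mask = torch.rand(sample.size()) < MASKING_PROBA
    tokens_masked = sample.clone()
    tokens_masked[mask] = MASK_TOKEN_ID
    # BatchMaskedForMLM is (masked inputs, boolean mask, original targets).
    return tokens_masked, mask, sample
```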

src/mblm/data/dataset/clevr.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -401,4 +401,4 @@ def get_sample(self, from_idx: int) -> BatchWithLossMask:
         DistributedDataset superclass.
         """
         q_i_q_a, loss_mask, _ = self.get_sample_with_parts(from_idx)
-        return q_i_q_a, loss_mask
+        return q_i_q_a, loss_mask  #
```
Lines changed: 174 additions & 0 deletions

```diff
@@ -0,0 +1,174 @@
+from __future__ import annotations
+
+__copyright__ = """MIT License
+
+Copyright (c) 2025 - IBM Research
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE."""
+
+
+import pathlib
+import random
+from pathlib import Path
+from typing import TYPE_CHECKING, Generator
+
+import torch
+import tqdm
+from typing_extensions import Unpack
+
+from mblm.data.datasets import DistributedDataset, DistributedDatasetConfig
+from mblm.data.types import BatchMaskedForMLM, ModelMode
+from mblm.data.utils import Bytes
+
+if TYPE_CHECKING:
+    from mblm.train.mblm import TrainMaskedEntryConfig
+
+
+# @masked_dataset_registry.register("maskedPG19")
+class PG19Masked(DistributedDataset[BatchMaskedForMLM]):
+    """
+    https://github.com/google-deepmind/pg19
+
+    The `data_dir` is expected to have the exact structure of the
+    original dataset, although only the test, train and validation folders
+    are strictly needed:
+
+    ├── LICENSE
+    ├── README.md
+    ├── metadata.csv
+    ├── test
+    ├── train
+    └── validation
+    """
+
+    def __init__(
+        self,
+        data_dir: str | Path,
+        mode: ModelMode,
+        masked_token_id: int = -100,
+        masking_proba: float = 0.15,
+        load_mininterval: int = 30,
+        display_load_progress: bool = True,
+        padding_token_id: int = -101,
+        **config: Unpack[DistributedDatasetConfig],
+    ):
+        root = Path(data_dir)
+        if mode == ModelMode.VALID:
+            data_path = root / "validation"
+        else:
+            data_path = root / mode.value
+        self.txt_files = [file for file in pathlib.Path.iterdir(data_path)]
+        self.masking_proba = masking_proba
+        data_buff = bytearray()
+        for file in tqdm.tqdm(
+            self.txt_files,
+            desc=f"Loading pg19 {data_path}",
+            mininterval=load_mininterval,
+            disable=not display_load_progress,
+        ):
+            with Path.open(file, "rb") as f:
+                data_buff.extend(f.read())
+        self.data = Bytes.bytes_to_tensor(data_buff)
+        self.masked_token_id = masked_token_id
+        self.padding_token_id = padding_token_id
+        if masked_token_id == padding_token_id:
+            raise ValueError("The padding and mask token ids must not be the same value")
+
+        super().__init__(
+            data_size=self.data.numel(),
+            is_sequential=True,
+            **config,
+        )
+
+    @staticmethod
+    def from_train_entry_config(
+        config: TrainMaskedEntryConfig,
+        mode: ModelMode,
+        worker_id: int,
+        num_workers: int,
+    ) -> DistributedDataset[BatchMaskedForMLM]:
+        return PG19Masked(
+            data_dir=config.io.dataset_dir,
+            masking_proba=config.train.masking_proba,
+            masked_token_id=config.params.mask_token_id,
+            mode=mode,
+            padding_token_id=config.params.mblm_config.pad_token_id,
+            seq_len=config.params.input_seq_len,
+            worker_id=worker_id,
+            num_workers=num_workers,
+        )
+
+    @staticmethod
+    def supports_test_mode() -> bool:
+        return True
+
+    def get_sample(self, from_idx: int) -> BatchMaskedForMLM:
+        """
+        Get a sample with a loss mask. This method is required by the
+        DistributedDataset superclass.
+        """
+        sample = self.data[from_idx : from_idx + self.seq_len].long()
+        mask = torch.rand(sample.size()) < self.masking_proba
+        tokens_masked = sample.clone()
+
+        # TODO implement the same strategy as BERT: even when a position is
+        # masked, sometimes copy the correct token instead of masked_token_id
+        tokens_masked[mask] = self.masked_token_id
+        # Pad if necessary - should only be needed when from_idx == len(self)
+        if sample.size(-1) != self.seq_len:
+            # pad with tensors filled with padding_token_id
+            pad_tensor = self.padding_token_id * torch.ones(self.seq_len - sample.size(-1))
+            tokens_masked = torch.concat((tokens_masked, pad_tensor))
+            # pad_tensor * 0 ensures that the loss is never computed over the
+            # padding tokens, as 1 marks MASKED elements and 0 non-MASKED tokens
+            mask = torch.concat((mask, pad_tensor * 0))
+            sample = torch.concat((sample, pad_tensor))
+        return tokens_masked.long(), mask.bool(), sample.long()
+
+    def book(self, name: str) -> str:
+        """
+        Get a book by its name (e.g., `44381.txt`) and return its content as
+        a string
+        """
+        for candidate in self.txt_files:
+            if candidate.name == name:
+                with Path.open(candidate, "r", encoding="utf8") as f:
+                    return f.read()
+        raise ValueError(f"Book {name} does not exist")
+
+    def iter_sequences_rand(self) -> Generator[torch.Tensor, None, None]:
+        """
+        Iterate over random sequences across books of PG19
+        """
+        max_sample_start_idx = len(self.data) - self.seq_len - 1
+        while True:
+            idx = random.randint(0, max_sample_start_idx)
+            yield self.data[idx : idx + self.seq_len]
+
+    def iter_books(self, shuffle: bool = False) -> Generator[tuple[str, str], None, None]:
+        """
+        Iterate over all the books in PG19, possibly in random order. Return
+        an iterator over the name of the book and its content as a string
+        """
+        txt_file_idxs = list(range(len(self.txt_files)))
+        if shuffle:
+            random.shuffle(txt_file_idxs)
+        for i in txt_file_idxs:
+            book = self.txt_files[i]
+            with Path.open(book, "r", encoding="utf8") as f:
+                yield book.name, f.read()
```
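
Note: a sketch of driving this dataset directly, outside the training registry. The import path and the local data directory are assumptions (the new file's name is cut off in this diff view); `seq_len`, `worker_id` and `num_workers` are the `DistributedDatasetConfig` fields forwarded in `from_train_entry_config`:

```python
# Assumed import path for the new module shown above.
from mblm.data.dataset.pg19 import PG19Masked
from mblm.data.types import ModelMode

dataset = PG19Masked(
    data_dir="/data/pg19",  # assumed local copy of the PG19 dataset
    mode=ModelMode.TRAIN,
    seq_len=8192,
    worker_id=0,
    num_workers=1,
)

tokens_masked, mask, targets = dataset.get_sample(0)
# tokens_masked: bytes with ~15% of positions replaced by masked_token_id
# mask:          True where a position was masked (and loss is computed)
# targets:       the original, unmasked byte sequence
```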

src/mblm/data/types.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -33,3 +33,4 @@ class ModelMode(Enum):
 
 
 BatchWithLossMask: TypeAlias = tuple[torch.Tensor, torch.Tensor]
+BatchMaskedForMLM: TypeAlias = tuple[torch.Tensor, torch.Tensor, torch.Tensor]
```
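
Note: the new alias is a three-tensor tuple. Going by `PG19Masked.get_sample`, the components are (masked inputs, boolean mask, original targets); a toy value built by hand, with illustrative numbers:

```python
import torch

from mblm.data.types import BatchMaskedForMLM

targets = torch.tensor([72, 101, 108, 108, 111])        # original bytes ("Hello")
mask = torch.tensor([False, True, False, False, True])  # positions to mask
tokens_masked = targets.clone()
tokens_masked[mask] = -100                              # assumed mask token id

batch: BatchMaskedForMLM = (tokens_masked, mask, targets)
```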

src/mblm/model/block.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -98,4 +98,4 @@ def try_parse(
         except Exception:
             pass
 
-        raise ValueError(f"Coult not parse data to any of {self}")
+        raise ValueError(f"Could not parse data to any of {self}")
```

src/mblm/model/config.py

Lines changed: 8 additions & 1 deletion

```diff
@@ -34,17 +34,19 @@
 
 from mblm.model.block import StageBlock, StageBlockRegistry
 from mblm.model.mamba import MambaBlock
-from mblm.model.transformer import TransformerBlock
+from mblm.model.transformer import TransformerBlock, TransformerEncoderBlock
 
 block_registry = StageBlockRegistry()
 block_registry.register()(TransformerBlock)
 block_registry.register()(MambaBlock)
+block_registry.register()(TransformerEncoderBlock)
 
 
 class MBLMReturnType(str, Enum):
     LOGITS = auto()
     LOSS = auto()
     LOSS_LOGITS = auto()
+    HIDDEN_STATE = auto()
 
 
 class MBLMModelConfig(BaseModel):
@@ -113,3 +115,8 @@ def stage_blocks(self) -> Sequence[StageBlock]:
         if isinstance(self.block, Sequence):
             return self.block
         return list(repeat(self.block, len(self.hidden_dims)))
+
+
+class MBLMEncoderModelConfig(BaseModel):
+    mask_token_id: int
+    mblm_config: MBLMModelConfig
```
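
Note: `MBLMEncoderModelConfig` just layers a mask token id over an existing decoder config. A sketch of how the pieces could fit together; the `MBLMEncoder` constructor signature is an assumption (it is not shown in this diff), and only `mask_token_id` and `mblm_config` are confirmed fields:

```python
from mblm import MBLMEncoder, MBLMEncoderModelConfig, MBLMModelConfig, MBLMReturnType

def build_encoder(mblm_config: MBLMModelConfig) -> MBLMEncoder:
    encoder_config = MBLMEncoderModelConfig(
        mask_token_id=-100,  # assumed value; must not collide with mblm_config.pad_token_id
        mblm_config=mblm_config,
    )
    # Assumption: MBLMEncoder is constructed from the encoder config,
    # mirroring how MBLM is constructed from MBLMModelConfig.
    return MBLMEncoder(encoder_config)

# With an encoder in hand, the new enum member can request raw hidden states:
# hidden = encoder(tokens, return_type=MBLMReturnType.HIDDEN_STATE)
```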

src/mblm/model/mamba.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -21,7 +21,7 @@
 SOFTWARE."""
 
 
-from pydantic import Field
+from pydantic import Field, model_validator
 
 from mblm.model.block import StageBlock
 from mblm.model.mamba_shim import Mamba1, Mamba1Config, Mamba2Mixer
@@ -79,3 +79,9 @@ def to_model(self, model_dim, num_layers):
         raise RuntimeError(
             "Failed to import any Mamba version - this should never happen",
         )
+
+    @model_validator(mode="after")
+    def validate_block_type(self):
+        if "mamba" not in self.block_type:
+            raise ValueError("This model is a mamba block")
+        return self
```
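
Note: the added `model_validator(mode="after")` runs once all fields are parsed and rejects a `MambaBlock` whose `block_type` does not mention Mamba. The same pattern in isolation, as a self-contained toy (the surrounding `MambaBlock` fields are not shown in this hunk):

```python
from pydantic import BaseModel, ValidationError, model_validator

class ToyBlock(BaseModel):
    """Stand-in reproducing the validator pattern added to MambaBlock."""

    block_type: str

    @model_validator(mode="after")
    def validate_block_type(self):
        # Reject any config whose type string does not mention "mamba".
        if "mamba" not in self.block_type:
            raise ValueError("This model is a mamba block")
        return self

ToyBlock(block_type="mamba1")  # passes validation

try:
    ToyBlock(block_type="transformer")
except ValidationError as err:
    print(err)  # pydantic wraps the ValueError raised by the validator
```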
