 # SPDX-License-Identifier: Apache-2.0

 import copy
+import hashlib
+import os
+import zipfile
 from typing import Any, Callable

 import pytest
+import requests
 from pytest import fixture
-from tenacity import retry, stop_after_attempt, wait_fixed
 from transformers import PreTrainedTokenizerBase

 from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM
@@ -195,109 +198,81 @@ def tiny_bert_config_helper():
     return config_object


-## TOKENIZER HELPERS ##
-@retry(
-    wait=wait_fixed(5),
-    stop=stop_after_attempt(1),
-)
-def tiny_gpt2_tokenizer_helper(add_pad: bool = False):
-    transformers = pytest.importorskip('transformers')
-
-    hf_tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
-
-    if add_pad:
-        hf_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
-    return hf_tokenizer
+def assets_path():
+    rank = os.environ.get('RANK', '0')
+    folder_name = 'tokenizers' + (f'_{rank}' if rank != '0' else '')
+    return os.path.join(
+        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+        'assets',
+        folder_name,
+    )


-@retry(
-    wait=wait_fixed(5),
-    stop=stop_after_attempt(1),
-)
-def tiny_llama_tokenizer_helper():
-    transformers = pytest.importorskip('transformers')
+@pytest.fixture(scope='session')
+def tokenizers_assets():
+    download_tokenizers_files()

-    hf_tokenizer = transformers.AutoTokenizer.from_pretrained(
-        'huggyllama/llama-7b',
-        use_fast=False,
-    )
-    return hf_tokenizer

+def download_tokenizers_files():
+    """Download the tokenizers assets.

-@retry(
-    wait=wait_fixed(5),
-    stop=stop_after_attempt(1),
-)
-def tiny_codellama_tokenizer_helper():
-    transformers = pytest.importorskip('transformers')
+    We download from GitHub because downloading from HF directly is flaky and gets rate limited easily.

-    hf_tokenizer = transformers.AutoTokenizer.from_pretrained(
-        'codellama/CodeLlama-7b-hf',
-    )
-    return hf_tokenizer
+    Raises:
+        ValueError: If the checksum of the downloaded file does not match the expected checksum.
+    """
+    # Define paths
+    tokenizers_dir = assets_path()

+    if os.path.exists(tokenizers_dir):
+        return

-@retry(
-    wait=wait_fixed(5),
-    stop=stop_after_attempt(1),
-)
-def tiny_neox_tokenizer_helper():
-    transformers = pytest.importorskip('transformers')
+    # Create assets directory if it doesn't exist
+    os.makedirs(tokenizers_dir, exist_ok=True)

-    hf_tokenizer = transformers.AutoTokenizer.from_pretrained(
-        'EleutherAI/gpt-neox-20b',
-        model_max_length=2048,
-    )
-    return hf_tokenizer
+    # URL for the tokenizers.zip file
+    url = 'https://github.com/mosaicml/ci-testing/releases/download/tokenizers/tokenizers.zip'
+    expected_checksum = '12dc1f254270582f7806588f1f1d47945590c5b42dee28925e5dab95f2d08075'

+    # Download the zip file
+    response = requests.get(url, stream=True)
+    response.raise_for_status()

-@retry(
-    wait=wait_fixed(5),
-    stop=stop_after_attempt(1),
-)
-def tiny_t5_tokenizer_helper():
-    transformers = pytest.importorskip('transformers')
+    zip_path = os.path.join(tokenizers_dir, 'tokenizers.zip')

-    hf_tokenizer = transformers.AutoTokenizer.from_pretrained('t5-base',)
-    return hf_tokenizer
+    # Check the checksum
+    checksum = hashlib.sha256(response.content).hexdigest()
+    if checksum != expected_checksum:
+        raise ValueError(
+            f'Checksum mismatch: expected {expected_checksum}, got {checksum}',
+        )

+    with open(zip_path, 'wb') as f:
+        for chunk in response.iter_content(chunk_size=8192):
+            f.write(chunk)

-@retry(
-    wait=wait_fixed(5),
-    stop=stop_after_attempt(1),
-)
-def tiny_bert_tokenizer_helper():
-    transformers = pytest.importorskip('transformers')
+    # Extract the zip file
+    print(f'Extracting tokenizers.zip to {tokenizers_dir}')
+    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+        zip_ref.extractall(tokenizers_dir)

-    return transformers.AutoTokenizer.from_pretrained(
-        'google-bert/bert-base-uncased',
-    )
+    # Remove the zip file after extraction
+    os.remove(zip_path)


-@retry(
-    wait=wait_fixed(5),
-    stop=stop_after_attempt(1),
-)
-def tiny_mpt_tokenizer_helper():
+## TOKENIZER HELPERS ##
+def assets_tokenizer_helper(name: str):
+    """Load a tokenizer from the assets directory."""
     transformers = pytest.importorskip('transformers')

-    return transformers.AutoTokenizer.from_pretrained(
-        'mosaicml/mpt-7b',
-        model_max_length=2048,
-    )
-
+    download_tokenizers_files()

-@retry(
-    wait=wait_fixed(5),
-    stop=stop_after_attempt(1),
-)
-def tiny_mpt_chat_tokenizer_helper():
-    transformers = pytest.importorskip('transformers')
+    assets_dir = assets_path()
+    tokenizer_path = os.path.join(assets_dir, name)

-    return transformers.AutoTokenizer.from_pretrained(
-        'mosaicml/mpt-7b-8k-chat',
-        model_max_length=2048,
-    )
+    # Load the tokenizer
+    hf_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)
+    return hf_tokenizer


 ## SESSION MODELS ##
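A quick illustration of the rank-aware assets path added above; the snippet and its RANK values are hypothetical and not part of the diff, and the printed paths are only indicative of where the helper resolves:

# Hypothetical illustration (not part of this change): assets_path() keys the
# download directory off the RANK environment variable, so concurrent
# distributed test workers extract into separate folders instead of racing on
# the same one.
os.environ['RANK'] = '0'
print(assets_path())  # ends with .../assets/tokenizers
os.environ['RANK'] = '1'
print(assets_path())  # ends with .../assets/tokenizers_1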
@@ -336,48 +311,50 @@ def _session_tiny_bert_config():  # type: ignore

 ## SESSION TOKENIZERS ##
 @pytest.fixture(scope='session')
-def _session_tiny_gpt2_tokenizer():  # type: ignore
-    return tiny_gpt2_tokenizer_helper()
+def _session_tiny_gpt2_tokenizer(tokenizers_assets):  # type: ignore
+    return assets_tokenizer_helper('gpt2')


 @pytest.fixture(scope='session')
-def _session_tiny_gpt2_with_pad_tokenizer():  # type: ignore
-    return tiny_gpt2_tokenizer_helper(add_pad=True)
+def _session_tiny_gpt2_with_pad_tokenizer(tokenizers_assets):  # type: ignore
+    tokenizer = assets_tokenizer_helper('gpt2')
+    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+    return tokenizer


 @pytest.fixture(scope='session')
-def _session_tiny_llama_tokenizer():  # type: ignore
-    return tiny_llama_tokenizer_helper()
+def _session_tiny_llama_tokenizer(tokenizers_assets):  # type: ignore
+    return assets_tokenizer_helper('llama')


 @pytest.fixture(scope='session')
-def _session_tiny_codellama_tokenizer():  # type: ignore
-    return tiny_codellama_tokenizer_helper()
+def _session_tiny_codellama_tokenizer(tokenizers_assets):  # type: ignore
+    return assets_tokenizer_helper('codellama')


 @pytest.fixture(scope='session')
-def _session_tiny_neox_tokenizer():  # type: ignore
-    return tiny_neox_tokenizer_helper()
+def _session_tiny_neox_tokenizer(tokenizers_assets):  # type: ignore
+    return assets_tokenizer_helper('neox')


 @pytest.fixture(scope='session')
-def _session_tiny_t5_tokenizer():  # type: ignore
-    return tiny_t5_tokenizer_helper()
+def _session_tiny_t5_tokenizer(tokenizers_assets):  # type: ignore
+    return assets_tokenizer_helper('t5')


 @pytest.fixture(scope='session')
-def _session_tiny_bert_tokenizer():  # type: ignore
-    return tiny_bert_tokenizer_helper()
+def _session_tiny_bert_tokenizer(tokenizers_assets):  # type: ignore
+    return assets_tokenizer_helper('bertt')


 @pytest.fixture(scope='session')
-def _session_tiny_mpt_tokenizer():  # type: ignore
-    return tiny_mpt_tokenizer_helper()
+def _session_tiny_mpt_tokenizer(tokenizers_assets):  # type: ignore
+    return assets_tokenizer_helper('mptt')


 @pytest.fixture(scope='session')
-def _session_tiny_mpt_chat_tokenizer():  # type: ignore
-    return tiny_mpt_chat_tokenizer_helper()
+def _session_tiny_mpt_chat_tokenizer(tokenizers_assets):  # type: ignore
+    return assets_tokenizer_helper('mptct')


 ## MODEL FIXTURES ##
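For completeness, a minimal sketch of how a downstream test would pick up the new flow; the test name below is hypothetical, but the fixture name is one defined above, and requesting it transitively runs the session-scoped tokenizers_assets fixture and therefore the one-time download:

# Hypothetical test (not part of this change): depending on the session
# fixture runs download_tokenizers_files() at most once per session, then
# loads the GPT-2 tokenizer from the extracted local assets instead of the
# Hugging Face Hub.
def test_tiny_gpt2_tokenizer_loads(_session_tiny_gpt2_tokenizer):
    ids = _session_tiny_gpt2_tokenizer('hello world')['input_ids']
    assert len(ids) > 0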