+from __future__ import annotations
+
+import hashlib
+import os
 import sys
+import tarfile
+import tempfile
+import urllib.request
 from functools import lru_cache
-from typing import List, Tuple
-
-if sys.version_info < (3, 8):
-    from typing_extensions import Final  # pragma: no cover
-else:
-    from typing import Final
+from typing import Any, Final, List, Tuple
 
 import nltk
 from nltk import pos_tag as _pos_tag
 from nltk import sent_tokenize as _sent_tokenize
 from nltk import word_tokenize as _word_tokenize
 
 CACHE_MAX_SIZE: Final[int] = 128
 
+NLTK_DATA_URL = "https://utic-public-cf.s3.amazonaws.com/nltk_data.tgz"
+NLTK_DATA_SHA256 = "126faf671cd255a062c436b3d0f2d311dfeefcd92ffa43f7c3ab677309404d61"
+
+
+def _raise_on_nltk_download(*args: Any, **kwargs: Any) -> None:
+    raise ValueError("NLTK download disabled. See CVE-2024-39705")
+
+
+nltk.download = _raise_on_nltk_download
+
+
+# NOTE(robinson) - mimic default dir logic from NLTK
+# https://github.com/nltk/nltk/
+# blob/8c233dc585b91c7a0c58f96a9d99244a379740d5/nltk/downloader.py#L1046
+def get_nltk_data_dir() -> str | None:
+    """Locates the directory the nltk data will be saved to. The directory
+    set by the NLTK_DATA environment variable takes highest precedence. Otherwise
+    the default is determined by the rules indicated below. Returns None when
+    the directory is not writable.
+
+    On Windows, the default download directory is
+    ``PYTHONHOME/lib/nltk``, where *PYTHONHOME* is the
+    directory containing Python, e.g. ``C:\\Python311``.
+
+    On all other platforms, the default directory is the first of
+    the following which exists or which can be created with write
+    permission: ``/usr/share/nltk_data``, ``/usr/local/share/nltk_data``,
+    ``/usr/lib/nltk_data``, ``/usr/local/lib/nltk_data``, ``~/nltk_data``.
+    """
+    # Check if we are on GAE where we cannot write into the filesystem.
+    if "APPENGINE_RUNTIME" in os.environ:
+        return None
+
+    # Check if we have sufficient permissions to install in a
+    # variety of system-wide locations.
+    for nltkdir in nltk.data.path:
+        if os.path.exists(nltkdir) and nltk.internals.is_writable(nltkdir):
+            return nltkdir
+
+    # On Windows, use %APPDATA%
+    if sys.platform == "win32" and "APPDATA" in os.environ:
+        homedir = os.environ["APPDATA"]
+
+    # Otherwise, install in the user's home directory.
+    else:
+        homedir = os.path.expanduser("~/")
+        if homedir == "~/":
+            raise ValueError("Could not find a default download directory")
+
+    # NOTE(robinson) - NLTK appends nltk_data to the homedir. That's already
+    # present in the tar file so we don't have to do that here.
+    return homedir
+
+
+def download_nltk_packages() -> None:
+    nltk_data_dir = get_nltk_data_dir()
+
+    if nltk_data_dir is None:
+        raise OSError("NLTK data directory does not exist or is not writable.")
+
+    def sha256_checksum(filename: str, block_size: int = 65536) -> str:
+        sha256 = hashlib.sha256()
+        with open(filename, "rb") as f:
+            for block in iter(lambda: f.read(block_size), b""):
+                sha256.update(block)
+        return sha256.hexdigest()
+
+    with tempfile.NamedTemporaryFile() as tmp_file:
+        tgz_file = tmp_file.name
+        urllib.request.urlretrieve(NLTK_DATA_URL, tgz_file)
+
+        file_hash = sha256_checksum(tgz_file)
+        if file_hash != NLTK_DATA_SHA256:
+            # NOTE - the context manager deletes the temp file on exit, so
+            # it must not also be removed here (a double delete raises).
+            raise ValueError(f"SHA-256 mismatch: expected {NLTK_DATA_SHA256}, got {file_hash}")
+
+        # Extract the contents while the temp file still exists.
+        if not os.path.exists(nltk_data_dir):
+            os.makedirs(nltk_data_dir)
+
+        with tarfile.open(tgz_file, "r:gz") as tar:
+            tar.extractall(path=nltk_data_dir)
+
+
+def check_for_nltk_package(package_name: str, package_category: str) -> bool:
+    """Checks to see if the specified NLTK package exists on the file system."""
+    paths: list[str] = []
+    for path in nltk.data.path:
+        if not path.endswith("nltk_data"):
+            path = os.path.join(path, "nltk_data")
+        paths.append(path)
 
-def _download_nltk_package_if_not_present(package_name: str, package_category: str):
-    """If the required nlt package is not present, download it."""
     try:
-        nltk.find(f"{package_category}/{package_name}")
+        nltk.find(f"{package_category}/{package_name}", paths=paths)
+        return True
     except LookupError:
-        nltk.download(package_name)
+        return False
+
+
+def _download_nltk_packages_if_not_present() -> None:
+    """If required NLTK packages are not available, download them."""
+
+    tagger_available = check_for_nltk_package(
+        package_category="taggers",
+        package_name="averaged_perceptron_tagger",
+    )
+    tokenizer_available = check_for_nltk_package(
+        package_category="tokenizers", package_name="punkt"
+    )
+
+    if not (tokenizer_available and tagger_available):
+        download_nltk_packages()
 
 
 @lru_cache(maxsize=CACHE_MAX_SIZE)
 def sent_tokenize(text: str) -> List[str]:
     """A wrapper around the NLTK sentence tokenizer with LRU caching enabled."""
-    _download_nltk_package_if_not_present(package_category="tokenizers", package_name="punkt")
+    _download_nltk_packages_if_not_present()
     return _sent_tokenize(text)
 
 
 @lru_cache(maxsize=CACHE_MAX_SIZE)
 def word_tokenize(text: str) -> List[str]:
     """A wrapper around the NLTK word tokenizer with LRU caching enabled."""
-    _download_nltk_package_if_not_present(package_category="tokenizers", package_name="punkt")
+    _download_nltk_packages_if_not_present()
     return _word_tokenize(text)
 
 
 @lru_cache(maxsize=CACHE_MAX_SIZE)
 def pos_tag(text: str) -> List[Tuple[str, str]]:
     """A wrapper around the NLTK POS tagger with LRU caching enabled."""
-    _download_nltk_package_if_not_present(package_category="tokenizers", package_name="punkt")
-    _download_nltk_package_if_not_present(
-        package_category="taggers",
-        package_name="averaged_perceptron_tagger",
-    )
+    _download_nltk_packages_if_not_present()
     # NOTE(robinson) - Splitting into sentences before tokenizing. This helps with
     # situations like "ITEM 1A. PROPERTIES" where "PROPERTIES" can be mistaken
     # for a verb because it looks like it's in verb form and "ITEM 1A." looks like the subject.
     sentences = _sent_tokenize(text)
-    parts_of_speech = []
+    parts_of_speech: list[tuple[str, str]] = []
     for sentence in sentences:
         tokens = _word_tokenize(sentence)
         parts_of_speech.extend(_pos_tag(tokens))
     return parts_of_speech
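The nltk.download guard near the top of the diff is the core of the CVE-2024-39705 mitigation: any code path that still calls nltk.download() fails loudly instead of fetching packages from the network. A minimal regression-test sketch for that guard, assuming pytest is available and the patched module has been imported (neither is part of this diff):

import nltk
import pytest


def test_nltk_download_is_disabled() -> None:
    # Importing the patched module replaces nltk.download with a
    # function that always raises, so any download attempt must fail.
    with pytest.raises(ValueError, match="CVE-2024-39705"):
        nltk.download("punkt")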
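The SHA-256 pin authenticates the archive before extraction, but tar.extractall still trusts the member paths inside it. Not part of this commit, but a hardened extraction step one might consider as a sketch; safe_extract is a hypothetical helper, and the "data" filter is the stdlib option added in Python 3.12 that rejects absolute paths and ".." components:

import os
import sys
import tarfile


def safe_extract(tgz_path: str, dest: str) -> None:
    """Extract a checksum-verified tarball, refusing path-traversal members."""
    with tarfile.open(tgz_path, "r:gz") as tar:
        if sys.version_info >= (3, 12):
            # The stdlib "data" filter blocks absolute paths, "..",
            # and other unsafe member types.
            tar.extractall(path=dest, filter="data")
        else:
            # Manual guard for older interpreters: every member must
            # resolve to a path inside the destination directory.
            dest_root = os.path.realpath(dest)
            for member in tar.getmembers():
                target = os.path.realpath(os.path.join(dest, member.name))
                if os.path.commonpath([dest_root, target]) != dest_root:
                    raise ValueError(f"Unsafe path in archive: {member.name}")
            tar.extractall(path=dest)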
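With the patch applied, the public wrappers stay drop-in: the first call checks the filesystem for punkt and averaged_perceptron_tagger, downloads and verifies the pinned tarball once if either is missing, and serves repeated inputs from the lru_cache. A usage sketch, with the import path shown as the hypothetical nlp_tokenize since the diff does not name the module:

# Hypothetical module name for the file shown in this diff.
from nlp_tokenize import pos_tag, sent_tokenize, word_tokenize

print(sent_tokenize("ITEM 1A. PROPERTIES This section describes our facilities."))
print(word_tokenize("The company owns two manufacturing plants."))
print(pos_tag("ITEM 1A. PROPERTIES"))  # sentence-splits first to steady the tagger
# Identical inputs hit the cache and skip the availability check entirely.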