Commit 7b25dfc

MthwRobinson and scanny authored

fix(CVE-2024-39705): remove nltk download (#3361)

### Summary

Addresses [CVE-2024-39705](https://nvd.nist.gov/vuln/detail/CVE-2024-39705), which highlights the risk of remote code execution when running `nltk.download`. Removes `nltk.download` in favor of a `.tgz` file with the appropriate NLTK data files, checking the SHA256 hash to validate the download. An error is now raised if `nltk.download` is invoked. The logic for determining the NLTK download directory is borrowed from `nltk`, so users can still set `NLTK_DATA` as they did previously.

### Testing

1. Create a directory called `~/tmp/nltk_test` and set `NLTK_DATA=${HOME}/tmp/nltk_test`.
2. From a Python interactive session, run:

   ```python
   from unstructured.nlp.tokenize import download_nltk_packages

   download_nltk_packages()
   ```

3. Run `ls ~/tmp/nltk_test/nltk_data`. You should see the downloaded data.

Co-authored-by: Steve Canny <[email protected]>

1 parent d48fa3b · commit 7b25dfc
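Below is a minimal sketch of the new behavior, assuming unstructured 0.14.10 is installed: importing `unstructured.nlp.tokenize` monkey-patches `nltk.download` so any invocation raises, while `download_nltk_packages()` is the supported replacement.

```python
# Sketch only: requires unstructured 0.14.10+ and network access to the
# pinned S3 URL. Importing the module disables nltk.download at import time.
from unstructured.nlp import tokenize

try:
    tokenize.nltk.download("punkt")  # previously fetched data over the network
except ValueError as err:
    print(err)  # -> NLTK download disabled. See CVE-2024-39705

# Supported replacement: fetches the pinned .tgz and verifies its SHA-256
# before extracting into the resolved NLTK data directory.
tokenize.download_nltk_packages()
```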

File tree: 12 files changed (+179, -27 lines)

**.github/workflows/ci.yml** (+2)

```diff
@@ -256,6 +256,8 @@ jobs:
       matrix:
         python-version: [ "3.9","3.10" ]
     runs-on: ubuntu-latest
+    env:
+      NLTK_DATA: ${{ github.workspace }}/nltk_data
     needs: [ setup_ingest, lint ]
     steps:
       # actions/checkout MUST come before auth
```
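The new `env` entry points `NLTK_DATA` at the workspace so CI jobs resolve the data to a predictable, writable directory. As the summary notes, the same variable still works for users. A minimal sketch of how it is honored (the path below is the hypothetical one from the Testing steps), assuming only that `nltk` reads `NLTK_DATA` when it is first imported:

```python
# Sketch only: NLTK_DATA must be set before nltk is first imported, because
# nltk builds its data search path at import time.
import os

os.environ["NLTK_DATA"] = os.path.expanduser("~/tmp/nltk_test")  # hypothetical path

import nltk  # noqa: E402 - imported after setting NLTK_DATA so it is honored

print(nltk.data.path[0])  # NLTK_DATA entries come first in the search path
```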

**CHANGELOG.md** (+2, -1)

```diff
@@ -1,4 +1,4 @@
-## 0.14.10-dev13
+## 0.14.10
 
 ### Enhancements
 
@@ -14,6 +14,7 @@
 
 * **Fix counting false negatives and false positives in table structure evaluation**
 * **Fix Slack CI test** Change channel that Slack test is pointing to because previous test bot expired
+* **Remove NLTK download** Removes `nltk.download` in favor of downloading from an S3 bucket we host to mitigate CVE-2024-39705
 
 ## 0.14.9
```

**Dockerfile** (+2, -3)

```diff
@@ -1,4 +1,4 @@
-FROM quay.io/unstructured-io/base-images:wolfi-base-d46498e@sha256:3db0544df1d8d9989cd3c3b28670d8b81351dfdc1d9129004c71ff05996fd51e as base
+FROM quay.io/unstructured-io/base-images:wolfi-base-e48da6b@sha256:8ad3479e5dc87a86e4794350cca6385c01c6d110902c5b292d1a62e231be711b as base
 
 USER root
 
@@ -18,8 +18,7 @@ USER notebook-user
 
 RUN find requirements/ -type f -name "*.txt" -exec pip3.11 install --no-cache-dir --user -r '{}' ';' && \
     pip3.11 install unstructured.paddlepaddle && \
-    python3.11 -c "import nltk; nltk.download('punkt')" && \
-    python3.11 -c "import nltk; nltk.download('averaged_perceptron_tagger')" && \
+    python3.11 -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()" && \
     python3.11 -c "from unstructured.partition.model_init import initialize; initialize()" && \
     python3.11 -c "from unstructured_inference.models.tables import UnstructuredTableTransformerModel; model = UnstructuredTableTransformerModel(); model.initialize('microsoft/table-transformer-structure-recognition')"
```

**test_unstructured/nlp/test_tokenize.py** (+10, -4)

```diff
@@ -2,22 +2,28 @@
 from unittest.mock import patch
 
 import nltk
+import pytest
 
 from test_unstructured.nlp.mock_nltk import mock_sent_tokenize, mock_word_tokenize
 from unstructured.nlp import tokenize
 
 
+def test_error_raised_on_nltk_download():
+    with pytest.raises(ValueError):
+        tokenize.nltk.download("tokenizers/punkt")
+
+
 def test_nltk_packages_download_if_not_present():
     with patch.object(nltk, "find", side_effect=LookupError):
-        with patch.object(nltk, "download") as mock_download:
-            tokenize._download_nltk_package_if_not_present("fake_package", "tokenizers")
+        with patch.object(tokenize, "download_nltk_packages") as mock_download:
+            tokenize._download_nltk_packages_if_not_present()
 
-    mock_download.assert_called_with("fake_package")
+    mock_download.assert_called_once()
 
 
 def test_nltk_packages_do_not_download_if():
     with patch.object(nltk, "find"), patch.object(nltk, "download") as mock_download:
-        tokenize._download_nltk_package_if_not_present("fake_package", "tokenizers")
+        tokenize._download_nltk_packages_if_not_present()
 
     mock_download.assert_not_called()
```

**typings/nltk/__init__.pyi** (new file, +17)

```diff
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+from nltk import data, internals
+from nltk.data import find
+from nltk.downloader import download
+from nltk.tag import pos_tag
+from nltk.tokenize import sent_tokenize, word_tokenize
+
+__all__ = [
+    "data",
+    "download",
+    "find",
+    "internals",
+    "pos_tag",
+    "sent_tokenize",
+    "word_tokenize",
+]
```

**typings/nltk/data.pyi** (new file, +7)

```diff
@@ -0,0 +1,7 @@
+from __future__ import annotations
+
+from typing import Sequence
+
+path: list[str]
+
+def find(resource_name: str, paths: Sequence[str] | None = None) -> str: ...
```

**typings/nltk/downloader.pyi** (new file, +5)

```diff
@@ -0,0 +1,5 @@
+from __future__ import annotations
+
+from typing import Callable
+
+download: Callable[..., bool]
```

**typings/nltk/internals.pyi** (new file, +3)

```diff
@@ -0,0 +1,3 @@
+from __future__ import annotations
+
+def is_writable(path: str) -> bool: ...
```

**typings/nltk/tag.pyi** (new file, +5)

```diff
@@ -0,0 +1,5 @@
+from __future__ import annotations
+
+def pos_tag(
+    tokens: list[str], tagset: str | None = None, lang: str = "eng"
+) -> list[tuple[str, str]]: ...
```

**typings/nltk/tokenize.pyi** (new file, +4)

```diff
@@ -0,0 +1,4 @@
+from __future__ import annotations
+
+def sent_tokenize(text: str, language: str = ...) -> list[str]: ...
+def word_tokenize(text: str, language: str = ..., preserve_line: bool = ...) -> list[str]: ...
```

**unstructured/__version__.py** (+1, -1)

```diff
@@ -1 +1 @@
-__version__ = "0.14.10-dev13" # pragma: no cover
+__version__ = "0.14.10" # pragma: no cover
```

**unstructured/nlp/tokenize.py** (+121, -18)

```diff
@@ -1,11 +1,13 @@
+from __future__ import annotations
+
+import hashlib
+import os
 import sys
+import tarfile
+import tempfile
+import urllib.request
 from functools import lru_cache
-from typing import List, Tuple
-
-if sys.version_info < (3, 8):
-    from typing_extensions import Final # pragma: no cover
-else:
-    from typing import Final
+from typing import Any, Final, List, Tuple
 
 import nltk
 from nltk import pos_tag as _pos_tag
@@ -14,42 +16,143 @@
 
 CACHE_MAX_SIZE: Final[int] = 128
 
+NLTK_DATA_URL = "https://utic-public-cf.s3.amazonaws.com/nltk_data.tgz"
+NLTK_DATA_SHA256 = "126faf671cd255a062c436b3d0f2d311dfeefcd92ffa43f7c3ab677309404d61"
+
+
+def _raise_on_nltk_download(*args: Any, **kwargs: Any):
+    raise ValueError("NLTK download disabled. See CVE-2024-39705")
+
+
+nltk.download = _raise_on_nltk_download
+
+
+# NOTE(robinson) - mimic default dir logic from NLTK
+# https://github.com/nltk/nltk/
+# blob/8c233dc585b91c7a0c58f96a9d99244a379740d5/nltk/downloader.py#L1046
+def get_nltk_data_dir() -> str | None:
+    """Locates the directory the nltk data will be saved too. The directory
+    set by the NLTK environment variable takes highest precedence. Otherwise
+    the default is determined by the rules indicated below. Returns None when
+    the directory is not writable.
+
+    On Windows, the default download directory is
+    ``PYTHONHOME/lib/nltk``, where *PYTHONHOME* is the
+    directory containing Python, e.g. ``C:\\Python311``.
+
+    On all other platforms, the default directory is the first of
+    the following which exists or which can be created with write
+    permission: ``/usr/share/nltk_data``, ``/usr/local/share/nltk_data``,
+    ``/usr/lib/nltk_data``, ``/usr/local/lib/nltk_data``, ``~/nltk_data``.
+    """
+    # Check if we are on GAE where we cannot write into filesystem.
+    if "APPENGINE_RUNTIME" in os.environ:
+        return
+
+    # Check if we have sufficient permissions to install in a
+    # variety of system-wide locations.
+    for nltkdir in nltk.data.path:
+        if os.path.exists(nltkdir) and nltk.internals.is_writable(nltkdir):
+            return nltkdir
+
+    # On Windows, use %APPDATA%
+    if sys.platform == "win32" and "APPDATA" in os.environ:
+        homedir = os.environ["APPDATA"]
+
+    # Otherwise, install in the user's home directory.
+    else:
+        homedir = os.path.expanduser("~/")
+        if homedir == "~/":
+            raise ValueError("Could not find a default download directory")
+
+    # NOTE(robinson) - NLTK appends nltk_data to the homedir. That's already
+    # present in the tar file so we don't have to do that here.
+    return homedir
+
+
+def download_nltk_packages():
+    nltk_data_dir = get_nltk_data_dir()
+
+    if nltk_data_dir is None:
+        raise OSError("NLTK data directory does not exist or is not writable.")
+
+    def sha256_checksum(filename: str, block_size: int = 65536):
+        sha256 = hashlib.sha256()
+        with open(filename, "rb") as f:
+            for block in iter(lambda: f.read(block_size), b""):
+                sha256.update(block)
+        return sha256.hexdigest()
+
+    with tempfile.NamedTemporaryFile() as tmp_file:
+        tgz_file = tmp_file.name
+        urllib.request.urlretrieve(NLTK_DATA_URL, tgz_file)
+
+        file_hash = sha256_checksum(tgz_file)
+        if file_hash != NLTK_DATA_SHA256:
+            os.remove(tgz_file)
+            raise ValueError(f"SHA-256 mismatch: expected {NLTK_DATA_SHA256}, got {file_hash}")
+
+        # Extract the contents
+        if not os.path.exists(nltk_data_dir):
+            os.makedirs(nltk_data_dir)
+
+        with tarfile.open(tgz_file, "r:gz") as tar:
+            tar.extractall(path=nltk_data_dir)
+
+
+def check_for_nltk_package(package_name: str, package_category: str) -> bool:
+    """Checks to see if the specified NLTK package exists on the file system"""
+    paths: list[str] = []
+    for path in nltk.data.path:
+        if not path.endswith("nltk_data"):
+            path = os.path.join(path, "nltk_data")
+        paths.append(path)
 
-def _download_nltk_package_if_not_present(package_name: str, package_category: str):
-    """If the required nlt package is not present, download it."""
     try:
-        nltk.find(f"{package_category}/{package_name}")
+        nltk.find(f"{package_category}/{package_name}", paths=paths)
+        return True
     except LookupError:
-        nltk.download(package_name)
+        return False
+
+
+def _download_nltk_packages_if_not_present():
+    """If required NLTK packages are not available, download them."""
+
+    tagger_available = check_for_nltk_package(
+        package_category="taggers",
+        package_name="averaged_perceptron_tagger",
+    )
+    tokenizer_available = check_for_nltk_package(
+        package_category="tokenizers", package_name="punkt"
+    )
+
+    if not (tokenizer_available and tagger_available):
+        download_nltk_packages()
 
 
 @lru_cache(maxsize=CACHE_MAX_SIZE)
 def sent_tokenize(text: str) -> List[str]:
     """A wrapper around the NLTK sentence tokenizer with LRU caching enabled."""
-    _download_nltk_package_if_not_present(package_category="tokenizers", package_name="punkt")
+    _download_nltk_packages_if_not_present()
     return _sent_tokenize(text)
 
 
 @lru_cache(maxsize=CACHE_MAX_SIZE)
 def word_tokenize(text: str) -> List[str]:
     """A wrapper around the NLTK word tokenizer with LRU caching enabled."""
-    _download_nltk_package_if_not_present(package_category="tokenizers", package_name="punkt")
+    _download_nltk_packages_if_not_present()
     return _word_tokenize(text)
 
 
 @lru_cache(maxsize=CACHE_MAX_SIZE)
 def pos_tag(text: str) -> List[Tuple[str, str]]:
     """A wrapper around the NLTK POS tagger with LRU caching enabled."""
-    _download_nltk_package_if_not_present(package_category="tokenizers", package_name="punkt")
-    _download_nltk_package_if_not_present(
-        package_category="taggers",
-        package_name="averaged_perceptron_tagger",
-    )
+    _download_nltk_packages_if_not_present()
     # NOTE(robinson) - Splitting into sentences before tokenizing. The helps with
     # situations like "ITEM 1A. PROPERTIES" where "PROPERTIES" can be mistaken
     # for a verb because it looks like it's in verb form an "ITEM 1A." looks like the subject.
     sentences = _sent_tokenize(text)
-    parts_of_speech = []
+    parts_of_speech: list[tuple[str, str]] = []
     for sentence in sentences:
         tokens = _word_tokenize(sentence)
         parts_of_speech.extend(_pos_tag(tokens))
```
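For reference, a short usage sketch (assuming unstructured 0.14.10 and network access on first use): the LRU-cached wrappers now share a single lazy check, so the verified download runs at most once, and only when `punkt` or the tagger is missing.

```python
# Usage sketch: each wrapper calls _download_nltk_packages_if_not_present(),
# which fetches and verifies nltk_data.tgz only if a required package is absent.
from unstructured.nlp.tokenize import pos_tag, sent_tokenize, word_tokenize

text = "ITEM 1A. PROPERTIES"
print(sent_tokenize(text))  # first call may trigger the verified download
print(word_tokenize(text))
print(pos_tag(text))  # splits into sentences before tagging, per the NOTE above
```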
