Commit b1cce16

feat: translate_text cleaning brick (#101)
* initial implementation for translate brick
* more input validation
* tests for translate brick
* added docs
* bumped version
* chinese and arabic tests
* re-run pip-compile
* add torch to dependencies
* cleanup doc string
* fix long string
* fix typo in docs
* take out empty string check
* return string if string is empty
* added huggingface into make install
1 parent 1700d4d commit b1cce16

File tree

8 files changed (+203, -5 lines)


Diff for: CHANGELOG.md (+4)

@@ -1,3 +1,7 @@
+## 0.3.2-dev0
+
+* Added `translate_text` brick for translating text between languages
+
 ## 0.3.1
 
 * Added __init.py__ to `partition`

Diff for: Makefile (+1, -1)

@@ -17,7 +17,7 @@ install-base: install-base-pip-packages install-nltk-models
 
 ## install: installs all test, dev, and experimental requirements
 .PHONY: install
-install: install-base-pip-packages install-dev install-nltk-models install-test
+install: install-base-pip-packages install-dev install-nltk-models install-test install-huggingface
 
 .PHONY: install-ci
 install-ci: install-base-pip-packages install-test install-nltk-models install-huggingface

Diff for: docs/source/bricks.rst (+31)

@@ -447,6 +447,37 @@ Examples:
     extract_text_after(text, r"SPEAKER \d{1}:")
 
 
+``translate_text``
+------------------
+
+The ``translate_text`` cleaning brick translates text between languages. ``translate_text``
+uses the `Helsinki NLP MT models <https://huggingface.co/Helsinki-NLP>`_ from
+``transformers`` for machine translation. It works for Russian, Chinese, Arabic, and many
+other languages.
+
+Parameters:
+
+* ``text``: the input string to translate.
+* ``source_lang``: the two-letter language code for the source language of the text.
+  If ``source_lang`` is not specified, the language will be detected using ``langdetect``.
+* ``target_lang``: the two-letter language code for the target language of the translation.
+  Defaults to ``"en"``.
+
+Examples:
+
+.. code:: python
+
+    from unstructured.cleaners.translate import translate_text
+
+    # Output is "I'm a Berliner!"
+    translate_text("Ich bin ein Berliner!")
+
+    # Output is "I can also translate Russian!"
+    translate_text("Я тоже можно переводать русский язык!", "ru", "en")
+
 #######
 Staging
 #######
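
A note on direction: the examples above translate into English, the default ``target_lang``. Translating out of English should work the same way whenever a matching Helsinki-NLP checkpoint exists on the HuggingFace hub; a minimal sketch, assuming the `opus-mt-en-de` model is available:

    from unstructured.cleaners.translate import translate_text

    # English -> German; translate_text resolves this to the
    # Helsinki-NLP/opus-mt-en-de model (assumed available on the hub)
    translate_text("I am a Berliner!", "en", "de")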

Diff for: requirements/huggingface.txt (+21, -3)

@@ -15,7 +15,9 @@ certifi==2022.9.24
 charset-normalizer==2.1.1
     # via requests
 click==8.1.3
-    # via nltk
+    # via
+    #   nltk
+    #   sacremoses
 deprecated==1.2.13
     # via argilla
 filelock==3.8.2
@@ -35,7 +37,11 @@ idna==3.4
     #   requests
     #   rfc3986
 joblib==1.2.0
-    # via nltk
+    # via
+    #   nltk
+    #   sacremoses
+langdetect==1.0.9
+    # via unstructured (setup.py)
 lxml==4.9.1
     # via unstructured (setup.py)
 monotonic==1.6
@@ -69,33 +75,45 @@ pyyaml==6.0
 regex==2022.10.31
     # via
     #   nltk
+    #   sacremoses
     #   transformers
 requests==2.28.1
     # via
     #   huggingface-hub
     #   transformers
 rfc3986[idna2008]==1.5.0
     # via httpx
+sacremoses==0.0.53
+    # via unstructured (setup.py)
+sentencepiece==0.1.97
+    # via unstructured (setup.py)
 six==1.16.0
-    # via python-dateutil
+    # via
+    #   langdetect
+    #   python-dateutil
+    #   sacremoses
 sniffio==1.3.0
     # via
     #   httpcore
     #   httpx
 tokenizers==0.13.2
     # via transformers
+torch==1.13.0
+    # via unstructured (setup.py)
 tqdm==4.64.1
     # via
     #   argilla
     #   huggingface-hub
     #   nltk
+    #   sacremoses
     #   transformers
 transformers==4.23.1
     # via unstructured (setup.py)
 typing-extensions==4.4.0
     # via
     #   huggingface-hub
     #   pydantic
+    #   torch
 urllib3==1.26.13
     # via requests
 wrapt==1.13.3

Diff for: setup.py (+4)

@@ -54,6 +54,10 @@
 ],
 extras_require={
     "huggingface": [
+        "langdetect",
+        "sacremoses",
+        "sentencepiece",
+        "torch",
         "transformers",
     ],
 },
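
These pins live under the `huggingface` key of `extras_require`, so translation support stays opt-in. Installing the extra should pull in all five dependencies; the `install-huggingface` Make target wired into `make install` above presumably does the equivalent from the pinned requirements file:

    pip install "unstructured[huggingface]"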

Diff for: test_unstructured/cleaners/test_translate.py (+55)

@@ -0,0 +1,55 @@
+import pytest
+
+import unstructured.cleaners.translate as translate
+
+
+def test_get_opus_mt_model_name():
+    model_name = translate._get_opus_mt_model_name("ru", "en")
+    assert model_name == "Helsinki-NLP/opus-mt-ru-en"
+
+
+@pytest.mark.parametrize("code", ["way-too-long", "a", "", None])
+def test_validate_language_code(code):
+    with pytest.raises(ValueError):
+        translate._validate_language_code(code)
+
+
+def test_translate_returns_same_text_if_dest_is_same():
+    text = "This is already in English!"
+    assert translate.translate_text(text, "en", "en") == text
+
+
+def test_translate_returns_same_text_if_text_is_empty():
+    text = " "
+    assert translate.translate_text(text) == text
+
+
+def test_translate_with_language_specified():
+    text = "Ich bin ein Berliner!"
+    assert translate.translate_text(text, "de") == "I'm a Berliner!"
+
+
+def test_translate_with_no_language_specified():
+    text = "Ich bin ein Berliner!"
+    assert translate.translate_text(text) == "I'm a Berliner!"
+
+
+def test_translate_raises_with_bad_language():
+    text = "Ich bin ein Berliner!"
+    with pytest.raises(ValueError):
+        translate.translate_text(text, "zz")
+
+
+def test_translate_works_with_russian():
+    text = "Я тоже можно переводать русский язык!"
+    assert translate.translate_text(text) == "I can also translate Russian!"
+
+
+def test_translate_works_with_chinese():
+    text = "網站有中、英文版本"
+    assert translate.translate_text(text) == "Website available in Chinese and English"
+
+
+def test_translate_works_with_arabic():
+    text = "مرحباً بكم في متجرنا"
+    assert translate.translate_text(text) == "Welcome to our store."

Diff for: unstructured/__version__.py (+1, -1)

@@ -1 +1 @@
-__version__ = "0.3.1" # pragma: no cover
+__version__ = "0.3.2-dev0" # pragma: no cover

Diff for: unstructured/cleaners/translate.py (+86)

@@ -0,0 +1,86 @@
+from typing import List, Optional
+import warnings
+
+import langdetect
+from transformers import MarianMTModel, MarianTokenizer
+
+from unstructured.staging.huggingface import chunk_by_attention_window
+from unstructured.nlp.tokenize import sent_tokenize
+
+
+def _get_opus_mt_model_name(source_lang: str, target_lang: str):
+    """Constructs the name of the MarianMT machine translation model based on the
+    source and target language."""
+    return f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
+
+
+def _validate_language_code(language_code: str):
+    if not isinstance(language_code, str) or len(language_code) != 2:
+        raise ValueError(
+            f"Invalid language code: {language_code}. Language codes must be two letter strings."
+        )
+
+
+def translate_text(text, source_lang: Optional[str] = None, target_lang: str = "en") -> str:
+    """Translates the foreign language text. If the source language is not specified, the
+    function will attempt to detect it using langdetect.
+
+    Parameters
+    ----------
+    text: str
+        The text to translate
+    source_lang: Optional[str]
+        The two letter language code for the language of the input text. If source_lang is
+        not provided, the function will try to detect it.
+    target_lang: str
+        The two letter language code for the target language. Defaults to "en".
+    """
+    if text.strip() == "":
+        return text
+
+    _source_lang: str = source_lang if source_lang is not None else langdetect.detect(text)
+    # NOTE(robinson) - Chinese gets detected with codes zh-cn, zh-tw, zh-hk for various
+    # Chinese variants. We normalize these because there is a single model for Chinese
+    # machine translation
+    if _source_lang.startswith("zh"):
+        _source_lang = "zh"
+
+    _validate_language_code(target_lang)
+    _validate_language_code(_source_lang)
+
+    if target_lang == _source_lang:
+        return text
+
+    model_name = _get_opus_mt_model_name(_source_lang, target_lang)
+
+    try:
+        tokenizer = MarianTokenizer.from_pretrained(model_name)
+        model = MarianMTModel.from_pretrained(model_name)
+    except OSError:
+        raise ValueError(
+            f"Transformers could not find the translation model {model_name}. "
+            "The requested source/target language combo is not supported."
+        )
+
+    chunks: List[str] = chunk_by_attention_window(text, tokenizer, split_function=sent_tokenize)
+
+    translated_chunks: List[str] = list()
+    for chunk in chunks:
+        translated_chunks.append(_translate_text(chunk, model, tokenizer))
+
+    return " ".join(translated_chunks)
+
+
+def _translate_text(text, model, tokenizer):
+    """Translates text using the specified model and tokenizer."""
+    # NOTE(robinson) - Suppresses the HuggingFace UserWarning resulting from the "max_length"
+    # key in the MarianMT config. The warning states that "max_length" will be deprecated
+    # in transformers v5
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        translated = model.generate(
+            **tokenizer([text], return_tensors="pt", padding="max_length", max_length=512)
+        )
+    return [tokenizer.decode(t, max_new_tokens=512, skip_special_tokens=True) for t in translated][
+        0
+    ]
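
For inputs longer than the model's attention window, `translate_text` splits the text with `chunk_by_attention_window` (using `sent_tokenize` as the split function), translates each chunk, and re-joins the results with spaces. A minimal sketch of that chunking step in isolation, assuming the `Helsinki-NLP/opus-mt-de-en` checkpoint is available:

    from transformers import MarianTokenizer

    from unstructured.nlp.tokenize import sent_tokenize
    from unstructured.staging.huggingface import chunk_by_attention_window

    # A long German passage that would overflow a single 512-token window
    long_german_text = "Ich bin ein Berliner! " * 500

    # Split on sentence boundaries so each chunk fits the attention window;
    # translate_text then runs the model on every chunk independently.
    tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")
    chunks = chunk_by_attention_window(long_german_text, tokenizer, split_function=sent_tokenize)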
