Skip to content

Commit ad49c05

Browse files
sharanvamsi authored and facebook-github-bot committed
Updated method and encoder params to accept strings
Summary: Changed functional.py to accept strings for the method and encoder params. This addresses the comments from D76369496. Added enum type checking to be done in functional for encoder checking, and method checking in encode_text_strategy. Updated all dependent files as well. Reviewed By: bclarkson-code. Differential Revision: D76706716. fbshipit-source-id: ed7431a6b335466858ef454490e9aa7c1ac0573f
1 parent ec9d07e commit ad49c05

File tree

10 files changed

+71
-111
lines changed

10 files changed

+71
-111
lines changed

augly/tests/assets/expected_metadata/text_tests/expected_metadata.json

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,7 @@
9595
"n": 1,
9696
"p": 1.0,
9797
"encoder": "base64",
98-
"method": "sentence"
98+
"granularity": "all"
9999
}
100100
],
101101
"get_baseline": [

augly/tests/text_tests/functional_unit_test.py

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,7 @@
1010
import unittest
1111

1212
from augly import text as txtaugs
13-
from augly.text.augmenters.utils import Encoding
1413
from augly.utils import FUN_FONTS_GREEK_PATH
15-
from nlpaug.util import Method
1614

1715

1816
class FunctionalTextUnitTest(unittest.TestCase):
@@ -40,22 +38,16 @@ def test_apply_lambda(self) -> None:
4038

4139
def test_base64_sentence(self) -> None:
4240
augmented_words = txtaugs.encode_text(
43-
"Hello, world!", 1, 1, 1.0, Method.SENTENCE, Encoding.BASE64
41+
"Hello, world!", 1, 1, 1.0, "all", "base64"
4442
)
4543
self.assertEqual(augmented_words[0], "SGVsbG8sIHdvcmxkIQ==")
4644

4745
def test_base64_word(self) -> None:
4846
augmented_words_word = txtaugs.encode_text(
49-
"Hello, world!", 1, 1, 1.0, Method.WORD, Encoding.BASE64
47+
"Hello, world!", 1, 1, 1.0, "word", "base64"
5048
)
5149
self.assertEqual(augmented_words_word[0], "SGVsbG8=, world!")
5250

53-
def test_base64_char(self) -> None:
54-
augmented_words_char = txtaugs.encode_text(
55-
"Hello, world!", 1, 1, 1.0, Method.CHAR, Encoding.BASE64
56-
)
57-
self.assertEqual(augmented_words_char[0], "SA==ello LA== dw==orld IQ==")
58-
5951
def test_change_case(self) -> None:
6052
augmented_words = txtaugs.change_case(self.texts[0], cadence=3.0, case="upper")
6153
self.assertTrue(
@@ -274,13 +266,13 @@ def test_insert_zero_width_chars(self) -> None:
274266

275267
def test_leetspeak_sentence(self) -> None:
276268
augmented_words = txtaugs.encode_text(
277-
"Hello, world!", 1, 1, 1.0, Method.SENTENCE, Encoding.LEETSPEAK
269+
"Hello, world!", 1, 1, 1.0, "all", "leetspeak"
278270
)
279271
self.assertEqual(augmented_words[0], "h3110, w0r1d!")
280272

281273
def test_leetspeak_word(self) -> None:
282274
augmented_words = txtaugs.encode_text(
283-
"Hello, world!", 1, 1, 1.0, Method.WORD, Encoding.LEETSPEAK
275+
"Hello, world!", 1, 1, 1.0, "word", "leetspeak"
284276
)
285277
self.assertEqual(augmented_words[0], "h3110, world!")
286278

augly/tests/text_tests/transforms_unit_test.py

Lines changed: 12 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,7 @@
1414
from typing import Any, Dict, List
1515

1616
from augly import text as txtaugs
17-
from augly.text.augmenters.utils import Encoding
1817
from augly.utils import TEXT_METADATA_PATH
19-
from nlpaug.util import Method
2018

2119

2220
def are_equal_metadata(
@@ -143,17 +141,15 @@ def test_Base64_Sentence(self) -> None:
143141
aug_min=1,
144142
aug_max=1,
145143
aug_p=1.0,
146-
method=Method.SENTENCE,
147-
encoder=Encoding.BASE64,
144+
granularity="all",
145+
encoder="base64",
148146
n=1,
149147
p=1.0,
150148
)(
151149
["Hello, world!"],
152150
metadata=self.metadata,
153151
)
154-
155152
self.assertTrue(augmented_text[0] == "SGVsbG8sIHdvcmxkIQ==")
156-
self.expected_metadata["encode_text"][0]["encoder"] = Encoding.BASE64
157153
self.assertTrue(
158154
are_equal_metadata(self.metadata, self.expected_metadata["encode_text"])
159155
)
@@ -165,8 +161,8 @@ def test_Base64_Word(self) -> None:
165161
aug_min=1,
166162
aug_max=1,
167163
aug_p=1.0,
168-
method=Method.WORD,
169-
encoder=Encoding.BASE64,
164+
granularity="word",
165+
encoder="base64",
170166
n=1,
171167
p=1.0,
172168
)(
@@ -176,32 +172,9 @@ def test_Base64_Word(self) -> None:
176172
self.assertEqual(augmented_text[0], "SGVsbG8=, world!")
177173

178174
metadata_expected = deepcopy(self.expected_metadata["encode_text"])
179-
metadata_expected[0]["method"] = "word"
180-
metadata_expected[0]["encoder"] = Encoding.BASE64
175+
metadata_expected[0]["granularity"] = "word"
181176
self.assertTrue(are_equal_metadata(self.metadata, metadata_expected))
182177

183-
def test_Base64_Char(self) -> None:
184-
self.metadata = []
185-
186-
augmented_text = txtaugs.EncodeTextTransform(
187-
aug_min=1,
188-
aug_max=1,
189-
aug_p=1.0,
190-
method=Method.CHAR,
191-
encoder=Encoding.BASE64,
192-
n=1,
193-
p=1.0,
194-
)(
195-
["Hello, world!"],
196-
metadata=self.metadata,
197-
)
198-
self.assertEqual(augmented_text[0], "SA==ello LA== wocg==ld IQ==")
199-
200-
expected_metadata = deepcopy(self.expected_metadata["encode_text"])
201-
expected_metadata[0]["method"] = "char"
202-
expected_metadata[0]["encoder"] = Encoding.BASE64
203-
self.assertTrue(are_equal_metadata(self.metadata, expected_metadata))
204-
205178
def test_GetBaseline(self) -> None:
206179
augmented_baseline = txtaugs.GetBaseline()(self.texts, metadata=self.metadata)
207180

@@ -296,8 +269,8 @@ def test_LeetSpeak_Sentence(self) -> None:
296269
aug_min=1,
297270
aug_max=1,
298271
aug_p=1.0,
299-
method=Method.SENTENCE,
300-
encoder=Encoding.LEETSPEAK,
272+
granularity="all",
273+
encoder="leetspeak",
301274
n=1,
302275
p=1.0,
303276
)(
@@ -306,7 +279,7 @@ def test_LeetSpeak_Sentence(self) -> None:
306279
)
307280

308281
self.assertTrue(augmented_text[0] == "h3110, w0r1d!")
309-
self.expected_metadata["encode_text"][0]["encoder"] = Encoding.LEETSPEAK
282+
self.expected_metadata["encode_text"][0]["encoder"] = "leetspeak"
310283
self.assertTrue(
311284
are_equal_metadata(self.metadata, self.expected_metadata["encode_text"])
312285
)
@@ -318,8 +291,8 @@ def test_Leetspeak_Word(self) -> None:
318291
aug_min=1,
319292
aug_max=1,
320293
aug_p=1.0,
321-
method=Method.WORD,
322-
encoder=Encoding.LEETSPEAK,
294+
granularity="word",
295+
encoder="leetspeak",
323296
n=1,
324297
p=1.0,
325298
)(
@@ -329,8 +302,8 @@ def test_Leetspeak_Word(self) -> None:
329302
self.assertEqual(augmented_text[0], "h3110, world!")
330303

331304
metadata_expected = deepcopy(self.expected_metadata["encode_text"])
332-
metadata_expected[0]["method"] = "word"
333-
metadata_expected[0]["encoder"] = Encoding.LEETSPEAK
305+
metadata_expected[0]["granularity"] = "word"
306+
metadata_expected[0]["encoder"] = "leetspeak"
334307
self.assertTrue(are_equal_metadata(self.metadata, metadata_expected))
335308

336309
def test_MergeWords(self) -> None:

augly/text/augmenters/base64.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,9 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import codecs
9+
from typing import Literal
910

1011
from augly.text.augmenters.encode_text_strategy import EncodeTextAugmentation
11-
from augly.text.augmenters.utils import Encoding
12-
from nlpaug.util import Method
1312

1413

1514
class Base64(EncodeTextAugmentation):
@@ -18,16 +17,20 @@ def __init__(
1817
aug_min: int,
1918
aug_max: int,
2019
aug_p: float,
21-
method: Method,
20+
granularity: Literal["all", "word", "char"],
2221
):
2322
super().__init__(
2423
name="Base64",
2524
aug_min=aug_min,
2625
aug_max=aug_max,
2726
aug_p=aug_p,
28-
encoder=Encoding.BASE64,
29-
method=str(method),
27+
encoder="base64",
28+
granularity=granularity,
3029
)
30+
assert granularity in {
31+
"all",
32+
"word",
33+
}, f"Base64 only supports granularity type 'all' or 'word', found type {granularity}"
3134
assert 0 <= aug_min <= aug_max
3235
assert 0 <= aug_p <= 1
3336

augly/text/augmenters/encode_text_strategy.py

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,9 @@
88
# pyre-unsafe
99

1010
from abc import abstractmethod
11-
from typing import List, Union
11+
from typing import List, Literal, Union
1212

13-
from augly.text.augmenters.utils import detokenize, Encoding, get_aug_idxes, tokenize
13+
from augly.text.augmenters.utils import detokenize, get_aug_idxes, tokenize
1414
from nlpaug.augmenter.word import Augmenter
1515
from nlpaug.util import Action, Method
1616

@@ -22,20 +22,25 @@ def __init__(
2222
aug_min: int,
2323
aug_max: int,
2424
aug_p: float,
25-
encoder: Encoding = Encoding.BASE64,
26-
method: str = Method.SENTENCE,
25+
granularity: Literal["all", "word", "char"],
26+
encoder: Literal["base64", "leetspeak"],
2727
):
28+
assert granularity in {
29+
"all",
30+
"word",
31+
"char",
32+
}, f"Granularity type must be either 'all', 'word', 'char', found type {granularity}"
2833
super().__init__(
2934
name=name,
3035
aug_min=aug_min,
3136
aug_max=aug_max,
3237
aug_p=aug_p,
3338
action=Action.SUBSTITUTE,
34-
method=method,
39+
method=Method.SENTENCE,
3540
)
3641

3742
self.encoder = encoder
38-
self.method = method
43+
self.granularity = granularity
3944

4045
@classmethod
4146
def clean(cls, data: Union[str, List[str], None]) -> Union[str, List[str]]:
@@ -61,35 +66,35 @@ def encode(self, input_string: str) -> str:
6166
raise NotImplementedError
6267

6368
def substitute(self, data: str) -> str:
64-
if self.method == Method.SENTENCE:
69+
if self.granularity == "all":
6570
return self.encode(data)
6671

6772
tokens = tokenize(data)
6873
if not tokens:
6974
return ""
7075

71-
if self.method == Method.WORD:
76+
if self.granularity == "word":
7277
augment_count = self._generate_aug_cnt(
7378
len(tokens), self.aug_min, self.aug_max, self.aug_p
7479
)
7580
to_augment = set(
7681
get_aug_idxes(
77-
self, tokens, list(range(len(tokens))), augment_count, Method.WORD
82+
self, tokens, list(range(len(tokens))), augment_count, "word"
7883
)
7984
)
8085
for i, token in enumerate(tokens):
8186
if i in to_augment:
8287
tokens[i] = self.encode(token)
8388

84-
elif self.method == Method.CHAR:
89+
elif self.granularity == "char":
8590
for token_idx, token in enumerate(tokens):
8691
chars = list(token)
8792
augment_count = self._generate_aug_cnt(
8893
len(chars), self.aug_min, self.aug_max, self.aug_p
8994
)
9095
to_augment = set(
9196
get_aug_idxes(
92-
self, chars, list(range(len(chars))), augment_count, Method.CHAR
97+
self, chars, list(range(len(chars))), augment_count, "char"
9398
)
9499
)
95100
for char_idx, char in enumerate(chars):

augly/text/augmenters/leetspeak.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,9 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import random
9+
from typing import Literal
910

1011
from augly.text.augmenters.encode_text_strategy import EncodeTextAugmentation
11-
from augly.text.augmenters.utils import Encoding
12-
from nlpaug.util import Method
1312

1413

1514
class LeetSpeak(EncodeTextAugmentation):
@@ -18,15 +17,15 @@ def __init__(
1817
aug_min: int,
1918
aug_max: int,
2019
aug_p: float,
21-
method: Method,
20+
granularity: Literal["all", "word", "char"],
2221
):
2322
super().__init__(
2423
name="LeetSpeak",
2524
aug_min=aug_min,
2625
aug_max=aug_max,
2726
aug_p=aug_p,
28-
encoder=Encoding.LEETSPEAK,
29-
method=str(method),
27+
encoder="leetspeak",
28+
granularity=granularity,
3029
)
3130
assert 0 <= aug_min <= aug_max
3231
assert 0 <= aug_p <= 1

augly/text/augmenters/utils.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
# pyre-unsafe
99

1010
import re
11-
from enum import Enum
1211
from typing import List, Optional, Tuple
1312

1413
import regex
@@ -270,8 +269,3 @@ def get_aug_idxes(
270269
aug_idxes = augmenter.sample(priority_idxes, aug_cnt)
271270

272271
return aug_idxes
273-
274-
275-
class Encoding(Enum):
276-
BASE64 = "base64"
277-
LEETSPEAK = "leetspeak"

augly/text/functional.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -8,18 +8,16 @@
88
# pyre-unsafe
99

1010
from copy import deepcopy
11-
from typing import Any, Callable, Dict, List, Optional, Union
11+
from typing import Any, Callable, Dict, List, Literal, Optional, Union
1212

1313
from augly.text import augmenters as a, utils as txtutils
14-
from augly.text.augmenters.utils import Encoding
1514
from augly.utils import (
1615
CONTRACTIONS_MAPPING,
1716
FUN_FONTS_PATH,
1817
GENDERED_WORDS_MAPPING,
1918
MISSPELLING_DICTIONARY_PATH,
2019
UNICODE_MAPPING_PATH,
2120
)
22-
from nlpaug.util import Method
2321

2422

2523
def apply_lambda(
@@ -174,8 +172,8 @@ def encode_text(
174172
aug_min: int,
175173
aug_max: int,
176174
aug_p: float,
177-
method: Method,
178-
encoder: Encoding,
175+
granularity: Literal["all", "word", "char"],
176+
encoder: Literal["base64", "leetspeak"],
179177
n: int = 1,
180178
p: float = 1.0,
181179
metadata: Optional[List[Dict[str, Any]]] = None,
@@ -206,14 +204,19 @@ def encode_text(
206204
207205
@returns: the list of augmented(now in base 64) text documents
208206
"""
207+
assert encoder in {
208+
"base64",
209+
"leetspeak",
210+
}, f"Encode text only supports encoder type 'base64' or 'leetspeak', found type {encoder}"
211+
209212
func_kwargs = txtutils.get_func_kwargs(metadata, locals())
210213

211214
if not isinstance(texts, list):
212215
texts = [texts]
213-
if encoder == Encoding.BASE64:
214-
encoder_strategy = a.Base64(aug_min, aug_max, aug_p, method)
216+
if encoder == "base64":
217+
encoder_strategy = a.Base64(aug_min, aug_max, aug_p, granularity)
215218
else:
216-
encoder_strategy = a.LeetSpeak(aug_min, aug_max, aug_p, method)
219+
encoder_strategy = a.LeetSpeak(aug_min, aug_max, aug_p, granularity)
217220
encoder_context = a.EncodeText(encoder_strategy)
218221
aug_texts = encoder_context.augmenter(texts)
219222

0 commit comments

Comments (0)