Skip to content

Commit 8ca745c

Browse files
sharanvamsifacebook-github-bot
authored andcommitted
Created new context strategy function with base64 function (#265)
Summary: Pull Request resolved: #265 Created a context interface, and strategy parent class for base64 function encoding and future types of encoding. Lot of code is reused from diffs D75792134 and D75792239. Differential Revision: D76028127
1 parent 4fb3507 commit 8ca745c

File tree

13 files changed

+292
-176
lines changed

13 files changed

+292
-176
lines changed

augly/tests/assets/expected_metadata/text_tests/expected_metadata.json

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,18 +82,20 @@
8282
"src_length": 1
8383
}
8484
],
85-
"encode_base64": [
85+
"encode_text": [
8686
{
8787
"dst_length": 1,
8888
"input_type": "list",
8989
"intensity": 100.0,
90-
"name": "encode_base64",
90+
"name": "encode_text",
9191
"src_length": 1,
92-
"granularity": "all",
9392
"aug_min": 1,
94-
"aug_max": 10,
95-
"aug_p": 0.3,
96-
"n": 1
93+
"aug_max": 1,
94+
"aug_p": 1.0,
95+
"n": 1,
96+
"p": 1.0,
97+
"encoder": "base64",
98+
"method": "sentence"
9799
}
98100
],
99101
"get_baseline": [

augly/tests/text_tests/functional_unit_test.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
import unittest
1111

1212
from augly import text as txtaugs
13+
from augly.text.augmenters.utils import Encoding
1314
from augly.utils import FUN_FONTS_GREEK_PATH
15+
from nlpaug.util import Method
1416

1517

1618
class FunctionalTextUnitTest(unittest.TestCase):
@@ -51,34 +53,23 @@ def test_contractions(self) -> None:
5153
augmented_words[0] == "I would call him but I don't know where he's gone"
5254
)
5355

54-
def test_encode_base64_all(self) -> None:
55-
augmented_words = txtaugs.encode_base64("Hello, world!")
56-
self.assertTrue(augmented_words[0] == "SGVsbG8sIHdvcmxkIQ==")
57-
58-
def test_encode_base64_word(self) -> None:
59-
random.seed(42) # Set seed for reproducibility
60-
augmented_words_word = txtaugs.encode_base64(
61-
"Hello, world!", granularity="word", aug_min=1, aug_max=1, aug_p=1.0
56+
def test_encode_text_base64_sentence(self) -> None:
57+
augmented_words = txtaugs.encode_text(
58+
"Hello, world!", 1, 1, 1.0, Method.SENTENCE, Encoding.BASE64
6259
)
63-
self.assertEqual(augmented_words_word[0], "SGVsbG8=, world!")
60+
self.assertEqual(augmented_words[0], "SGVsbG8sIHdvcmxkIQ==")
6461

65-
def test_encode_base64_char(self) -> None:
66-
random.seed(42)
67-
augmented_words_char = txtaugs.encode_base64(
68-
"Hello, world!", granularity="char", aug_min=1, aug_max=2, aug_p=1.0
62+
def test_encode_text_base64_word(self) -> None:
63+
augmented_words_word = txtaugs.encode_text(
64+
"Hello, world!", 1, 1, 1.0, Method.WORD, Encoding.BASE64
6965
)
70-
self.assertEqual(augmented_words_char[0], "SA==ellbw== LA== wbw==rlZA== IQ==")
66+
self.assertEqual(augmented_words_word[0], "SGVsbG8=, world!")
7167

72-
def test_encode_base64_general(self) -> None:
73-
random.seed(42)
74-
augmented_words_low_p = txtaugs.encode_base64(
75-
"Hello, world!", granularity="word", aug_min=1, aug_max=2, aug_p=0.1
76-
)
77-
random.seed(42)
78-
augmented_words_high_p = txtaugs.encode_base64(
79-
"Hello, world!", granularity="word", aug_min=1, aug_max=2, aug_p=0.9
68+
def test_encode_text_base64_char(self) -> None:
69+
augmented_words_char = txtaugs.encode_text(
70+
"Hello, world!", 1, 1, 1.0, Method.CHAR, Encoding.BASE64
8071
)
81-
self.assertTrue(len(augmented_words_high_p[0]) > len(augmented_words_low_p[0]))
72+
self.assertEqual(augmented_words_char[0], "SA==ello LA== dw==orld IQ==")
8273

8374
def test_get_baseline(self) -> None:
8475
augmented_baseline = txtaugs.get_baseline(self.texts)

augly/tests/text_tests/transforms_unit_test.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
from typing import Any, Dict, List
1515

1616
from augly import text as txtaugs
17+
from augly.text.augmenters.utils import Encoding
1718
from augly.utils import TEXT_METADATA_PATH
19+
from nlpaug.util import Method
1820

1921

2022
def are_equal_metadata(
@@ -136,57 +138,68 @@ def test_Compose(self) -> None:
136138
are_equal_metadata(self.metadata, self.expected_metadata["compose"]),
137139
)
138140

139-
def test_EncodeBase64(self) -> None:
140-
augmented_text = txtaugs.EncodeBase64(
141-
granularity="all", aug_min=1, aug_max=10, aug_p=0.3, n=1, p=1.0
141+
def test_EncodeText_Base64_Sentence(self) -> None:
142+
augmented_text = txtaugs.EncodeTextTransform(
143+
aug_min=1,
144+
aug_max=1,
145+
aug_p=1.0,
146+
method=Method.SENTENCE,
147+
encoder=Encoding.BASE64,
148+
n=1,
149+
p=1.0,
142150
)(
143151
["Hello, world!"],
144152
metadata=self.metadata,
145153
)
146154

147155
self.assertTrue(augmented_text[0] == "SGVsbG8sIHdvcmxkIQ==")
156+
self.expected_metadata["encode_text"][0]["encoder"] = Encoding.BASE64
148157
self.assertTrue(
149-
are_equal_metadata(self.metadata, self.expected_metadata["encode_base64"])
158+
are_equal_metadata(self.metadata, self.expected_metadata["encode_text"])
150159
)
151160

152-
def test_EncodeBase64_Word(self) -> None:
161+
def test_EncodeText_Base64_Word(self) -> None:
153162
self.metadata = []
154163

155-
random.seed(42)
156-
augmented_text = txtaugs.EncodeBase64(
157-
granularity="word", aug_min=1, aug_max=1, aug_p=1.0, n=1, p=1.0
164+
augmented_text = txtaugs.EncodeTextTransform(
165+
aug_min=1,
166+
aug_max=1,
167+
aug_p=1.0,
168+
method=Method.WORD,
169+
encoder=Encoding.BASE64,
170+
n=1,
171+
p=1.0,
158172
)(
159173
["Hello, world!"],
160174
metadata=self.metadata,
161175
)
162176
self.assertEqual(augmented_text[0], "SGVsbG8=, world!")
163177

164-
expected_metadata = deepcopy(self.expected_metadata["encode_base64"])
165-
expected_metadata[0]["granularity"] = "word"
166-
expected_metadata[0]["aug_p"] = 1.0
167-
expected_metadata[0]["aug_max"] = 1
168-
expected_metadata[0]["intensity"] = 100.0
178+
metadata_expected = deepcopy(self.expected_metadata["encode_text"])
179+
metadata_expected[0]["method"] = "word"
180+
metadata_expected[0]["encoder"] = Encoding.BASE64
181+
self.assertTrue(are_equal_metadata(self.metadata, metadata_expected))
169182

170-
self.assertTrue(are_equal_metadata(self.metadata, expected_metadata))
171-
172-
def test_EncodeBase64_Char(self) -> None:
183+
def test_EncodeText_Base64_Char(self) -> None:
173184
self.metadata = []
174185

175-
random.seed(42)
176-
augmented_text = txtaugs.EncodeBase64(
177-
granularity="char", aug_min=1, aug_max=2, aug_p=1.0, n=1, p=1.0
186+
augmented_text = txtaugs.EncodeTextTransform(
187+
aug_min=1,
188+
aug_max=1,
189+
aug_p=1.0,
190+
method=Method.CHAR,
191+
encoder=Encoding.BASE64,
192+
n=1,
193+
p=1.0,
178194
)(
179195
["Hello, world!"],
180196
metadata=self.metadata,
181197
)
182-
self.assertEqual(augmented_text[0], "SA==ebA==lo LA== wbw==rlZA== IQ==")
183-
184-
expected_metadata = deepcopy(self.expected_metadata["encode_base64"])
185-
expected_metadata[0]["granularity"] = "char"
186-
expected_metadata[0]["aug_p"] = 1.0
187-
expected_metadata[0]["aug_max"] = 2
188-
expected_metadata[0]["intensity"] = 100.0
198+
self.assertEqual(augmented_text[0], "SA==ello LA== wocg==ld IQ==")
189199

200+
expected_metadata = deepcopy(self.expected_metadata["encode_text"])
201+
expected_metadata[0]["method"] = "char"
202+
expected_metadata[0]["encoder"] = Encoding.BASE64
190203
self.assertTrue(are_equal_metadata(self.metadata, expected_metadata))
191204

192205
def test_GetBaseline(self) -> None:

augly/text/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
apply_lambda,
1313
change_case,
1414
contractions,
15-
encode_base64,
15+
encode_text,
1616
get_baseline,
1717
insert_punctuation_chars,
1818
insert_text,
@@ -32,9 +32,10 @@
3232
)
3333
from augly.text.intensity import (
3434
apply_lambda_intensity,
35+
base64_intensity,
3536
change_case_intensity,
3637
contractions_intensity,
37-
encode_base64_intensity,
38+
encode_text_intensity,
3839
get_baseline_intensity,
3940
insert_punctuation_chars_intensity,
4041
insert_text_intensity,
@@ -56,7 +57,7 @@
5657
ApplyLambda,
5758
ChangeCase,
5859
Contractions,
59-
EncodeBase64,
60+
EncodeTextTransform,
6061
GetBaseline,
6162
InsertPunctuationChars,
6263
InsertText,
@@ -81,7 +82,7 @@
8182
"ApplyLambda",
8283
"ChangeCase",
8384
"Contractions",
84-
"EncodeBase64",
85+
"EncodeTextTransform",
8586
"GetBaseline",
8687
"InsertPunctuationChars",
8788
"InsertText",
@@ -101,7 +102,7 @@
101102
"apply_lambda",
102103
"change_case",
103104
"contractions",
104-
"encode_base64",
105+
"encode_text",
105106
"get_baseline",
106107
"insert_punctuation_chars",
107108
"insert_text",
@@ -119,9 +120,10 @@
119120
"split_words",
120121
"swap_gendered_words",
121122
"apply_lambda_intensity",
123+
"base64_intensity",
122124
"change_case_intensity",
123125
"contractions_intensity",
124-
"encode_base64_intensity",
126+
"encode_text_intensity",
125127
"get_baseline_intensity",
126128
"insert_punctuation_chars_intensity",
127129
"insert_text_intensity",

augly/text/augmenters/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77

88
# pyre-unsafe
99

10+
from augly.text.augmenters.base64 import Base64
1011
from augly.text.augmenters.baseline import BaselineAugmenter
1112
from augly.text.augmenters.bidirectional import BidirectionalAugmenter
1213
from augly.text.augmenters.case import CaseAugmenter
1314
from augly.text.augmenters.contraction import ContractionAugmenter
14-
from augly.text.augmenters.encode_base64 import EncodeBase64
15+
from augly.text.augmenters.encode_text_context import EncodeText
16+
from augly.text.augmenters.encode_text_strategy import EncodeTextAugmentation
1517
from augly.text.augmenters.fun_fonts import FunFontsAugmenter
1618
from augly.text.augmenters.insert_text import InsertTextAugmenter
1719
from augly.text.augmenters.insertion import InsertionAugmenter
@@ -22,13 +24,14 @@
2224
from augly.text.augmenters.word_replacement import WordReplacementAugmenter
2325
from augly.text.augmenters.words_augmenter import WordsAugmenter
2426

25-
2627
__all__ = [
28+
"Base64",
2729
"BaselineAugmenter",
2830
"BidirectionalAugmenter",
2931
"CaseAugmenter",
3032
"ContractionAugmenter",
31-
"EncodeBase64",
33+
"EncodeText",
34+
"EncodeTextAugmentation",
3235
"FunFontsAugmenter",
3336
"InsertTextAugmenter",
3437
"InsertionAugmenter",

augly/text/augmenters/base64.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import codecs
9+
10+
from augly.text.augmenters.encode_text_strategy import EncodeTextAugmentation
11+
from augly.text.augmenters.utils import Encoding
12+
from nlpaug.util import Method
13+
14+
15+
class Base64(EncodeTextAugmentation):
16+
def __init__(
17+
self,
18+
aug_min: int,
19+
aug_max: int,
20+
aug_p: float,
21+
method: Method,
22+
):
23+
super().__init__(
24+
name="Base64",
25+
aug_min=aug_min,
26+
aug_max=aug_max,
27+
aug_p=aug_p,
28+
encoder=Encoding.BASE64,
29+
method=str(method),
30+
)
31+
assert 0 <= aug_min <= aug_max
32+
assert 0 <= aug_p <= 1
33+
34+
def encode(self, input_string: str) -> str:
35+
encoded_bytes = codecs.encode(input_string.encode("utf-8"), "base64")
36+
return encoded_bytes.decode("utf-8").strip()

0 commit comments

Comments
 (0)