Commit 8b33f14

Revert "Feat/chipper repetitions" (#312)
Reverts #295
1 parent 54e3e46 · commit 8b33f14

4 files changed: +19 −253 lines

CHANGELOG.md

Lines changed: 0 additions & 4 deletions
@@ -1,7 +1,3 @@
-## 0.7.21
-
-* Revised repetitions for Chipper
-
 ## 0.7.20
 
 * chipper-v3: improved table prediction

test_unstructured_inference/models/test_chippermodel.py

Lines changed: 5 additions & 35 deletions
@@ -139,7 +139,7 @@ def test_no_repeat_ngram_logits():
 
     no_repeat_ngram_size = 2
 
-    logitsProcessor = chipper.NoRepeatNGramLogitsProcessor(ngram_size=2, context_length=10)
+    logitsProcessor = chipper.NoRepeatNGramLogitsProcessor(ngram_size=2)
     output = logitsProcessor(input_ids=input_ids, scores=logits)
 
     assert (
@@ -194,49 +194,20 @@ def test_ngram_repetiton_stopping_criteria():
     logits = torch.tensor([[0.1, -0.3, -0.5, 0, 1.0, -0.9]])
 
     stoppingCriteria = chipper.NGramRepetitonStoppingCriteria(
-        ngram_size=2, context_length=10, skip_tokens={0, 1, 2, 3, 4}
+        repetition_window=2, skip_tokens={0, 1, 2, 3, 4}
     )
 
     output = stoppingCriteria(input_ids=input_ids, scores=logits)
 
     assert output is False
 
     stoppingCriteria = chipper.NGramRepetitonStoppingCriteria(
-        ngram_size=2, context_length=10, skip_tokens={1, 2, 3, 4}
+        repetition_window=2, skip_tokens={1, 2, 3, 4}
     )
     output = stoppingCriteria(input_ids=input_ids, scores=logits)
     assert output is True
 
 
-def test_no_repeat_group_ngram_logits_processor():
-    input_ids = torch.tensor([[1, 2, 3, 4, 0, 1, 2, 3, 4]])
-    logits = torch.tensor([[0.1, -0.3, -0.5, 0, 1.0, -0.9]])
-
-    logitsProcessor = chipper.NoRepeatGroupNGramLogitsProcessor(ngram_size=3, token_group=[1, 2])
-
-    output = logitsProcessor(input_ids=input_ids, scores=logits)
-
-    assert (
-        int(
-            torch.sum(
-                output == torch.tensor([[0.1000, -0.3000, -0.5000, 0.0000, 1.0000, -0.9000]]),
-            ),
-        )
-        == 6
-    )
-
-
-def test_target_token_id_stopping_criterion():
-    input_ids = torch.tensor([1, 2, 3])
-    logits = torch.tensor([0.1, 0.2, 0.3])
-
-    stoppingCriterion = chipper.TargetTokenIdStoppingCriterion(1)
-
-    output = stoppingCriterion(input_ids=input_ids, scores=logits)
-
-    assert output is True
-
-
 @pytest.mark.parametrize(
     ("decoded_str", "expected_classes"),
     [
@@ -288,8 +259,7 @@ def test_predict_tokens_beam_indices():
     model = get_model("chipper")
     model.stopping_criteria = [
         chipper.NGramRepetitonStoppingCriteria(
-            ngram_size=1,
-            context_length=10,
+            repetition_window=1,
             skip_tokens={},
         ),
     ]
@@ -326,7 +296,7 @@ def test_deduplicate_detected_elements():
 
 def test_norepeatnGramlogitsprocessor_exception():
     with pytest.raises(ValueError):
-        chipper.NoRepeatNGramLogitsProcessor(ngram_size="", context_length=10)
+        chipper.NoRepeatNGramLogitsProcessor(ngram_size="")
 
 
 def test_run_chipper_v3():
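
For context, `NoRepeatNGramLogitsProcessor(ngram_size=2)` follows the usual no-repeat n-gram idea: any token that would complete an n-gram already present in the generated ids has its logit masked, so the same n-gram cannot be produced twice. The `context_length` argument removed by this revert presumably limited how far back that check looked. Below is a minimal sketch of the general technique for a single-sequence batch; it is not the chipper implementation, and `mask_repeated_ngrams` is an illustrative name.

```python
import torch


def mask_repeated_ngrams(input_ids: torch.Tensor, scores: torch.Tensor, ngram_size: int) -> torch.Tensor:
    """Set to -inf the score of any token that would repeat an already-seen n-gram."""
    tokens = input_ids[0].tolist()
    if ngram_size < 2 or len(tokens) < ngram_size:
        return scores
    # The last (n - 1) generated tokens form the prefix of the next n-gram.
    prefix = tuple(tokens[-(ngram_size - 1):])
    banned = {
        tokens[i + ngram_size - 1]
        for i in range(len(tokens) - ngram_size + 1)
        if tuple(tokens[i:i + ngram_size - 1]) == prefix
    }
    out = scores.clone()
    for token_id in banned:
        out[0, token_id] = float("-inf")
    return out
```

For example, with `input_ids = torch.tensor([[1, 2, 3, 1, 2]])`, `scores = torch.zeros(1, 6)` and `ngram_size=2`, the sketch masks token 3, because the bigram (2, 3) has already appeared.
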
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "0.7.21" # pragma: no cover
+__version__ = "0.7.20" # pragma: no cover
