@@ -139,7 +139,7 @@ def test_no_repeat_ngram_logits():
139139
140140 no_repeat_ngram_size = 2
141141
142- logitsProcessor = chipper .NoRepeatNGramLogitsProcessor (ngram_size = 2 , context_length = 10 )
142+ logitsProcessor = chipper .NoRepeatNGramLogitsProcessor (ngram_size = 2 )
143143 output = logitsProcessor (input_ids = input_ids , scores = logits )
144144
145145 assert (
@@ -194,49 +194,20 @@ def test_ngram_repetiton_stopping_criteria():
194194 logits = torch .tensor ([[0.1 , - 0.3 , - 0.5 , 0 , 1.0 , - 0.9 ]])
195195
196196 stoppingCriteria = chipper .NGramRepetitonStoppingCriteria (
197- ngram_size = 2 , context_length = 10 , skip_tokens = {0 , 1 , 2 , 3 , 4 }
197+ repetition_window = 2 , skip_tokens = {0 , 1 , 2 , 3 , 4 }
198198 )
199199
200200 output = stoppingCriteria (input_ids = input_ids , scores = logits )
201201
202202 assert output is False
203203
204204 stoppingCriteria = chipper .NGramRepetitonStoppingCriteria (
205- ngram_size = 2 , context_length = 10 , skip_tokens = {1 , 2 , 3 , 4 }
205+ repetition_window = 2 , skip_tokens = {1 , 2 , 3 , 4 }
206206 )
207207 output = stoppingCriteria (input_ids = input_ids , scores = logits )
208208 assert output is True
209209
210210
211- def test_no_repeat_group_ngram_logits_processor ():
212- input_ids = torch .tensor ([[1 , 2 , 3 , 4 , 0 , 1 , 2 , 3 , 4 ]])
213- logits = torch .tensor ([[0.1 , - 0.3 , - 0.5 , 0 , 1.0 , - 0.9 ]])
214-
215- logitsProcessor = chipper .NoRepeatGroupNGramLogitsProcessor (ngram_size = 3 , token_group = [1 , 2 ])
216-
217- output = logitsProcessor (input_ids = input_ids , scores = logits )
218-
219- assert (
220- int (
221- torch .sum (
222- output == torch .tensor ([[0.1000 , - 0.3000 , - 0.5000 , 0.0000 , 1.0000 , - 0.9000 ]]),
223- ),
224- )
225- == 6
226- )
227-
228-
229- def test_target_token_id_stopping_criterion ():
230- input_ids = torch .tensor ([1 , 2 , 3 ])
231- logits = torch .tensor ([0.1 , 0.2 , 0.3 ])
232-
233- stoppingCriterion = chipper .TargetTokenIdStoppingCriterion (1 )
234-
235- output = stoppingCriterion (input_ids = input_ids , scores = logits )
236-
237- assert output is True
238-
239-
240211@pytest .mark .parametrize (
241212 ("decoded_str" , "expected_classes" ),
242213 [
@@ -288,8 +259,7 @@ def test_predict_tokens_beam_indices():
288259 model = get_model ("chipper" )
289260 model .stopping_criteria = [
290261 chipper .NGramRepetitonStoppingCriteria (
291- ngram_size = 1 ,
292- context_length = 10 ,
262+ repetition_window = 1 ,
293263 skip_tokens = {},
294264 ),
295265 ]
@@ -326,7 +296,7 @@ def test_deduplicate_detected_elements():
326296
327297def test_norepeatnGramlogitsprocessor_exception ():
328298 with pytest .raises (ValueError ):
329- chipper .NoRepeatNGramLogitsProcessor (ngram_size = "" , context_length = 10 )
299+ chipper .NoRepeatNGramLogitsProcessor (ngram_size = "" )
330300
331301
332302def test_run_chipper_v3 ():
0 commit comments