Skip to content

Commit 47f8da5

Browse files
committed
[llm_bench] Add hooks for transformers v5
1 parent 16150e4 commit 47f8da5

26 files changed

Lines changed: 4238 additions & 2903 deletions

tools/llm_bench/llm_bench_utils/hook_beam_search.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,34 @@
1212
tm_infer_list = []
1313
tm_mm_embeddings = []
1414

15-
if version.parse(transformers.__version__) >= version.parse("4.57.0"):
16-
import llm_bench_utils.llm_hook_beam_search.hook_beam_search_v57 as hook_beam_search_v57
15+
if version.parse(transformers.__version__) >= version.parse("5.3.0"):
16+
import tools.llm_bench.llm_bench_utils.llm_hook_beam_search.hook_beam_search_v5_3 as hook_beam_search_v5_3
1717

18-
new_beam_search = hook_beam_search_v57.new_beam_search_v57
18+
new_beam_search = hook_beam_search_v5_3.new_beam_search
19+
elif version.parse(transformers.__version__) >= version.parse("5.0"):
20+
import tools.llm_bench.llm_bench_utils.llm_hook_beam_search.hook_beam_search_v5 as hook_beam_search_v5
21+
22+
new_beam_search = hook_beam_search_v5.new_beam_search
23+
elif version.parse(transformers.__version__) >= version.parse("4.57.0"):
24+
import tools.llm_bench.llm_bench_utils.llm_hook_beam_search.hook_beam_search_v4_57 as hook_beam_search_v4_57
25+
26+
new_beam_search = hook_beam_search_v4_57.new_beam_search_v57
1927
elif version.parse(transformers.__version__) >= version.parse("4.55.0"):
20-
import llm_bench_utils.llm_hook_beam_search.hook_beam_search_v55 as hook_beam_search_v55
21-
new_beam_search = hook_beam_search_v55.new_beam_search_v55
28+
import tools.llm_bench.llm_bench_utils.llm_hook_beam_search.hook_beam_search_v4_55 as hook_beam_search_v4_55
29+
30+
new_beam_search = hook_beam_search_v4_55.new_beam_search_v55
2231
elif version.parse(transformers.__version__) >= version.parse("4.52.0"):
23-
import llm_bench_utils.llm_hook_beam_search.hook_beam_search_v52 as hook_beam_search_v52
24-
new_beam_search = hook_beam_search_v52.new_beam_search_v52
32+
import tools.llm_bench.llm_bench_utils.llm_hook_beam_search.hook_beam_search_v4_52 as hook_beam_search_v4_52
33+
34+
new_beam_search = hook_beam_search_v4_52.new_beam_search_v52
2535
elif version.parse(transformers.__version__) >= version.parse("4.51.0"):
26-
import llm_bench_utils.llm_hook_beam_search.hook_beam_search_v51 as hook_beam_search_v51
27-
new_beam_search = hook_beam_search_v51.new_beam_search_v51
36+
import tools.llm_bench.llm_bench_utils.llm_hook_beam_search.hook_beam_search_v4_51 as hook_beam_search_v4_51
37+
38+
new_beam_search = hook_beam_search_v4_51.new_beam_search_v51
2839
else:
29-
import llm_bench_utils.llm_hook_beam_search.hook_beam_search_v40 as hook_beam_search_v40
30-
new_beam_search = hook_beam_search_v40.new_beam_search_v40
40+
import tools.llm_bench.llm_bench_utils.llm_hook_beam_search.hook_beam_search_v4_40 as hook_beam_search_v4_40
41+
42+
new_beam_search = hook_beam_search_v4_40.new_beam_search_v40
3143

3244
def new_get_multimodal_embeddings(
3345
self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, **kwargs

tools/llm_bench/llm_bench_utils/hook_greedy_search.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -379,25 +379,38 @@ def new_forward(self, model):
379379
"""Define a new greedy search function."""
380380
model._greedy_search = new_greedy_search.__get__(model, model.__class__)
381381
trans_version = version.parse(transformers.__version__)
382-
if trans_version >= version.parse("4.57.0"):
383-
import llm_bench_utils.llm_hook_sample.hook_sample_v57 as hook_sample_v57
382+
if trans_version >= version.parse("5.3.0"):
383+
import llm_bench_utils.llm_hook_sample.hook_sample_v5_3 as hook_sample_v5_3
384384

385-
type(model)._sample = hook_sample_v57.new_sample
385+
type(model)._sample = hook_sample_v5_3.new_sample
386+
elif trans_version >= version.parse("5.0"):
387+
import llm_bench_utils.llm_hook_sample.hook_sample_v5 as hook_sample_v5
388+
389+
type(model)._sample = hook_sample_v5.new_sample
390+
elif trans_version >= version.parse("4.57.0"):
391+
import tools.llm_bench.llm_bench_utils.llm_hook_sample.hook_sample_v4_57 as hook_sample_v4_57
392+
393+
type(model)._sample = hook_sample_v4_57.new_sample
386394
elif trans_version >= version.parse("4.55.0"):
387-
import llm_bench_utils.llm_hook_sample.hook_sample_v55 as hook_sample_v55
388-
model._sample = hook_sample_v55.new_sample.__get__(model, model.__class__)
395+
import tools.llm_bench.llm_bench_utils.llm_hook_sample.hook_sample_v4_55 as hook_sample_v4_55
396+
397+
model._sample = hook_sample_v4_55.new_sample.__get__(model, model.__class__)
389398
elif trans_version >= version.parse('4.52.0'):
390-
import llm_bench_utils.llm_hook_sample.hook_sample_v52 as hook_sample_v52
391-
model._sample = hook_sample_v52.new_sample.__get__(model, model.__class__)
399+
import tools.llm_bench.llm_bench_utils.llm_hook_sample.hook_sample_v4_52 as hook_sample_v4_52
400+
401+
model._sample = hook_sample_v4_52.new_sample.__get__(model, model.__class__)
392402
elif trans_version >= version.parse('4.51.0'):
393-
import llm_bench_utils.llm_hook_sample.hook_sample_v51 as hook_sample_v51
394-
model._sample = hook_sample_v51.new_sample.__get__(model, model.__class__)
403+
import tools.llm_bench.llm_bench_utils.llm_hook_sample.hook_sample_v4_51 as hook_sample_v4_51
404+
405+
model._sample = hook_sample_v4_51.new_sample.__get__(model, model.__class__)
395406
elif trans_version >= version.parse('4.45.0'):
396-
import llm_bench_utils.llm_hook_sample.hook_sample_v45 as hook_sample_v45
397-
model._sample = hook_sample_v45.new_sample.__get__(model, model.__class__)
407+
import tools.llm_bench.llm_bench_utils.llm_hook_sample.hook_sample_v4_45 as hook_sample_v4_45
408+
409+
model._sample = hook_sample_v4_45.new_sample.__get__(model, model.__class__)
398410
elif trans_version >= version.parse('4.43.0'):
399-
import llm_bench_utils.llm_hook_sample.hook_sample_v43 as hook_sample_v43
400-
model._sample = hook_sample_v43.new_sample.__get__(model, model.__class__)
411+
import tools.llm_bench.llm_bench_utils.llm_hook_sample.hook_sample_v4_43 as hook_sample_v4_43
412+
413+
model._sample = hook_sample_v4_43.new_sample.__get__(model, model.__class__)
401414
else:
402415
model._sample = hook_sample.new_sample.__get__(model, model.__class__)
403416

0 commit comments

Comments
 (0)