@@ -379,25 +379,38 @@ def new_forward(self, model):
379379 """Define a new greedy search function."""
380380 model ._greedy_search = new_greedy_search .__get__ (model , model .__class__ )
381381 trans_version = version .parse (transformers .__version__ )
382- if trans_version >= version .parse ("4.57 .0" ):
383- import llm_bench_utils .llm_hook_sample .hook_sample_v57 as hook_sample_v57
382+ if trans_version >= version .parse ("5.3 .0" ):
383+ import llm_bench_utils .llm_hook_sample .hook_sample_v5_3 as hook_sample_v5_3
384384
385- type(model )._sample = hook_sample_v57 .new_sample
385+ type(model )._sample = hook_sample_v5_3 .new_sample
386+ if trans_version >= version .parse ("5.0" ):
387+ import llm_bench_utils .llm_hook_sample .hook_sample_v5 as hook_sample_v5
388+
389+ type(model )._sample = hook_sample_v5 .new_sample
390+ elif trans_version >= version .parse ("4.57.0" ):
391+ import tools .llm_bench .llm_bench_utils .llm_hook_sample .hook_sample_v4_57 as hook_sample_v4_57
392+
393+ type(model )._sample = hook_sample_v4_57 .new_sample
386394 elif trans_version >= version .parse ("4.55.0" ):
387- import llm_bench_utils .llm_hook_sample .hook_sample_v55 as hook_sample_v55
388- model ._sample = hook_sample_v55 .new_sample .__get__ (model , model .__class__ )
395+ import tools .llm_bench .llm_bench_utils .llm_hook_sample .hook_sample_v4_55 as hook_sample_v4_55
396+
397+ model ._sample = hook_sample_v4_55 .new_sample .__get__ (model , model .__class__ )
389398 elif trans_version >= version .parse ('4.52.0' ):
390- import llm_bench_utils .llm_hook_sample .hook_sample_v52 as hook_sample_v52
391- model ._sample = hook_sample_v52 .new_sample .__get__ (model , model .__class__ )
399+ import tools .llm_bench .llm_bench_utils .llm_hook_sample .hook_sample_v4_52 as hook_sample_v4_52
400+
401+ model ._sample = hook_sample_v4_52 .new_sample .__get__ (model , model .__class__ )
392402 elif trans_version >= version .parse ('4.51.0' ):
393- import llm_bench_utils .llm_hook_sample .hook_sample_v51 as hook_sample_v51
394- model ._sample = hook_sample_v51 .new_sample .__get__ (model , model .__class__ )
403+ import tools .llm_bench .llm_bench_utils .llm_hook_sample .hook_sample_v4_51 as hook_sample_v4_51
404+
405+ model ._sample = hook_sample_v4_51 .new_sample .__get__ (model , model .__class__ )
395406 elif trans_version >= version .parse ('4.45.0' ):
396- import llm_bench_utils .llm_hook_sample .hook_sample_v45 as hook_sample_v45
397- model ._sample = hook_sample_v45 .new_sample .__get__ (model , model .__class__ )
407+ import tools .llm_bench .llm_bench_utils .llm_hook_sample .hook_sample_v4_45 as hook_sample_v4_45
408+
409+ model ._sample = hook_sample_v4_45 .new_sample .__get__ (model , model .__class__ )
398410 elif trans_version >= version .parse ('4.43.0' ):
399- import llm_bench_utils .llm_hook_sample .hook_sample_v43 as hook_sample_v43
400- model ._sample = hook_sample_v43 .new_sample .__get__ (model , model .__class__ )
411+ import tools .llm_bench .llm_bench_utils .llm_hook_sample .hook_sample_v4_43 as hook_sample_v4_43
412+
413+ model ._sample = hook_sample_v4_43 .new_sample .__get__ (model , model .__class__ )
401414 else :
402415 model ._sample = hook_sample .new_sample .__get__ (model , model .__class__ )
403416
0 commit comments