File tree 2 files changed +4
-2
lines changed
2 files changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -972,16 +972,18 @@ def _get_model_args(self) -> Dict[str, Any]:
972
972
return args
973
973
974
974
def _create_pipeline (self , model_args : Dict [str , Any ]):
975
- from transformers import pipeline
975
+ from transformers import AutoTokenizer , pipeline
976
976
977
977
path = self .model_name
978
978
if settings .hf_offline_models_path is not None :
979
979
path = os .path .join (settings .hf_offline_models_path , path )
980
980
981
+ tokenizer = AutoTokenizer .from_pretrained (self .model_name )
981
982
self .model = pipeline (
982
983
model = path ,
983
984
task = self .task ,
984
985
use_fast = self .use_fast_tokenizer ,
986
+ tokenizer = tokenizer ,
985
987
trust_remote_code = settings .allow_unverified_code ,
986
988
** model_args ,
987
989
** self .to_dict (
Original file line number Diff line number Diff line change @@ -307,7 +307,7 @@ def load_dataset(
307
307
self .__class__ ._loader_cache .max_size = settings .loader_cache_size
308
308
if not disable_memory_caching :
309
309
self .__class__ ._loader_cache [dataset_id ] = dataset
310
- return self . __class__ . _loader_cache [ dataset_id ]
310
+ return dataset
311
311
312
312
def _maybe_set_classification_policy (self ):
313
313
if os .path .exists (self .path ):
You can’t perform that action at this time.
0 commit comments