Skip to content

Commit 945e94f

Browse files
Merge branch 'main' into add-cache-gitignore
2 parents b50d8e6 + 55aaea4 commit 945e94f

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

src/unitxt/inference.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -972,16 +972,18 @@ def _get_model_args(self) -> Dict[str, Any]:
972972
return args
973973

974974
def _create_pipeline(self, model_args: Dict[str, Any]):
975-
from transformers import pipeline
975+
from transformers import AutoTokenizer, pipeline
976976

977977
path = self.model_name
978978
if settings.hf_offline_models_path is not None:
979979
path = os.path.join(settings.hf_offline_models_path, path)
980980

981+
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
981982
self.model = pipeline(
982983
model=path,
983984
task=self.task,
984985
use_fast=self.use_fast_tokenizer,
986+
tokenizer=tokenizer,
985987
trust_remote_code=settings.allow_unverified_code,
986988
**model_args,
987989
**self.to_dict(

src/unitxt/loaders.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def load_dataset(
307307
self.__class__._loader_cache.max_size = settings.loader_cache_size
308308
if not disable_memory_caching:
309309
self.__class__._loader_cache[dataset_id] = dataset
310-
return self.__class__._loader_cache[dataset_id]
310+
return dataset
311311

312312
def _maybe_set_classification_policy(self):
313313
if os.path.exists(self.path):

0 commit comments

Comments
 (0)