This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Commit 18e1439

Add pipeline as kwargs to hf models

1 parent afad25a

1 file changed: +11 −3 lines

Diff for: genai_stack/model/hf.py (+11 −3)
@@ -1,5 +1,6 @@
 from typing import Optional, Dict
 from langchain.llms import HuggingFacePipeline
+from transformers import pipeline
 
 from genai_stack.model.base import BaseModel, BaseModelConfig, BaseModelConfigModel
 
@@ -17,6 +18,8 @@ class HuggingFaceModelConfigModel(BaseModelConfigModel):
     """Key word arguments passed to the pipeline."""
     task: str = "text-generation"
     """Valid tasks: 'text2text-generation', 'text-generation', 'summarization'"""
+    pipeline: Optional[pipeline] = None
+    """If pipeline is passed, all other configs are ignored."""
 
 
 class HuggingFaceModelConfig(BaseModelConfig):
@@ -30,9 +33,14 @@ def _post_init(self, *args, **kwargs):
         self.model = self.load()
 
     def load(self):
-        model = HuggingFacePipeline.from_model_id(
-            model_id=self.config.model, task=self.config.task, model_kwargs=self.config.model_kwargs
-        )
+        if self.config.pipeline is not None:
+            model = self.config.pipeline
+        else:
+            model = HuggingFacePipeline.from_model_id(
+                model_id=self.config.model,
+                task=self.config.task,
+                model_kwargs=self.config.model_kwargs,
+            )
         return model
 
     def predict(self, prompt: str):
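
For reference, a minimal usage sketch of the new `pipeline` field. Only the field itself is confirmed by this diff; the import path and keyword-style construction of the config model below are assumptions, and the model class that consumes the config is not shown here.

# Sketch only: the `pipeline` field comes from this commit; the import path and
# keyword construction are assumptions, not confirmed by the diff.
from transformers import pipeline

from genai_stack.model.hf import HuggingFaceModelConfigModel

# Build any transformers pipeline up front; when it is passed via `pipeline`,
# load() uses it directly and model/task/model_kwargs are ignored.
custom_pipe = pipeline(task="text-generation", model="gpt2")
config = HuggingFaceModelConfigModel(pipeline=custom_pipe)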
