@@ -29,6 +29,7 @@ class Hyperparameters(GenerativeLM.Hyperparameters):
29
29
gpu_memory_utilization : confloat (gt = 0.0 , le = 1.0 ) = 0.95
30
30
max_model_len : conint (ge = 1 )
31
31
generation_params : Union [TextGenerationParams , Dict , str ]
32
+ api_key : Optional [str ] = None
32
33
33
34
@model_validator (mode = "before" )
34
35
@classmethod
@@ -46,14 +47,25 @@ def set_params(cls, params: Dict) -> Dict:
46
47
params ,
47
48
param = "max_model_len" ,
48
49
alias = [
50
+ "max_length" ,
49
51
"max_len" ,
50
- "max_model_len" ,
51
52
"max_sequence_length" ,
52
53
"max_sequence_len" ,
53
54
"max_input_length" ,
54
55
"max_input_len" ,
56
+ "max_model_length" ,
57
+ "max_model_len" ,
58
+ ],
59
+ )
60
+ set_param_from_alias (
61
+ params ,
62
+ param = "api_key" ,
63
+ alias = [
64
+ "token" ,
65
+ "api_token" ,
55
66
],
56
67
)
68
+
57
69
params ["generation_params" ] = TextGenerationParamsMapper .of (
58
70
params ["generation_params" ]
59
71
).initialize ()
@@ -74,10 +86,11 @@ def initialize(self, model_dir: Optional[FileMetadata] = None):
74
86
gpu_memory_utilization = self .hyperparams .gpu_memory_utilization ,
75
87
max_model_len = self .hyperparams .max_model_len ,
76
88
)
77
-
89
+ kwargs [ "hf_overrides" ]: Dict = dict ()
78
90
if self .cache_dir is not None :
79
91
kwargs ["download_dir" ] = self .cache_dir .path
80
-
92
+ if self .hyperparams .api_key is not None :
93
+ kwargs ["hf_overrides" ]["api_key" ] = self .hyperparams .api_key
81
94
print (f"Initializing vllm with kwargs: { kwargs } " )
82
95
self .llm = LLM (** kwargs )
83
96
@@ -103,6 +116,7 @@ def predict_step(self, batch: Prompts, **kwargs) -> Dict:
103
116
outputs = self .llm .generate (
104
117
prompts ,
105
118
sampling_params = sampling_params ,
119
+ use_tqdm = False ,
106
120
)
107
121
108
122
result = {GENERATED_TEXTS_COL : [output .outputs [0 ].text for output in outputs ]}
0 commit comments