Commit fd85862

Merge pull request #26 from amazon-science/vllm
[#23] 🚀 Added support for vllm language models.
2 parents 080b519 + b0e7122

3 files changed (+233, -52 lines)

src/fmcore/algorithm/vllm.py (new file: +175 lines)
from typing import Any, Dict, List, Optional, Union

import numpy as np
from bears import FileMetadata
from bears.util import EnvUtil, get_default, optional_dependency, set_param_from_alias
from pydantic import confloat, conint, model_validator

from fmcore.framework._task.text_generation import (
    GENERATED_TEXTS_COL,
    GenerativeLM,
    NextTokens,
    Prompts,
    TextGenerationParams,
    TextGenerationParamsMapper,
)

with optional_dependency("vllm"):
    from vllm import LLM, SamplingParams

    class VLLMGenerativeLM(GenerativeLM):
        aliases = ["vllm"]

        llm: Optional[LLM] = None
        cache_dir: Optional[Union[FileMetadata, Dict, str]] = None

        class Hyperparameters(GenerativeLM.Hyperparameters):
            model_name: str
            tensor_parallel_size: Optional[conint(ge=1)] = None
            gpu_memory_utilization: confloat(gt=0.0, le=1.0) = 0.95
            max_model_len: conint(ge=1)
            generation_params: Union[TextGenerationParams, Dict, str]

            @model_validator(mode="before")
            @classmethod
            def set_params(cls, params: Dict) -> Dict:
                set_param_from_alias(
                    params,
                    param="model_name",
                    alias=[
                        "model",
                        "pretrained_model_name_or_path",
                        "model_name_or_path",
                    ],
                )
                set_param_from_alias(
                    params,
                    param="max_model_len",
                    alias=[
                        "max_len",
                        "max_model_len",
                        "max_sequence_length",
                        "max_sequence_len",
                        "max_input_length",
                        "max_input_len",
                    ],
                )
                params["generation_params"] = TextGenerationParamsMapper.of(
                    params["generation_params"]
                ).initialize()
                if params.get("cache_dir") is not None:
                    params["cache_dir"] = FileMetadata.of(params["cache_dir"])
                return params

        def initialize(self, model_dir: Optional[FileMetadata] = None):
            """Initialize the vLLM model"""
            tensor_parallel_size: Optional[conint(ge=1)] = get_default(
                self.hyperparams.tensor_parallel_size,
                EnvUtil.num_gpus(),  # Use all GPUs by default
            )

            kwargs = dict(
                model=self.hyperparams.model_name,
                tensor_parallel_size=tensor_parallel_size,
                gpu_memory_utilization=self.hyperparams.gpu_memory_utilization,
                max_model_len=self.hyperparams.max_model_len,
            )

            if self.cache_dir is not None:
                kwargs["download_dir"] = self.cache_dir.path

            print(f"Initializing vllm with kwargs: {kwargs}")
            self.llm = LLM(**kwargs)

        def predict_step(self, batch: Prompts, **kwargs) -> Dict:
            """Run prediction on a batch of prompts"""
            prompts: List[str] = batch.prompts().to_list()

            sampling_params = SamplingParams(
                min_tokens=self.hyperparams.generation_params.min_new_tokens,
                max_tokens=self.hyperparams.generation_params.max_new_tokens,
                temperature=0.0
                if not self.hyperparams.generation_params.do_sample
                else self.hyperparams.generation_params.temperature,
                top_p=self.hyperparams.generation_params.top_p
                if hasattr(self.hyperparams.generation_params, "top_p")
                else 1.0,
                top_k=self.hyperparams.generation_params.top_k
                if hasattr(self.hyperparams.generation_params, "top_k")
                else -1,
                stop=self.hyperparams.generation_params.stop_sequences,
                logprobs=self.hyperparams.generation_params.output_scores,
            )
            outputs = self.llm.generate(
                prompts,
                sampling_params=sampling_params,
            )

            result = {GENERATED_TEXTS_COL: [output.outputs[0].text for output in outputs]}

            if self.hyperparams.generation_params.output_scores:
                # Get token IDs and logprobs for each generation
                token_ids = []
                tokens = []
                token_scores = []

                for output in outputs:
                    # Get the first (and only) generation
                    generation = output.outputs[0]

                    # Extract token IDs, tokens and logprobs
                    gen_token_ids = generation.token_ids
                    gen_tokens = generation.tokens
                    gen_logprobs = generation.logprobs

                    # Convert scores based on output_scores_format
                    if self.hyperparams.generation_params.output_scores_format == "probabilities":
                        # Convert from log probabilities to probabilities
                        gen_logprobs = np.exp(gen_logprobs)
                        # Filter based on tolerance
                        if self.hyperparams.generation_params.output_scores_tolerance is not None:
                            mask = gen_logprobs >= self.hyperparams.generation_params.output_scores_tolerance
                            gen_token_ids = [t for t, m in zip(gen_token_ids, mask) if m]
                            gen_tokens = [t for t, m in zip(gen_tokens, mask) if m]
                            gen_logprobs = [s for s, m in zip(gen_logprobs, mask) if m]

                    elif self.hyperparams.generation_params.output_scores_format == "log-probabilities":
                        # Already in log probabilities format
                        # Filter based on tolerance
                        if self.hyperparams.generation_params.output_scores_tolerance is not None:
                            mask = gen_logprobs >= self.hyperparams.generation_params.output_scores_tolerance
                            gen_token_ids = [t for t, m in zip(gen_token_ids, mask) if m]
                            gen_tokens = [t for t, m in zip(gen_tokens, mask) if m]
                            gen_logprobs = [s for s, m in zip(gen_logprobs, mask) if m]

                    elif self.hyperparams.generation_params.output_scores_format == "logits":
                        # Don't filter or modify scores when using raw logits
                        pass

                    token_ids.append(gen_token_ids)
                    tokens.append(gen_tokens)
                    token_scores.append(gen_logprobs)

                result.update(
                    {
                        "generated_token_ids": token_ids,
                        "generated_tokens": tokens,
                        "generated_token_scores": token_scores,
                    }
                )

            return result

        def _create_predictions(self, batch: Prompts, predictions: Any, **kwargs) -> NextTokens:
            """Convert raw predictions to NextTokens format"""
            return NextTokens.from_task_data(data=batch, predictions=predictions, **kwargs)

        @property
        def max_num_generated_tokens(self) -> int:
            return self.hyperparams.generation_params.max_new_tokens

        def cleanup(self):
            """Cleanup the llm"""
            if self.llm is not None:
                del self.llm
                self.llm = None
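
The class above is a thin wrapper over vLLM's offline inference API. For orientation, here is a minimal standalone sketch of the same generate path, using only the vLLM calls the wrapper relies on (LLM, SamplingParams, generate); the model name and prompt are placeholders, not anything shipped by this commit:

from vllm import LLM, SamplingParams

# Placeholder checkpoint; any HF-compatible model supported by vLLM works here.
llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.95, max_model_len=2048)

# Greedy decoding, mirroring the temperature=0.0 branch used in predict_step() above.
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(["Write a haiku about GPUs."], sampling_params)

# Each RequestOutput holds one or more completions; the wrapper keeps only the first.
print(outputs[0].outputs[0].text)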

src/fmcore/framework/_evaluator/Evaluator.py (+5, -4)

@@ -369,7 +369,7 @@ def evaluate(
         )

         Alias.set_return_predictions(kwargs)
-        return_predictions: bool = kwargs.pop("return_predictions", False)
+        return_predictions: bool = kwargs.pop("return_predictions", True)

         Alias.set_predictions_destination(kwargs)
         predictions_destination: Optional[Union[io.IOBase, FileMetadata, Dict, str]] = kwargs.pop(
@@ -419,9 +419,8 @@ def evaluate(
             kwargs["tracker"]: Tracker = Tracker.of(kwargs["tracker"])

         try:
-            self._evaluator_is_running: bool = (
-                True  ## Ensures we do not accidentally delete the models while running.
-            )
+            ## Ensures we do not accidentally delete the models while running.
+            self._evaluator_is_running: bool = True
             evaluated_predictions, evaluated_metrics = self._run_evaluation(
                 dataset,
                 metrics=metrics,
@@ -430,6 +429,8 @@ def evaluate(
                 progress_bar=progress_bar,
                 **kwargs,
             )
+        except Exception as e:
+            raise e
         finally:
             if self.cache_timeout is not None:  ## Resets the timeout
                 self.cache_timeout.reset_timeout()
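
The functional change in this file is the default for return_predictions flipping from False to True. A hedged sketch of what that means for callers follows; the evaluator, dataset, and metric values are hypothetical placeholders, and only the keyword names are taken from this diff:

# `evaluator` and `dataset` are assumed to exist already; their construction is not shown in this commit,
# and the exact return shape of evaluate() is likewise not shown here.
evaluated = evaluator.evaluate(
    dataset,
    metrics=["Accuracy"],  # placeholder metric spec
    batch_size=16,         # RayEvaluator asks for an explicit batch_size if the hyperparams lack one
)
# Before this commit, predictions were only surfaced when return_predictions=True was passed explicitly;
# after it, they are included by default, and return_predictions=False restores the old behavior.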

src/fmcore/framework/_evaluator/RayEvaluator.py (+53, -48)

@@ -126,11 +126,14 @@ def __init__(
         self.actor = actor
         self.request_counter: RequestCounter = request_counter

-    def is_available(self) -> bool:
+    def get_evaluator_status(self) -> str:
         try:
-            return self.evaluator is None
-        except Exception:
-            return False
+            if self.evaluator is None:
+                return "Evaluator not initialized."
+            assert isinstance(self.evaluator, Evaluator)
+            return (self.evaluator.class_name, self.evaluator.model.class_name)
+        except Exception as e:
+            return String.format_exception_msg(e)

     def get_ip_address(self) -> Optional[str]:
         try:
@@ -363,6 +366,10 @@ def ray_evaluator_params(cls, params: Dict) -> Dict:
         return params

     def initialize(self, reinit_ray: bool = False, **kwargs):
+        if self.model_num_gpus <= 1 or self.AlgorithmClass.class_name == "VLLMGenerativeLM":
+            self.nested_evaluator_name: str = get_default(self.nested_evaluator_name, "local")
+        else:
+            self.nested_evaluator_name: str = get_default(self.nested_evaluator_name, "accelerate")
         ## Connect to the Ray cluster
         if not ray.is_initialized() or reinit_ray is True:
             ray.init(
@@ -385,7 +392,7 @@ def _load_model(
         **kwargs,
     ) -> List[RayActorComposite]:
         num_actors: int = get_default(num_actors, self.num_actors)
-        progress_bar: Optional[Dict] = self._run_evaluation_progress_bar(progress_bar)
+        progress_bar: Union[Dict, bool] = self._run_evaluation_progress_bar(progress_bar)
         nested_evaluator_params: Dict = self._create_nested_evaluator_params(**kwargs)

         def actor_factory(*, request_counter: Any, actor_i: int, actor_id: str, **kwargs):
@@ -475,10 +482,7 @@ def num_actors(self) -> int:
         return num_actors

     def _create_nested_evaluator_params(self, **kwargs) -> Dict:
-        nested_evaluator_name: str = get_default(
-            self.nested_evaluator_name,
-            "accelerate" if self.model_num_gpus > 1 else "local",
-        )
+        nested_evaluator_name: str = self.nested_evaluator_name
         if self.model_dir is not None and not self.model_dir.is_remote_storage():
             raise ValueError(
                 f"When passing `model_dir` to {self.class_name}.of(...), the model directory "
@@ -563,44 +567,44 @@ def _run_evaluation(
         evaluated_predictions: Optional[Predictions] = None
         evaluated_metrics: Optional[List[Metric]] = None

-        try:
-            timer: Timer = Timer(silent=True)
-            timer.start()
-            ## Verbosity >= 1: progress bars
-            progress_bar: Optional[Dict] = self._run_evaluation_progress_bar(progress_bar)
-            ## Verbosity >= 2: basic logging
-            main_logger: Callable = partial(
-                self.ray_logger,
-                ## Unless we request silence (verbosity=0), print important information.
-                should_log=self.verbosity >= 2,
-                tracker=tracker,
-            )
-            ## Verbosity >= 3: detailed logging
-            debug_logger: Callable = partial(
-                self.ray_logger,
-                ## Unless we request silence (verbosity=0), print important information.
-                should_log=self.verbosity >= 3,
-                tracker=tracker,
+        timer: Timer = Timer(silent=True)
+        timer.start()
+        ## Verbosity >= 1: progress bars
+        progress_bar: Union[Dict, bool] = self._run_evaluation_progress_bar(progress_bar)
+        ## Verbosity >= 2: basic logging
+        main_logger: Callable = partial(
+            self.ray_logger,
+            ## Unless we request silence (verbosity=0), print important information.
+            should_log=self.verbosity >= 2,
+            tracker=tracker,
+        )
+        ## Verbosity >= 3: detailed logging
+        debug_logger: Callable = partial(
+            self.ray_logger,
+            ## Unless we request silence (verbosity=0), print important information.
+            should_log=self.verbosity >= 3,
+            tracker=tracker,
+        )
+        main_logger(self._evaluate_start_msg(tracker=tracker, **kwargs))
+        if batch_size is None:
+            raise ValueError(
+                f"Could not find batch_size in model hyperparams; "
+                f"please pass it explicitly like so: {self.class_name}.evaluate(batch_size=...)"
             )
-            main_logger(self._evaluate_start_msg(tracker=tracker, **kwargs))
-            if batch_size is None:
+        if predictions_destination is not None:
+            if predictions_destination.storage is not Storage.S3:
                 raise ValueError(
-                    f"Could not find batch_size in model hyperparams; "
-                    f"please pass it explicitly like so: {self.class_name}.evaluate(batch_size=...)"
+                    f"Results can only be saved to {Storage.S3}; "
+                    f"found storage {predictions_destination.storage} having path: {predictions_destination.path}"
                 )
-            if predictions_destination is not None:
-                if predictions_destination.storage is not Storage.S3:
-                    raise ValueError(
-                        f"Results can only be saved to {Storage.S3}; "
-                        f"found storage {predictions_destination.storage} having path: {predictions_destination.path}"
-                    )
-                if not predictions_destination.is_path_valid_dir():
-                    raise ValueError(
-                        f"Expected predictions destination to be a valid directory; "
-                        f'found: "{predictions_destination.path}"...did you forget a "/" at the end?'
-                    )
-                assert predictions_destination.format is not None  ## Checked in .evaluate().
+            if not predictions_destination.is_path_valid_dir():
+                raise ValueError(
+                    f"Expected predictions destination to be a valid directory; "
+                    f'found: "{predictions_destination.path}"...did you forget a "/" at the end?'
+                )
+            assert predictions_destination.format is not None  ## Checked in .evaluate().

+        try:
             actors_were_created_in_this_call: bool = self.init_model(progress_bar=progress_bar, **kwargs)
             num_actors_created: int = len(self.model)
             if actors_were_created_in_this_call:
@@ -869,15 +873,16 @@ def _run_evaluation(
                 )
             )
             return evaluated_predictions, evaluated_metrics
+        except Exception as e:
+            raise e
         except KeyboardInterrupt as e:
             raise e
         finally:
             if "row_counter" in locals():
                 accumulate(ray.kill(row_counter))
                 del row_counter
-            if (
-                self.cache_timeout is None
-            ):  ## If we don't have a timeout, delete actors after every execution.
+            ## If we don't have a timeout, delete actors after every execution.
+            if self.cache_timeout is None:
                 self.cleanup_model()
         return evaluated_predictions, evaluated_metrics

@@ -894,10 +899,10 @@ def _get_actor_usages(self) -> List[Tuple[int, float, str]]:
             )
         return actor_usages

-    def _run_evaluation_progress_bar(self, progress_bar: Optional[Dict], **kwargs) -> Optional[Dict]:
+    def _run_evaluation_progress_bar(self, progress_bar: Optional[Dict], **kwargs) -> Union[Dict, bool]:
         if self.verbosity >= 2:
             return progress_bar
-        return None
+        return False

     def _evaluate_start_msg(self, *, tracker: Tracker, **kwargs) -> str:
         if tracker.tracker_name == "noop":
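
The net effect of the initialize() hunk above: when the algorithm is VLLMGenerativeLM, the nested evaluator defaults to "local" even for multi-GPU models, since vLLM performs tensor parallelism itself (via tensor_parallel_size) rather than relying on the "accelerate" nested evaluator. A small sketch restating that selection logic in isolation, with names copied from the diff:

from typing import Optional

def pick_nested_evaluator(model_num_gpus: int, algorithm_class_name: str,
                          nested_evaluator_name: Optional[str]) -> str:
    # Mirrors the new RayEvaluator.initialize() defaulting: vLLM manages its own
    # tensor parallelism, so it runs under the plain "local" nested evaluator.
    if model_num_gpus <= 1 or algorithm_class_name == "VLLMGenerativeLM":
        return nested_evaluator_name if nested_evaluator_name is not None else "local"
    return nested_evaluator_name if nested_evaluator_name is not None else "accelerate"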
