-
Notifications
You must be signed in to change notification settings - Fork 592
[Feature] Support for Gemma-3 Models #821
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
3602b03
797140f
dff6085
68ae823
c28f92d
c0f7e0c
96d75c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| # Run and exactly reproduce gemma3 results! | ||
| # mme as an example | ||
|
|
||
| accelerate launch --num_processes=8 --main_process_port=12345 -m lmms_eval \ | ||
| --model gemma3 \ | ||
| --model_args=pretrained=google/gemma-3-4b-it \ | ||
| --tasks mmmu_val,ai2d,mathvista_testmini \ | ||
| --batch_size 1 --output_path ./logs/ | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,299 @@ | ||||||||||||||||||||||||||
| import base64 | ||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||
| import re | ||||||||||||||||||||||||||
| import uuid | ||||||||||||||||||||||||||
| import warnings | ||||||||||||||||||||||||||
| from io import BytesIO | ||||||||||||||||||||||||||
| from typing import List, Optional, Tuple, Union | ||||||||||||||||||||||||||
|
coderabbitai[bot] marked this conversation as resolved.
Outdated
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| import decord | ||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||
| from accelerate import Accelerator, DistributedType | ||||||||||||||||||||||||||
| from PIL import Image | ||||||||||||||||||||||||||
| from tqdm import tqdm | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from lmms_eval import utils | ||||||||||||||||||||||||||
| from lmms_eval.api.instance import Instance | ||||||||||||||||||||||||||
| from lmms_eval.api.model import lmms | ||||||||||||||||||||||||||
| from lmms_eval.api.registry import register_model | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| warnings.simplefilter("ignore", category=DeprecationWarning) | ||||||||||||||||||||||||||
| warnings.filterwarnings("ignore") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from loguru import logger as eval_logger | ||||||||||||||||||||||||||
| from transformers import AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @register_model("gemma3") | ||||||||||||||||||||||||||
| class Gemma3(lmms): | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| Gemma3 Model | ||||||||||||||||||||||||||
| https://huggingface.co/google/gemma-3-27b-it | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||
| pretrained: str = "google/gemma-3-27b-it", | ||||||||||||||||||||||||||
| device: Optional[str] = "cuda", | ||||||||||||||||||||||||||
| device_map: Optional[str] = "auto", | ||||||||||||||||||||||||||
| batch_size: Optional[Union[int, str]] = 1, | ||||||||||||||||||||||||||
| trust_remote_code: Optional[bool] = True, | ||||||||||||||||||||||||||
| use_cache=True, | ||||||||||||||||||||||||||
|
Comment on lines
+40
to
+42
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. batch_size type mismatch. Annotated as Optional[Union[int, str]] but cast with int(); non‑numeric strings will crash. Apply this diff to accept only int: - batch_size: Optional[Union[int, str]] = 1,
+ batch_size: int = 1,
@@
- self.batch_size_per_gpu = int(batch_size)
+ self.batch_size_per_gpu = batch_sizeAlso applies to: 80-80 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
| attn_implementation: Optional[str] = None, | ||||||||||||||||||||||||||
| min_pixels: int = 256 * 28 * 28, | ||||||||||||||||||||||||||
| max_pixels: int = 1605632, | ||||||||||||||||||||||||||
| max_num_frames: int = 32, | ||||||||||||||||||||||||||
| interleave_visuals: Optional[bool] = False, | ||||||||||||||||||||||||||
| system_prompt: Optional[str] = "You are a helpful assistant.", | ||||||||||||||||||||||||||
| reasoning_prompt: Optional[str] = None, | ||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||
| # Do not use kwargs for now | ||||||||||||||||||||||||||
| assert kwargs == {}, f"Unexpected kwargs: {kwargs}" | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| accelerator = Accelerator() | ||||||||||||||||||||||||||
|
Comment on lines
+34
to
+56
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Conflicting kwargs usage for max_length. You assert kwargs == {} but later read max_length from kwargs; unreachable path. Apply this diff to make max_length explicit and remove dead kwargs access: - system_prompt: Optional[str] = "You are a helpful assistant.",
- reasoning_prompt: Optional[str] = None,
- **kwargs,
+ system_prompt: Optional[str] = "You are a helpful assistant.",
+ reasoning_prompt: Optional[str] = None,
+ max_length: int = 2048,
) -> None:
super().__init__()
- # Do not use kwargs for now
- assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
+ # NOTE: no extra kwargs are accepted to avoid silent misconfigurations
+ # (add parameters explicitly if needed)
@@
- self._max_length = kwargs.get("max_length", 2048)
+ self._max_length = int(max_length)Also applies to: 77-79 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
| if accelerator.num_processes > 1: | ||||||||||||||||||||||||||
| self._device = torch.device(f"cuda:{accelerator.local_process_index}") | ||||||||||||||||||||||||||
| self.device_map = f"cuda:{accelerator.local_process_index}" | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| self._device = torch.device(device) | ||||||||||||||||||||||||||
| self.device_map = device_map if device_map else device | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Prepare model loading arguments | ||||||||||||||||||||||||||
| model_kwargs = { | ||||||||||||||||||||||||||
| "torch_dtype": torch.bfloat16, | ||||||||||||||||||||||||||
| "device_map": self.device_map, | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Add attention implementation if specified | ||||||||||||||||||||||||||
| if attn_implementation is not None: | ||||||||||||||||||||||||||
| model_kwargs["attn_implementation"] = attn_implementation | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self._model = Gemma3ForConditionalGeneration.from_pretrained(pretrained, **model_kwargs).eval() | ||||||||||||||||||||||||||
| self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code, device_map=self.device_map) | ||||||||||||||||||||||||||
| self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
Comment on lines
+82
to
+84
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AutoTokenizer.from_pretrained does not accept device_map. Passing device_map will raise a TypeError; remove it from tokenizer init. Apply this diff: - self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code, device_map=self.device_map)
+ self._tokenizer = AutoTokenizer.from_pretrained(
+ pretrained,
+ trust_remote_code=trust_remote_code,
+ )📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
| self._config = self._model.config | ||||||||||||||||||||||||||
| self._max_length = kwargs.get("max_length", 2048) | ||||||||||||||||||||||||||
| self.model.tie_weights() | ||||||||||||||||||||||||||
| self.batch_size_per_gpu = int(batch_size) | ||||||||||||||||||||||||||
| self.use_cache = use_cache | ||||||||||||||||||||||||||
| self.system_prompt = system_prompt | ||||||||||||||||||||||||||
| self.interleave_visuals = interleave_visuals | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.max_pixels = max_pixels | ||||||||||||||||||||||||||
| self.min_pixels = min_pixels | ||||||||||||||||||||||||||
| self.max_num_frames = max_num_frames | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
Comment on lines
+93
to
+96
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Unused parameters: interleave_visuals, max_num_frames. They are not used; either implement or remove to avoid confusion. Would you like me to wire these into visual processing (e.g., limit frame sampling)? 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
| if reasoning_prompt: | ||||||||||||||||||||||||||
| self.reasoning_prompt = reasoning_prompt.replace("\\n", "\n") | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| self.reasoning_prompt = None | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if accelerator.num_processes > 1: | ||||||||||||||||||||||||||
| assert accelerator.distributed_type in [ | ||||||||||||||||||||||||||
| DistributedType.FSDP, | ||||||||||||||||||||||||||
| DistributedType.MULTI_GPU, | ||||||||||||||||||||||||||
| ], "Unsupported distributed type provided. Only DDP and FSDP are supported." | ||||||||||||||||||||||||||
| if accelerator.distributed_type == DistributedType.FSDP: | ||||||||||||||||||||||||||
| self._model = accelerator.prepare(self.model) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| self._model = accelerator.prepare_model(self.model, evaluation_mode=True) | ||||||||||||||||||||||||||
| self.accelerator = accelerator | ||||||||||||||||||||||||||
| if self.accelerator.is_local_main_process: | ||||||||||||||||||||||||||
| eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") | ||||||||||||||||||||||||||
| self._rank = self.accelerator.local_process_index | ||||||||||||||||||||||||||
| self._world_size = self.accelerator.num_processes | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| self.model.to(self._device) | ||||||||||||||||||||||||||
| self._rank = 0 | ||||||||||||||||||||||||||
| self._world_size = 1 | ||||||||||||||||||||||||||
| self.model.eval() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||
| def config(self): | ||||||||||||||||||||||||||
| # return the associated transformers.AutoConfig for the given pretrained model. | ||||||||||||||||||||||||||
| return self._config | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||
| def tokenizer(self): | ||||||||||||||||||||||||||
| return self._tokenizer | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||
| def model(self): | ||||||||||||||||||||||||||
| # returns the model, unwrapping it if using Accelerate | ||||||||||||||||||||||||||
| if hasattr(self, "accelerator"): | ||||||||||||||||||||||||||
| return self.accelerator.unwrap_model(self._model) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| return self._model | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||
| def eot_token_id(self): | ||||||||||||||||||||||||||
| # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* | ||||||||||||||||||||||||||
| # return self.tokenizer.eod_id | ||||||||||||||||||||||||||
| return self.tokenizer.eos_token_id | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||
| def max_length(self): | ||||||||||||||||||||||||||
| return self._max_length | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||
| def batch_size(self): | ||||||||||||||||||||||||||
| return self.batch_size_per_gpu | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||
| def device(self): | ||||||||||||||||||||||||||
| return self._device | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||
| def rank(self): | ||||||||||||||||||||||||||
| return self._rank | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||
| def world_size(self): | ||||||||||||||||||||||||||
| return self._world_size | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
Comment on lines
+122
to
+164
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Add missing type hints and docstrings for public API. Guidelines require type hints and docstrings; properties and flatten lack them. Apply this diff: class Gemma3(lmms):
- """
- Gemma3 Model
- https://huggingface.co/google/gemma-3-27b-it
- """
+ """Gemma-3 simple wrapper for lmms-eval.
+
+ Loads a Gemma-3 IT checkpoint and provides batched `generate_until`.
+ """
@@
- def config(self):
- # return the associated transformers.AutoConfig for the given pretrained model.
- return self._config
+ def config(self) -> PretrainedConfig:
+ """Return the underlying Transformers config."""
+ return self._config
@@
- def tokenizer(self):
- return self._tokenizer
+ def tokenizer(self) -> PreTrainedTokenizerBase:
+ """Return the tokenizer/processor tokenizer."""
+ return self._tokenizer
@@
- def model(self):
- # returns the model, unwrapping it if using Accelerate
+ def model(self) -> Gemma3ForConditionalGeneration:
+ """Return the model, unwrapped if using Accelerate."""
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self._model)
else:
return self._model
@@
- def eot_token_id(self):
+ def eot_token_id(self) -> int:
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
# return self.tokenizer.eod_id
return self.tokenizer.eos_token_id
@@
- def max_length(self):
+ def max_length(self) -> int:
return self._max_length
@@
- def batch_size(self):
+ def batch_size(self) -> int:
return self.batch_size_per_gpu
@@
- def device(self):
+ def device(self) -> torch.device:
return self._device
@@
- def rank(self):
+ def rank(self) -> int:
return self._rank
@@
- def world_size(self):
+ def world_size(self) -> int:
return self._world_size
@@
- def flatten(self, input):
- new_list = []
- for i in input:
- for j in i:
- new_list.append(j)
- return new_list
+ def flatten(self, items: List[List[str]]) -> List[str]:
+ """Flatten a list-of-lists one level."""
+ return [j for i in items for j in i]
@@
- def generate_until(self, requests: List[Instance]) -> List[str]:
+ def generate_until(self, requests: List[Instance]) -> List[str]:
+ """Generate responses for each Instance until stop tokens."""
res = []
@@
- def generate_until_multi_round(self, requests) -> List[str]:
- raise NotImplementedError("TODO: Implement multi-round generation")
+ def generate_until_multi_round(self, requests: List[Instance]) -> List[str]:
+ """Multi-round chat generation (not yet implemented)."""
+ raise NotImplementedError("TODO: Implement multi-round generation")Also applies to: 160-166, 167-167, 298-300 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
| def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: | ||||||||||||||||||||||||||
| raise NotImplementedError("Not implemented for Gemma3.") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def flatten(self, input): | ||||||||||||||||||||||||||
| new_list = [] | ||||||||||||||||||||||||||
| for i in input: | ||||||||||||||||||||||||||
| for j in i: | ||||||||||||||||||||||||||
| new_list.append(j) | ||||||||||||||||||||||||||
| return new_list | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def generate_until(self, requests: List[Instance]) -> List[str]: | ||||||||||||||||||||||||||
| res = [] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _collate(x): | ||||||||||||||||||||||||||
| # the negative sign on len(toks) sorts descending - this has a few advantages: | ||||||||||||||||||||||||||
| # - time estimates will always be over not underestimates, which is more useful for planning | ||||||||||||||||||||||||||
| # - to know the size of a batch when going through the list, you know the first one is always the batch | ||||||||||||||||||||||||||
| # padded context length. this is useful to simplify the batching logic and more importantly to make | ||||||||||||||||||||||||||
| # automatic adaptive batches much much easier to implement | ||||||||||||||||||||||||||
| # - any OOMs will happen right away rather than near the end | ||||||||||||||||||||||||||
| toks = self.tokenizer.encode(x[0]) | ||||||||||||||||||||||||||
| return -len(toks), x[0] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") | ||||||||||||||||||||||||||
| # we group requests by their generation_kwargs, | ||||||||||||||||||||||||||
| # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling | ||||||||||||||||||||||||||
| # in the same batch. | ||||||||||||||||||||||||||
| re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True) | ||||||||||||||||||||||||||
| chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) | ||||||||||||||||||||||||||
| for chunk in chunks: | ||||||||||||||||||||||||||
| contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk) | ||||||||||||||||||||||||||
| task = task[0] | ||||||||||||||||||||||||||
| split = split[0] | ||||||||||||||||||||||||||
| visual_list = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] | ||||||||||||||||||||||||||
| gen_kwargs = all_gen_kwargs[0] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Set default until or update values from gen_kwargs if present | ||||||||||||||||||||||||||
| until = gen_kwargs.get("until", [self.tokenizer.decode(self.eot_token_id)]) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if isinstance(until, str): | ||||||||||||||||||||||||||
| until = [until] | ||||||||||||||||||||||||||
| elif not isinstance(until, list): | ||||||||||||||||||||||||||
| raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str, list], but got {type(until)}") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Avoid using '\n\n' as a stopper to prevent truncation, which can lead to incorrect results | ||||||||||||||||||||||||||
| until = [item for item in until if item != "\n\n"] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if isinstance(contexts, tuple): | ||||||||||||||||||||||||||
| contexts = list(contexts) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| for i in range(len(contexts)): | ||||||||||||||||||||||||||
| if "<image>" in contexts[i]: | ||||||||||||||||||||||||||
| contexts[i] = contexts[i].replace("<image>", "") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| batched_messages = [] | ||||||||||||||||||||||||||
| for i, context in enumerate(contexts): | ||||||||||||||||||||||||||
| if "<image>" in context: | ||||||||||||||||||||||||||
| context = context.replace("<image>", "") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
Comment on lines
+231
to
+239
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion You strip Apply this diff to remove the first loop: - for i in range(len(contexts)):
- if "<image>" in contexts[i]:
- contexts[i] = contexts[i].replace("<image>", "")📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
| message = [{"role": "system", "content": [{"type": "text", "text": self.system_prompt}]}] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if self.reasoning_prompt: | ||||||||||||||||||||||||||
| context = context.strip() + self.reasoning_prompt | ||||||||||||||||||||||||||
| contexts[i] = context | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| processed_visuals = [] | ||||||||||||||||||||||||||
| for visual in visual_list[i]: | ||||||||||||||||||||||||||
| if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")): # Video file | ||||||||||||||||||||||||||
| processed_visuals.append({"type": "video", "video": visual, "max_pixels": self.max_pixels, "min_pixels": self.min_pixels}) | ||||||||||||||||||||||||||
| elif isinstance(visual, Image.Image): # Handle both single and multiple images | ||||||||||||||||||||||||||
| base64_image = visual.convert("RGB") | ||||||||||||||||||||||||||
| buffer = BytesIO() | ||||||||||||||||||||||||||
| base64_image.save(buffer, format="JPEG") | ||||||||||||||||||||||||||
| base64_bytes = base64.b64encode(buffer.getvalue()) | ||||||||||||||||||||||||||
| base64_string = base64_bytes.decode("utf-8") | ||||||||||||||||||||||||||
| processed_visuals.append({"type": "image", "image": f"data:image/jpeg;base64,{base64_string}", "max_pixels": self.max_pixels, "min_pixels": self.min_pixels}) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| message.append( | ||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||
| "role": "user", | ||||||||||||||||||||||||||
| "content": processed_visuals + [{"type": "text", "text": context}], | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| batched_messages.append(message) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| inputs = self.processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", padding="max_length", pad_to_multiple_of=8, max_length=self.max_length).to( | ||||||||||||||||||||||||||
| self.model.device, dtype=torch.bfloat16 | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if self.device_map == "auto": | ||||||||||||||||||||||||||
| inputs = inputs.to("cuda") | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| inputs = inputs.to(self.device) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
Comment on lines
+278
to
+282
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Avoid double .to() and wrap long call. Inputs are moved twice; also exceed 88 chars. Apply this diff: - inputs = self.processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", padding="max_length", pad_to_multiple_of=8, max_length=self.max_length).to(
- self.model.device, dtype=torch.bfloat16
- )
-
- if self.device_map == "auto":
- inputs = inputs.to("cuda")
- else:
- inputs = inputs.to(self.device)
+ inputs = self.processor.apply_chat_template(
+ batched_messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding="max_length",
+ pad_to_multiple_of=8,
+ max_length=self.max_length,
+ )
+ target_device = "cuda" if self.device_map == "auto" else self.device
+ inputs = inputs.to(target_device, dtype=torch.bfloat16)Also applies to: 243-245 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
| # Set default generation kwargs | ||||||||||||||||||||||||||
| default_gen_kwargs = { | ||||||||||||||||||||||||||
| "max_new_tokens": 128, | ||||||||||||||||||||||||||
| "temperature": 0.0, # Set to 0 for greedy default | ||||||||||||||||||||||||||
| "top_p": None, | ||||||||||||||||||||||||||
| "num_beams": 1, | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| # Update with provided kwargs | ||||||||||||||||||||||||||
| current_gen_kwargs = {**default_gen_kwargs, **gen_kwargs} | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if current_gen_kwargs["temperature"] > 0: | ||||||||||||||||||||||||||
| current_gen_kwargs["do_sample"] = True | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| current_gen_kwargs["do_sample"] = False | ||||||||||||||||||||||||||
| current_gen_kwargs["temperature"] = None | ||||||||||||||||||||||||||
| current_gen_kwargs["top_p"] = None | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
Comment on lines
+284
to
+299
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Generation kwargs: avoid passing None values to generate(). temperature/top_p=None can raise at runtime. Only pass when set; keep greedy defaults. Apply this diff: - default_gen_kwargs = {
- "max_new_tokens": 128,
- "temperature": 0.0, # Set to 0 for greedy default
- "top_p": None,
- "num_beams": 1,
- }
+ default_gen_kwargs = {
+ "max_new_tokens": 128,
+ "temperature": 0.0, # greedy
+ "num_beams": 1,
+ }
@@
- if current_gen_kwargs["temperature"] > 0:
- current_gen_kwargs["do_sample"] = True
- else:
- current_gen_kwargs["do_sample"] = False
- current_gen_kwargs["temperature"] = None
- current_gen_kwargs["top_p"] = None
+ do_sample = float(current_gen_kwargs.get("temperature", 0.0)) > 0.0
+ top_p = current_gen_kwargs.get("top_p", None)
+ if do_sample and top_p is None:
+ top_p = 0.95
@@
- cont = self.model.generate(
- **inputs,
- do_sample=current_gen_kwargs["do_sample"],
- temperature=current_gen_kwargs["temperature"],
- top_p=current_gen_kwargs["top_p"],
- num_beams=current_gen_kwargs["num_beams"],
- max_new_tokens=current_gen_kwargs["max_new_tokens"],
- use_cache=self.use_cache,
- )
+ gen_args = dict(
+ **inputs,
+ num_beams=current_gen_kwargs["num_beams"],
+ max_new_tokens=current_gen_kwargs["max_new_tokens"],
+ use_cache=self.use_cache,
+ )
+ if do_sample:
+ gen_args.update(do_sample=True, temperature=float(current_gen_kwargs["temperature"]), top_p=float(top_p))
+ cont = self.model.generate(**gen_args)Also applies to: 269-277 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
| cont = self.model.generate( | ||||||||||||||||||||||||||
| **inputs, | ||||||||||||||||||||||||||
| do_sample=current_gen_kwargs["do_sample"], | ||||||||||||||||||||||||||
| temperature=current_gen_kwargs["temperature"], | ||||||||||||||||||||||||||
| top_p=current_gen_kwargs["top_p"], | ||||||||||||||||||||||||||
| num_beams=current_gen_kwargs["num_beams"], | ||||||||||||||||||||||||||
| max_new_tokens=current_gen_kwargs["max_new_tokens"], | ||||||||||||||||||||||||||
| use_cache=self.use_cache, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, cont)] | ||||||||||||||||||||||||||
| answers = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False) | ||||||||||||||||||||||||||
| for i, ans in enumerate(answers): | ||||||||||||||||||||||||||
| # print(f"Raw answer {i}: {ans}") | ||||||||||||||||||||||||||
| for term in until: | ||||||||||||||||||||||||||
| if len(term) > 0: | ||||||||||||||||||||||||||
| ans = ans.split(term)[0] | ||||||||||||||||||||||||||
| answers[i] = ans | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| for ans, context in zip(answers, contexts): | ||||||||||||||||||||||||||
| res.append(ans) | ||||||||||||||||||||||||||
| self.cache_hook.add_partial("generate_until", (context, gen_kwargs), ans) | ||||||||||||||||||||||||||
| pbar.update(1) | ||||||||||||||||||||||||||
| # reorder this group of results back to original unsorted form | ||||||||||||||||||||||||||
| res = re_ords.get_original(res) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| pbar.close() | ||||||||||||||||||||||||||
| return res | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def generate_until_multi_round(self, requests) -> List[str]: | ||||||||||||||||||||||||||
| raise NotImplementedError("TODO: Implement multi-round generation") | ||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.