Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions examples/models/gemma3.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Run lmms_eval to reproduce Gemma 3 results.
# Launches 8 processes through accelerate and evaluates google/gemma-3-4b-it on
# mmmu_val, ai2d and mathvista_testmini as an example task set; edit --tasks to
# run other benchmarks. --output_path is set to ./logs/.

accelerate launch --num_processes=8 --main_process_port=12345 -m lmms_eval \
--model gemma3 \
--model_args=pretrained=google/gemma-3-4b-it \
--tasks mmmu_val,ai2d,mathvista_testmini \
--batch_size 1 --output_path ./logs/
Comment thread
Luodian marked this conversation as resolved.
Outdated
1 change: 1 addition & 0 deletions lmms_eval/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"fuyu": "Fuyu",
"gemini_api": "GeminiAPI",
"gpt4o_audio": "GPT4OAudio",
"gemma3": "Gemma3",
"gpt4v": "GPT4V",
"idefics2": "Idefics2",
"instructblip": "InstructBLIP",
Expand Down
299 changes: 299 additions & 0 deletions lmms_eval/models/simple/gemma3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
import base64
import os
import re
import uuid
import warnings
from io import BytesIO
from typing import List, Optional, Tuple, Union
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

import decord
import torch
from accelerate import Accelerator, DistributedType
from PIL import Image
from tqdm import tqdm

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

warnings.simplefilter("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore")

from loguru import logger as eval_logger
from transformers import AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration


@register_model("gemma3")
class Gemma3(lmms):
"""
Gemma3 Model
https://huggingface.co/google/gemma-3-27b-it
"""

def __init__(
self,
pretrained: str = "google/gemma-3-27b-it",
device: Optional[str] = "cuda",
device_map: Optional[str] = "auto",
batch_size: Optional[Union[int, str]] = 1,
trust_remote_code: Optional[bool] = True,
use_cache=True,
Comment on lines +40 to +42
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

batch_size type mismatch.

Annotated as Optional[Union[int, str]] but cast with int(); non‑numeric strings will crash.

Apply this diff to accept only int:

-        batch_size: Optional[Union[int, str]] = 1,
+        batch_size: int = 1,
@@
-        self.batch_size_per_gpu = int(batch_size)
+        self.batch_size_per_gpu = batch_size

Also applies to: 80-80

🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 39-41 (and also at line 80),
the batch_size parameter is annotated as Optional[Union[int, str]] but the
implementation casts it with int(), which will crash on non-numeric strings;
change the annotation to Optional[int] and update the code to accept only ints
by removing string handling/casting (or explicitly validate and raise a clear
TypeError if a non-int is passed), and adjust any callers or docstrings
accordingly so batch_size is strictly an int (or None).

attn_implementation: Optional[str] = None,
min_pixels: int = 256 * 28 * 28,
max_pixels: int = 1605632,
max_num_frames: int = 32,
interleave_visuals: Optional[bool] = False,
system_prompt: Optional[str] = "You are a helpful assistant.",
reasoning_prompt: Optional[str] = None,
**kwargs,
) -> None:
super().__init__()
# Do not use kwargs for now
assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

accelerator = Accelerator()
Comment on lines +34 to +56
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Conflicting kwargs usage for max_length.

You assert kwargs == {} but later read max_length from kwargs; unreachable path.

Apply this diff to make max_length explicit and remove dead kwargs access:

-        system_prompt: Optional[str] = "You are a helpful assistant.",
-        reasoning_prompt: Optional[str] = None,
-        **kwargs,
+        system_prompt: Optional[str] = "You are a helpful assistant.",
+        reasoning_prompt: Optional[str] = None,
+        max_length: int = 2048,
     ) -> None:
         super().__init__()
-        # Do not use kwargs for now
-        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
+        # NOTE: no extra kwargs are accepted to avoid silent misconfigurations
+        # (add parameters explicitly if needed)
@@
-        self._max_length = kwargs.get("max_length", 2048)
+        self._max_length = int(max_length)

Also applies to: 77-79

🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 34 to 55 (and also adjust
lines ~77-79), the constructor asserts kwargs == {} but later code reads
max_length from kwargs, making that path unreachable; add max_length as an
explicit parameter (e.g., max_length: Optional[int] = None) to the __init__
signature, remove any reads of max_length from kwargs, and drop the kwargs
assertion only if other legitimate kwargs are expected (otherwise keep assertion
and ensure no kwargs are used anywhere); update the later block at lines ~77-79
to reference the explicit max_length parameter instead of accessing kwargs.

if accelerator.num_processes > 1:
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"
else:
self._device = torch.device(device)
self.device_map = device_map if device_map else device

# Prepare model loading arguments
model_kwargs = {
"torch_dtype": torch.bfloat16,
"device_map": self.device_map,
}

# Add attention implementation if specified
if attn_implementation is not None:
model_kwargs["attn_implementation"] = attn_implementation

self._model = Gemma3ForConditionalGeneration.from_pretrained(pretrained, **model_kwargs).eval()
self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code, device_map=self.device_map)
self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels)

Comment on lines +82 to +84
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

AutoTokenizer.from_pretrained does not accept device_map.

Passing device_map will raise a TypeError; remove it from tokenizer init.

Apply this diff:

-        self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code, device_map=self.device_map)
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            pretrained,
+            trust_remote_code=trust_remote_code,
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code, device_map=self.device_map)
self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels)
self._tokenizer = AutoTokenizer.from_pretrained(
pretrained,
trust_remote_code=trust_remote_code,
)
self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels)
🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 74-76,
AutoTokenizer.from_pretrained is being called with a device_map keyword which
AutoTokenizer does not accept; remove the device_map argument from the tokenizer
initialization and ensure any device_map usage is applied when loading the model
(not the tokenizer), i.e., call AutoTokenizer.from_pretrained(pretrained,
trust_remote_code=trust_remote_code) and keep device_map only for model loading
logic.

self._config = self._model.config
self._max_length = kwargs.get("max_length", 2048)
self.model.tie_weights()
self.batch_size_per_gpu = int(batch_size)
self.use_cache = use_cache
self.system_prompt = system_prompt
self.interleave_visuals = interleave_visuals

self.max_pixels = max_pixels
self.min_pixels = min_pixels
self.max_num_frames = max_num_frames

Comment on lines +93 to +96
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Unused parameters: interleave_visuals, max_num_frames.

They are not used; either implement or remove to avoid confusion.

Would you like me to wire these into visual processing (e.g., limit frame sampling)?

🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 85 to 88, the constructor
assigns max_num_frames and ignores interleave_visuals (unused) which causes
confusion; either remove the unused parameters from the signature and delete
their assignments, or persist and use them: store self.interleave_visuals =
interleave_visuals and self.max_num_frames = max_num_frames (if not already),
then wire them into visual processing code—use interleave_visuals to control
whether visual frames are interleaved with other modalities and enforce
self.max_num_frames when sampling/iterating frames (truncate or sample to that
limit) wherever frames are prepared or batched.

if reasoning_prompt:
self.reasoning_prompt = reasoning_prompt.replace("\\n", "\n")
else:
self.reasoning_prompt = None

if accelerator.num_processes > 1:
assert accelerator.distributed_type in [
DistributedType.FSDP,
DistributedType.MULTI_GPU,
], "Unsupported distributed type provided. Only DDP and FSDP are supported."
if accelerator.distributed_type == DistributedType.FSDP:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
else:
self.model.to(self._device)
self._rank = 0
self._world_size = 1
self.model.eval()

@property
def config(self):
    """Return the configuration object read from the loaded model (``model.config``)."""
    return self._config

@property
def tokenizer(self):
    """Return the tokenizer created by ``AutoTokenizer.from_pretrained`` in ``__init__``."""
    return self._tokenizer

@property
def model(self):
    """Return the underlying model, unwrapped through Accelerate when one is attached.

    ``self.accelerator`` is only set on the multi-process path of ``__init__``;
    without it the stored model is returned as-is.
    """
    if not hasattr(self, "accelerator"):
        return self._model
    return self.accelerator.unwrap_model(self._model)

@property
def eot_token_id(self):
    """Return the tokenizer's ``eos_token_id``, exposed here as the end-of-text id."""
    # "End of text" describes the stopping point better than "end of sentence",
    # hence the EOT name on top of the tokenizer's EOS id.
    active_tokenizer = self.tokenizer
    return active_tokenizer.eos_token_id

@property
def max_length(self) -> int:
    """Return the maximum sequence length stored by ``__init__`` (``_max_length``)."""
    return self._max_length

@property
def batch_size(self) -> int:
    """Return the per-process batch size (``batch_size_per_gpu``)."""
    return self.batch_size_per_gpu

@property
def device(self):
    """Return the ``torch.device`` selected for this process in ``__init__``."""
    return self._device

@property
def rank(self) -> int:
    """Return this process's rank.

    ``__init__`` stores Accelerate's local process index when several
    processes are running, and 0 otherwise.
    """
    return self._rank

@property
def world_size(self) -> int:
    """Return the process count: Accelerate's ``num_processes``, or 1 when single-process."""
    return self._world_size

Comment on lines +122 to +164
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add missing type hints and docstrings for public API.

Guidelines require type hints and docstrings; properties and flatten lack them.

Apply this diff:

 class Gemma3(lmms):
-    """
-    Gemma3 Model
-    https://huggingface.co/google/gemma-3-27b-it
-    """
+    """Gemma-3 simple wrapper for lmms-eval.
+
+    Loads a Gemma-3 IT checkpoint and provides batched `generate_until`.
+    """

@@
-    def config(self):
-        # return the associated transformers.AutoConfig for the given pretrained model.
-        return self._config
+    def config(self) -> PretrainedConfig:
+        """Return the underlying Transformers config."""
+        return self._config
@@
-    def tokenizer(self):
-        return self._tokenizer
+    def tokenizer(self) -> PreTrainedTokenizerBase:
+        """Return the tokenizer/processor tokenizer."""
+        return self._tokenizer
@@
-    def model(self):
-        # returns the model, unwrapping it if using Accelerate
+    def model(self) -> Gemma3ForConditionalGeneration:
+        """Return the model, unwrapped if using Accelerate."""
         if hasattr(self, "accelerator"):
             return self.accelerator.unwrap_model(self._model)
         else:
             return self._model
@@
-    def eot_token_id(self):
+    def eot_token_id(self) -> int:
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
         # return self.tokenizer.eod_id
         return self.tokenizer.eos_token_id
@@
-    def max_length(self):
+    def max_length(self) -> int:
         return self._max_length
@@
-    def batch_size(self):
+    def batch_size(self) -> int:
         return self.batch_size_per_gpu
@@
-    def device(self):
+    def device(self) -> torch.device:
         return self._device
@@
-    def rank(self):
+    def rank(self) -> int:
         return self._rank
@@
-    def world_size(self):
+    def world_size(self) -> int:
         return self._world_size
@@
-    def flatten(self, input):
-        new_list = []
-        for i in input:
-            for j in i:
-                new_list.append(j)
-        return new_list
+    def flatten(self, items: List[List[str]]) -> List[str]:
+        """Flatten a list-of-lists one level."""
+        return [j for i in items for j in i]
@@
-    def generate_until(self, requests: List[Instance]) -> List[str]:
+    def generate_until(self, requests: List[Instance]) -> List[str]:
+        """Generate responses for each Instance until stop tokens."""
         res = []
@@
-    def generate_until_multi_round(self, requests) -> List[str]:
-        raise NotImplementedError("TODO: Implement multi-round generation")
+    def generate_until_multi_round(self, requests: List[Instance]) -> List[str]:
+        """Multi-round chat generation (not yet implemented)."""
+        raise NotImplementedError("TODO: Implement multi-round generation")

Also applies to: 160-166, 167-167, 298-300

🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 114 to 156 (also apply similar
changes at 160-166, 167, and 298-300), several public @property methods lack
type hints and docstrings; add explicit return type annotations for each
property (config, tokenizer, model, eot_token_id, max_length, batch_size,
device, rank, world_size) and short one-line docstrings describing the returned
value and types (e.g., "Return the transformers.AutoConfig for the pretrained
model."), and fix the batch_size property to return self.batch_size_per_gpu with
the proper type hint; ensure imports support any forward types if needed and
keep docstrings consistent with project style.

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
    """Log-likelihood scoring entry point; unsupported by this wrapper.

    Raises:
        NotImplementedError: Always, regardless of ``requests``.
    """
    message = "Not implemented for Gemma3."
    raise NotImplementedError(message)

def flatten(self, input):
    """Flatten ``input`` by exactly one level.

    Args:
        input: An iterable of iterables.

    Returns:
        A new list with the elements of every inner iterable, in order.
    """
    return [element for group in input for element in group]

def generate_until(self, requests: List[Instance]) -> List[str]:
res = []

def _collate(x):
    # Sort key: (negative token count of the context, the context string).
    # Negating the length orders requests longest-first, which means:
    # - time estimates err on the high side, which is more useful for planning
    # - the first item of a batch always has the padded context length, which
    #   simplifies batching and makes automatic adaptive batches much easier
    # - any OOM shows up right away rather than near the end
    context = x[0]
    token_count = len(self.tokenizer.encode(context))
    return -token_count, context

pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
for chunk in chunks:
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
task = task[0]
split = split[0]
visual_list = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
gen_kwargs = all_gen_kwargs[0]

# Set default until or update values from gen_kwargs if present
until = gen_kwargs.get("until", [self.tokenizer.decode(self.eot_token_id)])

if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str, list], but got {type(until)}")

# Avoid using '\n\n' as a stopper to prevent truncation, which can lead to incorrect results
until = [item for item in until if item != "\n\n"]

if isinstance(contexts, tuple):
contexts = list(contexts)

for i in range(len(contexts)):
if "<image>" in contexts[i]:
contexts[i] = contexts[i].replace("<image>", "")

batched_messages = []
for i, context in enumerate(contexts):
if "<image>" in context:
context = context.replace("<image>", "")

Comment on lines +231 to +239
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Duplicate '<image>' cleanup.

You strip '<image>' twice (Lines 207–211 and 213–215). Keep one.

Apply this diff to remove the first loop:

-            for i in range(len(contexts)):
-                if "<image>" in contexts[i]:
-                    contexts[i] = contexts[i].replace("<image>", "")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for i in range(len(contexts)):
if "<image>" in contexts[i]:
contexts[i] = contexts[i].replace("<image>", "")
batched_messages = []
for i, context in enumerate(contexts):
if "<image>" in context:
context = context.replace("<image>", "")
batched_messages = []
for i, context in enumerate(contexts):
if "<image>" in context:
context = context.replace("<image>", "")
🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 207 to 215 there is duplicate
removal of the "<image>" token (first loop lines 207–211 and again inside the
batched_messages loop lines 213–215); remove the first standalone loop (lines
207–211) so that "<image>" is only stripped once when building batched_messages,
leaving the single replacement inside the batched_messages creation.

message = [{"role": "system", "content": [{"type": "text", "text": self.system_prompt}]}]

if self.reasoning_prompt:
context = context.strip() + self.reasoning_prompt
contexts[i] = context

processed_visuals = []
for visual in visual_list[i]:
if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")): # Video file
processed_visuals.append({"type": "video", "video": visual, "max_pixels": self.max_pixels, "min_pixels": self.min_pixels})
elif isinstance(visual, Image.Image): # Handle both single and multiple images
base64_image = visual.convert("RGB")
buffer = BytesIO()
base64_image.save(buffer, format="JPEG")
base64_bytes = base64.b64encode(buffer.getvalue())
base64_string = base64_bytes.decode("utf-8")
processed_visuals.append({"type": "image", "image": f"data:image/jpeg;base64,{base64_string}", "max_pixels": self.max_pixels, "min_pixels": self.min_pixels})

message.append(
{
"role": "user",
"content": processed_visuals + [{"type": "text", "text": context}],
}
)

batched_messages.append(message)

inputs = self.processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", padding="max_length", pad_to_multiple_of=8, max_length=self.max_length).to(
self.model.device, dtype=torch.bfloat16
)

if self.device_map == "auto":
inputs = inputs.to("cuda")
else:
inputs = inputs.to(self.device)

Comment on lines +278 to +282
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Avoid double .to() and wrap long call.

Inputs are moved to the device twice, and the call exceeds the 88-character line limit.

Apply this diff:

-            inputs = self.processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", padding="max_length", pad_to_multiple_of=8, max_length=self.max_length).to(
-                self.model.device, dtype=torch.bfloat16
-            )
-
-            if self.device_map == "auto":
-                inputs = inputs.to("cuda")
-            else:
-                inputs = inputs.to(self.device)
+            inputs = self.processor.apply_chat_template(
+                batched_messages,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+                padding="max_length",
+                pad_to_multiple_of=8,
+                max_length=self.max_length,
+            )
+            target_device = "cuda" if self.device_map == "auto" else self.device
+            inputs = inputs.to(target_device, dtype=torch.bfloat16)

Also applies to: 243-245

🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 243-245 and 247-251, the code
calls .to() twice on inputs and creates lines that exceed the 88-char limit;
replace the conditional device selection with a single device variable (e.g.
device = "cuda" if self.device_map == "auto" else self.device) and then call
inputs = inputs.to(device) once, breaking the line if necessary to keep it under
88 characters.

# Set default generation kwargs
default_gen_kwargs = {
"max_new_tokens": 128,
"temperature": 0.0, # Set to 0 for greedy default
"top_p": None,
"num_beams": 1,
}
# Update with provided kwargs
current_gen_kwargs = {**default_gen_kwargs, **gen_kwargs}

if current_gen_kwargs["temperature"] > 0:
current_gen_kwargs["do_sample"] = True
else:
current_gen_kwargs["do_sample"] = False
current_gen_kwargs["temperature"] = None
current_gen_kwargs["top_p"] = None

Comment on lines +284 to +299
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Generation kwargs: avoid passing None values to generate().

temperature/top_p=None can raise at runtime. Only pass when set; keep greedy defaults.

Apply this diff:

-            default_gen_kwargs = {
-                "max_new_tokens": 128,
-                "temperature": 0.0,  # Set to 0 for greedy default
-                "top_p": None,
-                "num_beams": 1,
-            }
+            default_gen_kwargs = {
+                "max_new_tokens": 128,
+                "temperature": 0.0,  # greedy
+                "num_beams": 1,
+            }
@@
-            if current_gen_kwargs["temperature"] > 0:
-                current_gen_kwargs["do_sample"] = True
-            else:
-                current_gen_kwargs["do_sample"] = False
-                current_gen_kwargs["temperature"] = None
-                current_gen_kwargs["top_p"] = None
+            do_sample = float(current_gen_kwargs.get("temperature", 0.0)) > 0.0
+            top_p = current_gen_kwargs.get("top_p", None)
+            if do_sample and top_p is None:
+                top_p = 0.95
@@
-            cont = self.model.generate(
-                **inputs,
-                do_sample=current_gen_kwargs["do_sample"],
-                temperature=current_gen_kwargs["temperature"],
-                top_p=current_gen_kwargs["top_p"],
-                num_beams=current_gen_kwargs["num_beams"],
-                max_new_tokens=current_gen_kwargs["max_new_tokens"],
-                use_cache=self.use_cache,
-            )
+            gen_args = dict(
+                **inputs,
+                num_beams=current_gen_kwargs["num_beams"],
+                max_new_tokens=current_gen_kwargs["max_new_tokens"],
+                use_cache=self.use_cache,
+            )
+            if do_sample:
+                gen_args.update(do_sample=True, temperature=float(current_gen_kwargs["temperature"]), top_p=float(top_p))
+            cont = self.model.generate(**gen_args)

Also applies to: 269-277

🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 253-268 (and also apply same
fix to 269-277), the merged generation kwargs may include keys with value None
which will raise at runtime when passed to generate(); after merging
default_gen_kwargs and gen_kwargs, determine do_sample = True if temperature is
provided and > 0, otherwise set do_sample = False and ensure temperature and
top_p are omitted rather than set to None; finally, remove any keys whose value
is None from current_gen_kwargs (or build current_gen_kwargs only from keys with
non-None values) so generate() only receives valid parameters.

cont = self.model.generate(
**inputs,
do_sample=current_gen_kwargs["do_sample"],
temperature=current_gen_kwargs["temperature"],
top_p=current_gen_kwargs["top_p"],
num_beams=current_gen_kwargs["num_beams"],
max_new_tokens=current_gen_kwargs["max_new_tokens"],
use_cache=self.use_cache,
)

generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, cont)]
answers = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for i, ans in enumerate(answers):
# print(f"Raw answer {i}: {ans}")
for term in until:
if len(term) > 0:
ans = ans.split(term)[0]
answers[i] = ans

for ans, context in zip(answers, contexts):
res.append(ans)
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), ans)
pbar.update(1)
# reorder this group of results back to original unsorted form
res = re_ords.get_original(res)

pbar.close()
return res

def generate_until_multi_round(self, requests) -> List[str]:
    """Multi-round generation entry point; not implemented yet.

    Raises:
        NotImplementedError: Always, regardless of ``requests``.
    """
    message = "TODO: Implement multi-round generation"
    raise NotImplementedError(message)