11 changes: 11 additions & 0 deletions olive/cli/benchmark.py
@@ -76,6 +76,13 @@ def register_subcommand(parser: ArgumentParser):
help="Backend for ONNX model evaluation. Use 'auto' to infer backend from model type.",
)

lmeval_group.add_argument(
"--confirm_run_unsafe_code",
action="store_true",
default=False,
help="Allow running tasks that execute model-generated code (e.g., MBPP, HumanEval).",
)

add_logging_options(sub_parser)
add_save_config_file_options(sub_parser)
add_shared_cache_options(sub_parser)
@@ -117,6 +124,10 @@ def _get_run_config(self, tempdir: str) -> dict:
("evaluators", "evaluator", "model_class"),
None if self.args.backend == "auto" else self.args.backend,
),
(
("evaluators", "evaluator", "confirm_run_unsafe_code"),
self.args.confirm_run_unsafe_code or None,
),
]

for keys, value in to_replace:
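For readers skimming the CLI change above: the new flag is wired into the run config through the same `(keys, value)` patch list as the backend option, with `or None` keeping the template default when the flag is not passed. The snippet below is a minimal illustrative sketch of that pattern, not code from this PR; `apply_overrides` and the config layout are hypothetical stand-ins.

# Hypothetical sketch of the (keys, value) patch pattern used by _get_run_config;
# the helper and config shape are illustrative, not Olive's actual implementation.
def apply_overrides(config: dict, to_replace) -> dict:
    for keys, value in to_replace:
        if value is None:
            continue  # leave the template default untouched (flag not passed)
        node = config
        for key in keys[:-1]:
            node = node.setdefault(key, {})
        node[keys[-1]] = value
    return config

config = {"evaluators": {"evaluator": {"type": "lm_eval"}}}
apply_overrides(
    config,
    [
        (("evaluators", "evaluator", "model_class"), "ortgenai"),
        (("evaluators", "evaluator", "confirm_run_unsafe_code"), True),
    ],
)
# config["evaluators"]["evaluator"] now carries confirm_run_unsafe_code=True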
119 changes: 117 additions & 2 deletions olive/evaluator/lmeval_ort.py
@@ -190,7 +190,10 @@ def loglikelihood_rolling(self, requests, disable_tqdm: bool = False) -> list[float]:
raise NotImplementedError("Yet to be implemented!")

def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
raise NotImplementedError("Yet to be implemented!")
raise NotImplementedError(
"generate_until is not supported by this model backend. "
"Use model_class='ortgenai' for generative tasks such as MBPP or HumanEval."
)


@register_model("ort")
@@ -509,7 +512,16 @@ def __init__(
self.max_length = max_length
else:
self.max_length = genai_config["search"]["max_length"]
self._eot_token_id = genai_config["model"]["eos_token_id"]
eos = genai_config["model"]["eos_token_id"]
# eos_token_id can be a single int or a list of ints
if isinstance(eos, list):
if not eos:
raise ValueError("genai_config model.eos_token_id must not be an empty list")
self._eot_token_id = eos[0]
self.eos_token_ids = set(eos)
else:
self._eot_token_id = eos
self.eos_token_ids = {eos}
self.params = og.GeneratorParams(self.model)
self.params.set_search_options(max_length=self.max_length, past_present_share_buffer=False)
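As a side note on the normalization above: genai_config.json may give eos_token_id either as a single id or as a list (for models with multiple end-of-text tokens). Below is a tiny illustrative sketch of the resulting (_eot_token_id, eos_token_ids) pair for both shapes, with made-up token ids.

# Illustrative only: the same int-or-list normalization applied to made-up ids.
def normalize_eos(eos):
    if isinstance(eos, list):
        if not eos:
            raise ValueError("eos_token_id must not be an empty list")
        return eos[0], set(eos)
    return eos, {eos}

print(normalize_eos(2))               # (2, {2})
print(normalize_eos([32000, 32007]))  # (32000, {32000, 32007})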

@@ -573,5 +585,108 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor:
# seq dimension so the continuation slice still lands on the correct positions.
return torch.cat(all_logits, dim=1) # [batch, n_logits, vocab]

def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
"""Generate text until a stop sequence is found or max tokens reached.

Supports generative evaluation tasks such as MBPP and HumanEval.
Each request's args is a tuple of (context_string, gen_kwargs_dict).
"""
results = []
for request in requests:
context = request.args[0]
gen_kwargs = request.args[1] if len(request.args) > 1 and isinstance(request.args[1], dict) else {}

# Extract stop sequences
until = gen_kwargs.get("until", [])
if isinstance(until, str):
until = [until]
elif until is None:
until = []
elif not isinstance(until, list):
until = [until]
until = [stop_seq for stop_seq in until if isinstance(stop_seq, str) and stop_seq]

# Extract generation parameters
max_gen_toks = gen_kwargs.get("max_gen_toks", gen_kwargs.get("max_new_tokens", gen_kwargs.get("max_tokens")))
try:
max_gen_toks = int(max_gen_toks) if max_gen_toks is not None else 256
except (TypeError, ValueError):
max_gen_toks = 256
max_gen_toks = max(max_gen_toks, 0)
temperature = gen_kwargs.get("temperature", 0.0)
do_sample = gen_kwargs.get("do_sample", temperature > 0)

# Tokenize the prompt
prompt_ids = self.tokenizer.encode(context).tolist()
prompt_len = len(prompt_ids)

# Compute total max_length: prompt + new tokens, capped by model limit
total_max_length = min(prompt_len + max_gen_toks, self.max_length)

# If the prompt already fills or exceeds the model limit, no generation is possible.
if prompt_len >= self.max_length or max_gen_toks == 0:
results.append("")
if hasattr(request, "cache_hook") and request.cache_hook is not None:
request.cache_hook.add_partial("generate_until", request.args, "")
continue

# Create fresh generator params per request to avoid state leakage
params = og.GeneratorParams(self.model)
search_options = {
"max_length": total_max_length,
"past_present_share_buffer": False,
"batch_size": 1,
}
if do_sample:
    # onnxruntime-genai only samples when do_sample is enabled; setting temperature alone is not enough
    search_options["do_sample"] = True
    search_options["temperature"] = temperature
else:
    search_options["temperature"] = 0.0
params.set_search_options(**search_options)

# Run generation token by token to check for stop sequences
generator = og.Generator(self.model, params)
generator.append_tokens([prompt_ids])

generated_chunks = []
stop_idx = None
# Tail buffer wide enough to detect any stop sequence across chunk boundaries
max_stop_len = max((len(s) for s in until), default=0)

while not generator.is_done():
generator.generate_next_token()
new_token = generator.get_sequence(0)[-1]

# Check for EOS token(s)
if new_token in self.eos_token_ids:
break

chunk = self.tokenizer.decode([new_token])
generated_chunks.append(chunk)

# Check stop sequences against a bounded tail window; the full text is only
# joined once, when a stop sequence is actually found, not on every token.
if until:
    tail = "".join(generated_chunks[-(max_stop_len + 1) :])
    earliest = None
    for stop_seq in until:
        idx = tail.find(stop_seq)
        if idx != -1 and (earliest is None or idx < earliest):
            earliest = idx
    if earliest is not None:
        # Absolute cut position: text length before the tail plus the match offset inside it.
        stop_idx = len("".join(generated_chunks)) - len(tail) + earliest
        break

full_text = "".join(generated_chunks)
generated_text = full_text if stop_idx is None else full_text[:stop_idx]

results.append(generated_text)

# lm-eval cache hook
if hasattr(request, "cache_hook") and request.cache_hook is not None:
request.cache_hook.add_partial("generate_until", request.args, generated_text)

return results

def complete(self):
pass
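To make the stop-sequence logic in generate_until easier to review in isolation, here is a self-contained sketch of the tail-window check, assuming non-empty stop strings as the method guarantees after filtering; the streamed chunks and stop sequences below are invented for illustration and are not taken from this PR.

# Illustrative only: find the earliest stop sequence while streaming decoded chunks,
# scanning a bounded tail each step and joining the full text only on a hit.
def find_stop(chunks, until):
    max_stop_len = max(len(s) for s in until)  # assumes until is non-empty
    generated = []
    for chunk in chunks:
        generated.append(chunk)
        tail = "".join(generated[-(max_stop_len + 1):])
        hits = [tail.find(s) for s in until if s in tail]
        if hits:
            # absolute cut = length of text before the tail + earliest match inside the tail
            return len("".join(generated)) - len(tail) + min(hits)
    return None

chunks = ["def ", "add(a, b):\n", "    return a + b\n", "\nclass", " Foo:"]
cut = find_stop(chunks, until=["\nclass", "\ndef "])
text = "".join(chunks)
print(text[:cut] if cut is not None else text)  # prints the function, stopping before "\nclass"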
2 changes: 2 additions & 0 deletions olive/evaluator/olive_evaluator.py
@@ -1029,6 +1029,7 @@ def __init__(self, tasks: list[str], **kwargs):
self.ep = kwargs.get("execution_provider")
self.ep_options = kwargs.get("provider_options")
self.device = kwargs.get("device")
self.confirm_run_unsafe_code = kwargs.get("confirm_run_unsafe_code", False)

def evaluate(
self,
@@ -1108,6 +1109,7 @@ def evaluate(
batch_size=self.batch_size,
device=device,
limit=self.limit,
confirm_run_unsafe_code=self.confirm_run_unsafe_code,
)

for task_name in sorted(results["results"].keys()):
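Taken together with the lmeval_ort.py change, running a code-executing task now needs both the 'ortgenai' model class (the plain 'ort' backend raises NotImplementedError for generate_until) and the new opt-in flag. The dict below is an illustrative stand-in for the evaluator kwargs read above together with the CLI-patched keys; it is not a verbatim Olive config.

# Illustrative only: kwargs mirroring the .get(...) reads above plus the CLI-patched keys.
evaluator_kwargs = {
    "model_class": "ortgenai",          # generative backend; 'ort' cannot serve generate_until
    "confirm_run_unsafe_code": True,    # opt in to tasks that execute model-generated code
    "device": "cpu",
    "execution_provider": "CPUExecutionProvider",
}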