Skip to content

Commit c19fd97

Browse files
committed
refactor: remove provider-local caches in favor of ResponseCache
1 parent 0f33ba0 commit c19fd97

15 files changed

Lines changed: 185 additions & 323 deletions

lmms_eval/api/instance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class Instance:
7171
request_type: Literal["loglikelihood", "generate_until", "generate_until_multi_round", "generate_until_agentic"]
7272
arguments: tuple
7373
idx: int
74-
metadata: Tuple[str, int, int] = field(default_factory=lambda: (None, None, None)) # TODO: better typehints here
74+
metadata: Dict[str, Union[str, int]] = field(default_factory=dict)
7575
resps: list = field(default_factory=list)
7676
filtered_resps: dict = field(default_factory=dict)
7777
raw_filtered_resps: dict = field(default_factory=dict)

lmms_eval/caching/response_cache.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
77
Activation: ``python -m lmms_eval --model ... --tasks ... --use_cache ./eval_cache``
88
9-
Cache key: sha256(request_type, task_name, doc_id, idx, canonical gen_kwargs).
9+
Cache key: sha256(request_type, task_name, doc_id, idx, canonical gen_kwargs, content_hash).
1010
Scoped per model: ``{use_cache}/{model_hash}/rank{N}.db``
1111
"""
1212

@@ -164,12 +164,14 @@ def fingerprint_callable(fn) -> str:
164164

165165

166166
def _extract_content_hash(instance: Instance) -> str:
167-
"""Hash the text content of loglikelihood args to prevent collisions.
167+
"""Hash leading text arguments to prevent cache-key collisions.
168168
169-
For multiple_choice with acc_mutual_info, conditional requests have
170-
``(ctx, continuation, ...)`` while unconditional have ``("", choice)``.
171-
Both share the same (task_name, doc_id, idx) so we need this hash
172-
to distinguish them.
169+
Some flows can issue multiple deterministic requests that share the same
170+
``(task_name, doc_id, idx, gen_kwargs)`` while differing in prompt text.
171+
This is common in multi-round / agentic generation loops.
172+
173+
We hash the leading consecutive string arguments (for example context and
174+
continuation) so those requests do not alias to the same cache entry.
173175
"""
174176
args = instance.args
175177
text_parts = []
@@ -375,7 +377,7 @@ def execute(self, lm: Any, reqtype: str, requests: List[Instance]) -> list:
375377
self._skipped += 1
376378
continue
377379

378-
ch = _extract_content_hash(req) if reqtype == "loglikelihood" else ""
380+
ch = _extract_content_hash(req)
379381
tf = self._task_fingerprints.get(req.task_name, "")
380382
cache_key = compute_cache_key(
381383
request_type=reqtype,
@@ -406,7 +408,7 @@ def execute(self, lm: Any, reqtype: str, requests: List[Instance]) -> list:
406408
cacheable = self._extract_cacheable(resp)
407409
gen_kwargs = extract_gen_kwargs(req)
408410
deterministic = is_deterministic(reqtype, gen_kwargs)
409-
ch = _extract_content_hash(req) if reqtype == "loglikelihood" else ""
411+
ch = _extract_content_hash(req)
410412
tf = self._task_fingerprints.get(req.task_name, "")
411413
cache_key = compute_cache_key(request_type=reqtype, task_name=req.task_name, doc_id=req.doc_id, gen_kwargs=gen_kwargs, idx=req.idx, content_hash=ch, task_fingerprint=tf) if deterministic else ""
412414
self._log_to_audit(reqtype, req.task_name, req.doc_id, req.idx, gen_kwargs, cacheable, cache_key=cache_key, deterministic=deterministic)

lmms_eval/evaluator.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,12 @@ def _adjust_config(task_dict):
463463
decontaminate_suffix = "_decontaminate"
464464

465465

466-
def _run_generate_until_agentic(lm, requests: list[Instance], agentic_trace_mode: str = "basic") -> list[str]:
466+
def _run_generate_until_agentic(
467+
lm,
468+
requests: list[Instance],
469+
agentic_trace_mode: str = "basic",
470+
response_cache: Optional[ResponseCache] = None,
471+
) -> list[str]:
467472
responses: list[str] = []
468473

469474
for req in requests:
@@ -522,7 +527,11 @@ def _agentic_doc_to_messages(_doc):
522527
idx=0,
523528
metadata=req.metadata,
524529
)
525-
current_output = lm.generate_until([single_req])[0]
530+
if response_cache is not None:
531+
current_raw_output = response_cache.execute(lm, "generate_until", [single_req])[0]
532+
else:
533+
current_raw_output = lm.generate_until([single_req])[0]
534+
current_output, _ = unwrap_generation_output(current_raw_output)
526535
model_outputs.append(current_output)
527536
final_response = current_output
528537

@@ -607,7 +616,7 @@ def _agentic_doc_to_messages(_doc):
607616

608617
@positional_deprecated
609618
def evaluate(
610-
lm: "LM",
619+
lm,
611620
task_dict,
612621
limit: Optional[int] = None,
613622
offset: int = 0,
@@ -694,7 +703,7 @@ def evaluate(
694703
lm.accelerator = Accelerator()
695704

696705
for task_output in eval_tasks:
697-
task: Task = task_output.task
706+
task = task_output.task
698707
task_name = task_output.task_name
699708
task.args = cli_args
700709

@@ -790,7 +799,12 @@ def evaluate(
790799
trace_mode = "basic"
791800
if cli_args is not None:
792801
trace_mode = getattr(cli_args, "agentic_trace_mode", "basic")
793-
resps = _run_generate_until_agentic(lm, cloned_reqs, agentic_trace_mode=trace_mode)
802+
resps = _run_generate_until_agentic(
803+
lm,
804+
cloned_reqs,
805+
agentic_trace_mode=trace_mode,
806+
response_cache=response_cache,
807+
)
794808
elif response_cache is not None:
795809
resps = response_cache.execute(lm, reqtype, cloned_reqs)
796810
else:

lmms_eval/models/chat/bagel_lmms_engine.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ def __init__(
6262
text_temperature: float = 0.3,
6363
seed: int = 0,
6464
image_ratio: str = "1:1",
65-
continual_mode: bool = True,
66-
response_persistent_folder: Optional[str] = None,
6765
device: Optional[str] = "cuda",
6866
device_map: Optional[str] = None,
6967
**kwargs,
@@ -74,7 +72,6 @@ def __init__(
7472
self.load_in_4bit = load_in_4bit
7573
self.load_in_8bit = load_in_8bit
7674
self.show_thinking = show_thinking
77-
self.continual_mode = continual_mode
7875

7976
# Generation hyperparameters
8077
self.cfg_text_scale = cfg_text_scale
@@ -106,7 +103,7 @@ def __init__(
106103
self.image_shapes = (1024, 1024)
107104

108105
if output_image_dir is None:
109-
self.output_image_dir = os.path.join(self.response_persistent_folder, "bagel_generated_images")
106+
self.output_image_dir = "./logs/bagel_generated_images"
110107
else:
111108
self.output_image_dir = output_image_dir
112109

lmms_eval/models/chat/openai.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
import time
32
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
43
from typing import List, Union
@@ -71,13 +70,14 @@ def generate_until(self, requests) -> List[GenerationResult]:
7170
latencies: List[float] = []
7271
completed_since_adapt = 0
7372
in_flight = {}
74-
doc_uuids: List[Union[str, None]] = [None] * len(reordered_requests)
7573
max_workers = max(
7674
1,
7775
self.adaptive_config.max_concurrency if self.adaptive_concurrency else current_concurrency,
7876
)
7977

80-
def process_single_request(local_index: int, payload: dict):
78+
def process_single_request(local_index: int, payload: dict | None):
79+
if payload is None:
80+
return "", local_index, False, False, 0.0, 0, 0, 0
8181
started_at = time.time()
8282
rate_limited = False
8383
last_error_msg = "unknown error"
@@ -170,15 +170,9 @@ def maybe_update_concurrency(force: bool = False) -> None:
170170
latencies = []
171171
completed_since_adapt = 0
172172

173-
def build_payload_for_index(global_index: int):
173+
def build_payload_for_index(global_index: int) -> dict:
174174
req = reordered_requests[global_index]
175175
_, doc_to_messages, gen_kwargs, doc_id, task, split = req.args
176-
doc_uuid = f"{task}___{split}___{doc_id}"
177-
178-
if self.continual_mode and self.cache_mode == "resume":
179-
cached_response = self.response_cache.get(doc_uuid)
180-
if cached_response:
181-
return doc_uuid, cached_response, None
182176

183177
chat_messages_raw = doc_to_messages(self.task_dict[task][split][doc_id])
184178
chat_messages: ChatMessages = ChatMessages(**{"messages": chat_messages_raw})
@@ -199,16 +193,15 @@ def build_payload_for_index(global_index: int):
199193
payload["response_format"] = {"type": "text"}
200194
payload["max_completion_tokens"] = 5000
201195

202-
return doc_uuid, None, payload
196+
return payload
203197

204198
with ThreadPoolExecutor(max_workers=max_workers) as executor:
205199
while cursor < len(dispatch_order) or in_flight:
206200
while cursor < len(dispatch_order) and len(in_flight) < max(1, current_concurrency):
207201
request_index = dispatch_order[cursor]
208-
doc_uuid, cached_response, payload = build_payload_for_index(request_index)
209-
doc_uuids[request_index] = doc_uuid
210-
if cached_response is not None:
211-
responses[request_index] = GenerationResult(text=cached_response, token_counts=TokenCounts())
202+
payload = build_payload_for_index(request_index)
203+
if payload is None:
204+
responses[request_index] = GenerationResult(text="", token_counts=TokenCounts())
212205
pbar.update(1)
213206
cursor += 1
214207
continue
@@ -255,19 +248,13 @@ def build_payload_for_index(global_index: int):
255248
if rate_limited:
256249
rate_limited_requests += 1
257250
completed_since_adapt += 1
258-
if self.continual_mode and doc_uuids[local_index] is not None:
259-
self.response_cache[doc_uuids[local_index]] = response_text
260251
totals = get_running_totals()
261252
pbar.set_postfix({"tokens": f"{totals['total_tokens']:,}"}, refresh=False)
262253
pbar.update(1)
263254
maybe_update_concurrency(force=False)
264255

265256
maybe_update_concurrency(force=True)
266257

267-
if self.continual_mode:
268-
with open(self.response_persistent_file, "w") as f:
269-
json.dump(self.response_cache, f)
270-
271258
avg_speed = total_tokens / total_latency if total_latency > 0 else 0
272259
log_metrics(
273260
total_elapsed_time=total_latency,

lmms_eval/models/simple/claude.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
import os
32
import time
43
from copy import deepcopy
@@ -43,8 +42,6 @@ def __init__(
4342
system_prompt: str = "", # Whether you want some special system prompt here
4443
modality: str = "image",
4544
max_frames_num: int = 10,
46-
continual_mode: bool = False,
47-
response_persistent_folder: str = None,
4845
**kwargs,
4946
) -> None:
5047
super().__init__()
@@ -53,24 +50,6 @@ def __init__(
5350
self.system_prompt = system_prompt
5451
self.modality = modality
5552
self.max_frames_num = max_frames_num
56-
57-
self.continual_mode = continual_mode
58-
if self.continual_mode:
59-
if response_persistent_folder is None:
60-
raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.")
61-
62-
os.makedirs(response_persistent_folder, exist_ok=True)
63-
self.response_persistent_folder = response_persistent_folder
64-
self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")
65-
66-
if os.path.exists(self.response_persistent_file):
67-
with open(self.response_persistent_file, "r") as f:
68-
self.response_cache = json.load(f)
69-
self.cache_mode = "resume"
70-
else:
71-
self.response_cache = {}
72-
self.cache_mode = "start"
73-
7453
accelerator = Accelerator()
7554
if accelerator.num_processes > 1:
7655
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
@@ -174,16 +153,6 @@ def generate_until(self, requests) -> List[GenerationResult]:
174153
res.append(GenerationResult(text="", token_counts=None))
175154
pbar.update(1)
176155
continue
177-
###################### CONTINUAL MODE ######################
178-
if self.continual_mode is True and self.cache_mode == "resume":
179-
doc_uuid = f"{task}___{split}___{doc_id}"
180-
if doc_uuid in self.response_cache:
181-
response_text = self.response_cache[doc_uuid]
182-
if response_text:
183-
res.append(GenerationResult(text=response_text, token_counts=None))
184-
pbar.update(1)
185-
continue
186-
187156
visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
188157
visuals = self.flatten(visuals)
189158
imgs = []
@@ -271,14 +240,6 @@ def generate_until(self, requests) -> List[GenerationResult]:
271240
res.append(GenerationResult(text=response_text, token_counts=token_counts))
272241
pbar.update(1)
273242

274-
###################### CONTINUAL MODE ######################
275-
if self.continual_mode is True: # Cache the response
276-
response_text = message.content[0].text
277-
doc_uuid = f"{task}___{split}___{doc_id}"
278-
self.response_cache[doc_uuid] = response_text
279-
with open(self.response_persistent_file, "w") as f:
280-
json.dump(self.response_cache, f, indent=4)
281-
282243
pbar.close()
283244

284245
return res

lmms_eval/models/simple/gemini_api.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import io
2-
import json
32
import os
43
import pathlib
54
import re
@@ -41,38 +40,17 @@ def __init__(
4140
model_version: str = "gemini-1.5-pro",
4241
# modality: str = "image",
4342
timeout: int = 120,
44-
continual_mode: bool = True,
45-
response_persistent_folder: str = "./logs/gemini_persistent_folder",
4643
interleave: bool = False,
47-
# We will cache the Gemini API response in this path and use it for future requests
4844
**kwargs,
4945
) -> None:
5046
super().__init__()
5147
self.model_version = model_version
5248
self.timeout = timeout
5349
self.model = genai.GenerativeModel(model_version)
54-
self.continual_mode = continual_mode
55-
self.response_persistent_file = ""
5650
self.interleave = interleave
57-
# if self.continual_mode and response_persistent_folder is None:
58-
# raise ValueError("Continual mode requires a persistent path for the response. We will cache the Gemini API response in this path and use it for future requests. Please provide a valid path.")
59-
if self.continual_mode:
60-
self.response_persistent_folder = response_persistent_folder
61-
if not os.path.exists(self.response_persistent_folder):
62-
os.makedirs(self.response_persistent_folder)
63-
self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")
64-
65-
if os.path.exists(self.response_persistent_file):
66-
with open(self.response_persistent_file, "r") as f:
67-
self.response_cache = json.load(f)
68-
self.cache_mode = "resume"
69-
else:
70-
self.response_cache = {}
71-
self.cache_mode = "start"
7251

7352
accelerator = Accelerator()
7453
if accelerator.num_processes > 1:
75-
assert self.continual_mode is False, "Continual mode is not supported with distributed inference."
7654
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
7755
self.accelerator = accelerator
7856
if self.accelerator.is_local_main_process:
@@ -163,15 +141,6 @@ def get_uuid(task, split, doc_id):
163141
res.append(GenerationResult(text="", token_counts=None))
164142
pbar.update(1)
165143
continue
166-
if self.continual_mode and self.cache_mode == "resume":
167-
doc_uuid = get_uuid(task, split, doc_id)
168-
if doc_uuid in self.response_cache:
169-
content = self.response_cache[doc_uuid]
170-
if content:
171-
res.append(GenerationResult(text=content, token_counts=None))
172-
pbar.update(1)
173-
continue
174-
175144
if "max_new_tokens" not in gen_kwargs:
176145
gen_kwargs["max_new_tokens"] = 1024
177146
if "temperature" not in gen_kwargs:
@@ -239,12 +208,6 @@ def get_uuid(task, split, doc_id):
239208

240209
self.free_video()
241210

242-
if self.continual_mode is True: # Cache the response
243-
doc_uuid = get_uuid(task, split, doc_id)
244-
self.response_cache[doc_uuid] = content
245-
with open(self.response_persistent_file, "w") as f:
246-
json.dump(self.response_cache, f)
247-
248211
pbar.close()
249212
return res
250213

0 commit comments

Comments
 (0)