Skip to content

Commit b203e91

Browse files
committed
feat(vlm-clients): pass cachecontrol to the correct part of an anthropic prompt
1 parent 00c9e91 commit b203e91

4 files changed

Lines changed: 252 additions & 52 deletions

File tree

src/paint_by_language_model/services/clients/stroke_vlm_client.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,25 +219,34 @@ def suggest_strokes(
219219

220220
# Query VLM
221221
try:
222-
images: list[tuple[bytes, str]] = []
222+
# Static stroke sample images: routed to ``cached_images`` so they
223+
# form the cached prefix of the user message on Anthropic. These
224+
# are byte-identical across all iterations of a run and are
225+
# therefore ideal cache content. ``cache_control`` is placed by
226+
# ``VLMClient`` on the last cached image block.
227+
cached_images: list[tuple[bytes, str]] = []
223228
allowed_lower = (
224229
[t.lower() for t in self.allowed_stroke_types]
225230
if self.allowed_stroke_types
226231
else None
227232
)
228233
for stroke_type, sample_bytes in self._stroke_samples.items():
229234
if allowed_lower is None or stroke_type.lower() in allowed_lower:
230-
images.append((sample_bytes, f"{stroke_type.upper()} stroke sample"))
231-
images.append((canvas_image, "Current canvas"))
235+
cached_images.append((sample_bytes, f"{stroke_type.upper()} stroke sample"))
236+
237+
# Dynamic per-iteration content: just the current canvas.
238+
images: list[tuple[bytes, str]] = [(canvas_image, "Current canvas")]
232239
logger.debug(
233-
f"Attaching {len(images) - 1} stroke sample image(s) "
240+
f"Attaching {len(cached_images)} stroke sample image(s) as cached prefix "
241+
f"and 1 canvas image as dynamic content "
234242
f"(allowed: {self.allowed_stroke_types or 'all'})"
235243
)
236244

237245
response_text = self.client.query_multimodal_multi_image(
238246
prompt=user_prompt,
239247
images=images,
240248
system_prompt=system_prompt,
249+
cached_images=cached_images,
241250
)
242251

243252
# Store raw response immediately so it is always available,

src/paint_by_language_model/vlm_client.py

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,13 @@ def _log_request(
8787
payload (dict): Request body sent to the API
8888
response (requests.Response): The HTTP response received
8989
"""
90-
log_dir = Path(GLOBAL_PROMPT_LOG_DIR)
90+
now = datetime.now()
91+
log_dir = (
92+
Path(GLOBAL_PROMPT_LOG_DIR) / f"{now.year:04d}" / f"{now.month:02d}" / f"{now.day:02d}"
93+
)
9194
log_dir.mkdir(parents=True, exist_ok=True)
9295

93-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
96+
timestamp = now.strftime("%Y%m%d_%H%M%S_%f")
9497
log_path = log_dir / f"{timestamp}-request.log"
9598

9699
# Mask sensitive header values
@@ -475,27 +478,71 @@ def _build_multi_image_payload(
475478
max_tokens: int,
476479
*,
477480
system_prompt: str,
481+
cached_images: list[tuple[bytes, str]] | None = None,
478482
) -> dict:
479483
"""
480484
Build request payload for a multi-image multimodal query.
481485
482486
Each image is preceded by a text label block. A final text block
483487
containing the main prompt is appended after all image blocks.
484488
489+
When ``cached_images`` is provided, those (image_bytes, label) pairs are
490+
prepended to the user message before the dynamic ``images``. On
491+
Anthropic, ``cache_control: ephemeral`` is placed on the **last cached
492+
image block**, telling Anthropic to cache the entire prompt prefix up
493+
to and including that block (system prompt + all cached images). The
494+
dynamic per-request ``images`` and the final prompt text follow and are
495+
not cached.
496+
485497
Args:
486498
prompt (str): The main text prompt appended after all images
487-
images (list[tuple[bytes, str]]): List of (image_bytes, label) pairs
499+
images (list[tuple[bytes, str]]): Dynamic per-request
500+
(image_bytes, label) pairs that change between calls.
488501
max_tokens (int): Maximum tokens in the response
489502
system_prompt (str): System-level instructions; provider-agnostic.
490503
Anthropic: placed in top-level ``system`` field as a content
491504
block array with block-level ``cache_control``. OpenAI-compatible
492505
providers: prepended as a ``role: system`` message.
506+
cached_images (list[tuple[bytes, str]] | None): Optional list of
507+
static (image_bytes, label) pairs that are byte-identical
508+
across requests. Prepended to the user message before the
509+
dynamic ``images``. On Anthropic, the **last** of these image
510+
blocks carries ``cache_control: ephemeral`` to mark the cache
511+
prefix boundary. On OpenAI-compatible providers, they are
512+
still prepended (no cache marker — caching is Anthropic-only).
493513
494514
Returns:
495515
dict: Request payload structure for the API
496516
"""
497517
message_content: list[dict] = []
498518

519+
# Prepend cached static images (with cache_control on the last image
520+
# block for Anthropic). Anthropic caches everything up to and
521+
# including the marked block, so the system prefix + all of these
522+
# images become the cache prefix.
523+
cached_list = cached_images or []
524+
for idx, (image_bytes, label) in enumerate(cached_list):
525+
base64_image = base64.b64encode(image_bytes).decode("utf-8")
526+
message_content.append({"type": "text", "text": label})
527+
is_last_cached = idx == len(cached_list) - 1
528+
if self.provider == "anthropic":
529+
image_block: dict = {
530+
"type": "image",
531+
"source": {
532+
"type": "base64",
533+
"media_type": "image/png",
534+
"data": base64_image,
535+
},
536+
}
537+
if is_last_cached:
538+
image_block["cache_control"] = {"type": "ephemeral"}
539+
message_content.append(image_block)
540+
else:
541+
# OpenAI-compatible: use data URL format, no cache support
542+
data_url = f"data:image/png;base64,{base64_image}"
543+
message_content.append({"type": "image_url", "image_url": {"url": data_url}})
544+
545+
# Append dynamic per-request images
499546
for image_bytes, label in images:
500547
base64_image = base64.b64encode(image_bytes).decode("utf-8")
501548
# Label block before each image
@@ -523,6 +570,11 @@ def _build_multi_image_payload(
523570
payload["model"] = self.model
524571

525572
if self.provider == "anthropic":
573+
# When cached_images are present, the cache breakpoint lives on
574+
# the last cached image block in the user message. The system
575+
# block keeps its own breakpoint for text-only / fallback cases,
576+
# which is harmless (Anthropic allows up to 4 breakpoints per
577+
# request and caches the longest matching prefix).
526578
payload["system"] = [
527579
{
528580
"type": "text",
@@ -549,6 +601,7 @@ def query_multimodal_multi_image(
549601
max_tokens: int = MAX_TOKENS,
550602
*,
551603
system_prompt: str,
604+
cached_images: list[tuple[bytes, str]] | None = None,
552605
) -> str:
553606
"""
554607
Send multiple labelled images and a text prompt to the VLM in one request.
@@ -559,13 +612,20 @@ def query_multimodal_multi_image(
559612
560613
Args:
561614
prompt (str): The main text prompt sent after all images
562-
images (list[tuple[bytes, str]]): List of (image_bytes, label) pairs
615+
images (list[tuple[bytes, str]]): Dynamic per-request
616+
(image_bytes, label) pairs that change between calls.
563617
max_tokens (int): Maximum tokens in response (default from config)
564618
system_prompt (str): System-level instructions sent to the model.
565619
Required keyword-only argument. Anthropic: placed in the
566620
top-level ``system`` field as a content block with
567621
``cache_control``. OpenAI-compatible providers: prepended as
568622
the first ``role: system`` message.
623+
cached_images (list[tuple[bytes, str]] | None): Optional list of
624+
static images to include in the cached prompt prefix.
625+
Prepended to the user message. On Anthropic, ``cache_control``
626+
is placed on the last cached image block to mark the cache
627+
boundary. Ignored as a cache marker on OpenAI-compatible
628+
providers (still prepended for content parity).
569629
570630
Returns:
571631
str: The VLM's response text
@@ -575,14 +635,22 @@ def query_multimodal_multi_image(
575635
ValueError: If image encoding fails
576636
requests.RequestException: For other HTTP errors
577637
"""
638+
cached_count = len(cached_images) if cached_images else 0
578639
total_bytes = sum(len(img_bytes) for img_bytes, _ in images)
640+
cached_bytes = sum(len(img_bytes) for img_bytes, _ in cached_images) if cached_images else 0
579641
logger.info(
580-
f"Sending multi-image query to VLM ({len(images)} images, {total_bytes} total bytes)"
642+
f"Sending multi-image query to VLM "
643+
f"({len(images)} dynamic images / {total_bytes} bytes; "
644+
f"{cached_count} cached images / {cached_bytes} bytes)"
581645
)
582646

583647
try:
584648
payload = self._build_multi_image_payload(
585-
prompt, images, max_tokens, system_prompt=system_prompt
649+
prompt,
650+
images,
651+
max_tokens,
652+
system_prompt=system_prompt,
653+
cached_images=cached_images,
586654
)
587655

588656
# Retry loop for rate limiting
@@ -624,6 +692,7 @@ def query_multimodal_multi_image(
624692
response.raise_for_status()
625693

626694
response_data = response.json()
695+
self.last_usage = response_data.get("usage")
627696
response_text: str = self._extract_response_text(response_data)
628697

629698
logger.info(f"Received VLM response ({len(response_text)} characters)")

tests/test_stroke_vlm_client.py

Lines changed: 36 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,12 @@ def test_sample_generator_initialized_at_init() -> None:
6666

6767

6868
def test_suggest_strokes_sends_sample_images() -> None:
69-
"""suggest_strokes() calls query_multimodal_multi_image with canvas + 9 sample images.
69+
"""suggest_strokes() routes stroke samples to cached_images and canvas to images.
7070
7171
Verifies:
7272
- ``query_multimodal_multi_image`` is called (not ``query_multimodal``)
73-
- The ``images`` argument contains exactly 11 entries (1 canvas + 10 samples)
74-
- The first image label is ``"Current canvas""
75-
- The remaining labels match the expected stroke sample names
73+
- The ``images`` argument contains exactly 1 entry (the current canvas)
74+
- The ``cached_images`` argument contains the 10 stroke sample entries
7675
- ``system_prompt`` keyword argument is passed
7776
"""
7877
client = StrokeVLMClient()
@@ -107,22 +106,18 @@ def test_suggest_strokes_sends_sample_images() -> None:
107106
assert isinstance(call_kwargs["system_prompt"], str)
108107
assert len(call_kwargs["system_prompt"]) > 0
109108

110-
# Inspect the images argument (passed as keyword argument)
111-
images: list[tuple[bytes, str]] = (
112-
call_kwargs.get("images") or mock_multi.call_args.args[1]
113-
)
114-
115-
assert len(images) == 11, (
116-
f"Expected 11 images (1 canvas + 10 samples), got {len(images)}"
117-
)
109+
# images = dynamic per-iteration content (just the canvas)
110+
images: list[tuple[bytes, str]] = call_kwargs["images"]
111+
assert len(images) == 1, f"Expected 1 image (canvas only), got {len(images)}"
112+
assert images[0][1] == "Current canvas"
113+
assert images[0][0] == b"fake_canvas_bytes"
118114

119-
# Last entry must be the current canvas
120-
assert images[-1][1] == "Current canvas", (
121-
f"Last image label should be 'Current canvas', got '{images[-1][1]}'"
115+
# cached_images = static prefix content (the 10 stroke samples)
116+
cached_images: list[tuple[bytes, str]] = call_kwargs["cached_images"]
117+
assert len(cached_images) == 10, (
118+
f"Expected 10 stroke samples in cached_images, got {len(cached_images)}"
122119
)
123-
124-
# First 10 labels must be the stroke sample labels
125-
sample_labels = {label for _, label in images[:-1]}
120+
sample_labels = {label for _, label in cached_images}
126121
assert sample_labels == _EXPECTED_SAMPLE_LABELS, (
127122
f"Sample labels mismatch. Expected {_EXPECTED_SAMPLE_LABELS}, got {sample_labels}"
128123
)
@@ -234,9 +229,8 @@ def test_suggest_strokes_filters_samples_to_allowed_type() -> None:
234229
"""suggest_strokes() only attaches sample images for allowed stroke types.
235230
236231
When ``allowed_stroke_types=["line"]`` is set, exactly one sample image
237-
(the LINE sample) should be appended beyond the canvas image, giving a total
238-
of 2 entries in the ``images`` argument passed to
239-
``query_multimodal_multi_image``.
232+
(the LINE sample) should appear in ``cached_images``. ``images`` always
233+
contains only the current canvas.
240234
"""
241235
client = StrokeVLMClient(allowed_stroke_types=["line"])
242236

@@ -253,27 +247,25 @@ def test_suggest_strokes_filters_samples_to_allowed_type() -> None:
253247
)
254248

255249
mock_multi.assert_called_once()
256-
call_kwargs = mock_multi.call_args
257-
images: list[tuple[bytes, str]] = (
258-
call_kwargs.kwargs.get("images") or call_kwargs.args[1]
259-
)
250+
call_kwargs = mock_multi.call_args.kwargs
260251

261-
assert len(images) == 2, (
262-
f"Expected 2 images (1 canvas + 1 allowed sample), got {len(images)}"
263-
)
264-
assert images[0][1] == "LINE stroke sample", (
265-
f"First image label should be 'LINE stroke sample', got '{images[0][1]}'"
266-
)
267-
assert images[-1][1] == "Current canvas", (
268-
f"Last image label should be 'Current canvas', got '{images[-1][1]}'"
252+
images: list[tuple[bytes, str]] = call_kwargs["images"]
253+
assert len(images) == 1, f"Expected 1 image (canvas), got {len(images)}"
254+
assert images[0][1] == "Current canvas"
255+
256+
cached_images: list[tuple[bytes, str]] = call_kwargs["cached_images"]
257+
assert len(cached_images) == 1, (
258+
f"Expected 1 sample in cached_images (LINE only), got {len(cached_images)}"
269259
)
260+
assert cached_images[0][1] == "LINE stroke sample"
270261

271262

272263
def test_suggest_strokes_sends_all_samples_when_allowed_none() -> None:
273-
"""suggest_strokes() attaches all sample images when allowed_stroke_types is None.
264+
"""suggest_strokes() attaches all sample images to cached_images when allowed is None.
274265
275266
When no ``allowed_stroke_types`` restriction is set (the default), all ten
276-
stroke sample images should be attached giving 11 total (canvas + 10 samples).
267+
stroke sample images should appear in ``cached_images``. ``images`` contains
268+
only the current canvas.
277269
"""
278270
client = StrokeVLMClient() # allowed_stroke_types defaults to None
279271

@@ -290,15 +282,17 @@ def test_suggest_strokes_sends_all_samples_when_allowed_none() -> None:
290282
)
291283

292284
mock_multi.assert_called_once()
293-
call_kwargs = mock_multi.call_args
294-
images: list[tuple[bytes, str]] = (
295-
call_kwargs.kwargs.get("images") or call_kwargs.args[1]
296-
)
285+
call_kwargs = mock_multi.call_args.kwargs
286+
287+
images: list[tuple[bytes, str]] = call_kwargs["images"]
288+
assert len(images) == 1
289+
assert images[0][1] == "Current canvas"
297290

298-
assert len(images) == 11, (
299-
f"Expected 11 images (1 canvas + 10 samples), got {len(images)}"
291+
cached_images: list[tuple[bytes, str]] = call_kwargs["cached_images"]
292+
assert len(cached_images) == 10, (
293+
f"Expected 10 stroke samples in cached_images, got {len(cached_images)}"
300294
)
301-
sample_labels = {label for _, label in images[:-1]}
295+
sample_labels = {label for _, label in cached_images}
302296
assert sample_labels == _EXPECTED_SAMPLE_LABELS, (
303297
f"Sample labels mismatch. Expected {_EXPECTED_SAMPLE_LABELS}, got {sample_labels}"
304298
)

0 commit comments

Comments
 (0)