Skip to content

Commit eef8e0c

Browse files
robot-ci-heartexniklubnik
authored
feat: DIA-1969: Support MIG Image use case (#343)
Co-authored-by: niklub <[email protected]> Co-authored-by: nik <[email protected]> Co-authored-by: niklub <[email protected]>
1 parent 777074f commit eef8e0c

File tree

10 files changed

+487
-278
lines changed

10 files changed

+487
-278
lines changed

Diff for: adala/runtimes/_litellm.py

+10
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,16 @@ async def batch_to_batch(
539539
)
540540

541541
extra_fields = extra_fields or {}
542+
input_field_types = input_field_types or {}
542543
records = batch.to_dict(orient="records")
544+
# in multi-image cases, the number of tokens can be too large for the context window
545+
# so we need to split the payloads into chunks
546+
# we use this heuristic for MIG projects as they more likely to have multi-image inputs
547+
# for other data types, we skip checking the context window as it will be slower
548+
ensure_messages_fit_in_context_window = any(
549+
input_field_types.get(field) == MessageChunkType.IMAGE_URLS
550+
for field in input_field_types
551+
)
543552

544553
df_data = await arun_instructor_with_payloads(
545554
client=self.client,
@@ -556,6 +565,7 @@ async def batch_to_batch(
556565
instructions_first=instructions_first,
557566
instructions_template=instructions_template,
558567
extra_fields=extra_fields,
568+
ensure_messages_fit_in_context_window=ensure_messages_fit_in_context_window,
559569
**self.model_extra,
560570
)
561571

Diff for: adala/skills/collection/label_studio.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,15 @@ async def aapply(
163163
f"Image tag {tag.name} has multiple variables: {variables}. Cannot mark these variables as image inputs."
164164
)
165165
continue
166-
input_field_types[variables[0]] = MessageChunkType.IMAGE_URL
166+
input_field_types[variables[0]] = (
167+
MessageChunkType.IMAGE_URLS
168+
if tag.attr.get("valueList")
169+
else MessageChunkType.IMAGE_URL
170+
)
171+
172+
logger.debug(
173+
f"Using VisionRuntime with input field types: {input_field_types}"
174+
)
167175
output = await runtime.batch_to_batch(
168176
input,
169177
input_template=self.input_template,

0 commit comments

Comments
 (0)