Skip to content

Commit 75247bb

Browse files
committed
fix(vertexai): aggregate streamed meta tool calls
1 parent d9292c7 commit 75247bb

2 files changed

Lines changed: 128 additions & 4 deletions

File tree

packages/genkit_vertexai/lib/src/meta_model.dart

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ ModelResponse fromMetaChatCompletionResponse(Map<String, dynamic> response) {
182182

183183
ModelResponse fromMetaChatCompletionChunks(List<Map<String, dynamic>> chunks) {
184184
final content = StringBuffer();
185+
final toolCallDeltas = <int, Map<String, dynamic>>{};
185186
var finishReason = FinishReason.unknown;
186187
GenerationUsage? usage;
187188

@@ -192,23 +193,94 @@ ModelResponse fromMetaChatCompletionChunks(List<Map<String, dynamic>> chunks) {
192193
final delta = choice['delta'] as Map<String, dynamic>?;
193194
final text = delta?['content'] as String?;
194195
if (text != null) content.write(text);
196+
_accumulateMetaToolCalls(toolCallDeltas, delta?['tool_calls'] as List?);
195197
if (choice['finish_reason'] != null) {
196198
finishReason = _mapMetaFinishReason(choice['finish_reason'] as String?);
197199
}
198200
usage = _fromMetaUsage(chunk['usage'] as Map<String, dynamic>?) ?? usage;
199201
}
200202

203+
final toolCalls = _toMetaAccumulatedToolCalls(toolCallDeltas);
204+
final message = _fromMetaMessage({
205+
'content': '$content',
206+
if (toolCalls.isNotEmpty) 'tool_calls': toolCalls,
207+
});
208+
201209
return ModelResponse(
202210
finishReason: finishReason,
203-
message: Message(
204-
role: Role.model,
205-
content: [TextPart(text: '$content')],
206-
),
211+
message: message.content.isEmpty
212+
? Message(
213+
role: Role.model,
214+
content: [TextPart(text: '$content')],
215+
)
216+
: message,
207217
raw: {'chunks': chunks},
208218
usage: usage,
209219
);
210220
}
211221

222+
void _accumulateMetaToolCalls(
223+
Map<int, Map<String, dynamic>> accumulated,
224+
List? toolCalls,
225+
) {
226+
if (toolCalls == null) return;
227+
for (final toolCall in toolCalls) {
228+
if (toolCall is! Map) continue;
229+
final delta = toolCall.cast<String, dynamic>();
230+
final index = (delta['index'] as num?)?.toInt() ?? accumulated.length;
231+
final target = accumulated.putIfAbsent(
232+
index,
233+
() => {
234+
'type': 'function',
235+
'function': <String, dynamic>{'arguments': ''},
236+
},
237+
);
238+
239+
final id = delta['id'] as String?;
240+
if (id != null && id.isNotEmpty) target['id'] = id;
241+
242+
final type = delta['type'] as String?;
243+
if (type != null && type.isNotEmpty) target['type'] = type;
244+
245+
final function = delta['function'];
246+
if (function is! Map) continue;
247+
final targetFunction = (target['function'] as Map).cast<String, dynamic>();
248+
final functionDelta = function.cast<String, dynamic>();
249+
250+
final name = functionDelta['name'] as String?;
251+
if (name != null && name.isNotEmpty) targetFunction['name'] = name;
252+
253+
final arguments = functionDelta['arguments'] as String?;
254+
if (arguments != null) {
255+
targetFunction['arguments'] = '${targetFunction['arguments']}$arguments';
256+
}
257+
}
258+
}
259+
260+
List<Map<String, dynamic>> _toMetaAccumulatedToolCalls(
261+
Map<int, Map<String, dynamic>> accumulated,
262+
) {
263+
final entries = accumulated.entries.toList()
264+
..sort((a, b) => a.key.compareTo(b.key));
265+
266+
return entries
267+
.map((entry) {
268+
final toolCall = entry.value;
269+
final function = (toolCall['function'] as Map).cast<String, dynamic>();
270+
if ((function['arguments'] as String).isEmpty) {
271+
function['arguments'] = '{}';
272+
}
273+
return toolCall;
274+
})
275+
.where((toolCall) {
276+
final id = toolCall['id'] as String?;
277+
final function = (toolCall['function'] as Map).cast<String, dynamic>();
278+
final name = function['name'] as String?;
279+
return id != null && id.isNotEmpty && name != null && name.isNotEmpty;
280+
})
281+
.toList();
282+
}
283+
212284
List<Map<String, dynamic>> _toMetaMessages(List<Message> messages) {
213285
final result = <Map<String, dynamic>>[];
214286
for (final message in messages) {

packages/genkit_vertexai/test/meta_test.dart

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,5 +123,57 @@ void main() {
123123
),
124124
);
125125
});
126+
127+
test('aggregates streamed tool call chunks', () {
128+
final response = fromMetaChatCompletionChunks([
129+
{
130+
'choices': [
131+
{
132+
'delta': {
133+
'tool_calls': [
134+
{
135+
'index': 0,
136+
'id': 'call_weather',
137+
'type': 'function',
138+
'function': {'name': 'getWeather', 'arguments': '{"loc'},
139+
},
140+
],
141+
},
142+
},
143+
],
144+
},
145+
{
146+
'choices': [
147+
{
148+
'delta': {
149+
'tool_calls': [
150+
{
151+
'index': 0,
152+
'function': {'arguments': 'ation":"Boston"}'},
153+
},
154+
],
155+
},
156+
'finish_reason': 'tool_calls',
157+
},
158+
],
159+
'usage': {
160+
'prompt_tokens': 3,
161+
'completion_tokens': 2,
162+
'total_tokens': 5,
163+
},
164+
},
165+
]);
166+
167+
final content = response.message!.content;
168+
expect(response.finishReason, FinishReason.stop);
169+
expect(content, hasLength(1));
170+
expect(content.first.isToolRequest, true);
171+
172+
final toolRequest = content.first.toolRequest!;
173+
expect(toolRequest.ref, 'call_weather');
174+
expect(toolRequest.name, 'getWeather');
175+
expect(toolRequest.input, {'location': 'Boston'});
176+
expect(response.usage?.totalTokens, 5);
177+
});
126178
});
127179
}

0 commit comments

Comments
 (0)