Skip to content

Commit 5fdb6f9

Browse files
committed
feat(gemini): add stream support for Gemini
1 parent bb662ec commit 5fdb6f9

File tree

5 files changed

+429
-1
lines changed

5 files changed

+429
-1
lines changed

src/Providers/Gemini/Gemini.php

+7-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
use Prism\Prism\Exceptions\PrismException;
1515
use Prism\Prism\Providers\Gemini\Handlers\Cache;
1616
use Prism\Prism\Providers\Gemini\Handlers\Embeddings;
17+
use Prism\Prism\Providers\Gemini\Handlers\Stream;
1718
use Prism\Prism\Providers\Gemini\Handlers\Structured;
1819
use Prism\Prism\Providers\Gemini\Handlers\Text;
1920
use Prism\Prism\Providers\Gemini\ValueObjects\GeminiCachedObject;
@@ -66,7 +67,12 @@ public function embeddings(EmbeddingRequest $request): EmbeddingResponse
6667
#[\Override]
6768
public function stream(TextRequest $request): Generator
6869
{
69-
throw PrismException::unsupportedProviderAction(__METHOD__, class_basename($this));
70+
$handler = new Stream(
71+
$this->client($request->clientOptions(), $request->clientRetry()),
72+
$this->apiKey
73+
);
74+
75+
return $handler->handle($request);
7076
}
7177

7278
/**
+310
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Prism\Prism\Providers\Gemini\Handlers;
6+
7+
use Generator;
8+
use Illuminate\Http\Client\PendingRequest;
9+
use Illuminate\Http\Client\Response;
10+
use Illuminate\Support\Str;
11+
use Prism\Prism\Concerns\CallsTools;
12+
use Prism\Prism\Enums\FinishReason;
13+
use Prism\Prism\Enums\Provider;
14+
use Prism\Prism\Exceptions\PrismChunkDecodeException;
15+
use Prism\Prism\Exceptions\PrismException;
16+
use Prism\Prism\Providers\Gemini\Concerns\ExtractSearchGroundings;
17+
use Prism\Prism\Providers\Gemini\Concerns\ValidatesResponse;
18+
use Prism\Prism\Providers\Gemini\Maps\FinishReasonMap;
19+
use Prism\Prism\Providers\Gemini\Maps\MessageMap;
20+
use Prism\Prism\Providers\Gemini\Maps\ToolChoiceMap;
21+
use Prism\Prism\Providers\Gemini\Maps\ToolMap;
22+
use Prism\Prism\Text\Chunk;
23+
use Prism\Prism\Text\Request;
24+
use Prism\Prism\ValueObjects\Messages\AssistantMessage;
25+
use Prism\Prism\ValueObjects\Messages\ToolResultMessage;
26+
use Prism\Prism\ValueObjects\ToolCall;
27+
use Psr\Http\Message\StreamInterface;
28+
use Throwable;
29+
30+
class Stream
31+
{
32+
use CallsTools, ExtractSearchGroundings, ValidatesResponse;
33+
34+
public function __construct(
35+
protected PendingRequest $client,
36+
#[\SensitiveParameter] protected string $apiKey,
37+
) {}
38+
39+
/**
40+
* @return Generator<Chunk>
41+
*/
42+
public function handle(Request $request): Generator
43+
{
44+
$response = $this->sendRequest($request);
45+
46+
yield from $this->processStream($response, $request);
47+
}
48+
49+
/**
50+
* @return Generator<Chunk>
51+
*/
52+
protected function processStream(Response $response, Request $request, int $depth = 0): Generator
53+
{
54+
// Prevent infinite recursion with tool calls
55+
if ($depth >= $request->maxSteps()) {
56+
throw new PrismException('Maximum tool call chain depth exceeded');
57+
}
58+
59+
$text = '';
60+
$toolCalls = [];
61+
$groundingSupports = [];
62+
$groundingChunks = [];
63+
64+
while (! $response->getBody()->eof()) {
65+
$data = $this->parseNextDataLine($response->getBody());
66+
67+
// Skip empty data
68+
if ($data === null) {
69+
continue;
70+
}
71+
72+
// Process tool calls
73+
if ($this->hasToolCalls($data)) {
74+
$toolCalls = $this->extractToolCalls($data, $toolCalls);
75+
76+
continue;
77+
}
78+
79+
// Extract search grounding information if present
80+
if ($this->hasSearchGrounding($data)) {
81+
$groundingSupports = array_merge($groundingSupports, data_get($data, 'candidates.0.content.parts.0.text.groundingSupport', []));
82+
$groundingChunks = array_merge($groundingChunks, data_get($data, 'groundingChunks', []));
83+
}
84+
85+
// Handle content
86+
$content = data_get($data, 'candidates.0.content.parts.0.text') ?? '';
87+
if (is_string($content)) {
88+
$text .= $content;
89+
90+
yield new Chunk(
91+
text: $content,
92+
finishReason: null,
93+
);
94+
}
95+
96+
// Handle finish reason
97+
$finishReason = $this->mapFinishReason($data);
98+
if ($finishReason === FinishReason::Unknown) {
99+
continue;
100+
}
101+
102+
if ($finishReason === FinishReason::ToolCalls) {
103+
yield from $this->handleToolCalls($request, $text, $toolCalls, $depth);
104+
105+
return;
106+
}
107+
108+
yield new Chunk(
109+
text: '',
110+
finishReason: $finishReason,
111+
);
112+
113+
return;
114+
}
115+
}
116+
117+
/**
118+
* @return array<string, mixed>|null Parsed JSON data or null if line should be skipped
119+
*/
120+
protected function parseNextDataLine(StreamInterface $stream): ?array
121+
{
122+
$line = $this->readLine($stream);
123+
124+
if (! str_starts_with($line, 'data:')) {
125+
return null;
126+
}
127+
128+
$line = trim(substr($line, strlen('data: ')));
129+
130+
if ($line === '' || $line === '[DONE]') {
131+
return null;
132+
}
133+
134+
try {
135+
return json_decode($line, true, flags: JSON_THROW_ON_ERROR);
136+
} catch (Throwable $e) {
137+
throw new PrismChunkDecodeException('Gemini', $e);
138+
}
139+
}
140+
141+
/**
142+
* @param array<string, mixed> $data
143+
* @param array<int, array<string, mixed>> $toolCalls
144+
* @return array<int, array<string, mixed>>
145+
*/
146+
protected function extractToolCalls(array $data, array $toolCalls): array
147+
{
148+
$parts = data_get($data, 'candidates.0.content.parts', []);
149+
150+
foreach ($parts as $index => $part) {
151+
if (isset($part['functionCall'])) {
152+
$toolCalls[$index]['name'] = data_get($part, 'functionCall.name');
153+
$toolCalls[$index]['arguments'] = data_get($part, 'functionCall.args', '');
154+
}
155+
}
156+
157+
return $toolCalls;
158+
}
159+
160+
/**
161+
* @param array<int, array<string, mixed>> $toolCalls
162+
* @return Generator<Chunk>
163+
*/
164+
protected function handleToolCalls(
165+
Request $request,
166+
string $text,
167+
array $toolCalls,
168+
int $depth
169+
): Generator {
170+
// Convert collected tool call data to ToolCall objects
171+
$toolCalls = $this->mapToolCalls($toolCalls);
172+
173+
// Call the tools and get results
174+
$toolResults = $this->callTools($request->tools(), $toolCalls);
175+
176+
$request->addMessage(new AssistantMessage($text, $toolCalls));
177+
$request->addMessage(new ToolResultMessage($toolResults));
178+
179+
// Yield the tool call chunk
180+
yield new Chunk(
181+
text: '',
182+
toolCalls: $toolCalls,
183+
toolResults: $toolResults,
184+
);
185+
186+
// Continue the conversation with tool results
187+
$nextResponse = $this->sendRequest($request);
188+
yield from $this->processStream($nextResponse, $request, $depth + 1);
189+
}
190+
191+
/**
192+
* Convert raw tool call data to ToolCall objects.
193+
*
194+
* @param array<int, array<string, mixed>> $toolCalls
195+
* @return array<int, ToolCall>
196+
*/
197+
protected function mapToolCalls(array $toolCalls): array
198+
{
199+
return collect($toolCalls)
200+
->map(fn ($toolCall): ToolCall => new ToolCall(
201+
(string) array_key_exists('id', $toolCall) !== '' && (string) array_key_exists('id', $toolCall) !== '0' ? $toolCall['id'] : 'gm-'.Str::random(20),
202+
data_get($toolCall, 'name'),
203+
data_get($toolCall, 'arguments'),
204+
))
205+
->toArray();
206+
}
207+
208+
/**
209+
* @param array<string, mixed> $data
210+
*/
211+
protected function hasToolCalls(array $data): bool
212+
{
213+
$parts = data_get($data, 'candidates.0.content.parts', []);
214+
215+
foreach ($parts as $part) {
216+
if (isset($part['functionCall'])) {
217+
return true;
218+
}
219+
}
220+
221+
return false;
222+
}
223+
224+
/**
225+
* @param array<string, mixed> $data
226+
*/
227+
protected function hasSearchGrounding(array $data): bool
228+
{
229+
return ! empty(data_get($data, 'candidates.0.content.parts.0.text.groundingSupport'))
230+
&& ! empty(data_get($data, 'groundingChunks'));
231+
}
232+
233+
/**
234+
* @param array<string, mixed> $data
235+
*/
236+
protected function mapFinishReason(array $data): FinishReason
237+
{
238+
$finishReason = data_get($data, 'candidates.0.finishReason');
239+
240+
if (! $finishReason) {
241+
return FinishReason::Unknown;
242+
}
243+
244+
$isToolCall = $this->hasToolCalls($data);
245+
246+
return FinishReasonMap::map($finishReason, $isToolCall);
247+
}
248+
249+
protected function sendRequest(Request $request): Response
250+
{
251+
try {
252+
$providerMeta = $request->providerMeta(Provider::Gemini);
253+
254+
if ($request->tools() !== [] && ($providerMeta['searchGrounding'] ?? false)) {
255+
throw new PrismException('Use of search grounding with custom tools is not currently supported by Prism.');
256+
}
257+
258+
$tools = match (true) {
259+
$providerMeta['searchGrounding'] ?? false => [
260+
[
261+
'google_search' => (object) [],
262+
],
263+
],
264+
$request->tools() !== [] => ['function_declarations' => ToolMap::map($request->tools())],
265+
default => [],
266+
};
267+
268+
return $this->client
269+
->withOptions(['stream' => true])
270+
->post(
271+
"{$request->model()}:streamGenerateContent?alt=sse",
272+
array_filter([
273+
...(new MessageMap($request->messages(), $request->systemPrompts()))(),
274+
'cachedContent' => $providerMeta['cachedContentName'] ?? null,
275+
'generationConfig' => array_filter([
276+
'temperature' => $request->temperature(),
277+
'topP' => $request->topP(),
278+
'maxOutputTokens' => $request->maxTokens(),
279+
]),
280+
'tools' => $tools !== [] ? $tools : null,
281+
'tool_config' => $request->toolChoice() ? ToolChoiceMap::map($request->toolChoice()) : null,
282+
'safetySettings' => $providerMeta['safetySettings'] ?? null,
283+
])
284+
);
285+
} catch (Throwable $e) {
286+
throw PrismException::providerRequestError($request->model(), $e);
287+
}
288+
}
289+
290+
protected function readLine(StreamInterface $stream): string
291+
{
292+
$buffer = '';
293+
294+
while (! $stream->eof()) {
295+
$byte = $stream->read(1);
296+
297+
if ($byte === '') {
298+
return $buffer;
299+
}
300+
301+
$buffer .= $byte;
302+
303+
if ($byte === "\n") {
304+
break;
305+
}
306+
}
307+
308+
return $buffer;
309+
}
310+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
data: {"candidates": [{"content": {"parts": [{"text": "AI"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 21,"totalTokenCount": 21,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 21}]},"modelVersion": "gemini-2.0-flash"}
2+
data: {"candidates": [{"content": {"parts": [{"text": "? It's simple! We just feed a computer a HUGE pile of information"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 21,"totalTokenCount": 21,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 21}]},"modelVersion": "gemini-2.0-flash"}
3+
data: {"candidates": [{"content": {"parts": [{"text": ", tell it to find patterns, and then it pretends to be smart! Like teaching"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 21,"totalTokenCount": 21,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 21}]},"modelVersion": "gemini-2.0-flash"}
4+
data: {"candidates": [{"content": {"parts": [{"text": " a parrot to say cool things. Mostly magic, though.\n"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 47,"totalTokenCount": 68,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 21}],"candidatesTokensDetails": [{"modality": "TEXT","tokenCount": 47}]},"modelVersion": "gemini-2.0-flash"}

0 commit comments

Comments
 (0)