Skip to content

Commit 15b92ae

Browse files
authored
feat(gemini): add stream support for Gemini (#284)
1 parent 04a0aff commit 15b92ae

File tree

5 files changed

+437
-1
lines changed

5 files changed

+437
-1
lines changed

Diff for: src/Providers/Gemini/Gemini.php

+7-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
use Prism\Prism\Exceptions\PrismException;
1515
use Prism\Prism\Providers\Gemini\Handlers\Cache;
1616
use Prism\Prism\Providers\Gemini\Handlers\Embeddings;
17+
use Prism\Prism\Providers\Gemini\Handlers\Stream;
1718
use Prism\Prism\Providers\Gemini\Handlers\Structured;
1819
use Prism\Prism\Providers\Gemini\Handlers\Text;
1920
use Prism\Prism\Providers\Gemini\ValueObjects\GeminiCachedObject;
@@ -66,7 +67,12 @@ public function embeddings(EmbeddingRequest $request): EmbeddingResponse
6667
#[\Override]
6768
public function stream(TextRequest $request): Generator
6869
{
69-
throw PrismException::unsupportedProviderAction(__METHOD__, class_basename($this));
70+
$handler = new Stream(
71+
$this->client($request->clientOptions(), $request->clientRetry()),
72+
$this->apiKey
73+
);
74+
75+
return $handler->handle($request);
7076
}
7177

7278
/**

Diff for: src/Providers/Gemini/Handlers/Stream.php

+277
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Prism\Prism\Providers\Gemini\Handlers;
6+
7+
use Generator;
8+
use Illuminate\Http\Client\PendingRequest;
9+
use Illuminate\Http\Client\Response;
10+
use Illuminate\Support\Str;
11+
use Prism\Prism\Concerns\CallsTools;
12+
use Prism\Prism\Enums\FinishReason;
13+
use Prism\Prism\Enums\Provider;
14+
use Prism\Prism\Exceptions\PrismChunkDecodeException;
15+
use Prism\Prism\Exceptions\PrismException;
16+
use Prism\Prism\Providers\Gemini\Maps\FinishReasonMap;
17+
use Prism\Prism\Providers\Gemini\Maps\MessageMap;
18+
use Prism\Prism\Providers\Gemini\Maps\ToolChoiceMap;
19+
use Prism\Prism\Providers\Gemini\Maps\ToolMap;
20+
use Prism\Prism\Text\Chunk;
21+
use Prism\Prism\Text\Request;
22+
use Prism\Prism\ValueObjects\Messages\AssistantMessage;
23+
use Prism\Prism\ValueObjects\Messages\ToolResultMessage;
24+
use Prism\Prism\ValueObjects\ToolCall;
25+
use Psr\Http\Message\StreamInterface;
26+
use Throwable;
27+
28+
class Stream
29+
{
30+
use CallsTools;
31+
32+
public function __construct(
33+
protected PendingRequest $client,
34+
#[\SensitiveParameter] protected string $apiKey,
35+
) {}
36+
37+
/**
38+
* @return Generator<Chunk>
39+
*/
40+
public function handle(Request $request): Generator
41+
{
42+
$response = $this->sendRequest($request);
43+
44+
yield from $this->processStream($response, $request);
45+
}
46+
47+
/**
48+
* @return Generator<Chunk>
49+
*/
50+
protected function processStream(Response $response, Request $request, int $depth = 0): Generator
51+
{
52+
// Prevent infinite recursion with tool calls
53+
if ($depth >= $request->maxSteps()) {
54+
throw new PrismException('Maximum tool call chain depth exceeded');
55+
}
56+
57+
$text = '';
58+
$toolCalls = [];
59+
60+
while (! $response->getBody()->eof()) {
61+
$data = $this->parseNextDataLine($response->getBody());
62+
63+
// Skip empty data
64+
if ($data === null) {
65+
continue;
66+
}
67+
68+
// Process tool calls
69+
if ($this->hasToolCalls($data)) {
70+
$toolCalls = $this->extractToolCalls($data, $toolCalls);
71+
72+
// Check if this is the final part of the tool calls
73+
if ($this->mapFinishReason($data) === FinishReason::ToolCalls) {
74+
yield from $this->handleToolCalls($request, $text, $toolCalls, $depth);
75+
}
76+
77+
continue;
78+
}
79+
80+
// Handle content
81+
$content = data_get($data, 'candidates.0.content.parts.0.text') ?? '';
82+
$text .= $content;
83+
84+
$finishReason = data_get($data, 'done', false) ? FinishReason::Stop : FinishReason::Unknown;
85+
86+
yield new Chunk(
87+
text: $content,
88+
finishReason: $finishReason !== FinishReason::Unknown ? $finishReason : null
89+
);
90+
}
91+
}
92+
93+
/**
94+
* @return array<string, mixed>|null Parsed JSON data or null if line should be skipped
95+
*/
96+
protected function parseNextDataLine(StreamInterface $stream): ?array
97+
{
98+
$line = $this->readLine($stream);
99+
100+
if (! str_starts_with($line, 'data:')) {
101+
return null;
102+
}
103+
104+
$line = trim(substr($line, strlen('data: ')));
105+
106+
if ($line === '' || $line === '[DONE]') {
107+
return null;
108+
}
109+
110+
try {
111+
return json_decode($line, true, flags: JSON_THROW_ON_ERROR);
112+
} catch (Throwable $e) {
113+
throw new PrismChunkDecodeException('Gemini', $e);
114+
}
115+
}
116+
117+
/**
118+
* @param array<string, mixed> $data
119+
* @param array<int, array<string, mixed>> $toolCalls
120+
* @return array<int, array<string, mixed>>
121+
*/
122+
protected function extractToolCalls(array $data, array $toolCalls): array
123+
{
124+
$parts = data_get($data, 'candidates.0.content.parts', []);
125+
126+
foreach ($parts as $index => $part) {
127+
if (isset($part['functionCall'])) {
128+
$toolCalls[$index]['name'] = data_get($part, 'functionCall.name');
129+
$toolCalls[$index]['arguments'] = data_get($part, 'functionCall.args', '');
130+
}
131+
}
132+
133+
return $toolCalls;
134+
}
135+
136+
/**
137+
* @param array<int, array<string, mixed>> $toolCalls
138+
* @return Generator<Chunk>
139+
*/
140+
protected function handleToolCalls(
141+
Request $request,
142+
string $text,
143+
array $toolCalls,
144+
int $depth
145+
): Generator {
146+
// Convert collected tool call data to ToolCall objects
147+
$toolCalls = $this->mapToolCalls($toolCalls);
148+
149+
// Call the tools and get results
150+
$toolResults = $this->callTools($request->tools(), $toolCalls);
151+
152+
$request->addMessage(new AssistantMessage($text, $toolCalls));
153+
$request->addMessage(new ToolResultMessage($toolResults));
154+
155+
// Yield the tool call chunk
156+
yield new Chunk(
157+
text: '',
158+
toolCalls: $toolCalls,
159+
toolResults: $toolResults,
160+
);
161+
162+
// Continue the conversation with tool results
163+
$nextResponse = $this->sendRequest($request);
164+
yield from $this->processStream($nextResponse, $request, $depth + 1);
165+
}
166+
167+
/**
168+
* Convert raw tool call data to ToolCall objects.
169+
*
170+
* @param array<int, array<string, mixed>> $toolCalls
171+
* @return array<int, ToolCall>
172+
*/
173+
protected function mapToolCalls(array $toolCalls): array
174+
{
175+
return collect($toolCalls)
176+
->map(fn ($toolCall): ToolCall => new ToolCall(
177+
(string) array_key_exists('id', $toolCall) !== '' && (string) array_key_exists('id', $toolCall) !== '0' ? $toolCall['id'] : 'gm-'.Str::random(20),
178+
data_get($toolCall, 'name'),
179+
data_get($toolCall, 'arguments'),
180+
))
181+
->toArray();
182+
}
183+
184+
/**
185+
* @param array<string, mixed> $data
186+
*/
187+
protected function hasToolCalls(array $data): bool
188+
{
189+
$parts = data_get($data, 'candidates.0.content.parts', []);
190+
191+
foreach ($parts as $part) {
192+
if (isset($part['functionCall'])) {
193+
return true;
194+
}
195+
}
196+
197+
return false;
198+
}
199+
200+
/**
201+
* @param array<string, mixed> $data
202+
*/
203+
protected function mapFinishReason(array $data): FinishReason
204+
{
205+
$finishReason = data_get($data, 'candidates.0.finishReason');
206+
207+
if (! $finishReason) {
208+
return FinishReason::Unknown;
209+
}
210+
211+
$isToolCall = $this->hasToolCalls($data);
212+
213+
return FinishReasonMap::map($finishReason, $isToolCall);
214+
}
215+
216+
protected function sendRequest(Request $request): Response
217+
{
218+
try {
219+
$providerMeta = $request->providerMeta(Provider::Gemini);
220+
221+
if ($request->tools() !== [] && ($providerMeta['searchGrounding'] ?? false)) {
222+
throw new PrismException('Use of search grounding with custom tools is not currently supported by Prism.');
223+
}
224+
225+
$tools = match (true) {
226+
$providerMeta['searchGrounding'] ?? false => [
227+
[
228+
'google_search' => (object) [],
229+
],
230+
],
231+
$request->tools() !== [] => ['function_declarations' => ToolMap::map($request->tools())],
232+
default => [],
233+
};
234+
235+
return $this->client
236+
->withOptions(['stream' => true])
237+
->post(
238+
"{$request->model()}:streamGenerateContent?alt=sse",
239+
array_filter([
240+
...(new MessageMap($request->messages(), $request->systemPrompts()))(),
241+
'cachedContent' => $providerMeta['cachedContentName'] ?? null,
242+
'generationConfig' => array_filter([
243+
'temperature' => $request->temperature(),
244+
'topP' => $request->topP(),
245+
'maxOutputTokens' => $request->maxTokens(),
246+
]),
247+
'tools' => $tools !== [] ? $tools : null,
248+
'tool_config' => $request->toolChoice() ? ToolChoiceMap::map($request->toolChoice()) : null,
249+
'safetySettings' => $providerMeta['safetySettings'] ?? null,
250+
])
251+
);
252+
} catch (Throwable $e) {
253+
throw PrismException::providerRequestError($request->model(), $e);
254+
}
255+
}
256+
257+
protected function readLine(StreamInterface $stream): string
258+
{
259+
$buffer = '';
260+
261+
while (! $stream->eof()) {
262+
$byte = $stream->read(1);
263+
264+
if ($byte === '') {
265+
return $buffer;
266+
}
267+
268+
$buffer .= $byte;
269+
270+
if ($byte === "\n") {
271+
break;
272+
}
273+
}
274+
275+
return $buffer;
276+
}
277+
}

Diff for: tests/Fixtures/gemini/stream-basic-text-1.sse

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
data: {"candidates": [{"content": {"parts": [{"text": "AI"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 21,"totalTokenCount": 21,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 21}]},"modelVersion": "gemini-2.0-flash"}
2+
data: {"candidates": [{"content": {"parts": [{"text": "? It's simple! We just feed a computer a HUGE pile of information"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 21,"totalTokenCount": 21,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 21}]},"modelVersion": "gemini-2.0-flash"}
3+
data: {"candidates": [{"content": {"parts": [{"text": ", tell it to find patterns, and then it pretends to be smart! Like teaching"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 21,"totalTokenCount": 21,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 21}]},"modelVersion": "gemini-2.0-flash"}
4+
data: {"candidates": [{"content": {"parts": [{"text": " a parrot to say cool things. Mostly magic, though.\n"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 47,"totalTokenCount": 68,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 21}],"candidatesTokensDetails": [{"modality": "TEXT","tokenCount": 47}]},"modelVersion": "gemini-2.0-flash"}

0 commit comments

Comments
 (0)