diff --git a/src/platform/src/Bridge/VertexAi/Gemini/ResultConverter.php b/src/platform/src/Bridge/VertexAi/Gemini/ResultConverter.php index 96f333b495..f1becbbb6b 100644 --- a/src/platform/src/Bridge/VertexAi/Gemini/ResultConverter.php +++ b/src/platform/src/Bridge/VertexAi/Gemini/ResultConverter.php @@ -94,6 +94,10 @@ public function getTokenUsageExtractor(): TokenUsageExtractor private function convertStream(RawResultInterface $result): \Generator { foreach ($result->getDataStream() as $data) { + if (isset($data['usageMetadata'])) { + yield $this->getTokenUsageExtractor()->fromUsageMetadata($data['usageMetadata']); + } + $choices = array_values(array_filter(array_map($this->convertChoice(...), $data['candidates'] ?? []))); if (!$choices) { diff --git a/src/platform/src/Bridge/VertexAi/Gemini/TokenUsageExtractor.php b/src/platform/src/Bridge/VertexAi/Gemini/TokenUsageExtractor.php index c4903626b2..447439d973 100644 --- a/src/platform/src/Bridge/VertexAi/Gemini/TokenUsageExtractor.php +++ b/src/platform/src/Bridge/VertexAi/Gemini/TokenUsageExtractor.php @@ -34,7 +34,7 @@ public function extract(RawResultInterface $rawResult, array $options = []): ?To return null; } - return $this->extractUsageMetadata($content['usageMetadata']); + return $this->fromUsageMetadata($content['usageMetadata']); } /** @@ -46,7 +46,7 @@ public function extract(RawResultInterface $rawResult, array $options = []): ?To * totalTokenCount?: int * } $usage */ - private function extractUsageMetadata(array $usage): TokenUsage + public function fromUsageMetadata(array $usage): TokenUsage { return new TokenUsage( promptTokens: $usage['promptTokenCount'] ?? null, diff --git a/src/platform/src/Bridge/VertexAi/Tests/Gemini/ResultConverterTest.php b/src/platform/src/Bridge/VertexAi/Tests/Gemini/ResultConverterTest.php index 3bf90228a3..becc709177 100644 --- a/src/platform/src/Bridge/VertexAi/Tests/Gemini/ResultConverterTest.php +++ b/src/platform/src/Bridge/VertexAi/Tests/Gemini/ResultConverterTest.php @@ -28,6 +28,7 @@ use Symfony\AI\Platform\Result\TextResult; use Symfony\AI\Platform\Result\ToolCall; use Symfony\AI\Platform\Result\ToolCallResult; +use Symfony\AI\Platform\TokenUsage\TokenUsageInterface; use Symfony\Contracts\HttpClient\ResponseInterface; final class ResultConverterTest extends TestCase @@ -329,4 +330,50 @@ public static function streamDeltaProvider(): iterable ], ], ChoiceDelta::class]; } + + public function testStreamingYieldsTokenUsageWhenUsageMetadataIsPresent() + { + $response = $this->createStub(ResponseInterface::class); + $response->method('getStatusCode')->willReturn(200); + + $rawResult = $this->createStub(RawResultInterface::class); + $rawResult->method('getObject')->willReturn($response); + $rawResult->method('getDataStream')->willReturn((static function (): \Generator { + yield [ + 'candidates' => [[ + 'content' => ['parts' => [['text' => 'Hello']]], + ]], + ]; + yield [ + 'candidates' => [[ + 'content' => ['parts' => [['text' => ' world']]], + ]], + 'usageMetadata' => [ + 'promptTokenCount' => 15, + 'candidatesTokenCount' => 25, + 'thoughtsTokenCount' => 3, + 'totalTokenCount' => 43, + ], + ]; + })()); + + $result = (new ResultConverter())->convert($rawResult, ['stream' => true]); + + $this->assertInstanceOf(StreamResult::class, $result); + + $items = iterator_to_array($result->getContent(), false); + + $this->assertCount(3, $items); + $this->assertInstanceOf(TextDelta::class, $items[0]); + $this->assertSame('Hello', $items[0]->getText()); + + $this->assertInstanceOf(TokenUsageInterface::class, $items[1]); + $this->assertSame(15, $items[1]->getPromptTokens()); + $this->assertSame(25, $items[1]->getCompletionTokens()); + $this->assertSame(3, $items[1]->getThinkingTokens()); + $this->assertSame(43, $items[1]->getTotalTokens()); + + $this->assertInstanceOf(TextDelta::class, $items[2]); + $this->assertSame(' world', $items[2]->getText()); + } }