diff --git a/packages/genkit_google_genai/lib/common.dart b/packages/genkit_google_genai/lib/common.dart index 4ca02b8b..64a73f57 100644 --- a/packages/genkit_google_genai/lib/common.dart +++ b/packages/genkit_google_genai/lib/common.dart @@ -17,4 +17,11 @@ library; export 'src/api_client.dart'; export 'src/common_plugin.dart'; +export 'src/generated/generativelanguage.dart' + show + BatchEmbedContentsRequest, + Content, + ContentEmbedding, + EmbedContentRequest, + Part; export 'src/model.dart'; diff --git a/packages/genkit_google_genai/lib/src/api_client.dart b/packages/genkit_google_genai/lib/src/api_client.dart index e667b98b..469e3ff9 100644 --- a/packages/genkit_google_genai/lib/src/api_client.dart +++ b/packages/genkit_google_genai/lib/src/api_client.dart @@ -40,6 +40,15 @@ class GenerativeLanguageBaseClient { return EmbedContentResponse.fromJson(res); } + Future batchEmbedContents( + BatchEmbedContentsRequest request, { + required String model, + }) async { + final url = '$apiUrlPrefix$model:batchEmbedContents'; + final res = await _call('POST', url, request.toJson()); + return BatchEmbedContentsResponse.fromJson(res); + } + Future generateContent( GenerateContentRequest request, { required String model, diff --git a/packages/genkit_google_genai/lib/src/generated/generativelanguage.dart b/packages/genkit_google_genai/lib/src/generated/generativelanguage.dart index 6ea92a04..b84ac2d2 100644 --- a/packages/genkit_google_genai/lib/src/generated/generativelanguage.dart +++ b/packages/genkit_google_genai/lib/src/generated/generativelanguage.dart @@ -1769,7 +1769,7 @@ extension type ContentEmbedding._(Map _data) { set values(List? value) => _data['values'] = value; - /// This field stores the soft tokens tensor frame shape (e.g. [1, 1, 256, 2048]). + /// This field stores the soft tokens tensor frame shape (e.g. `[1, 1, 256, 2048]`). List? 
get shape { final v = _data['shape']; if (v == null) return null; diff --git a/packages/genkit_vertexai/lib/src/embedders.dart b/packages/genkit_vertexai/lib/src/embedders.dart new file mode 100644 index 00000000..70b71e2e --- /dev/null +++ b/packages/genkit_vertexai/lib/src/embedders.dart @@ -0,0 +1,561 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'dart:convert'; + +import 'package:genkit/plugin.dart'; +import 'package:genkit_google_genai/common.dart' as google; + +List> listVertexEmbedders({ + required String pluginName, + required List publisherModels, +}) { + return publisherModels + .where((m) { + final modelMap = m as Map; + final name = modelMap['name'] as String?; + return name != null && name.contains('embedding'); + }) + .map((m) { + final modelMap = m as Map; + final modelName = (modelMap['name'] as String).split('/').last; + return _vertexEmbedderMetadata('$pluginName/$modelName'); + }) + .toList(); +} + +ActionMetadata _vertexEmbedderMetadata( + String name, +) { + final metadata = embedderMetadata( + name, + customOptions: google.TextEmbedderOptions.$schema, + ); + return ActionMetadata( + name: metadata.name, + description: metadata.description, + actionType: metadata.actionType, + inputSchema: EmbedRequest.$schema, + outputSchema: EmbedResponse.$schema, + metadata: metadata.metadata, + ); +} + +Embedder createVertexEmbedder({ + required String pluginName, + required String embedderName, + 
required Future Function() getApiClient, + required GenkitException Function(Object, StackTrace) handleException, + required bool closeService, +}) { + return Embedder( + name: '$pluginName/$embedderName', + fn: (req, ctx) async { + if (req == null || req.input.isEmpty) { + return EmbedResponse(embeddings: []); + } + + final service = await getApiClient(); + try { + final options = req.options != null + ? google.TextEmbedderOptions.fromJson(req.options!) + : null; + + final embeddings = switch (_requestShapeFor(embedderName)) { + _VertexEmbedderRequestShape.geminiEmbedding => + _runGeminiEmbeddingRequests( + service: service, + embedderName: embedderName, + docs: req.input, + options: options, + ), + _VertexEmbedderRequestShape.multimodalPredict => + _runMultimodalPredictRequests( + service: service, + embedderName: embedderName, + docs: req.input, + options: options, + ), + _VertexEmbedderRequestShape.textPredict => _runTextPredictRequests( + service: service, + embedderName: embedderName, + docs: req.input, + options: options, + ), + }; + return EmbedResponse(embeddings: await embeddings); + } catch (e, stack) { + throw handleException(e, stack); + } finally { + if (closeService) { + service.client.close(); + } + } + }, + ); +} + +String _documentText(DocumentData doc) { + return doc.content.where((p) => p.isText).map((p) => p.text).join('\n'); +} + +google.EmbedContentRequest _embedContentRequest( + DocumentData doc, + google.TextEmbedderOptions? options, +) { + final text = _documentText(doc); + return google.EmbedContentRequest( + content: google.Content(parts: [google.Part(text: text)]), + outputDimensionality: options?.outputDimensionality, + taskType: options?.taskType, + title: options?.title, + ); +} + +Future> _runGeminiEmbeddingRequests({ + required google.GenerativeLanguageBaseClient service, + required String embedderName, + required List docs, + required google.TextEmbedderOptions? 
options, +}) async { + if (docs.length == 1) { + return [ + await _runEmbedContentRequest( + service: service, + embedderName: embedderName, + doc: docs.single, + options: options, + ), + ]; + } + + final res = await service.batchEmbedContents( + google.BatchEmbedContentsRequest( + requests: docs.map((doc) => _embedContentRequest(doc, options)).toList(), + ), + model: 'models/$embedderName', + ); + final embeddings = _requireBatchEmbeddings( + res.embeddings, + expectedCount: docs.length, + ); + return embeddings + .map((embedding) => Embedding(embedding: embedding.values ?? const [])) + .toList(); +} + +Future _runEmbedContentRequest({ + required google.GenerativeLanguageBaseClient service, + required String embedderName, + required DocumentData doc, + required google.TextEmbedderOptions? options, +}) async { + final res = await service.embedContent( + _embedContentRequest(doc, options), + model: 'models/$embedderName', + ); + return Embedding(embedding: res.embedding?.values ?? []); +} + +List _requireBatchEmbeddings( + List? embeddings, { + required int expectedCount, +}) { + if (embeddings == null || embeddings.isEmpty) { + throw GenkitException( + 'Vertex AI returned no embeddings.', + status: StatusCodes.INTERNAL, + ); + } + if (embeddings.length != expectedCount) { + throw GenkitException( + 'Vertex AI returned ${embeddings.length} embeddings for $expectedCount input documents.', + status: StatusCodes.INTERNAL, + ); + } + return embeddings; +} + +List> _requirePredictions( + Object? rawPredictions, { + required int expectedCount, +}) { + if (rawPredictions is! 
List || rawPredictions.isEmpty) { + throw GenkitException( + 'Vertex AI returned no predictions.', + status: StatusCodes.INTERNAL, + ); + } + if (rawPredictions.length != expectedCount) { + throw GenkitException( + 'Vertex AI returned ${rawPredictions.length} predictions for $expectedCount input documents.', + status: StatusCodes.INTERNAL, + ); + } + + return rawPredictions.map((prediction) { + if (prediction is! Map) { + throw GenkitException( + 'Vertex AI returned an invalid prediction payload.', + status: StatusCodes.INTERNAL, + ); + } + return prediction; + }).toList(); +} + +Future> _runMultimodalPredictRequests({ + required google.GenerativeLanguageBaseClient service, + required String embedderName, + required List docs, + required google.TextEmbedderOptions? options, +}) async { + // Multimodal embedders use a different predict request and response shape. + final instances = [ + for (var i = 0; i < docs.length; i++) + _toMultimodalInstance(docs[i], documentIndex: i), + ]; + final parameters = {}; + if (options?.outputDimensionality != null) { + // Multimodal predict expects `parameters.dimension`, not + // `outputDimensionality`. + parameters['dimension'] = options!.outputDimensionality; + } + + final res = await service.predict({ + 'instances': instances.map((instance) => instance.instance).toList(), + if (parameters.isNotEmpty) 'parameters': parameters, + }, model: 'models/$embedderName'); + + final predictions = _requirePredictions( + res['predictions'], + expectedCount: instances.length, + ); + return [ + for (var i = 0; i < predictions.length; i++) + ..._multimodalPredictionEmbeddings( + predictions[i], + expectedOutputs: instances[i].expectedOutputs, + ), + ]; +} + +Future> _runTextPredictRequests({ + required google.GenerativeLanguageBaseClient service, + required String embedderName, + required List docs, + required google.TextEmbedderOptions? options, +}) async { + // Older text embedders still use the predict payload shape. 
+ final instances = docs.map((doc) { + final instance = {'content': _documentText(doc)}; + if (options?.title != null) { + instance['title'] = options!.title; + } + if (options?.taskType != null) { + instance['task_type'] = options!.taskType; + } + return instance; + }).toList(); + + final parameters = {}; + if (options?.outputDimensionality != null) { + parameters['outputDimensionality'] = options!.outputDimensionality; + } + + final res = await service.predict({ + 'instances': instances, + if (parameters.isNotEmpty) 'parameters': parameters, + }, model: 'models/$embedderName'); + + final predictions = _requirePredictions( + res['predictions'], + expectedCount: docs.length, + ); + return predictions.map(_textPredictionEmbedding).toList(); +} + +Embedding _textPredictionEmbedding(Map prediction) { + final embeddingData = prediction['embeddings']; + final values = embeddingData is Map + ? embeddingData['values'] + : null; + if (values is! List) { + throw GenkitException( + 'Vertex AI returned an invalid prediction payload.', + status: StatusCodes.INTERNAL, + ); + } + + return Embedding( + embedding: values.map((value) => (value as num).toDouble()).toList(), + ); +} + +_MultimodalInstance _toMultimodalInstance( + DocumentData doc, { + required int documentIndex, +}) { + final text = _documentText(doc).trim(); + final instance = {}; + final expectedOutputs = <_MultimodalExpectedOutput>[]; + + if (text.isNotEmpty) { + instance['text'] = text; + expectedOutputs.add( + _MultimodalExpectedOutput( + output: _MultimodalOutput.text, + metadata: { + 'documentIndex': documentIndex, + 'modality': 'text', + 'partIndices': [ + for (var i = 0; i < doc.content.length; i++) + if (doc.content[i].isText && + (doc.content[i].text?.trim().isNotEmpty ?? 
false)) + i, + ], + }, + ), + ); + } + + for (var i = 0; i < doc.content.length; i++) { + final part = doc.content[i]; + if (!part.isMedia) continue; + + final mediaField = _toMultimodalMediaField(part.media!); + if (instance.containsKey(mediaField.key)) { + throw GenkitException( + 'Vertex multimodalembedding supports at most one ${mediaField.key} part per input document.', + status: StatusCodes.INVALID_ARGUMENT, + ); + } + + instance[mediaField.key] = mediaField.value; + expectedOutputs.add( + _MultimodalExpectedOutput( + output: mediaField.key == 'image' + ? _MultimodalOutput.image + : _MultimodalOutput.video, + metadata: { + 'documentIndex': documentIndex, + 'modality': mediaField.key, + 'partIndex': i, + }, + ), + ); + } + + if (instance.isEmpty) { + throw GenkitException( + 'Vertex multimodalembedding requires text, image, or video input.', + status: StatusCodes.INVALID_ARGUMENT, + ); + } + + return _MultimodalInstance( + instance: instance, + expectedOutputs: expectedOutputs, + ); +} + +MapEntry> _toMultimodalMediaField(Media media) { + final mimeType = _mediaMimeType(media); + final fieldName = _multimodalFieldName(mimeType); + + // Convert the media input into the format Vertex expects. + if (media.url.startsWith('data:')) { + final data = Uri.tryParse(media.url)?.data; + if (data == null) { + throw GenkitException( + 'Vertex multimodalembedding media inputs require a valid data URI.', + status: StatusCodes.INVALID_ARGUMENT, + ); + } + + return MapEntry(fieldName, { + 'bytesBase64Encoded': base64Encode(data.contentAsBytes()), + if (mimeType != null && mimeType.isNotEmpty) 'mimeType': mimeType, + }); + } + + if (media.url.startsWith('gs://')) { + return MapEntry(fieldName, { + 'gcsUri': media.url, + if (mimeType != null && mimeType.isNotEmpty) 'mimeType': mimeType, + }); + } + + throw GenkitException( + 'Vertex multimodalembedding media inputs must use gs:// URIs or inline data URIs.', + status: StatusCodes.INVALID_ARGUMENT, + ); +} + +String? 
_mediaMimeType(Media media) { + if (media.contentType?.isNotEmpty == true) { + return media.contentType; + } + + if (media.url.startsWith('data:')) { + return Uri.tryParse(media.url)?.data?.mimeType; + } + + return null; +} + +String _multimodalFieldName(String? mimeType) { + if (mimeType == null || mimeType.isEmpty) { + throw GenkitException( + 'Vertex multimodalembedding media inputs require a MIME type.', + status: StatusCodes.INVALID_ARGUMENT, + ); + } + + if (mimeType.startsWith('image/')) { + return 'image'; + } + if (mimeType.startsWith('video/')) { + return 'video'; + } + + throw GenkitException( + 'Unsupported Vertex multimodalembedding media MIME type: $mimeType', + status: StatusCodes.INVALID_ARGUMENT, + ); +} + +List _multimodalPredictionEmbeddings( + Map prediction, { + required List<_MultimodalExpectedOutput> expectedOutputs, +}) { + final embeddings = []; + for (final expectedOutput in expectedOutputs) { + switch (expectedOutput.output) { + case _MultimodalOutput.text: + embeddings.add( + _embeddingFromMultimodalValues( + prediction['textEmbedding'] as List?, + expectedOutput: expectedOutput, + ), + ); + case _MultimodalOutput.image: + embeddings.add( + _embeddingFromMultimodalValues( + prediction['imageEmbedding'] as List?, + expectedOutput: expectedOutput, + ), + ); + case _MultimodalOutput.video: + final videoEmbeddings = prediction['videoEmbeddings'] as List?; + if (videoEmbeddings == null || videoEmbeddings.isEmpty) { + throw GenkitException( + 'Vertex multimodalembedding did not return a video embedding.', + status: StatusCodes.INTERNAL, + ); + } + + for (var i = 0; i < videoEmbeddings.length; i++) { + final videoEmbedding = videoEmbeddings[i] as Map; + embeddings.add( + _embeddingFromMultimodalValues( + videoEmbedding['embedding'] as List?, + expectedOutput: expectedOutput, + metadata: { + ...expectedOutput.metadata, + 'segmentIndex': i, + if (videoEmbedding['startOffsetSec'] != null) + 'startOffsetSec': videoEmbedding['startOffsetSec'], + if 
(videoEmbedding['endOffsetSec'] != null) + 'endOffsetSec': videoEmbedding['endOffsetSec'], + }, + ), + ); + } + } + } + return embeddings; +} + +Embedding _embeddingFromMultimodalValues( + List? values, { + required _MultimodalExpectedOutput expectedOutput, + Map? metadata, +}) { + if (values == null) { + throw GenkitException( + 'Vertex multimodalembedding did not return a ${expectedOutput.output.name} embedding.', + status: StatusCodes.INTERNAL, + ); + } + + return Embedding( + embedding: values.map((value) => (value as num).toDouble()).toList(), + metadata: metadata ?? expectedOutput.metadata, + ); +} + +String _baseModelName(String modelName) { + final atIndex = modelName.indexOf('@'); + if (atIndex == -1) return modelName; + return modelName.substring(0, atIndex); +} + +_VertexEmbedderRequestShape _requestShapeFor(String modelName) { + final baseModelName = _baseModelName(modelName); + final exactShape = _requestShapeByExactModel[baseModelName]; + if (exactShape != null) return exactShape; + + if (_isMultimodalEmbeddingFamily(baseModelName)) { + return _VertexEmbedderRequestShape.multimodalPredict; + } + if (_isGeminiEmbeddingFamily(baseModelName)) { + return _VertexEmbedderRequestShape.geminiEmbedding; + } + return _VertexEmbedderRequestShape.textPredict; +} + +bool _isMultimodalEmbeddingFamily(String modelName) { + return modelName.contains('multimodal') && modelName.contains('embedding'); +} + +bool _isGeminiEmbeddingFamily(String modelName) { + return modelName.startsWith('gemini-embedding-'); +} + +const _requestShapeByExactModel = { + 'gemini-embedding-001': _VertexEmbedderRequestShape.textPredict, +}; + +class _MultimodalInstance { + final Map instance; + final List<_MultimodalExpectedOutput> expectedOutputs; + + _MultimodalInstance({required this.instance, required this.expectedOutputs}); +} + +class _MultimodalExpectedOutput { + final _MultimodalOutput output; + final Map metadata; + + _MultimodalExpectedOutput({required this.output, required 
this.metadata}); +} + +enum _MultimodalOutput { text, image, video } + +enum _VertexEmbedderRequestShape { + geminiEmbedding, + multimodalPredict, + textPredict, +} diff --git a/packages/genkit_vertexai/lib/src/vertex_api_client.dart b/packages/genkit_vertexai/lib/src/vertex_api_client.dart index dd0a5cf7..7717eb67 100644 --- a/packages/genkit_vertexai/lib/src/vertex_api_client.dart +++ b/packages/genkit_vertexai/lib/src/vertex_api_client.dart @@ -19,6 +19,7 @@ import 'package:http/http.dart' as http; import 'package:meta/meta.dart'; import 'auth.dart'; +import 'embedders.dart'; @visibleForTesting class VertexAiPluginImpl extends CommonGoogleGenPlugin { @@ -107,20 +108,10 @@ class VertexAiPluginImpl extends CommonGoogleGenPlugin { }) .toList(); - final embedders = publisherModels - .where((m) { - final modelMap = m as Map; - final name = modelMap['name'] as String?; - return name != null && - (name.contains('text-embedding-') || - name.contains('embedding-')); - }) - .map((m) { - final modelMap = m as Map; - final modelName = (modelMap['name'] as String).split('/').last; - return embedderMetadata('$name/$modelName'); - }) - .toList(); + final embedders = listVertexEmbedders( + pluginName: name, + publisherModels: publisherModels, + ); return [...models, ...embedders]; } catch (e, stack) { @@ -136,58 +127,12 @@ class VertexAiPluginImpl extends CommonGoogleGenPlugin { @override Embedder createEmbedder(String embedderName) { - return Embedder( - name: '$name/$embedderName', - fn: (req, ctx) async { - if (req == null || req.input.isEmpty) { - return EmbedResponse(embeddings: []); - } - final service = await getApiClient(); - try { - final options = req.options != null - ? TextEmbedderOptions.fromJson(req.options!) 
- : null; - - final instances = req.input.map((doc) { - final text = doc.content - .where((p) => p.isText) - .map((p) => p.text) - .join('\n'); - return {'content': text}; - }).toList(); - - final parameters = {}; - if (options?.outputDimensionality != null) { - parameters['outputDimensionality'] = options!.outputDimensionality; - } - if (options?.taskType != null) { - parameters['taskType'] = options!.taskType; - } - - final res = await service.predict({ - 'instances': instances, - if (parameters.isNotEmpty) 'parameters': parameters, - }, model: 'models/$embedderName'); - - final predictions = res['predictions'] as List; - final embeddings = predictions.map((p) { - final emb = - (p as Map)['embeddings'] - as Map; - final vals = emb['values'] as List; - return Embedding( - embedding: vals.map((e) => (e as num).toDouble()).toList(), - ); - }).toList(); - return EmbedResponse(embeddings: embeddings); - } catch (e, stack) { - throw handleException(e, stack); - } finally { - if (authClient == null) { - service.client.close(); - } - } - }, + return createVertexEmbedder( + pluginName: name, + embedderName: embedderName, + getApiClient: getApiClient, + handleException: handleException, + closeService: authClient == null, ); } } diff --git a/packages/genkit_vertexai/test/embedders_test.dart b/packages/genkit_vertexai/test/embedders_test.dart new file mode 100644 index 00000000..9f0d2c9f --- /dev/null +++ b/packages/genkit_vertexai/test/embedders_test.dart @@ -0,0 +1,633 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +import 'dart:convert'; + +import 'package:genkit/genkit.dart'; +import 'package:genkit_vertexai/src/vertex_api_client.dart'; +import 'package:test/test.dart'; + +import 'test_http_client.dart'; + +typedef _EmbedderAction = Action; + +_EmbedderAction _resolveEmbedder(VertexAiPluginImpl plugin, String name) { + return plugin.resolve('embedder', name)! as _EmbedderAction; +} + +void main() { + group('Vertex AI Embedders', () { + test('lists embedders with schemas and custom options', () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final actions = await plugin.list(); + final embedder = actions.firstWhere( + (action) => action.name == 'vertexai/text-embedding-005', + ); + + expect(embedder.name, 'vertexai/text-embedding-005'); + expect(embedder.inputSchema, same(EmbedRequest.$schema)); + expect(embedder.outputSchema, same(EmbedResponse.$schema)); + + final modelMetadata = embedder.metadata['model'] as Map; + final customOptions = + modelMetadata['customOptions'] as Map; + expect(customOptions['properties'], contains('outputDimensionality')); + expect(customOptions['properties'], contains('taskType')); + }); + + test('uses embedContent for a single Gemini preview input', () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'gemini-embedding-2-preview'); + final req = EmbedRequest( + input: [ + DocumentData(content: [TextPart(text: 'hello')]), + ], + ); + + final response = await embedder.run(req); + + expect(mockClient.lastUrl, isNotNull); + expect( + mockClient.lastUrl.toString(), + 
'https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1/publishers/google/models/gemini-embedding-2-preview:embedContent', + ); + expect(response.result.embeddings, hasLength(1)); + expect(response.result.embeddings.first.embedding, [0.1, 0.2, 0.3]); + }); + + test('uses batchEmbedContents for multiple Gemini preview inputs', () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'gemini-embedding-2-preview'); + final req = EmbedRequest( + input: [ + DocumentData(content: [TextPart(text: 'hello')]), + DocumentData(content: [TextPart(text: 'world')]), + ], + ); + + final response = await embedder.run(req); + + expect(mockClient.lastUrl, isNotNull); + expect( + mockClient.lastUrl.toString(), + 'https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1/publishers/google/models/gemini-embedding-2-preview:batchEmbedContents', + ); + final requestBody = + jsonDecode(mockClient.lastBody!) 
as Map; + final requests = requestBody['requests'] as List; + expect(requests, hasLength(2)); + expect(response.result.embeddings, hasLength(2)); + expect(response.result.embeddings[0].embedding, [0.1, 0.2, 0.3]); + expect(response.result.embeddings[1].embedding, [1.1, 1.2, 1.3]); + }); + + test( + 'uses predict directly for legacy gemini-embedding-001 inputs', + () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'gemini-embedding-001'); + final req = EmbedRequest( + input: [ + DocumentData(content: [TextPart(text: 'hello')]), + DocumentData(content: [TextPart(text: 'world')]), + ], + ); + + final response = await embedder.run(req); + + expect(mockClient.lastUrl, isNotNull); + expect( + mockClient.lastUrl.toString(), + 'https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1/publishers/google/models/gemini-embedding-001:predict', + ); + expect( + mockClient.requestUrls.where( + (url) => url.path.endsWith('gemini-embedding-001:embedContent'), + ), + isEmpty, + ); + final requestBody = + jsonDecode(mockClient.lastBody!) 
as Map; + final instances = requestBody['instances'] as List; + expect(instances, hasLength(2)); + expect(response.result.embeddings, hasLength(2)); + expect(response.result.embeddings[0].embedding, [0.4, 0.5, 0.6]); + expect(response.result.embeddings[1].embedding, [1.4, 1.5, 1.6]); + }, + ); + + test('uses text predict REST option schema', () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'text-embedding-005'); + final req = EmbedRequest( + input: [ + DocumentData(content: [TextPart(text: 'hello')]), + ], + options: { + 'outputDimensionality': 256, + 'taskType': 'RETRIEVAL_DOCUMENT', + 'title': 'document title', + }, + ); + + await embedder.run(req); + + final requestBody = + jsonDecode(mockClient.lastBody!) as Map; + final instances = requestBody['instances'] as List; + expect(instances.single, { + 'content': 'hello', + 'task_type': 'RETRIEVAL_DOCUMENT', + 'title': 'document title', + }); + expect(requestBody['parameters'], {'outputDimensionality': 256}); + }); + + test('throws when a text prediction omits embedding values', () async { + final mockClient = MockHttpClient(returnInvalidTextPrediction: true); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'text-embedding-005'); + + await expectLater( + () => embedder.run( + EmbedRequest( + input: [ + DocumentData(content: [TextPart(text: 'hello')]), + ], + ), + ), + throwsA( + isA().having( + (error) => error.message, + 'message', + contains('Vertex AI returned an invalid prediction payload.'), + ), + ), + ); + }); + + test( + 'uses multimodal predict schema for text-only multimodal inputs', + () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', 
+ authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'multimodalembedding'); + final req = EmbedRequest( + input: [ + DocumentData(content: [TextPart(text: 'hello')]), + ], + options: { + 'outputDimensionality': 256, + 'taskType': 'RETRIEVAL_DOCUMENT', + }, + ); + + final response = await embedder.run(req); + + expect(mockClient.lastUrl, isNotNull); + expect( + mockClient.lastUrl.toString(), + 'https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1/publishers/google/models/multimodalembedding:predict', + ); + + final requestBody = + jsonDecode(mockClient.lastBody!) as Map; + final instances = requestBody['instances'] as List; + expect(instances.single, {'text': 'hello'}); + expect(requestBody['parameters'], {'dimension': 256}); + expect(response.result.embeddings, hasLength(1)); + expect(response.result.embeddings.first.embedding, [0.7, 0.8, 0.9]); + expect(response.result.embeddings.first.metadata, { + 'documentIndex': 0, + 'modality': 'text', + 'partIndices': [0], + }); + }, + ); + + test('flattens mixed multimodal outputs with source metadata', () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'multimodalembedding'); + final req = EmbedRequest( + input: [ + DocumentData( + content: [ + TextPart(text: 'hello'), + MediaPart(media: Media(url: 'data:image/png;base64,AA==')), + MediaPart(media: Media(url: 'data:video/mp4;base64,AA==')), + ], + ), + DocumentData(content: [TextPart(text: 'world')]), + ], + ); + + final response = await embedder.run(req); + + final requestBody = + jsonDecode(mockClient.lastBody!) 
as Map; + final instances = requestBody['instances'] as List; + expect(instances, hasLength(2)); + expect(instances.first, { + 'text': 'hello', + 'image': {'bytesBase64Encoded': 'AA==', 'mimeType': 'image/png'}, + 'video': {'bytesBase64Encoded': 'AA==', 'mimeType': 'video/mp4'}, + }); + expect(instances[1], {'text': 'world'}); + + final embeddings = response.result.embeddings; + expect(embeddings, hasLength(4)); + expect(embeddings[0].embedding, [0.7, 0.8, 0.9]); + expect(embeddings[0].metadata, { + 'documentIndex': 0, + 'modality': 'text', + 'partIndices': [0], + }); + expect(embeddings[1].embedding, [1.7, 1.8, 1.9]); + expect(embeddings[1].metadata, { + 'documentIndex': 0, + 'modality': 'image', + 'partIndex': 1, + }); + expect(embeddings[2].embedding, [2.7, 2.8, 2.9]); + expect(embeddings[2].metadata, { + 'documentIndex': 0, + 'modality': 'video', + 'partIndex': 2, + 'segmentIndex': 0, + 'startOffsetSec': 0, + 'endOffsetSec': 16, + }); + expect(embeddings[3].embedding, [1.7, 1.8, 1.9]); + expect(embeddings[3].metadata, { + 'documentIndex': 1, + 'modality': 'text', + 'partIndices': [0], + }); + }); + + test('uses data URI image inputs in multimodal predict requests', () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'multimodalembedding'); + + await embedder.run( + EmbedRequest( + input: [ + DocumentData( + content: [ + MediaPart(media: Media(url: 'data:image/png;base64,AA==')), + ], + ), + ], + ), + ); + + final requestBody = + jsonDecode(mockClient.lastBody!) 
as Map; + final instances = requestBody['instances'] as List; + expect(instances.single, { + 'image': {'bytesBase64Encoded': 'AA==', 'mimeType': 'image/png'}, + }); + }); + + test('uses gs image inputs in multimodal predict requests', () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'multimodalembedding'); + + await embedder.run( + EmbedRequest( + input: [ + DocumentData( + content: [ + MediaPart( + media: Media( + url: 'gs://my-bucket/image.png', + contentType: 'image/png', + ), + ), + ], + ), + ], + ), + ); + + final requestBody = + jsonDecode(mockClient.lastBody!) as Map; + final instances = requestBody['instances'] as List; + expect(instances.single, { + 'image': { + 'gcsUri': 'gs://my-bucket/image.png', + 'mimeType': 'image/png', + }, + }); + }); + + test('throws when a multimodal data URI is malformed', () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'multimodalembedding'); + + await expectLater( + () => embedder.run( + EmbedRequest( + input: [ + DocumentData( + content: [ + MediaPart( + media: Media( + url: 'data:image/png;base64,%', + contentType: 'image/png', + ), + ), + ], + ), + ], + ), + ), + throwsA( + isA().having( + (error) => error.message, + 'message', + contains( + 'Vertex multimodalembedding media inputs require a valid data URI.', + ), + ), + ), + ); + }); + + test('throws when a multimodal data URI MIME cannot be parsed', () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'multimodalembedding'); + + await expectLater( + () => embedder.run( + 
EmbedRequest( + input: [ + DocumentData( + content: [ + MediaPart(media: Media(url: 'data:image/png;base64,%')), + ], + ), + ], + ), + ), + throwsA( + isA().having( + (error) => error.message, + 'message', + contains( + 'Vertex multimodalembedding media inputs require a MIME type.', + ), + ), + ), + ); + }); + + test('throws when a multimodal document has multiple images', () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'multimodalembedding'); + + await expectLater( + () => embedder.run( + EmbedRequest( + input: [ + DocumentData( + content: [ + MediaPart(media: Media(url: 'data:image/png;base64,AA==')), + MediaPart(media: Media(url: 'data:image/jpeg;base64,AA==')), + ], + ), + ], + ), + ), + throwsA( + isA().having( + (error) => error.message, + 'message', + contains( + 'Vertex multimodalembedding supports at most one image part per input document.', + ), + ), + ), + ); + }); + + test('throws when multimodal media has an unsupported MIME type', () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'multimodalembedding'); + + await expectLater( + () => embedder.run( + EmbedRequest( + input: [ + DocumentData( + content: [ + MediaPart(media: Media(url: 'data:audio/wav;base64,AA==')), + ], + ), + ], + ), + ), + throwsA( + isA().having( + (error) => error.message, + 'message', + contains( + 'Unsupported Vertex multimodalembedding media MIME type: audio/wav', + ), + ), + ), + ); + }); + + test('throws when multimodal media is missing a MIME type', () async { + final mockClient = MockHttpClient(); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = 
_resolveEmbedder(plugin, 'multimodalembedding'); + + await expectLater( + () => embedder.run( + EmbedRequest( + input: [ + DocumentData( + content: [ + MediaPart(media: Media(url: 'gs://my-bucket/image.png')), + ], + ), + ], + ), + ), + throwsA( + isA().having( + (error) => error.message, + 'message', + contains( + 'Vertex multimodalembedding media inputs require a MIME type.', + ), + ), + ), + ); + }); + + test( + 'throws a descriptive error when Vertex returns no predictions', + () async { + final mockClient = MockHttpClient(returnEmptyPredictions: true); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'gemini-embedding-001'); + + await expectLater( + () => embedder.run( + EmbedRequest( + input: [ + DocumentData(content: [TextPart(text: 'hello')]), + ], + ), + ), + throwsA( + isA().having( + (error) => error.message, + 'message', + contains('Vertex AI returned no predictions.'), + ), + ), + ); + }, + ); + + test( + 'throws when a multimodal response omits the expected embedding field', + () async { + final mockClient = MockHttpClient( + returnMissingMultimodalEmbedding: true, + ); + final plugin = VertexAiPluginImpl( + projectId: 'my-project', + location: 'us-central1', + authClient: mockClient, + ); + + final embedder = _resolveEmbedder(plugin, 'multimodalembedding'); + + await expectLater( + () => embedder.run( + EmbedRequest( + input: [ + DocumentData(content: [TextPart(text: 'hello')]), + ], + ), + ), + throwsA( + isA().having( + (error) => error.message, + 'message', + contains('did not return a text embedding'), + ), + ), + ); + }, + ); + }); +} diff --git a/packages/genkit_vertexai/test/test_http_client.dart b/packages/genkit_vertexai/test/test_http_client.dart new file mode 100644 index 00000000..ee040a64 --- /dev/null +++ b/packages/genkit_vertexai/test/test_http_client.dart @@ -0,0 +1,163 @@ +// Copyright 2025 Google LLC +// +// 
/// A scripted [http.Client] for the Vertex AI plugin tests.
///
/// Every request is recorded ([requestUrls], [requestBodies], [lastUrl],
/// [lastBody]) and answered with a canned JSON payload chosen from the request
/// host and path. The constructor flags switch the `:predict` routes to
/// degenerate payloads so that error handling can be exercised.
class MockHttpClient extends http.BaseClient {
  MockHttpClient({
    this.returnEmptyPredictions = false,
    this.returnInvalidTextPrediction = false,
    this.returnMissingMultimodalEmbedding = false,
  });

  /// Whether `:predict` routes answer with an empty `predictions` list.
  final bool returnEmptyPredictions;

  /// Whether text `:predict` predictions carry an empty `embeddings` object
  /// instead of one holding `values`.
  final bool returnInvalidTextPrediction;

  /// Whether multimodal `:predict` predictions are empty objects.
  final bool returnMissingMultimodalEmbedding;

  /// URLs of all requests seen so far, in arrival order.
  final List<Uri> requestUrls = [];

  /// Bodies of all [http.Request]s seen so far, in arrival order.
  final List<String> requestBodies = [];

  /// URL of the most recent request of any kind.
  Uri? lastUrl;

  /// Body of the most recent [http.Request].
  ///
  /// Requests of other types leave this value untouched.
  String? lastBody;

  static const _tokenJson =
      '{"access_token": "ya29.mock", "expires_in": 3600, "token_type": "Bearer"}';

  static const _publisherModelsJson =
      '{"publisherModels": [{"name": "publishers/google/models/gemini-1.5-pro"}, {"name": "publishers/google/models/text-embedding-005"}, {"name": "publishers/google/models/multimodalembedding"}]}';

  static const _embedContentUnsupportedJson =
      '{"error": {"code": 400, "message": "Publisher Model `projects/my-project/locations/us-central1/publishers/google/models/gemini-embedding-001` is not supported in the embedContent API.", "status": "INVALID_ARGUMENT"}}';

  static const _embedContentJson = '{"embedding": {"values": [0.1, 0.2, 0.3]}}';

  // The trailing space is part of the original fixture and is kept verbatim.
  static const _generateContentJson =
      '{"candidates": [{"content": {"parts": [{"text": "response"}], "role": "model"}, "finishReason": "STOP"}]} ';

  @override
  Future<http.StreamedResponse> send(http.BaseRequest request) async {
    final url = request.url;
    requestUrls.add(url);
    lastUrl = url;

    // Only plain `http.Request`s expose their body as a string.
    final requestBody = request is http.Request ? request.body : null;
    if (requestBody != null) {
      lastBody = requestBody;
      requestBodies.add(requestBody);
    }

    if (url.host == 'metadata.google.internal' ||
        url.host == 'oauth2.googleapis.com') {
      return _jsonResponse(_tokenJson);
    }

    final path = url.path;
    if (path == '/v1beta1/publishers/google/models') {
      return _jsonResponse(_publisherModelsJson);
    }
    // Must be tested before the generic `:embedContent` route below.
    if (path.endsWith('gemini-embedding-001:embedContent')) {
      return _jsonResponse(_embedContentUnsupportedJson, statusCode: 400);
    }
    if (path.endsWith(':batchEmbedContents')) {
      final requests = _listField(requestBody, 'requests');
      final embeddings = [
        for (var index = 0; index < requests.length; index++)
          {
            'values': [index + 0.1, index + 0.2, index + 0.3],
          },
      ];
      return _jsonResponse(jsonEncode({'embeddings': embeddings}));
    }
    if (path.endsWith(':embedContent')) {
      return _jsonResponse(_embedContentJson);
    }
    if (path.endsWith(':predict')) {
      final instances = _listField(requestBody, 'instances');
      final predictions = path.contains('multimodalembedding')
          ? _multimodalPredictions(instances)
          : _textPredictions(instances);
      return _jsonResponse(jsonEncode({'predictions': predictions}));
    }
    return _jsonResponse(_generateContentJson);
  }

  /// Builds one multimodal prediction per instance, mirroring the modalities
  /// (`text`, `image`, `video`) present in that instance.
  List<Map<String, Object?>> _multimodalPredictions(List<dynamic> instances) {
    if (returnEmptyPredictions) return const [];
    return List.generate(instances.length, (index) {
      if (returnMissingMultimodalEmbedding) return <String, Object?>{};

      final instance = instances[index] as Map<String, dynamic>;
      return <String, Object?>{
        if (instance.containsKey('text'))
          'textEmbedding': [index + 0.7, index + 0.8, index + 0.9],
        if (instance.containsKey('image'))
          'imageEmbedding': [index + 1.7, index + 1.8, index + 1.9],
        if (instance.containsKey('video'))
          'videoEmbeddings': [
            {
              'embedding': [index + 2.7, index + 2.8, index + 2.9],
              'startOffsetSec': 0,
              'endOffsetSec': 16,
            },
          ],
      };
    });
  }

  /// Builds one text prediction per instance.
  List<Map<String, Object?>> _textPredictions(List<dynamic> instances) {
    if (returnEmptyPredictions) return const [];
    return List.generate(instances.length, (index) {
      if (returnInvalidTextPrediction) {
        return <String, Object?>{'embeddings': <String, Object?>{}};
      }
      return <String, Object?>{
        'embeddings': {
          'values': [index + 0.4, index + 0.5, index + 0.6],
        },
      };
    });
  }

  /// Decodes [requestBody] as a JSON object and returns its list-valued [key].
  ///
  /// Throws a [StateError] when the request carried no readable body, so a
  /// mis-routed request fails with a message rather than a null-check error.
  static List<dynamic> _listField(String? requestBody, String key) {
    if (requestBody == null) {
      throw StateError(
        'MockHttpClient can only read "$key" from an http.Request body.',
      );
    }
    final decoded = jsonDecode(requestBody) as Map<String, dynamic>;
    return decoded[key] as List<dynamic>;
  }

  /// Wraps [body] in a JSON [http.StreamedResponse] with the given status.
  static http.StreamedResponse _jsonResponse(
    String body, {
    int statusCode = 200,
  }) {
    return http.StreamedResponse(
      Stream.value(utf8.encode(body)),
      statusCode,
      headers: {'content-type': 'application/json'},
    );
  }
}