Skip to content

Feat(amazon bedrock): Add support for cohere embed models #6190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
5 changes: 5 additions & 0 deletions .changeset/fresh-pears-wave.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@ai-sdk/amazon-bedrock': patch
---

added support for amazon bedrock cohere models
29 changes: 25 additions & 4 deletions content/providers/01-ai-sdk-providers/08-amazon-bedrock.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -453,12 +453,33 @@ The following optional settings are available for Bedrock Titan embedding models

Flag indicating whether or not to normalize the output embeddings. Defaults to true.

Bedrock Cohere embedding models (like cohere.embed-english-v3) support additional settings:

```ts
const model = bedrock.embedding('cohere.embed-english-v3', {
inputType: 'search_document', // optional, type of input being embedded
truncate: 'END' // optional, how to handle inputs that exceed the token limit
})
```

The following optional settings are available for Bedrock Cohere embedding models:

- **inputType**: _'search_document' | 'search_query' | 'classification' | 'clustering'_

Specifies the type of input being embedded. Default is 'search_document'.

- **truncate**: _'NONE' | 'START' | 'END'_

Controls how input is truncated if it exceeds the token limit. Default is 'NONE'.

### Model Capabilities

| Model | Default Dimensions | Custom Dimensions |
| ------------------------------ | ------------------ | ------------------- |
| `amazon.titan-embed-text-v1` | 1536 | <Cross size={18} /> |
| `amazon.titan-embed-text-v2:0` | 1024 | <Check size={18} /> |
| Model | Default Dimensions | Custom Dimensions | Additional Features |
| ------------------------------ | ------------------ | ------------------- | ------------------- |
| `amazon.titan-embed-text-v1` | 1536 | <Cross size={18} /> | None |
| `amazon.titan-embed-text-v2:0` | 1024 | <Check size={18} /> | None |
| `cohere.embed-english-v3` | 1024 | <Cross size={18} /> | Search, classification, clustering |
| `cohere.embed-multilingual-v3` | 1024 | <Cross size={18} /> | Search, classification, clustering |

## Image Models

Expand Down
300 changes: 296 additions & 4 deletions packages/amazon-bedrock/src/bedrock-embedding-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,25 @@ const mockEmbeddings = [
const fakeFetchWithAuth = injectFetchHeaders({ 'x-amz-auth': 'test-auth' });

const testValues = ['sunny day at the beach', 'rainy day in the city'];
const mockImageUri =
'data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAASABIAAD/example';

const embedUrl = `https://bedrock-runtime.us-east-1.amazonaws.com/model/${encodeURIComponent(
const titanEmbedUrl = `https://bedrock-runtime.us-east-1.amazonaws.com/model/${encodeURIComponent(
'amazon.titan-embed-text-v2:0',
)}/invoke`;

describe('doEmbed', () => {
const cohereEmbedUrl = `https://bedrock-runtime.us-east-1.amazonaws.com/model/${encodeURIComponent(
'cohere.embed-english-v3',
)}/invoke`;

describe('doEmbed with Titan models', () => {
const mockConfigHeaders = {
'config-header': 'config-value',
'shared-header': 'config-shared',
};

const server = createTestServer({
[embedUrl]: {
[titanEmbedUrl]: {
response: {
type: 'binary',
headers: {
Expand Down Expand Up @@ -52,7 +58,7 @@ describe('doEmbed', () => {

beforeEach(() => {
callCount = 0;
server.urls[embedUrl].response = {
server.urls[titanEmbedUrl].response = {
type: 'binary',
headers: {
'content-type': 'application/json',
Expand Down Expand Up @@ -159,3 +165,289 @@ describe('doEmbed', () => {
expect(requestHeaders['authorization']).toBe('AWS4-HMAC-SHA256...');
});
});

describe('doEmbed with Cohere models', () => {
const server = createTestServer({
[cohereEmbedUrl]: {
response: {
type: 'binary',
headers: {
'content-type': 'application/json',
},
body: Buffer.from(
JSON.stringify({
embeddings: [mockEmbeddings[0]],
id: 'emb_123456',
response_type: 'embeddings_floats',
texts: [testValues[0]],
}),
),
},
},
});

const cohereModel = new BedrockEmbeddingModel(
'cohere.embed-english-v3',
{
cohere: {
input_type: 'search_document',
truncate: 'NONE',
embedding_types: ['float'],
},
},
{
baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com',
headers: {},
fetch: fakeFetchWithAuth,
},
);

beforeEach(() => {
server.urls[cohereEmbedUrl].response = {
type: 'binary',
headers: {
'content-type': 'application/json',
},
body: Buffer.from(
JSON.stringify({
embeddings: [mockEmbeddings[0]],
id: 'emb_123456',
response_type: 'embeddings_floats',
texts: [testValues[0]],
}),
),
};
});

it('should handle single input value for Cohere models', async () => {
const { embeddings } = await cohereModel.doEmbed({
values: [testValues[0]],
});

expect(embeddings.length).toBe(1);
expect(embeddings[0]).toStrictEqual(mockEmbeddings[0]);

const body = await server.calls[0].requestBody;
expect(body).toEqual({
texts: [testValues[0]],
input_type: 'search_document',
truncate: 'NONE',
embedding_types: ['float'],
});
});

it('should handle multiple input values for Cohere models', async () => {
// Update server response for multiple inputs
server.urls[cohereEmbedUrl].response = {
type: 'binary',
headers: {
'content-type': 'application/json',
},
body: Buffer.from(
JSON.stringify({
embeddings: mockEmbeddings,
id: 'emb_123456',
response_type: 'embeddings_floats',
texts: testValues,
}),
),
};

const { embeddings } = await cohereModel.doEmbed({
values: testValues,
});

expect(embeddings.length).toBe(2);
expect(embeddings[0]).toStrictEqual(mockEmbeddings[0]);
expect(embeddings[1]).toStrictEqual(mockEmbeddings[1]);

const body = await server.calls[0].requestBody;
expect(body).toEqual({
texts: testValues,
input_type: 'search_document',
truncate: 'NONE',
embedding_types: ['float'],
});
});

it('should estimate token usage for Cohere models', async () => {
const { usage } = await cohereModel.doEmbed({
values: [testValues[0]],
});

// Based on the approximate 1 token = 4 chars rule mentioned in AWS docs
const expectedTokens = Math.ceil(testValues[0].length / 4);
expect(usage?.tokens).toStrictEqual(expectedTokens);
});

it('should use default cohere settings if not specified', async () => {
const cohereModelWithDefaults = new BedrockEmbeddingModel(
'cohere.embed-english-v3',
{}, // No cohere settings specified
{
baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com',
headers: {},
fetch: fakeFetchWithAuth,
},
);

await cohereModelWithDefaults.doEmbed({
values: [testValues[0]],
});

const body = await server.calls[0].requestBody;
expect(body).toEqual({
texts: [testValues[0]],
input_type: 'search_document', // Default value
truncate: 'NONE', // Default value
embedding_types: ['float'], // Default value
});
});

it('should handle image input for Cohere models', async () => {
// Set up model with image settings
const cohereImageModel = new BedrockEmbeddingModel(
'cohere.embed-english-v3',
{
cohere: {
input_type: 'image',
images: [mockImageUri],
embedding_types: ['float'],
},
},
{
baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com',
headers: {},
fetch: fakeFetchWithAuth,
},
);

// Mock the response for image embed
server.urls[cohereEmbedUrl].response = {
type: 'binary',
headers: {
'content-type': 'application/json',
},
body: Buffer.from(
JSON.stringify({
embeddings: [mockEmbeddings[0]],
id: 'emb_789012',
response_type: 'embeddings_floats',
}),
),
};

// Call embed on image
await cohereImageModel.doEmbed({
values: ['placeholder'], // Value is ignored for image embedding
});

const body = await server.calls[0].requestBody;
expect(body).toEqual({
images: [mockImageUri],
input_type: 'image',
truncate: 'NONE',
embedding_types: ['float'],
});
});

it('should handle multiple embedding types in response', async () => {
// Set up model with multiple embedding types
const cohereMultiTypeModel = new BedrockEmbeddingModel(
'cohere.embed-english-v3',
{
cohere: {
input_type: 'search_document',
embedding_types: ['float', 'int8'],
},
},
{
baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com',
headers: {},
fetch: fakeFetchWithAuth,
},
);

// Mock the response for multiple embedding types
server.urls[cohereEmbedUrl].response = {
type: 'binary',
headers: {
'content-type': 'application/json',
},
body: Buffer.from(
JSON.stringify({
embeddings: {
float: [mockEmbeddings[0]],
int8: [[1, 2, 3, 4, 5]],
},
id: 'emb_345678',
response_type: 'embeddings_multiple',
texts: [testValues[0]],
}),
),
};

// Call embed with multiple types
const { embeddings } = await cohereMultiTypeModel.doEmbed({
values: [testValues[0]],
});

const body = await server.calls[0].requestBody;
expect(body).toEqual({
texts: [testValues[0]],
input_type: 'search_document',
truncate: 'NONE',
embedding_types: ['float', 'int8'],
});

// Should prefer float embeddings when multiple types are available
expect(embeddings.length).toBe(1);
expect(embeddings[0]).toStrictEqual(mockEmbeddings[0]);
});

it('should fall back to the first embedding type if float is not available', async () => {
// Set up model with multiple embedding types
const cohereMultiTypeModel = new BedrockEmbeddingModel(
'cohere.embed-english-v3',
{
cohere: {
input_type: 'search_document',
embedding_types: ['int8', 'binary'],
},
},
{
baseUrl: () => 'https://bedrock-runtime.us-east-1.amazonaws.com',
headers: {},
fetch: fakeFetchWithAuth,
},
);

// Mock the response without float embeddings
server.urls[cohereEmbedUrl].response = {
type: 'binary',
headers: {
'content-type': 'application/json',
},
body: Buffer.from(
JSON.stringify({
embeddings: {
int8: [[1, 2, 3, 4, 5]],
binary: [[0, 1, 0, 1, 0]],
},
id: 'emb_901234',
response_type: 'embeddings_multiple',
texts: [testValues[0]],
}),
),
};

// Call embed with multiple types
const { embeddings } = await cohereMultiTypeModel.doEmbed({
values: [testValues[0]],
});

// Should fall back to the first type (int8) when float is not available
expect(embeddings.length).toBe(1);
expect(embeddings[0]).toStrictEqual([1, 2, 3, 4, 5]);
});
});
Loading