Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,28 @@ describe('integration: extension functions', () => {
expect(rows.length).toBeGreaterThan(0);
expect(rows.some((r) => r.id === 1)).toBe(true);
});

it('cosineSimilarity computes similarity for identical vectors', async () => {
  // Project the similarity of post 1's embedding against [1, 0, 0] alongside its id.
  const query = db()
    .posts.select('id')
    .select('similarity', (fields, ops) => ops.cosineSimilarity(fields.embedding, [1, 0, 0]))
    .where((fields, ops) => ops.eq(fields.id, 1));

  const post = await query.first();

  expect(post).not.toBeNull();
  // Lowering template is 1 - (self <=> arg0); for identical vectors the
  // distance is 0, so the similarity is expected to be 1 (to 5 decimal places).
  expect(post!.similarity).toBeCloseTo(1, 5);
});

it('cosineSimilarity filters in WHERE', async () => {
  // Seed data this test relies on:
  //   post 1 has embedding [1,0,0] → similarity to [1,0,0] is 1.0
  //   post 3 has embedding [0,0,1] → similarity to [1,0,0] is ~0 (orthogonal)
  const rows = await collect(
    db()
      .posts.select('id')
      .where((f, fns) => fns.gt(fns.cosineSimilarity(f.embedding, [1, 0, 0]), 0.5))
      .all(),
  );
  expect(rows.length).toBeGreaterThan(0);
  // Post 1 clears the 0.5 threshold and must be returned.
  expect(rows.some((r) => r.id === 1)).toBe(true);
  // Post 3 is orthogonal to the query vector and must be filtered out. Without
  // this check the test would still pass if the WHERE clause were ignored.
  expect(rows.some((r) => r.id === 3)).toBe(false);
});
});
16 changes: 15 additions & 1 deletion packages/3-extensions/pgvector/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ The extension provides an `OperationTypes` export for vector operations:
import type { OperationTypes } from '@prisma-next/extension-pgvector/operation-types';

// OperationTypes['pg/vector@1']['cosineDistance'] = (rhs: number[] | vector) => number
// OperationTypes['pg/vector@1']['cosineSimilarity'] = (rhs: number[] | vector) => number
```

## Operations
Expand All @@ -171,11 +172,24 @@ Computes the cosine distance between two vectors.
const distance = tables.post.columns.embedding.cosineDistance(param('queryVector'));
```

### cosineSimilarity

Computes the cosine similarity between two vectors, defined as 1 minus their cosine distance. Vectors pointing in the same direction yield 1; orthogonal vectors yield 0.

**Signature**: `cosineSimilarity(rhs: number[] | vector): number`

**SQL**: Uses the pgvector `<=>` (cosine distance) operator and subtracts the result from 1: `1 - (vector1 <=> vector2)`

**Example**:
```typescript
const similarity = tables.post.columns.embedding.cosineSimilarity(param('queryVector'));
```

## Capabilities

The extension declares the following capabilities:

- `pgvector/cosine`: Indicates support for cosine distance operations
- `pgvector/cosine`: Indicates support for cosine distance and similarity operations

## References

Expand Down
38 changes: 32 additions & 6 deletions packages/3-extensions/pgvector/src/core/descriptor-meta.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,36 @@ import type { QueryOperationDescriptor } from '@prisma-next/sql-relational-core/

const pgvectorTypeId = 'pg/vector@1' as const;

const cosineLowering = {
const cosineDistanceLowering = {
targetFamily: 'sql',
strategy: 'function',
template: '{{self}} <=> {{arg0}}',
} as const;

/**
 * SQL lowering for `cosineSimilarity`: one minus the result of the `<=>` operator
 * applied to the two operands.
 *
 * NOTE(review): `{{self}}` and `{{arg0}}` are presumably substituted with the
 * receiver expression and the first argument by the template renderer — confirm.
 */
const cosineSimilarityLowering = {
targetFamily: 'sql',
strategy: 'function',
template: '1 - ({{self}} <=> {{arg0}})',
} as const;

const cosineDistanceOperation = Object.freeze({
method: 'cosineDistance',
args: [{ kind: 'param' }],
returns: { kind: 'builtin', type: 'number' },
lowering: cosineLowering,
lowering: cosineDistanceLowering,
} as const);

/**
 * Signature body for the `cosineSimilarity` method: a single `param` argument,
 * a builtin `number` result, lowered via `cosineSimilarityLowering`.
 * Wrapped in `Object.freeze` (shallow) so the shared object is not reassigned in place.
 */
const cosineSimilarityOperation = Object.freeze({
method: 'cosineSimilarity',
args: [{ kind: 'param' }],
returns: { kind: 'builtin', type: 'number' },
lowering: cosineSimilarityLowering,
} as const);

export const pgvectorOperationSignature: SqlOperationSignature = {
forTypeId: pgvectorTypeId,
...cosineDistanceOperation,
};
/**
 * Operation signatures exported for the pgvector pack. Each entry binds one
 * operation (`cosineDistance`, `cosineSimilarity`) to the `pg/vector@1` type id.
 */
export const pgvectorOperationSignatures: readonly SqlOperationSignature[] = [
{ forTypeId: pgvectorTypeId, ...cosineDistanceOperation },
{ forTypeId: pgvectorTypeId, ...cosineSimilarityOperation },
];

export const pgvectorQueryOperations: readonly QueryOperationDescriptor[] = [
{
Expand All @@ -35,6 +48,19 @@ export const pgvectorQueryOperations: readonly QueryOperationDescriptor[] = [
template: '{{self}} <=> {{arg0}}',
},
},
{
method: 'cosineSimilarity',
args: [
{ codecId: pgvectorTypeId, nullable: false },
{ codecId: pgvectorTypeId, nullable: false },
],
returns: { codecId: 'pg/float8@1', nullable: false },
lowering: {
targetFamily: 'sql',
strategy: 'function',
template: '1 - ({{self}} <=> {{arg0}})',
},
},
];

export const pgvectorPackMeta = {
Expand Down
4 changes: 2 additions & 2 deletions packages/3-extensions/pgvector/src/exports/control.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type {
ComponentDatabaseDependencies,
SqlControlExtensionDescriptor,
} from '@prisma-next/family-sql/control';
import { pgvectorOperationSignature, pgvectorPackMeta } from '../core/descriptor-meta';
import { pgvectorOperationSignatures, pgvectorPackMeta } from '../core/descriptor-meta';

const PGVECTOR_CODEC_ID = 'pg/vector@1' as const;

Expand Down Expand Up @@ -75,7 +75,7 @@ const pgvectorExtensionDescriptor: SqlControlExtensionDescriptor<'postgres'> = {
},
},
},
operationSignatures: () => [pgvectorOperationSignature],
operationSignatures: () => pgvectorOperationSignatures,
databaseDependencies: pgvectorDatabaseDependencies,
create: () => ({
familyId: 'sql' as const,
Expand Down
4 changes: 2 additions & 2 deletions packages/3-extensions/pgvector/src/exports/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { type as arktype } from 'arktype';
import { codecDefinitions } from '../core/codecs';
import { VECTOR_CODEC_ID, VECTOR_MAX_DIM } from '../core/constants';
import {
pgvectorOperationSignature,
pgvectorOperationSignatures,
pgvectorPackMeta,
pgvectorQueryOperations,
} from '../core/descriptor-meta';
Expand Down Expand Up @@ -49,7 +49,7 @@ const pgvectorRuntimeDescriptor: SqlRuntimeExtensionDescriptor<'postgres'> = {
familyId: 'sql' as const,
targetId: 'postgres' as const,
codecs: createPgvectorCodecRegistry,
operationSignatures: () => [pgvectorOperationSignature],
operationSignatures: () => pgvectorOperationSignatures,
queryOperations: () => pgvectorQueryOperations,
parameterizedCodecs: () => parameterizedCodecDescriptors,
create() {
Expand Down
23 changes: 23 additions & 0 deletions packages/3-extensions/pgvector/src/types/operation-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ export type OperationTypes = {
readonly template: string;
};
};
readonly cosineSimilarity: {
readonly args: readonly [
{
readonly kind: 'param';
},
];
readonly returns: {
readonly kind: 'builtin';
readonly type: 'number';
};
readonly lowering: {
readonly targetFamily: 'sql';
readonly strategy: 'function';
readonly template: string;
};
};
};
};

Expand All @@ -41,4 +57,11 @@ export type QueryOperationTypes = SqlQueryOperationTypes<{
];
readonly returns: { readonly codecId: 'pg/float8@1'; readonly nullable: false };
};
readonly cosineSimilarity: {
readonly args: readonly [
{ readonly codecId: 'pg/vector@1'; readonly nullable: boolean },
{ readonly codecId: 'pg/vector@1'; readonly nullable: boolean },
];
readonly returns: { readonly codecId: 'pg/float8@1'; readonly nullable: false };
};
}>;
19 changes: 18 additions & 1 deletion packages/3-extensions/pgvector/test/manifest.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ describe('pgvector descriptor', () => {

it('has cosineDistance operation via operationSignatures()', () => {
const operations = pgvectorExtensionDescriptor.operationSignatures();
expect(operations.length).toBeGreaterThan(0);
expect(operations.length).toBe(2);

const cosineDistanceOp = operations.find(
(op) => op.forTypeId === 'pg/vector@1' && op.method === 'cosineDistance',
Expand All @@ -48,6 +48,23 @@ describe('pgvector descriptor', () => {
});
});

it('has cosineSimilarity operation via operationSignatures()', () => {
  // Expected SQL lowering for the similarity operation.
  const expectedLowering = {
    targetFamily: 'sql',
    strategy: 'function',
    template: '1 - ({{self}} <=> {{arg0}})',
  };

  // Look the signature up by method name and owning type id.
  const signature = pgvectorExtensionDescriptor
    .operationSignatures()
    .find(
      (candidate) =>
        candidate.method === 'cosineSimilarity' && candidate.forTypeId === 'pg/vector@1',
    );

  expect(signature).toBeDefined();
  expect(signature?.args).toEqual([{ kind: 'param' }]);
  expect(signature?.returns).toEqual({ kind: 'builtin', type: 'number' });
  expect(signature?.lowering).toEqual(expectedLowering);
});

it(
'codec types are importable',
async () => {
Expand Down
21 changes: 16 additions & 5 deletions packages/3-extensions/pgvector/test/operations.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,29 @@ describe('pgvector operations', () => {
it('descriptor provides operation signatures', () => {
const operations = pgvectorDescriptor.operationSignatures();
expect(operations).toBeDefined();
expect(operations.length).toBe(1);
expect(operations.length).toBe(2);

const cosineDistanceOp = operations[0];
const cosineDistanceOp = operations.find((op) => op.method === 'cosineDistance');
expect(cosineDistanceOp).toBeDefined();
expect(cosineDistanceOp?.forTypeId).toBe('pg/vector@1');
expect(cosineDistanceOp?.method).toBe('cosineDistance');
expect(cosineDistanceOp?.args).toEqual([{ kind: 'param' }]);
expect(cosineDistanceOp?.returns).toEqual({ kind: 'builtin', type: 'number' });
expect(cosineDistanceOp?.lowering).toEqual({
targetFamily: 'sql',
strategy: 'function',
template: '{{self}} <=> {{arg0}}',
});

const cosineSimilarityOp = operations.find((op) => op.method === 'cosineSimilarity');
expect(cosineSimilarityOp).toBeDefined();
expect(cosineSimilarityOp?.forTypeId).toBe('pg/vector@1');
expect(cosineSimilarityOp?.args).toEqual([{ kind: 'param' }]);
expect(cosineSimilarityOp?.returns).toEqual({ kind: 'builtin', type: 'number' });
expect(cosineSimilarityOp?.lowering).toEqual({
targetFamily: 'sql',
strategy: 'function',
template: '1 - ({{self}} <=> {{arg0}})',
});
});

it('operations can be registered in operation registry', () => {
Expand All @@ -49,8 +59,9 @@ describe('pgvector operations', () => {
}

const registeredOps = registry.byType('pg/vector@1');
expect(registeredOps.length).toBe(1);
expect(registeredOps[0]?.method).toBe('cosineDistance');
expect(registeredOps.length).toBe(2);
expect(registeredOps.find((op) => op.method === 'cosineDistance')).toBeDefined();
expect(registeredOps.find((op) => op.method === 'cosineSimilarity')).toBeDefined();
});

it('codecs can be registered in codec registry', () => {
Expand Down
51 changes: 51 additions & 0 deletions packages/3-extensions/pgvector/test/result-type.types.test-d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,57 @@ test('ResultType infers Vector<1536> for parameterized non-nullable vector colum
expectTypeOf(_plan).toExtend<SqlQueryPlan<Row>>();
});

// Type-level test: a vector column declared with `typeParams: { length: 1536 }`
// must still expose `cosineSimilarity`, and the selected value must be
// assignable to `number`.
test('cosineSimilarity remains available on parameterized vector columns', () => {
// Minimal contract: one `post` table with an int4 id and a non-nullable,
// length-parameterized vector column.
const contractWithVector = validateContract<ContractWithNonNullableVector>({
target: 'postgres',
targetFamily: 'sql' as const,
storageHash: 'sha256:test-core',
profileHash: 'sha256:test-profile',
storage: {
tables: {
post: {
columns: {
id: { nativeType: 'int4', codecId: 'pg/int4@1', nullable: false },
embedding: {
nativeType: 'vector',
codecId: 'pg/vector@1',
nullable: false,
typeParams: { length: 1536 },
},
},
uniques: [],
indexes: [],
foreignKeys: [],
},
},
},
models: {},
relations: {},
mappings: {},
});

// Build a test context with the pgvector extension pack included.
const adapter = createPostgresAdapter();
const context = createTestContext(contractWithVector, adapter, {
extensionPacks: [pgvectorDescriptor],
});
const tables = schema<ContractWithNonNullableVector>(context).tables;
const postTable = tables['post'];
// Fail fast if the table lookup returned nothing.
if (!postTable) throw new Error('post table not found');
const postColumns = postTable.columns;

// Select the similarity between the embedding column and a bound parameter.
const _plan = sql<ContractWithNonNullableVector>({ context })
.from(postTable)
.select({
similarity: postColumns['embedding']!.cosineSimilarity(param('queryVector')),
})
.build({ params: { queryVector: [0, 1, 2] } });

// Compile-time check: the assignment below only type-checks when
// Row['similarity'] is assignable to number.
// NOTE(review): an `any` result type would also satisfy this assignment.
type Row = ResultType<typeof _plan>;
const similarityValue = 0 as Row['similarity'];
const similarityAsExpected: number = similarityValue;
// Reference the binding; presumably to avoid an unused-variable diagnostic.
void similarityAsExpected;
});

test('cosineDistance remains available on parameterized vector columns', () => {
const contractWithVector = validateContract<ContractWithNonNullableVector>({
target: 'postgres',
Expand Down
21 changes: 11 additions & 10 deletions test/integration/test/pgvector.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,9 @@ describe('pgvector extension pack integration', () => {
const registry = assembleOperationRegistry(descriptors);

const operations = registry.byType('pg/vector@1');
expect(operations.length).toBe(1);
expect(operations[0]?.method).toBe('cosineDistance');
expect(operations[0]?.forTypeId).toBe('pg/vector@1');
expect(operations[0]?.args).toEqual([{ kind: 'param' }]);
expect(operations[0]?.returns).toEqual({ kind: 'builtin', type: 'number' });
expect(operations.length).toBe(2);
expect(operations.find((op) => op.method === 'cosineDistance')).toBeDefined();
expect(operations.find((op) => op.method === 'cosineSimilarity')).toBeDefined();
// Note: lowering is SQL-specific and not part of core OperationSignature
// The SQL family descriptor converts manifests to SqlOperationSignature with lowering
// but the registry returns core OperationSignature types
Expand All @@ -85,11 +83,13 @@ describe('pgvector extension pack integration', () => {
it('descriptor provides operation signatures', () => {
const operations = pgvector.operationSignatures();
expect(operations).toBeDefined();
expect(operations.length).toBe(1);
expect(operations.length).toBe(2);

const cosineDistanceOp = operations[0];
const cosineDistanceOp = operations.find((op) => op.method === 'cosineDistance');
expect(cosineDistanceOp?.forTypeId).toBe('pg/vector@1');
expect(cosineDistanceOp?.method).toBe('cosineDistance');

const cosineSimilarityOp = operations.find((op) => op.method === 'cosineSimilarity');
expect(cosineSimilarityOp?.forTypeId).toBe('pg/vector@1');
});

it('codecs can be registered in registry', { timeout: 1_000 }, () => {
Expand All @@ -116,7 +116,8 @@ describe('pgvector extension pack integration', () => {
}

const registeredOps = registry.byType('pg/vector@1');
expect(registeredOps.length).toBe(1);
expect(registeredOps[0]?.method).toBe('cosineDistance');
expect(registeredOps.length).toBe(2);
expect(registeredOps.find((op) => op.method === 'cosineDistance')).toBeDefined();
expect(registeredOps.find((op) => op.method === 'cosineSimilarity')).toBeDefined();
});
});
9 changes: 9 additions & 0 deletions test/utils/src/operation-descriptors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ export type PgVectorOperations = {
readonly template: string;
};
};
readonly cosineSimilarity: {
readonly args: ReadonlyArray<{ readonly kind: 'typeId'; readonly type: 'pg/vector@1' }>;
readonly returns: { readonly kind: 'builtin'; readonly type: 'number' };
readonly lowering: {
readonly targetFamily: 'sql';
readonly strategy: 'function';
readonly template: string;
};
};
readonly l2Distance: {
readonly args: ReadonlyArray<{ readonly kind: 'typeId'; readonly type: 'pg/vector@1' }>;
readonly returns: { readonly kind: 'builtin'; readonly type: 'number' };
Expand Down
Loading