Skip to content

Commit 6af735d

Browse files
fix(community): Add INSERT support to PrismaVectorStore for ParentDocumentRetriever compatibility (#8833) (#8948)
1 parent 89a7909 commit 6af735d

5 files changed

Lines changed: 223 additions & 15 deletions

File tree

.changeset/fifty-plants-drive.md

Lines changed: 0 additions & 5 deletions
This file was deleted.

.changeset/hungry-dolls-turn.md

Lines changed: 0 additions & 6 deletions
This file was deleted.

.changeset/small-parrots-lick.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@langchain/community": patch
3+
---
4+
5+
fix(community): Add INSERT support to PrismaVectorStore for ParentDocumentRetriever compatibility (#8833)

libs/langchain-community/src/vectorstores/prisma.ts

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,13 @@ export class PrismaVectorStore<
138138

139139
protected columnTypes?: ColumnTypeConfig;
140140

141+
/**
142+
* When true, addDocuments uses INSERT statements to create new records.
143+
* When false (default), addDocuments uses UPDATE statements to update existing records by ID.
144+
* Set to true when using with ParentDocumentRetriever or when documents don't pre-exist in the database.
145+
*/
146+
protected useInsert: boolean;
147+
141148
static IdColumn: typeof IdColumnSymbol = IdColumnSymbol;
142149

143150
static ContentColumn: typeof ContentColumnSymbol = ContentColumnSymbol;
@@ -160,6 +167,12 @@ export class PrismaVectorStore<
160167
columns: TSelectModel;
161168
filter?: TFilterModel;
162169
columnTypes?: ColumnTypeConfig;
170+
/**
171+
* When true, addDocuments uses INSERT statements to create new records.
172+
* When false (default), addDocuments uses UPDATE statements to update existing records by ID.
173+
* Set to true when using with ParentDocumentRetriever or when documents don't pre-exist in the database.
174+
*/
175+
useInsert?: boolean;
163176
}
164177
) {
165178
super(embeddings, {});
@@ -182,6 +195,7 @@ export class PrismaVectorStore<
182195
this.tableName = config.tableName;
183196
this.vectorColumnName = config.vectorColumnName;
184197
this.columnTypes = config.columnTypes;
198+
this.useInsert = config.useInsert ?? false;
185199

186200
this.selectColumns = entries
187201
.map(([key, alias]) => (alias && key) || null)
@@ -211,6 +225,7 @@ export class PrismaVectorStore<
211225
columns: TColumns;
212226
filter?: TFilters;
213227
columnTypes?: ColumnTypeConfig;
228+
useInsert?: boolean;
214229
}
215230
) {
216231
type ModelName = keyof TPrisma["ModelName"] & string;
@@ -233,6 +248,7 @@ export class PrismaVectorStore<
233248
vectorColumnName: string;
234249
columns: TColumns;
235250
columnTypes?: ColumnTypeConfig;
251+
useInsert?: boolean;
236252
}
237253
) {
238254
const docs: Document[] = [];
@@ -264,6 +280,7 @@ export class PrismaVectorStore<
264280
vectorColumnName: string;
265281
columns: TColumns;
266282
columnTypes?: ColumnTypeConfig;
283+
useInsert?: boolean;
267284
}
268285
) {
269286
type ModelName = keyof TPrisma["ModelName"] & string;
@@ -303,10 +320,12 @@ export class PrismaVectorStore<
303320
*/
304321
async addDocuments(documents: Document<TModel>[]) {
// Collect the pageContent of every document so all texts are embedded in one call.
305322
const texts = documents.map(({ pageContent }) => pageContent);
// Lines marked `-` are the pre-commit body: the embeddings were handed straight to addVectors.
306-
return this.addVectors(
307-
await this.embeddings.embedDocuments(texts),
308-
documents
309-
);
// Lines marked `+` are the post-commit body: embed first, then choose the write path.
323+
const vectors = await this.embeddings.embedDocuments(texts);
324+
// useInsert (taken from config, false when omitted) routes to the INSERT-based method;
// otherwise the pre-existing addVectors path is kept, so default behaviour is unchanged.
325+
if (this.useInsert) {
326+
return this.addDocumentsWithVectors(vectors, documents);
327+
}
328+
return this.addVectors(vectors, documents);
310329
}
311330

312331
/**
@@ -350,6 +369,58 @@ export class PrismaVectorStore<
350369
);
351370
}
352371

372+
/**
373+
* Adds documents with their corresponding vectors to the store using INSERT statements.
374+
* One plain INSERT (no conflict clause) is issued per vector, all passed to a single
375+
* $transaction call, so callers such as ParentDocumentRetriever can store new child documents.
* The columns written are every entry of selectColumns followed by the vector column:
* the content column receives pageContent, every other column receives metadata[column].
* NOTE(review): vectors and documents are paired by index without any length check.
376+
* @param vectors The vectors to add.
377+
* @param documents The documents associated with the vectors.
378+
* @returns A promise that resolves when the documents have been added.
379+
*/
380+
async addDocumentsWithVectors(
381+
vectors: number[][],
382+
documents: Document<TModel>[]
383+
) {
384+
// table name, column name cannot be parametrised
385+
// these fields are thus not escaped by Prisma and can be dangerous if user input is used
386+
const tableNameRaw = this.Prisma.raw(`"${this.tableName}"`);
387+
const vectorColumnRaw = this.Prisma.raw(`"${this.vectorColumnName}"`);
388+
389+
// Build column names for INSERT statement (each name is wrapped in double quotes and passed as raw SQL)
390+
const columnNames = this.selectColumns.map((col) =>
391+
this.Prisma.raw(`"${col}"`)
392+
);
// The vector column is appended last; allValues below follows the same order.
393+
const allColumns = [...columnNames, vectorColumnRaw];
394+
395+
await this.db.$transaction(
// One $executeRaw call per vector; the matching document is looked up by position.
396+
vectors.map((vector, idx) => {
// NOTE(review): if documents has fewer entries than vectors, `document` is undefined
// and the property reads below throw; surplus documents are never visited.
397+
const document = documents[idx];
// Serialise the embedding as "[v1,v2,...]" text; it is cast with ::vector further down.
398+
const vectorString = `[${vector.join(",")}]`;
399+
400+
// Build values for each column
401+
const columnValues = this.selectColumns.map((col) => {
402+
if (col === this.contentColumn) {
403+
return document.pageContent;
404+
}
// Any non-content column is read from metadata under the same key
// (undefined when the key is absent — TODO confirm how Prisma binds that).
405+
return document.metadata[col];
406+
});
407+
408+
// Add vector as the last value
409+
const allValues = [
410+
...columnValues,
411+
this.Prisma.sql`${vectorString}::vector`,
412+
];
413+
414+
return this.db.$executeRaw(
415+
this.Prisma.sql`
416+
INSERT INTO ${tableNameRaw} (${this.Prisma.join(allColumns, ", ")})
417+
VALUES (${this.Prisma.join(allValues, ", ")})
418+
`
419+
);
420+
})
421+
);
422+
}
423+
353424
/**
354425
* Performs a similarity search with the specified query.
355426
* @param query The query to use for the similarity search.
@@ -572,6 +643,7 @@ export class PrismaVectorStore<
572643
vectorColumnName: string;
573644
columns: ModelColumns<Record<string, unknown>>;
574645
columnTypes?: ColumnTypeConfig;
646+
useInsert?: boolean;
575647
}
576648
): Promise<DefaultPrismaVectorStore> {
577649
const docs: Document[] = [];
@@ -604,6 +676,7 @@ export class PrismaVectorStore<
604676
vectorColumnName: string;
605677
columns: ModelColumns<Record<string, unknown>>;
606678
columnTypes?: ColumnTypeConfig;
679+
useInsert?: boolean;
607680
}
608681
): Promise<DefaultPrismaVectorStore> {
609682
const instance = new PrismaVectorStore(embeddings, dbConfig);

libs/langchain-community/src/vectorstores/tests/prisma.test.ts

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/* eslint-disable @typescript-eslint/no-explicit-any */
22
import { FakeEmbeddings } from "@langchain/core/utils/testing";
33
import { jest, test, expect } from "@jest/globals";
4+
import { Document } from "@langchain/core/documents";
45
import { PrismaVectorStore } from "../prisma.js";
56

67
class Sql {
@@ -39,6 +40,7 @@ describe("Prisma", () => {
3940
beforeEach(() => {
4041
jest.clearAllMocks();
4142
});
43+
4244
test("passes provided filters with simiaritySearch", async () => {
4345
const embeddings = new FakeEmbeddings();
4446
const store = new PrismaVectorStore(new FakeEmbeddings(), {
@@ -270,4 +272,143 @@ describe("Prisma", () => {
270272
expect(sqlCall).toBeDefined();
271273
});
272274
});
275+
276+
test("addDocumentsWithVectors creates new documents with INSERT", async () => {
// Calls addDocumentsWithVectors directly with two documents and two vectors and checks
// that a single $transaction receives two statements built through $executeRaw.
277+
const embeddings = new FakeEmbeddings();
278+
const store = new PrismaVectorStore(embeddings, {
279+
db: mockPrismaClient,
280+
prisma: mockPrismaNamespace,
281+
tableName: "test",
282+
vectorColumnName: "vector",
283+
columns: mockColumns,
284+
});
285+
286+
const documents = [
287+
new Document({
288+
pageContent: "test content 1",
289+
metadata: { id: "doc1", custom: "value1" },
290+
}),
291+
new Document({
292+
pageContent: "test content 2",
293+
metadata: { id: "doc2", custom: "value2" },
294+
}),
295+
];
296+
297+
const vectors = [
298+
[1, 2, 3],
299+
[4, 5, 6],
300+
];
301+
302+
// Mock the transaction so the statements queued for it can be counted
303+
$transaction.mockImplementation((queries) => {
304+
// Only the number of queued statements is asserted; the SQL text is not inspected,
// so this test alone does not prove that INSERT (rather than UPDATE) is generated.
305+
expect(queries).toHaveLength(2);
306+
return Promise.resolve();
307+
});
308+
309+
await store.addDocumentsWithVectors(vectors, documents);
310+
311+
expect($transaction).toHaveBeenCalledTimes(1);
312+
expect($executeRaw).toHaveBeenCalledTimes(2);
313+
});
314+
315+
test("addDocuments uses addVectors by default (backward compatibility)", async () => {
// Builds a store without useInsert and checks which write method addDocuments dispatches to.
316+
const embeddings = new FakeEmbeddings();
317+
const store = new PrismaVectorStore(embeddings, {
318+
db: mockPrismaClient,
319+
prisma: mockPrismaNamespace,
320+
tableName: "test",
321+
vectorColumnName: "vector",
322+
columns: mockColumns,
323+
});
324+
325+
const documents = [
326+
new Document({
327+
pageContent: "test content",
328+
metadata: { id: "doc1" },
329+
}),
330+
];
331+
332+
// Spy on both write methods and stub them to resolve, so only the dispatch is observed
333+
const addDocumentsWithVectorsSpy = jest
334+
.spyOn(store, "addDocumentsWithVectors")
335+
.mockResolvedValue();
336+
const addVectorsSpy = jest.spyOn(store, "addVectors").mockResolvedValue();
337+
338+
await store.addDocuments(documents);
339+
340+
// Verify addVectors was called (default behavior)
341+
expect(addVectorsSpy).toHaveBeenCalledTimes(1);
342+
// Verify addDocumentsWithVectors was NOT called
343+
expect(addDocumentsWithVectorsSpy).not.toHaveBeenCalled();
344+
});
345+
346+
test("addDocuments uses addDocumentsWithVectors when useInsert is true", async () => {
// Same setup as the default-path test, but with useInsert: true in the store config.
347+
const embeddings = new FakeEmbeddings();
348+
const store = new PrismaVectorStore(embeddings, {
349+
db: mockPrismaClient,
350+
prisma: mockPrismaNamespace,
351+
tableName: "test",
352+
vectorColumnName: "vector",
353+
columns: mockColumns,
354+
useInsert: true,
355+
});
356+
357+
const documents = [
358+
new Document({
359+
pageContent: "test content",
360+
metadata: { id: "doc1" },
361+
}),
362+
];
363+
364+
// Spy on both write methods and stub them to resolve, so only the dispatch is observed
365+
const addDocumentsWithVectorsSpy = jest
366+
.spyOn(store, "addDocumentsWithVectors")
367+
.mockResolvedValue();
368+
const addVectorsSpy = jest.spyOn(store, "addVectors").mockResolvedValue();
369+
370+
await store.addDocuments(documents);
371+
372+
// Verify addDocumentsWithVectors was called
373+
expect(addDocumentsWithVectorsSpy).toHaveBeenCalledTimes(1);
374+
// Verify addVectors was NOT called
375+
expect(addVectorsSpy).not.toHaveBeenCalled();
376+
});
377+
378+
test("addVectors still uses UPDATE statements for backward compatibility", async () => {
// Calls addVectors directly and inspects the SQL template text handed to the mocked sql tag.
379+
const embeddings = new FakeEmbeddings();
380+
const store = new PrismaVectorStore(embeddings, {
381+
db: mockPrismaClient,
382+
prisma: mockPrismaNamespace,
383+
tableName: "test",
384+
vectorColumnName: "vector",
385+
columns: mockColumns,
386+
});
387+
388+
const documents = [
389+
new Document({
390+
pageContent: "test content",
391+
metadata: { id: "doc1" },
392+
}),
393+
];
394+
395+
const vectors = [[1, 2, 3]];
396+
397+
// Mock sql function to capture the SQL template.
// capturedSql is overwritten on every call, so it ends up holding the static template
// pieces of the LAST sql`...` invocation only.
// NOTE(review): this override is not restored within the test — confirm that the
// jest.clearAllMocks() in beforeEach is enough if more tests are added after this one.
398+
let capturedSql = "";
399+
// @ts-expect-error - we are mocking the sql function
400+
sql.mockImplementation((strings: string[], ...values) => {
401+
capturedSql = strings.join("");
402+
return { strings, values };
403+
});
404+
405+
$transaction.mockResolvedValue([]);
406+
407+
await store.addVectors(vectors, documents);
408+
409+
expect($transaction).toHaveBeenCalledTimes(1);
410+
// Verify UPDATE statement is used
411+
expect(capturedSql).toContain("UPDATE");
412+
expect(capturedSql).not.toContain("INSERT");
413+
});
273414
});

0 commit comments

Comments
 (0)