Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .changeset/fifty-plants-drive.md

This file was deleted.

6 changes: 0 additions & 6 deletions .changeset/hungry-dolls-turn.md

This file was deleted.

5 changes: 5 additions & 0 deletions .changeset/small-parrots-lick.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@langchain/community": patch
---

fix(community): Add INSERT support to PrismaVectorStore for ParentDocumentRetriever compatibility (#8833)
81 changes: 77 additions & 4 deletions libs/langchain-community/src/vectorstores/prisma.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ export class PrismaVectorStore<

protected columnTypes?: ColumnTypeConfig;

/**
* When true, addDocuments uses INSERT statements to create new records.
* When false (default), addDocuments uses UPDATE statements to update existing records by ID.
* Set to true when using with ParentDocumentRetriever or when documents don't pre-exist in the database.
*/
protected useInsert: boolean;

static IdColumn: typeof IdColumnSymbol = IdColumnSymbol;

static ContentColumn: typeof ContentColumnSymbol = ContentColumnSymbol;
Expand All @@ -160,6 +167,12 @@ export class PrismaVectorStore<
columns: TSelectModel;
filter?: TFilterModel;
columnTypes?: ColumnTypeConfig;
/**
* When true, addDocuments uses INSERT statements to create new records.
* When false (default), addDocuments uses UPDATE statements to update existing records by ID.
* Set to true when using with ParentDocumentRetriever or when documents don't pre-exist in the database.
*/
useInsert?: boolean;
}
) {
super(embeddings, {});
Expand All @@ -182,6 +195,7 @@ export class PrismaVectorStore<
this.tableName = config.tableName;
this.vectorColumnName = config.vectorColumnName;
this.columnTypes = config.columnTypes;
this.useInsert = config.useInsert ?? false;

this.selectColumns = entries
.map(([key, alias]) => (alias && key) || null)
Expand Down Expand Up @@ -211,6 +225,7 @@ export class PrismaVectorStore<
columns: TColumns;
filter?: TFilters;
columnTypes?: ColumnTypeConfig;
useInsert?: boolean;
}
) {
type ModelName = keyof TPrisma["ModelName"] & string;
Expand All @@ -233,6 +248,7 @@ export class PrismaVectorStore<
vectorColumnName: string;
columns: TColumns;
columnTypes?: ColumnTypeConfig;
useInsert?: boolean;
}
) {
const docs: Document[] = [];
Expand Down Expand Up @@ -264,6 +280,7 @@ export class PrismaVectorStore<
vectorColumnName: string;
columns: TColumns;
columnTypes?: ColumnTypeConfig;
useInsert?: boolean;
}
) {
type ModelName = keyof TPrisma["ModelName"] & string;
Expand Down Expand Up @@ -303,10 +320,12 @@ export class PrismaVectorStore<
*/
async addDocuments(documents: Document<TModel>[]) {
  const texts = documents.map(({ pageContent }) => pageContent);
  // Embed exactly once; both write paths below consume the same vectors.
  const vectors = await this.embeddings.embedDocuments(texts);

  // Opt-in INSERT path (config.useInsert): creates rows instead of updating
  // existing ones.
  if (this.useInsert) {
    return this.addDocumentsWithVectors(vectors, documents);
  }

  // Default path, kept for backward compatibility: delegate to addVectors.
  return this.addVectors(vectors, documents);
}

/**
Expand Down Expand Up @@ -350,6 +369,58 @@ export class PrismaVectorStore<
);
}

/**
 * Adds documents with their corresponding vectors to the store using INSERT statements.
 *
 * One INSERT is built per vector and all of them are submitted through a single
 * `$transaction` call. Each row is populated from the selected columns (the content
 * column receives `pageContent`, every other selected column is read from
 * `document.metadata`) followed by the vector column. Because rows are created rather
 * than updated, this method is compatible with ParentDocumentRetriever, which creates
 * new child documents.
 * @param vectors The vectors to add; `vectors[i]` is stored with `documents[i]`.
 * @param documents The documents associated with the vectors.
 * @returns A promise that resolves when the documents have been added.
 * @throws Error if `vectors` and `documents` do not have the same length.
 */
async addDocumentsWithVectors(
  vectors: number[][],
  documents: Document<TModel>[]
) {
  // Vectors and documents are paired by index, so a length mismatch would either
  // silently drop documents or dereference an undefined document below.
  if (vectors.length !== documents.length) {
    throw new Error(
      `Expected vectors and documents to have the same length, but got ${vectors.length} vectors and ${documents.length} documents.`
    );
  }

  // table name, column name cannot be parametrised
  // these fields are thus not escaped by Prisma and can be dangerous if user input is used
  const tableNameRaw = this.Prisma.raw(`"${this.tableName}"`);
  const vectorColumnRaw = this.Prisma.raw(`"${this.vectorColumnName}"`);

  // Column list for the INSERT: the selected columns first, the vector column last.
  const columnNames = this.selectColumns.map((col) =>
    this.Prisma.raw(`"${col}"`)
  );
  const allColumns = [...columnNames, vectorColumnRaw];

  await this.db.$transaction(
    vectors.map((vector, idx) => {
      const document = documents[idx];
      const vectorString = `[${vector.join(",")}]`;

      // Values in the same order as `columnNames`.
      const columnValues = this.selectColumns.map((col) => {
        if (col === this.contentColumn) {
          return document.pageContent;
        }
        return document.metadata[col];
      });

      // The vector goes last, matching `allColumns`; it is passed as a bound
      // string parameter and cast to `vector` in SQL.
      const allValues = [
        ...columnValues,
        this.Prisma.sql`${vectorString}::vector`,
      ];

      return this.db.$executeRaw(
        this.Prisma.sql`
          INSERT INTO ${tableNameRaw} (${this.Prisma.join(allColumns, ", ")})
          VALUES (${this.Prisma.join(allValues, ", ")})
        `
      );
    })
  );
}

/**
* Performs a similarity search with the specified query.
* @param query The query to use for the similarity search.
Expand Down Expand Up @@ -572,6 +643,7 @@ export class PrismaVectorStore<
vectorColumnName: string;
columns: ModelColumns<Record<string, unknown>>;
columnTypes?: ColumnTypeConfig;
useInsert?: boolean;
}
): Promise<DefaultPrismaVectorStore> {
const docs: Document[] = [];
Expand Down Expand Up @@ -604,6 +676,7 @@ export class PrismaVectorStore<
vectorColumnName: string;
columns: ModelColumns<Record<string, unknown>>;
columnTypes?: ColumnTypeConfig;
useInsert?: boolean;
}
): Promise<DefaultPrismaVectorStore> {
const instance = new PrismaVectorStore(embeddings, dbConfig);
Expand Down
141 changes: 141 additions & 0 deletions libs/langchain-community/src/vectorstores/tests/prisma.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { FakeEmbeddings } from "@langchain/core/utils/testing";
import { jest, test, expect } from "@jest/globals";
import { Document } from "@langchain/core/documents";
import { PrismaVectorStore } from "../prisma.js";

class Sql {
Expand Down Expand Up @@ -39,6 +40,7 @@ describe("Prisma", () => {
beforeEach(() => {
jest.clearAllMocks();
});

test("passes provided filters with simiaritySearch", async () => {
const embeddings = new FakeEmbeddings();
const store = new PrismaVectorStore(new FakeEmbeddings(), {
Expand Down Expand Up @@ -270,4 +272,143 @@ describe("Prisma", () => {
expect(sqlCall).toBeDefined();
});
});

test("addDocumentsWithVectors creates new documents with INSERT", async () => {
  const vectorStore = new PrismaVectorStore(new FakeEmbeddings(), {
    db: mockPrismaClient,
    prisma: mockPrismaNamespace,
    tableName: "test",
    vectorColumnName: "vector",
    columns: mockColumns,
  });

  // Two embeddings, paired by index with the two documents below.
  const embeddingRows = [
    [1, 2, 3],
    [4, 5, 6],
  ];

  const firstDoc = new Document({
    pageContent: "test content 1",
    metadata: { id: "doc1", custom: "value1" },
  });
  const secondDoc = new Document({
    pageContent: "test content 2",
    metadata: { id: "doc2", custom: "value2" },
  });

  // The transaction mock checks that one query per document was queued.
  $transaction.mockImplementation((queries) => {
    expect(queries).toHaveLength(2);
    return Promise.resolve();
  });

  await vectorStore.addDocumentsWithVectors(embeddingRows, [
    firstDoc,
    secondDoc,
  ]);

  // A single transaction wrapping one $executeRaw call per document.
  expect($transaction).toHaveBeenCalledTimes(1);
  expect($executeRaw).toHaveBeenCalledTimes(2);
});

test("addDocuments uses addVectors by default (backward compatibility)", async () => {
  // `useInsert` is deliberately omitted from the config.
  const vectorStore = new PrismaVectorStore(new FakeEmbeddings(), {
    db: mockPrismaClient,
    prisma: mockPrismaNamespace,
    tableName: "test",
    vectorColumnName: "vector",
    columns: mockColumns,
  });

  // Stub both write paths so the test only observes which one is chosen.
  const updatePathSpy = jest
    .spyOn(vectorStore, "addVectors")
    .mockResolvedValue();
  const insertPathSpy = jest
    .spyOn(vectorStore, "addDocumentsWithVectors")
    .mockResolvedValue();

  const onlyDoc = new Document({
    pageContent: "test content",
    metadata: { id: "doc1" },
  });
  await vectorStore.addDocuments([onlyDoc]);

  // The default route goes through addVectors...
  expect(updatePathSpy).toHaveBeenCalledTimes(1);
  // ...and never touches the INSERT-based method.
  expect(insertPathSpy).not.toHaveBeenCalled();
});

test("addDocuments uses addDocumentsWithVectors when useInsert is true", async () => {
  // Same configuration as the default case, plus the opt-in flag.
  const vectorStore = new PrismaVectorStore(new FakeEmbeddings(), {
    db: mockPrismaClient,
    prisma: mockPrismaNamespace,
    tableName: "test",
    vectorColumnName: "vector",
    columns: mockColumns,
    useInsert: true,
  });

  // Stub both write paths so the test only observes which one is chosen.
  const updatePathSpy = jest
    .spyOn(vectorStore, "addVectors")
    .mockResolvedValue();
  const insertPathSpy = jest
    .spyOn(vectorStore, "addDocumentsWithVectors")
    .mockResolvedValue();

  const onlyDoc = new Document({
    pageContent: "test content",
    metadata: { id: "doc1" },
  });
  await vectorStore.addDocuments([onlyDoc]);

  // With the flag set, the INSERT-based method is used...
  expect(insertPathSpy).toHaveBeenCalledTimes(1);
  // ...and addVectors is bypassed.
  expect(updatePathSpy).not.toHaveBeenCalled();
});

test("addVectors still uses UPDATE statements for backward compatibility", async () => {
  const vectorStore = new PrismaVectorStore(new FakeEmbeddings(), {
    db: mockPrismaClient,
    prisma: mockPrismaNamespace,
    tableName: "test",
    vectorColumnName: "vector",
    columns: mockColumns,
  });

  // Keep the joined string parts from the latest call to the mocked sql function.
  let lastSqlText = "";
  // @ts-expect-error - we are mocking the sql function
  sql.mockImplementation((strings: string[], ...values) => {
    lastSqlText = strings.join("");
    return { strings, values };
  });

  $transaction.mockResolvedValue([]);

  const onlyDoc = new Document({
    pageContent: "test content",
    metadata: { id: "doc1" },
  });
  await vectorStore.addVectors([[1, 2, 3]], [onlyDoc]);

  expect($transaction).toHaveBeenCalledTimes(1);
  // The captured statement text must be an UPDATE, not an INSERT.
  expect(lastSqlText).toContain("UPDATE");
  expect(lastSqlText).not.toContain("INSERT");
});
});
Loading