Skip to content

Commit 796315f

Browse files
authored
feat: add Rerankers from Azure (#5576)
* Create AzureRerankerApi.credential.ts Add Azure Foundry Reranker integration * Create test * Add files via upload feat: Add Azure Foundry Reranker integration * Delete packages/components/nodes/retrievers/AzureRerankRetriever/test * feat: Add Azure Foundry Reranker integration * Delete packages/components/nodes/retrievers/AzureRerankRetriever/03513-icon-service-AI-Studio.svg * feat: Add Azure Reranker integration * feat: Add Azure Reranker integration * feat: Add Azure Reranker integration * feat: Add Azure Reranker integration * feat: Add Azure Reranker integration * Update AzureRerankRetriever.ts * Update AzureRerank.ts
1 parent 1bfa7a1 commit 796315f

File tree

4 files changed

+254
-0
lines changed

4 files changed

+254
-0
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import { INodeParams, INodeCredential } from '../src/Interface'
2+
3+
class AzureRerankerApi implements INodeCredential {
4+
label: string
5+
name: string
6+
version: number
7+
description: string
8+
inputs: INodeParams[]
9+
10+
constructor() {
11+
this.label = 'Azure Foundry API'
12+
this.name = 'azureFoundryApi'
13+
this.version = 1.0
14+
this.description =
15+
'Refer to <a target="_blank" href="https://docs.microsoft.com/en-us/azure/ai-foundry/">Azure AI Foundry documentation</a> for setup instructions'
16+
this.inputs = [
17+
{
18+
label: 'Azure Foundry API Key',
19+
name: 'azureFoundryApiKey',
20+
type: 'password',
21+
description: 'Your Azure AI Foundry API key'
22+
},
23+
{
24+
label: 'Azure Foundry Endpoint',
25+
name: 'azureFoundryEndpoint',
26+
type: 'string',
27+
placeholder: 'https://your-foundry-instance.services.ai.azure.com/providers/cohere/v2/rerank',
28+
description: 'Your Azure AI Foundry endpoint URL'
29+
}
30+
]
31+
}
32+
}
33+
34+
module.exports = { credClass: AzureRerankerApi }
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import axios from 'axios'
2+
import { Callbacks } from '@langchain/core/callbacks/manager'
3+
import { Document } from '@langchain/core/documents'
4+
import { BaseDocumentCompressor } from 'langchain/retrievers/document_compressors'
5+
6+
export class AzureRerank extends BaseDocumentCompressor {
7+
private readonly azureApiKey: string
8+
private readonly azureApiUrl: string
9+
private readonly model: string
10+
private readonly k: number
11+
private readonly maxChunksPerDoc: number
12+
constructor(azureApiKey: string, azureApiUrl: string, model: string, k: number, maxChunksPerDoc: number) {
13+
super()
14+
this.azureApiKey = azureApiKey
15+
this.azureApiUrl = azureApiUrl
16+
this.model = model
17+
this.k = k
18+
this.maxChunksPerDoc = maxChunksPerDoc
19+
}
20+
async compressDocuments(
21+
documents: Document<Record<string, any>>[],
22+
query: string,
23+
_?: Callbacks | undefined
24+
): Promise<Document<Record<string, any>>[]> {
25+
// avoid empty api call
26+
if (documents.length === 0) {
27+
return []
28+
}
29+
const config = {
30+
headers: {
31+
'api-key': `${this.azureApiKey}`,
32+
'Content-Type': 'application/json',
33+
Accept: 'application/json'
34+
}
35+
}
36+
const data = {
37+
model: this.model,
38+
top_n: this.k,
39+
max_chunks_per_doc: this.maxChunksPerDoc,
40+
query: query,
41+
return_documents: false,
42+
documents: documents.map((doc) => doc.pageContent)
43+
}
44+
try {
45+
let returnedDocs = await axios.post(this.azureApiUrl, data, config)
46+
const finalResults: Document<Record<string, any>>[] = []
47+
returnedDocs.data.results.forEach((result: any) => {
48+
const doc = documents[result.index]
49+
doc.metadata.relevance_score = result.relevance_score
50+
finalResults.push(doc)
51+
})
52+
return finalResults.splice(0, this.k)
53+
} catch (error) {
54+
throw new Error(`Azure Rerank API call failed: ${error.message}`)
55+
}
56+
}
57+
}
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import { BaseRetriever } from '@langchain/core/retrievers'
2+
import { VectorStoreRetriever } from '@langchain/core/vectorstores'
3+
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression'
4+
import { AzureRerank } from './AzureRerank'
5+
import { getCredentialData, getCredentialParam, handleEscapeCharacters } from '../../../src'
6+
import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/Interface'
7+
8+
class AzureRerankRetriever_Retrievers implements INode {
9+
label: string
10+
name: string
11+
version: number
12+
description: string
13+
type: string
14+
icon: string
15+
category: string
16+
baseClasses: string[]
17+
inputs: INodeParams[]
18+
credential: INodeParams
19+
badge: string
20+
outputs: INodeOutputsValue[]
21+
22+
constructor() {
23+
this.label = 'Azure Rerank Retriever'
24+
this.name = 'AzureRerankRetriever'
25+
this.version = 1.0
26+
this.type = 'Azure Rerank Retriever'
27+
this.icon = 'azurefoundry.svg'
28+
this.category = 'Retrievers'
29+
this.description = 'Azure Rerank indexes the documents from most to least semantically relevant to the query.'
30+
this.baseClasses = [this.type, 'BaseRetriever']
31+
this.credential = {
32+
label: 'Connect Credential',
33+
name: 'credential',
34+
type: 'credential',
35+
credentialNames: ['azureFoundryApi']
36+
}
37+
this.inputs = [
38+
{
39+
label: 'Vector Store Retriever',
40+
name: 'baseRetriever',
41+
type: 'VectorStoreRetriever'
42+
},
43+
{
44+
label: 'Model Name',
45+
name: 'model',
46+
type: 'options',
47+
options: [
48+
{
49+
label: 'rerank-v3.5',
50+
name: 'rerank-v3.5'
51+
},
52+
{
53+
label: 'rerank-english-v3.0',
54+
name: 'rerank-english-v3.0'
55+
},
56+
{
57+
label: 'rerank-multilingual-v3.0',
58+
name: 'rerank-multilingual-v3.0'
59+
},
60+
{
61+
label: 'Cohere-rerank-v4.0-fast',
62+
name: 'Cohere-rerank-v4.0-fast'
63+
},
64+
{
65+
label: 'Cohere-rerank-v4.0-pro',
66+
name: 'Cohere-rerank-v4.0-pro'
67+
}
68+
],
69+
default: 'Cohere-rerank-v4.0-fast',
70+
optional: true
71+
},
72+
{
73+
label: 'Query',
74+
name: 'query',
75+
type: 'string',
76+
description: 'Query to retrieve documents from retriever. If not specified, user question will be used',
77+
optional: true,
78+
acceptVariable: true
79+
},
80+
{
81+
label: 'Top K',
82+
name: 'topK',
83+
description: 'Number of top results to fetch. Default to the TopK of the Base Retriever',
84+
placeholder: '4',
85+
type: 'number',
86+
additionalParams: true,
87+
optional: true
88+
},
89+
{
90+
label: 'Max Chunks Per Doc',
91+
name: 'maxChunksPerDoc',
92+
description: 'The maximum number of chunks to produce internally from a document. Default to 10',
93+
placeholder: '10',
94+
type: 'number',
95+
additionalParams: true,
96+
optional: true
97+
}
98+
]
99+
this.outputs = [
100+
{
101+
label: 'Azure Rerank Retriever',
102+
name: 'retriever',
103+
baseClasses: this.baseClasses
104+
},
105+
{
106+
label: 'Document',
107+
name: 'document',
108+
description: 'Array of document objects containing metadata and pageContent',
109+
baseClasses: ['Document', 'json']
110+
},
111+
{
112+
label: 'Text',
113+
name: 'text',
114+
description: 'Concatenated string from pageContent of documents',
115+
baseClasses: ['string', 'json']
116+
}
117+
]
118+
}
119+
120+
async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
121+
const baseRetriever = nodeData.inputs?.baseRetriever as BaseRetriever
122+
const model = nodeData.inputs?.model as string
123+
const query = nodeData.inputs?.query as string
124+
const credentialData = await getCredentialData(nodeData.credential ?? '', options)
125+
const azureApiKey = getCredentialParam('azureFoundryApiKey', credentialData, nodeData)
126+
if (!azureApiKey) {
127+
throw new Error('Azure Foundry API Key is missing in credentials.')
128+
}
129+
const azureEndpoint = getCredentialParam('azureFoundryEndpoint', credentialData, nodeData)
130+
if (!azureEndpoint) {
131+
throw new Error('Azure Foundry Endpoint is missing in credentials.')
132+
}
133+
const topK = nodeData.inputs?.topK as string
134+
const k = topK ? parseFloat(topK) : (baseRetriever as VectorStoreRetriever).k ?? 4
135+
const maxChunksPerDoc = nodeData.inputs?.maxChunksPerDoc as string
136+
const maxChunksPerDocValue = maxChunksPerDoc ? parseFloat(maxChunksPerDoc) : 10
137+
const output = nodeData.outputs?.output as string
138+
139+
const azureCompressor = new AzureRerank(azureApiKey, azureEndpoint, model, k, maxChunksPerDocValue)
140+
141+
const retriever = new ContextualCompressionRetriever({
142+
baseCompressor: azureCompressor,
143+
baseRetriever: baseRetriever
144+
})
145+
146+
if (output === 'retriever') return retriever
147+
else if (output === 'document') return await retriever.getRelevantDocuments(query ? query : input)
148+
else if (output === 'text') {
149+
let finaltext = ''
150+
151+
const docs = await retriever.getRelevantDocuments(query ? query : input)
152+
153+
for (const doc of docs) finaltext += `${doc.pageContent}\n`
154+
155+
return handleEscapeCharacters(finaltext, false)
156+
}
157+
158+
return retriever
159+
}
160+
}
161+
162+
module.exports = { nodeClass: AzureRerankRetriever_Retrievers }
Lines changed: 1 addition & 0 deletions
Loading

0 commit comments

Comments
 (0)