diff --git a/docs/extras/modules/data_connection/retrievers/how_to/multi-query-retriever.mdx b/docs/extras/modules/data_connection/retrievers/how_to/multi-query-retriever.mdx new file mode 100644 index 000000000000..93d2a9a40955 --- /dev/null +++ b/docs/extras/modules/data_connection/retrievers/how_to/multi-query-retriever.mdx @@ -0,0 +1,29 @@ +--- +hide_table_of_contents: true +--- + +# MultiQuery Retriever + +Distance-based vector database retrieval embeds (represents) queries in high-dimensional space and finds similar embedded documents based on "distance". +But retrieval may produce different results with subtle changes in query wording or if the embeddings do not capture the semantics of the data well. +Prompt engineering / tuning is sometimes done to manually address these problems, but can be tedious. + +The MultiQueryRetriever automates the process of prompt tuning by using an LLM to generate multiple queries from different perspectives for a given user input query. +For each query, it retrieves a set of relevant documents and takes the unique union across all queries to get a larger set of potentially relevant documents. +By generating multiple perspectives on the same question, the MultiQueryRetriever might be able to overcome some of the limitations of the distance-based retrieval and get a richer set of results. + +## Usage + +import CodeBlock from "@theme/CodeBlock"; +import Example from "@examples/retrievers/multi_query.ts"; + +<CodeBlock language="typescript">{Example}</CodeBlock> + +## Customization + +You can also supply a custom prompt to tune what types of questions are generated. +You can also pass a custom output parser to parse and split the results of the LLM call into a list of queries. 
+ +import CustomExample from "@examples/retrievers/multi_query_custom.ts"; + +<CodeBlock language="typescript">{CustomExample}</CodeBlock> diff --git a/environment_tests/test-exports-bun/src/entrypoints.js b/environment_tests/test-exports-bun/src/entrypoints.js index 0cef0f43a27f..4602b02474b2 100644 --- a/environment_tests/test-exports-bun/src/entrypoints.js +++ b/environment_tests/test-exports-bun/src/entrypoints.js @@ -49,6 +49,7 @@ export * from "langchain/retrievers/remote"; export * from "langchain/retrievers/databerry"; export * from "langchain/retrievers/contextual_compression"; export * from "langchain/retrievers/document_compressors"; +export * from "langchain/retrievers/multi_query"; export * from "langchain/retrievers/multi_vector"; export * from "langchain/retrievers/parent_document"; export * from "langchain/retrievers/time_weighted"; diff --git a/environment_tests/test-exports-cf/src/entrypoints.js b/environment_tests/test-exports-cf/src/entrypoints.js index 0cef0f43a27f..4602b02474b2 100644 --- a/environment_tests/test-exports-cf/src/entrypoints.js +++ b/environment_tests/test-exports-cf/src/entrypoints.js @@ -49,6 +49,7 @@ export * from "langchain/retrievers/remote"; export * from "langchain/retrievers/databerry"; export * from "langchain/retrievers/contextual_compression"; export * from "langchain/retrievers/document_compressors"; +export * from "langchain/retrievers/multi_query"; export * from "langchain/retrievers/multi_vector"; export * from "langchain/retrievers/parent_document"; export * from "langchain/retrievers/time_weighted"; diff --git a/environment_tests/test-exports-cjs/src/entrypoints.js b/environment_tests/test-exports-cjs/src/entrypoints.js index 84c4e138726e..04d4e3076180 100644 --- a/environment_tests/test-exports-cjs/src/entrypoints.js +++ b/environment_tests/test-exports-cjs/src/entrypoints.js @@ -49,6 +49,7 @@ const retrievers_remote = require("langchain/retrievers/remote"); const retrievers_databerry = require("langchain/retrievers/databerry"); const 
retrievers_contextual_compression = require("langchain/retrievers/contextual_compression"); const retrievers_document_compressors = require("langchain/retrievers/document_compressors"); +const retrievers_multi_query = require("langchain/retrievers/multi_query"); const retrievers_multi_vector = require("langchain/retrievers/multi_vector"); const retrievers_parent_document = require("langchain/retrievers/parent_document"); const retrievers_time_weighted = require("langchain/retrievers/time_weighted"); diff --git a/environment_tests/test-exports-esbuild/src/entrypoints.js b/environment_tests/test-exports-esbuild/src/entrypoints.js index 4cf67cdf40e2..6be1da07b222 100644 --- a/environment_tests/test-exports-esbuild/src/entrypoints.js +++ b/environment_tests/test-exports-esbuild/src/entrypoints.js @@ -49,6 +49,7 @@ import * as retrievers_remote from "langchain/retrievers/remote"; import * as retrievers_databerry from "langchain/retrievers/databerry"; import * as retrievers_contextual_compression from "langchain/retrievers/contextual_compression"; import * as retrievers_document_compressors from "langchain/retrievers/document_compressors"; +import * as retrievers_multi_query from "langchain/retrievers/multi_query"; import * as retrievers_multi_vector from "langchain/retrievers/multi_vector"; import * as retrievers_parent_document from "langchain/retrievers/parent_document"; import * as retrievers_time_weighted from "langchain/retrievers/time_weighted"; diff --git a/environment_tests/test-exports-esm/src/entrypoints.js b/environment_tests/test-exports-esm/src/entrypoints.js index 4cf67cdf40e2..6be1da07b222 100644 --- a/environment_tests/test-exports-esm/src/entrypoints.js +++ b/environment_tests/test-exports-esm/src/entrypoints.js @@ -49,6 +49,7 @@ import * as retrievers_remote from "langchain/retrievers/remote"; import * as retrievers_databerry from "langchain/retrievers/databerry"; import * as retrievers_contextual_compression from 
"langchain/retrievers/contextual_compression"; import * as retrievers_document_compressors from "langchain/retrievers/document_compressors"; +import * as retrievers_multi_query from "langchain/retrievers/multi_query"; import * as retrievers_multi_vector from "langchain/retrievers/multi_vector"; import * as retrievers_parent_document from "langchain/retrievers/parent_document"; import * as retrievers_time_weighted from "langchain/retrievers/time_weighted"; diff --git a/environment_tests/test-exports-vercel/src/entrypoints.js b/environment_tests/test-exports-vercel/src/entrypoints.js index 0cef0f43a27f..4602b02474b2 100644 --- a/environment_tests/test-exports-vercel/src/entrypoints.js +++ b/environment_tests/test-exports-vercel/src/entrypoints.js @@ -49,6 +49,7 @@ export * from "langchain/retrievers/remote"; export * from "langchain/retrievers/databerry"; export * from "langchain/retrievers/contextual_compression"; export * from "langchain/retrievers/document_compressors"; +export * from "langchain/retrievers/multi_query"; export * from "langchain/retrievers/multi_vector"; export * from "langchain/retrievers/parent_document"; export * from "langchain/retrievers/time_weighted"; diff --git a/environment_tests/test-exports-vite/src/entrypoints.js b/environment_tests/test-exports-vite/src/entrypoints.js index 0cef0f43a27f..4602b02474b2 100644 --- a/environment_tests/test-exports-vite/src/entrypoints.js +++ b/environment_tests/test-exports-vite/src/entrypoints.js @@ -49,6 +49,7 @@ export * from "langchain/retrievers/remote"; export * from "langchain/retrievers/databerry"; export * from "langchain/retrievers/contextual_compression"; export * from "langchain/retrievers/document_compressors"; +export * from "langchain/retrievers/multi_query"; export * from "langchain/retrievers/multi_vector"; export * from "langchain/retrievers/parent_document"; export * from "langchain/retrievers/time_weighted"; diff --git a/examples/src/retrievers/multi_query.ts 
b/examples/src/retrievers/multi_query.ts new file mode 100644 index 000000000000..eb953b9b1dc3 --- /dev/null +++ b/examples/src/retrievers/multi_query.ts @@ -0,0 +1,54 @@ +import { MemoryVectorStore } from "langchain/vectorstores/memory"; +import { CohereEmbeddings } from "langchain/embeddings/cohere"; +import { ChatAnthropic } from "langchain/chat_models/anthropic"; +import { MultiQueryRetriever } from "langchain/retrievers/multi_query"; + +const vectorstore = await MemoryVectorStore.fromTexts( + [ + "Buildings are made out of brick", + "Buildings are made out of wood", + "Buildings are made out of stone", + "Cars are made out of metal", + "Cars are made out of plastic", + "mitochondria is the powerhouse of the cell", + "mitochondria is made of lipids", + ], + [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }], + new CohereEmbeddings() +); +const model = new ChatAnthropic({}); +const retriever = MultiQueryRetriever.fromLLM({ + llm: model, + retriever: vectorstore.asRetriever(), + verbose: true, +}); + +const query = "What are mitochondria made of?"; +const retrievedDocs = await retriever.getRelevantDocuments(query); + +/* + Generated queries: What are the components of mitochondria?,What substances comprise the mitochondria organelle? ,What is the molecular composition of mitochondria? 
+*/
+
+console.log(retrievedDocs);
+
+/*
+  [
+    Document {
+      pageContent: 'mitochondria is the powerhouse of the cell',
+      metadata: {}
+    },
+    Document {
+      pageContent: 'mitochondria is made of lipids',
+      metadata: {}
+    },
+    Document {
+      pageContent: 'Buildings are made out of brick',
+      metadata: { id: 1 }
+    },
+    Document {
+      pageContent: 'Buildings are made out of wood',
+      metadata: { id: 2 }
+    }
+  ]
+*/
diff --git a/examples/src/retrievers/multi_query_custom.ts b/examples/src/retrievers/multi_query_custom.ts
new file mode 100644
index 000000000000..df2e05f6c19c
--- /dev/null
+++ b/examples/src/retrievers/multi_query_custom.ts
@@ -0,0 +1,124 @@
+import { MemoryVectorStore } from "langchain/vectorstores/memory";
+import { CohereEmbeddings } from "langchain/embeddings/cohere";
+import { ChatAnthropic } from "langchain/chat_models/anthropic";
+import { MultiQueryRetriever } from "langchain/retrievers/multi_query";
+import { BaseOutputParser } from "langchain/schema/output_parser";
+import { PromptTemplate } from "langchain/prompts";
+import { LLMChain } from "langchain/chains";
+
+type LineList = {
+  lines: string[];
+};
+
+/**
+ * Splits the model's reply into one query per line. Only the text between
+ * `<questions>` and `</questions>` is used; when a tag is missing, the start
+ * or end of the reply is used instead.
+ */
+class LineListOutputParser extends BaseOutputParser<LineList> {
+  static lc_name() {
+    return "LineListOutputParser";
+  }
+
+  lc_namespace = ["langchain", "retrievers", "multiquery"];
+
+  async parse(text: string): Promise<LineList> {
+    // Locate the XML wrapper that the prompt asks the model to emit.
+    const startKeyIndex = text.indexOf("<questions>");
+    const endKeyIndex = text.indexOf("</questions>");
+    const questionsStartIndex =
+      startKeyIndex === -1 ? 0 : startKeyIndex + "<questions>".length;
+    const questionsEndIndex = endKeyIndex === -1 ? text.length : endKeyIndex;
+    const lines = text
+      .slice(questionsStartIndex, questionsEndIndex)
+      .trim()
+      .split("\n")
+      .filter((line) => line.trim() !== "");
+    return { lines };
+  }
+
+  // Not supported: the prompt below already spells out the expected format.
+  getFormatInstructions(): string {
+    throw new Error("Not implemented.");
+  }
+}
+
+// Create template
+const prompt =
+  PromptTemplate.fromTemplate(`You are an AI language model assistant. Your task is
+to generate {queryCount} different versions of the given user
+question to retrieve relevant documents from a vector database.
+By generating multiple perspectives on the user question,
+your goal is to help the user overcome some of the limitations
+of distance-based similarity search.
+
+All questions should be in German.
+
+Provide these alternative questions separated by newlines between XML tags. For example:
+
+<questions>
+Question 1
+Question 2
+Question 3
+</questions>
+
+Original question: {question}`);
+
+const vectorstore = await MemoryVectorStore.fromTexts(
+  [
+    "Gebäude werden aus Ziegelsteinen hergestellt",
+    "Gebäude werden aus Holz hergestellt",
+    "Gebäude werden aus Stein hergestellt",
+    "Autos werden aus Metall hergestellt",
+    "Autos werden aus Kunststoff hergestellt",
+    "Mitochondrien sind die Energiekraftwerke der Zelle",
+    "Mitochondrien bestehen aus Lipiden",
+  ],
+  [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
+  new CohereEmbeddings()
+);
+const model = new ChatAnthropic({});
+const llmChain = new LLMChain({
+  llm: model,
+  prompt,
+  outputParser: new LineListOutputParser(),
+});
+const retriever = new MultiQueryRetriever({
+  retriever: vectorstore.asRetriever(),
+  llmChain,
+  verbose: true,
+});
+
+const query = "What are mitochondria made of?";
+const retrievedDocs = await retriever.getRelevantDocuments(query);
+
+/*
+  Generated queries: Was besteht ein Mitochondrium?,Aus welchen Komponenten setzt sich ein Mitochondrium zusammen? ,Welche Moleküle finden sich in einem Mitochondrium?
+*/ + +console.log(retrievedDocs); + +/* + [ + Document { + pageContent: 'Mitochondrien bestehen aus Lipiden', + metadata: {} + }, + Document { + pageContent: 'Mitochondrien sind die Energiekraftwerke der Zelle', + metadata: {} + }, + Document { + pageContent: 'Autos werden aus Metall hergestellt', + metadata: { id: 4 } + }, + Document { + pageContent: 'Gebäude werden aus Holz hergestellt', + metadata: { id: 2 } + }, + Document { + pageContent: 'Gebäude werden aus Ziegelsteinen hergestellt', + metadata: { id: 1 } + } + ] +*/ diff --git a/langchain/.gitignore b/langchain/.gitignore index 9cd3d57c6a84..3f7c1e7206f3 100644 --- a/langchain/.gitignore +++ b/langchain/.gitignore @@ -445,6 +445,9 @@ retrievers/contextual_compression.d.ts retrievers/document_compressors.cjs retrievers/document_compressors.js retrievers/document_compressors.d.ts +retrievers/multi_query.cjs +retrievers/multi_query.js +retrievers/multi_query.d.ts retrievers/multi_vector.cjs retrievers/multi_vector.js retrievers/multi_vector.d.ts diff --git a/langchain/package.json b/langchain/package.json index 38d326e63143..0ce405260f8d 100644 --- a/langchain/package.json +++ b/langchain/package.json @@ -457,6 +457,9 @@ "retrievers/document_compressors.cjs", "retrievers/document_compressors.js", "retrievers/document_compressors.d.ts", + "retrievers/multi_query.cjs", + "retrievers/multi_query.js", + "retrievers/multi_query.d.ts", "retrievers/multi_vector.cjs", "retrievers/multi_vector.js", "retrievers/multi_vector.d.ts", @@ -1843,6 +1846,11 @@ "import": "./retrievers/document_compressors.js", "require": "./retrievers/document_compressors.cjs" }, + "./retrievers/multi_query": { + "types": "./retrievers/multi_query.d.ts", + "import": "./retrievers/multi_query.js", + "require": "./retrievers/multi_query.cjs" + }, "./retrievers/multi_vector": { "types": "./retrievers/multi_vector.d.ts", "import": "./retrievers/multi_vector.js", diff --git a/langchain/scripts/create-entrypoints.js 
b/langchain/scripts/create-entrypoints.js
index 5444aad6092a..a1af2baff932 100644
--- a/langchain/scripts/create-entrypoints.js
+++ b/langchain/scripts/create-entrypoints.js
@@ -182,6 +182,7 @@ const entrypoints = {
   "retrievers/databerry": "retrievers/databerry",
   "retrievers/contextual_compression": "retrievers/contextual_compression",
   "retrievers/document_compressors": "retrievers/document_compressors/index",
+  "retrievers/multi_query": "retrievers/multi_query",
   "retrievers/multi_vector": "retrievers/multi_vector",
   "retrievers/parent_document": "retrievers/parent_document",
   "retrievers/time_weighted": "retrievers/time_weighted",
diff --git a/langchain/src/load/import_map.ts b/langchain/src/load/import_map.ts
index 3a80f0c54840..db54ae87a59b 100644
--- a/langchain/src/load/import_map.ts
+++ b/langchain/src/load/import_map.ts
@@ -50,6 +50,7 @@ export * as retrievers__remote from "../retrievers/remote/index.js";
 export * as retrievers__databerry from "../retrievers/databerry.js";
 export * as retrievers__contextual_compression from "../retrievers/contextual_compression.js";
 export * as retrievers__document_compressors from "../retrievers/document_compressors/index.js";
+export * as retrievers__multi_query from "../retrievers/multi_query.js";
 export * as retrievers__multi_vector from "../retrievers/multi_vector.js";
 export * as retrievers__parent_document from "../retrievers/parent_document.js";
 export * as retrievers__time_weighted from "../retrievers/time_weighted.js";
diff --git a/langchain/src/retrievers/multi_query.ts b/langchain/src/retrievers/multi_query.ts
new file mode 100644
index 000000000000..0cad7d9ed45d
--- /dev/null
+++ b/langchain/src/retrievers/multi_query.ts
@@ -0,0 +1,196 @@
+import { LLMChain } from "../chains/llm_chain.js";
+import { PromptTemplate } from "../prompts/prompt.js";
+import { Document } from "../document.js";
+import { BaseOutputParser } from "../schema/output_parser.js";
+import { BaseRetriever, BaseRetrieverInput } from "../schema/retriever.js";
+import { CallbackManagerForRetrieverRun } from "../callbacks/index.js";
+import { BaseLanguageModel } from "../base_language/index.js";
+import { BasePromptTemplate } from "../prompts/base.js";
+
+/** Parsed LLM output: one generated query per entry. */
+interface LineList {
+  lines: string[];
+}
+
+/**
+ * Splits the LLM reply into one query per line. Only the text between
+ * `<questions>` and `</questions>` is used; when a tag is missing, the start
+ * or end of the reply is used instead.
+ */
+class LineListOutputParser extends BaseOutputParser<LineList> {
+  static lc_name() {
+    return "LineListOutputParser";
+  }
+
+  lc_namespace = ["langchain", "retrievers", "multiquery"];
+
+  async parse(text: string): Promise<LineList> {
+    const startKeyIndex = text.indexOf("<questions>");
+    const endKeyIndex = text.indexOf("</questions>");
+    const questionsStartIndex =
+      startKeyIndex === -1 ? 0 : startKeyIndex + "<questions>".length;
+    const questionsEndIndex = endKeyIndex === -1 ? text.length : endKeyIndex;
+    const lines = text
+      .slice(questionsStartIndex, questionsEndIndex)
+      .trim()
+      .split("\n")
+      .filter((line) => line.trim() !== "");
+    return { lines };
+  }
+
+  // The expected format is spelled out in DEFAULT_QUERY_PROMPT instead.
+  getFormatInstructions(): string {
+    throw new Error("Not implemented.");
+  }
+}
+
+// Default prompt: asks for `queryCount` rewrites of `question`, one per line,
+// wrapped in <questions> tags so LineListOutputParser can extract them.
+const DEFAULT_QUERY_PROMPT = /* #__PURE__ */ new PromptTemplate({
+  inputVariables: ["question", "queryCount"],
+  template: `You are an AI language model assistant. Your task is
+to generate {queryCount} different versions of the given user
+question to retrieve relevant documents from a vector database.
+By generating multiple perspectives on the user question,
+your goal is to help the user overcome some of the limitations
+of distance-based similarity search.
+
+Provide these alternative questions separated by newlines between XML tags. For example:
+
+<questions>
+Question 1
+Question 2
+Question 3
+</questions>
+
+Original question: {question}`,
+});
+
+/**
+ * Constructor options for MultiQueryRetriever.
+ *
+ * - `retriever`: retriever that is run once per generated query.
+ * - `llmChain`: chain that produces the queries; its `text` output is read
+ *   at `parserKey`.
+ * - `queryCount`: number of queries requested from the LLM (default 3).
+ * - `parserKey`: key of the parsed output holding the queries (default "lines").
+ */
+export interface MultiQueryRetrieverInput extends BaseRetrieverInput {
+  retriever: BaseRetriever;
+  llmChain: LLMChain;
+  queryCount?: number;
+  parserKey?: string;
+}
+
+/**
+ * Retriever that asks an LLM for several rewrites of the question, runs
+ * every rewrite through the wrapped retriever and returns the de-duplicated
+ * union of the results.
+ */
+export class MultiQueryRetriever extends BaseRetriever {
+  static lc_name() {
+    return "MultiQueryRetriever";
+  }
+
+  lc_namespace = ["langchain", "retrievers", "multiquery"];
+
+  private retriever: BaseRetriever;
+
+  private llmChain: LLMChain;
+
+  private queryCount = 3;
+
+  private parserKey = "lines";
+
+  constructor(fields: MultiQueryRetrieverInput) {
+    super(fields);
+    this.retriever = fields.retriever;
+    this.llmChain = fields.llmChain;
+    this.queryCount = fields.queryCount ?? this.queryCount;
+    this.parserKey = fields.parserKey ?? this.parserKey;
+  }
+
+  /**
+   * Builds a retriever from a bare LLM: wraps `llm` and `prompt` (default
+   * DEFAULT_QUERY_PROMPT) in an LLMChain that uses LineListOutputParser.
+   */
+  static fromLLM(
+    fields: Omit<MultiQueryRetrieverInput, "llmChain"> & {
+      llm: BaseLanguageModel;
+      prompt?: BasePromptTemplate;
+    }
+  ): MultiQueryRetriever {
+    const {
+      retriever,
+      llm,
+      prompt = DEFAULT_QUERY_PROMPT,
+      queryCount,
+      parserKey,
+      ...rest
+    } = fields;
+    const outputParser = new LineListOutputParser();
+    const llmChain = new LLMChain({ llm, prompt, outputParser });
+    return new this({ retriever, llmChain, queryCount, parserKey, ...rest });
+  }
+
+  // Generate the different queries for each retrieval, using our llmChain
+  private async _generateQueries(
+    question: string,
+    runManager?: CallbackManagerForRetrieverRun
+  ): Promise<string[]> {
+    const response = await this.llmChain.call(
+      { question, queryCount: this.queryCount },
+      runManager?.getChild()
+    );
+    // Fall back to no queries when the parsed output lacks `parserKey`.
+    const lines = response.text[this.parserKey] || [];
+    if (this.verbose) {
+      console.log(`Generated queries: ${lines}`);
+    }
+    return lines;
+  }
+
+  // Retrieve documents using the original retriever
+  private async _retrieveDocuments(
+    queries: string[],
+    runManager?: CallbackManagerForRetrieverRun
+  ): Promise<Document[]> {
+    const documents: Document[] = [];
+    for (const query of queries) {
+      const docs = await this.retriever.getRelevantDocuments(
+        query,
+        runManager?.getChild()
+      );
+      documents.push(...docs);
+    }
+    return documents;
+  }
+
+  // Deduplicate the documents that were returned in multiple retrievals
+  private _uniqueUnion(documents: Document[]): Document[] {
+    const uniqueDocumentsDict: { [key: string]: Document } = {};
+
+    for (const doc of documents) {
+      // Two documents are the same when page content and metadata both match;
+      // entries are sorted so metadata key order does not matter.
+      const key = `${doc.pageContent}:${JSON.stringify(
+        Object.entries(doc.metadata).sort()
+      )}`;
+      uniqueDocumentsDict[key] = doc;
+    }
+
+    const uniqueDocuments = Object.values(uniqueDocumentsDict);
+    return uniqueDocuments;
+  }
+
+  // Pipeline: generate queries, retrieve for each one, then de-duplicate.
+  async _getRelevantDocuments(
+    question: string,
+    runManager?: CallbackManagerForRetrieverRun
+  ): Promise<Document[]> {
+    const queries = await this._generateQueries(question, runManager);
+    const documents = await this._retrieveDocuments(queries, runManager);
+    const uniqueDocuments = this._uniqueUnion(documents);
+    return uniqueDocuments;
+  }
+}
diff --git a/langchain/src/retrievers/tests/multi_query.int.test.ts b/langchain/src/retrievers/tests/multi_query.int.test.ts
new file mode 100644
index 000000000000..32b7f9be9806
--- /dev/null
+++ b/langchain/src/retrievers/tests/multi_query.int.test.ts
@@ -0,0 +1,57 @@
+import { expect, test } from "@jest/globals";
+import { CohereEmbeddings } from "../../embeddings/cohere.js";
+import { MemoryVectorStore } from "../../vectorstores/memory.js";
+import { MultiQueryRetriever } from "../multi_query.js";
+import { ChatAnthropic } from "../../chat_models/anthropic.js";
+
+test("Should work with a question input", async () => {
+  const vectorstore = await MemoryVectorStore.fromTexts(
+    [
+      "Buildings are made out of brick",
+      "Buildings are made out of wood",
+      "Buildings are made out of stone",
+      "Cars are made out of metal",
+      "Cars are made out of plastic",
+      "mitochondria is the powerhouse of the cell",
+      "mitochondria is made of lipids",
+    ],
+    [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
+    new CohereEmbeddings()
+  );
+  const
model = new ChatAnthropic({}); + const retriever = MultiQueryRetriever.fromLLM({ + llm: model, + retriever: vectorstore.asRetriever(), + verbose: true, + }); + + const query = "What are mitochondria made of?"; + const retrievedDocs = await retriever.getRelevantDocuments(query); + expect(retrievedDocs[0].pageContent).toContain("mitochondria"); +}); + +test("Should work with a keyword", async () => { + const vectorstore = await MemoryVectorStore.fromTexts( + [ + "Buildings are made out of brick", + "Buildings are made out of wood", + "Buildings are made out of stone", + "Cars are made out of metal", + "Cars are made out of plastic", + "mitochondria is the powerhouse of the cell", + "mitochondria is made of lipids", + ], + [{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }], + new CohereEmbeddings() + ); + const model = new ChatAnthropic({}); + const retriever = MultiQueryRetriever.fromLLM({ + llm: model, + retriever: vectorstore.asRetriever(), + verbose: true, + }); + + const query = "cars"; + const retrievedDocs = await retriever.getRelevantDocuments(query); + expect(retrievedDocs[0].pageContent).toContain("Cars"); +}); diff --git a/langchain/tsconfig.json b/langchain/tsconfig.json index a209b97aeb81..ff11a05c563b 100644 --- a/langchain/tsconfig.json +++ b/langchain/tsconfig.json @@ -181,6 +181,7 @@ "src/retrievers/databerry.ts", "src/retrievers/contextual_compression.ts", "src/retrievers/document_compressors/index.ts", + "src/retrievers/multi_query.ts", "src/retrievers/multi_vector.ts", "src/retrievers/parent_document.ts", "src/retrievers/time_weighted.ts",