Skip to content

Commit

Permalink
Allow ParentDocumentRetriever to subclass MultiVectorRetriever (langc…
Browse files Browse the repository at this point in the history
…hain-ai#2853)

* Allow ParentDocumentRetriever to subclass MultiVectorRetriever

* Remove subclassing

* Update ParentDocumentRetriever

* Allow backwards compatibility with docstore

* Parallelize mget

* Formatting
  • Loading branch information
jacoblee93 authored Oct 13, 2023
1 parent b74b340 commit 70774b7
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 31 deletions.
4 changes: 2 additions & 2 deletions examples/src/retrievers/parent_document_retriever.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { InMemoryDocstore } from "langchain/stores/doc/in_memory";
import { InMemoryStore } from "langchain/storage/in_memory";
import { ParentDocumentRetriever } from "langchain/retrievers/parent_document";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { TextLoader } from "langchain/document_loaders/fs/text";

const vectorstore = new MemoryVectorStore(new OpenAIEmbeddings());
const docstore = new InMemoryDocstore();
const docstore = new InMemoryStore();
const retriever = new ParentDocumentRetriever({
vectorstore,
docstore,
Expand Down
6 changes: 3 additions & 3 deletions langchain/src/retrievers/multi_vector.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { BaseStore } from "../schema/storage.js";
import { BaseStoreInterface } from "../schema/storage.js";
import { Document } from "../document.js";
import { BaseRetriever, BaseRetrieverInput } from "../schema/retriever.js";
import { VectorStore } from "../vectorstores/base.js";
Expand All @@ -8,7 +8,7 @@ import { VectorStore } from "../vectorstores/base.js";
*/
export interface MultiVectorRetrieverInput extends BaseRetrieverInput {
vectorstore: VectorStore;
docstore: BaseStore<string, Document>;
docstore: BaseStoreInterface<string, Document>;
idKey?: string;
childK?: number;
parentK?: number;
Expand All @@ -28,7 +28,7 @@ export class MultiVectorRetriever extends BaseRetriever {

public vectorstore: VectorStore;

public docstore: BaseStore<string, Document>;
public docstore: BaseStoreInterface<string, Document>;

protected idKey: string;

Expand Down
35 changes: 14 additions & 21 deletions langchain/src/retrievers/parent_document.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
import * as uuid from "uuid";

import { BaseRetriever, BaseRetrieverInput } from "../schema/retriever.js";
import { Document } from "../document.js";
import { VectorStore } from "../vectorstores/base.js";
import { Docstore } from "../schema/index.js";
import { TextSplitter } from "../text_splitter.js";
import {
MultiVectorRetriever,
type MultiVectorRetrieverInput,
} from "./multi_vector.js";

/**
* Interface for the fields required to initialize a
* ParentDocumentRetriever instance.
*/
export interface ParentDocumentRetrieverFields extends BaseRetrieverInput {
vectorstore: VectorStore;
docstore: Docstore;
export type ParentDocumentRetrieverFields = MultiVectorRetrieverInput & {
childSplitter: TextSplitter;
parentSplitter?: TextSplitter;
idKey?: string;
childK?: number;
parentK?: number;
}
};

// TODO: Change this to subclass MultiVectorRetriever
/**
* A type of document retriever that splits input documents into smaller chunks
* while separately storing and preserving the original documents.
Expand All @@ -30,16 +26,14 @@ export interface ParentDocumentRetrieverFields extends BaseRetrieverInput {
* This strikes a balance between better targeted retrieval with small documents
* and the more context-rich larger documents.
*/
export class ParentDocumentRetriever extends BaseRetriever {
export class ParentDocumentRetriever extends MultiVectorRetriever {
static lc_name() {
return "ParentDocumentRetriever";
}

lc_namespace = ["langchain", "retrievers", "parent_document"];

protected vectorstore: VectorStore;

protected docstore: Docstore;
vectorstore: VectorStore;

protected childSplitter: TextSplitter;

Expand Down Expand Up @@ -72,12 +66,11 @@ export class ParentDocumentRetriever extends BaseRetriever {
}
}
const parentDocs: Document[] = [];
for (const parentDocId of parentDocIds) {
const parentDoc = await this.docstore.search(parentDocId);
if (parentDoc !== undefined) {
parentDocs.push(parentDoc);
}
}
const storedParentDocs = await this.docstore.mget(parentDocIds);
const retrievedDocs: Document[] = storedParentDocs.filter(
(doc?: Document): doc is Document => doc !== undefined
);
parentDocs.push(...retrievedDocs);
return parentDocs.slice(0, this.parentK);
}

Expand Down Expand Up @@ -138,7 +131,7 @@ export class ParentDocumentRetriever extends BaseRetriever {
}
await this.vectorstore.addDocuments(embeddedDocs);
if (addToDocstore) {
await this.docstore.add(fullDocs);
await this.docstore.mset(Object.entries(fullDocs));
}
}
}
27 changes: 24 additions & 3 deletions langchain/src/retrievers/tests/parent_document.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { expect, test } from "@jest/globals";
import { TextLoader } from "../../document_loaders/fs/text.js";
import { InMemoryDocstore } from "../../stores/doc/in_memory.js";
import { InMemoryStore } from "../../storage/in_memory.js";
import { OpenAIEmbeddings } from "../../embeddings/openai.js";
import { MemoryVectorStore } from "../../vectorstores/memory.js";
import { ParentDocumentRetriever } from "../parent_document.js";
Expand All @@ -10,7 +11,7 @@ test("Should return the full document if an unsplit parent document has been add
const vectorstore = new MemoryVectorStore(new OpenAIEmbeddings());
const retriever = new ParentDocumentRetriever({
vectorstore,
docstore: new InMemoryDocstore(),
docstore: new InMemoryStore(),
childSplitter: new RecursiveCharacterTextSplitter({
chunkOverlap: 0,
chunkSize: 100,
Expand All @@ -29,7 +30,7 @@ test("Should return the full document if an unsplit parent document has been add

test("Should return a part of a document if a parent splitter is passed", async () => {
const vectorstore = new MemoryVectorStore(new OpenAIEmbeddings());
const docstore = new InMemoryDocstore();
const docstore = new InMemoryStore();
const retriever = new ParentDocumentRetriever({
vectorstore,
docstore,
Expand All @@ -46,7 +47,6 @@ test("Should return a part of a document if a parent splitter is passed", async
"../examples/state_of_the_union.txt"
).load();
await retriever.addDocuments(docs);
console.log(docstore._docs.size);
const query = "justice breyer";
const retrievedDocs = await retriever.getRelevantDocuments(query);
const vectorstoreRetreivedDocs = await vectorstore.similaritySearch(
Expand All @@ -57,3 +57,24 @@ test("Should return a part of a document if a parent splitter is passed", async
expect(retrievedDocs.length).toBeGreaterThan(1);
expect(retrievedDocs[0].pageContent.length).toBeGreaterThan(100);
});

// Integration check: the retriever must still accept an InMemoryDocstore as
// its docstore. No parentSplitter is passed, and the test expects exactly one
// retrieved document whose content is longer than 1000 characters.
test("Should work with a backwards compatible docstore too", async () => {
  const vectorstore = new MemoryVectorStore(new OpenAIEmbeddings());
  const legacyDocstore = new InMemoryDocstore();
  const childSplitter = new RecursiveCharacterTextSplitter({
    chunkOverlap: 0,
    chunkSize: 100,
  });
  const retriever = new ParentDocumentRetriever({
    vectorstore,
    docstore: legacyDocstore,
    childSplitter,
  });

  const loader = new TextLoader("../examples/state_of_the_union.txt");
  const sourceDocs = await loader.load();
  await retriever.addDocuments(sourceDocs);

  const results = await retriever.getRelevantDocuments("justice breyer");
  expect(results.length).toEqual(1);
  expect(results[0].pageContent.length).toBeGreaterThan(1000);
});
36 changes: 35 additions & 1 deletion langchain/src/schema/storage.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,43 @@
import { Serializable } from "../load/serializable.js";

/**
 * Structural contract for a key-value store: batched get, set and delete,
 * plus asynchronous enumeration of keys.
 *
 * @deprecated For backwards compatibility only. Remove on next minor version upgrade.
 */
export interface BaseStoreInterface<K, V> {
/**
 * Gets the values for a set of keys.
 * @param keys - The keys to look up.
 * @returns A Promise resolving to an array of values, with `undefined`
 * in place of any key that was not found.
 */
mget(keys: K[]): Promise<(V | undefined)[]>;

/**
 * Sets values for multiple keys.
 * @param keyValuePairs - An array of `[key, value]` pairs to store.
 * @returns A Promise that resolves when the operation is complete.
 */
mset(keyValuePairs: [K, V][]): Promise<void>;

/**
 * Deletes multiple keys.
 * @param keys - The keys to delete.
 * @returns A Promise that resolves when the operation is complete.
 */
mdelete(keys: K[]): Promise<void>;

/**
 * Yields keys, optionally restricted by a prefix.
 * @param prefix - Optional prefix used to filter keys.
 * @returns An asynchronous generator that yields keys on iteration. The
 * yielded type is `K | string`, so a key may be yielded as a plain string
 * even when `K` is a different type.
 */
yieldKeys(prefix?: string): AsyncGenerator<K | string>;
}

/**
* Abstract interface for a key-value store.
*/
export abstract class BaseStore<K, V> extends Serializable {
export abstract class BaseStore<K, V>
extends Serializable
implements BaseStoreInterface<K, V>
{
/**
* Abstract method to get multiple values for a set of keys.
* @param {K[]} keys - An array of keys.
Expand Down
25 changes: 24 additions & 1 deletion langchain/src/stores/doc/in_memory.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import { Document } from "../../document.js";
import { Docstore } from "../../schema/index.js";
import { BaseStoreInterface } from "../../schema/storage.js";

/**
* Class for storing and retrieving documents in memory asynchronously.
* Extends the Docstore class.
*/
export class InMemoryDocstore extends Docstore {
export class InMemoryDocstore
extends Docstore
implements BaseStoreInterface<string, Document>
{
_docs: Map<string, Document>;

constructor(docs?: Map<string, Document>) {
Expand Down Expand Up @@ -44,6 +48,25 @@ export class InMemoryDocstore extends Docstore {
this._docs.set(key, value);
}
}

/**
 * Fetches several documents at once by delegating each key to `search`.
 * All lookups are started before any is awaited, and results come back in
 * the same order as `keys`.
 *
 * NOTE(review): what `search` does for an unknown key is not visible here.
 * If it rejects, this whole call rejects instead of yielding `undefined`
 * for that entry. Confirm against `search`.
 *
 * @param keys - Document ids to look up.
 * @returns A Promise resolving to the documents found for `keys`.
 */
async mget(keys: string[]): Promise<Document[]> {
  const lookups = keys.map((id) => this.search(id));
  return Promise.all(lookups);
}

/**
 * Stores several documents by calling `add` once per pair, each time with a
 * single-entry record, and waits for every call to settle successfully.
 *
 * @param keyValuePairs - `[id, document]` pairs to store.
 * @returns A Promise that resolves once every `add` call has resolved.
 */
async mset(keyValuePairs: [string, Document][]): Promise<void> {
  const pendingAdds = keyValuePairs.map((pair) => {
    const [id, doc] = pair;
    return this.add({ [id]: doc });
  });
  await Promise.all(pendingAdds);
}

/**
 * Deletion is not supported by this docstore. The method exists only to
 * satisfy the store interface and always rejects.
 *
 * @param _keys - Ignored.
 * @throws Error with the message "Not implemented." on every call.
 */
async mdelete(_keys: string[]): Promise<void> {
  const message = "Not implemented.";
  throw new Error(message);
}

/**
 * Key enumeration is not supported by this docstore. Because this is an
 * async generator, the error surfaces when the generator is first iterated,
 * not when `yieldKeys` itself is called.
 *
 * @param _prefix - Ignored.
 * @throws Error with the message "Not implemented" on first iteration.
 */
// eslint-disable-next-line require-yield
async *yieldKeys(_prefix?: string): AsyncGenerator<string> {
  const message = "Not implemented";
  throw new Error(message);
}
}

/**
Expand Down

0 comments on commit 70774b7

Please sign in to comment.