Skip to content

Commit

Permalink
fix: batch embedding queries, fix get client in tests (langchain-ai#2908
Browse files Browse the repository at this point in the history
)

* fix: batch embedding queries, fix get client in tests

* fix: more efficient .push instead of .concat, fix jsdoc

* nit: console log error inside catch

* chore: lint files

* fix: rename, add @deprecated jsdoc to _embedText

* fix: jsdoc

* chore: lint files

* fix: scrap new method, instead refactor existing _embedText method

* chore: lint files

* fix: remove batching (caller does that) & update jsdoc comment
  • Loading branch information
bracesproul authored Oct 13, 2023
1 parent 5019bb4 commit 85586fa
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 41 deletions.
74 changes: 47 additions & 27 deletions langchain/src/embeddings/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import {
BedrockRuntimeClient,
InvokeModelCommand,
} from "@aws-sdk/client-bedrock-runtime";

import { Embeddings, EmbeddingsParams } from "./base.js";
import type { CredentialType } from "../util/bedrock.js";

Expand Down Expand Up @@ -40,6 +39,8 @@ export class BedrockEmbeddings

client: BedrockRuntimeClient;

batchSize = 512;

constructor(fields?: BedrockEmbeddingsParams) {
super(fields ?? {});

Expand All @@ -53,28 +54,48 @@ export class BedrockEmbeddings
});
}

/**
* Protected method to make a request to the Bedrock API to generate
* embeddings. Handles the retry logic and returns the response from the
* API.
* @param request Request to send to the Bedrock API.
* @returns Promise that resolves to the response from the API.
*/
protected async _embedText(text: string): Promise<number[]> {
// replace newlines, which can negatively affect performance.
const cleanedText = text.replace(/\n/g, " ");

const res = await this.client.send(
new InvokeModelCommand({
modelId: this.model,
body: JSON.stringify({
inputText: cleanedText,
}),
contentType: "application/json",
accept: "application/json",
})
);

try {
const body = new TextDecoder().decode(res.body);

return JSON.parse(body).embedding;
} catch (e) {
throw new Error("An invalid response was returned by Bedrock.");
}
return this.caller.call(async () => {
try {
// replace newlines, which can negatively affect performance.
const cleanedText = text.replace(/\n/g, " ");

const res = await this.client.send(
new InvokeModelCommand({
modelId: this.model,
body: JSON.stringify({
inputText: cleanedText,
}),
contentType: "application/json",
accept: "application/json",
})
);

const body = new TextDecoder().decode(res.body);
return JSON.parse(body).embedding;
} catch (e) {
console.error({
error: e,
});
// eslint-disable-next-line no-instanceof/no-instanceof
if (e instanceof Error) {
throw new Error(
`An error occurred while embedding documents with Bedrock: ${e.message}`
);
}

throw new Error(
"An error occurred while embedding documents with Bedrock"
);
}
});
}

/**
Expand All @@ -93,13 +114,12 @@ export class BedrockEmbeddings
}

/**
* Method that takes an array of documents as input and returns a promise
* that resolves to a 2D array of embeddings for each document. It calls
* the _embedText method for each document in the array.
* @param documents Array of documents for which to generate embeddings.
* Method to generate embeddings for an array of texts. Calls _embedText
* method which batches and handles retry logic when calling the AWS Bedrock API.
* @param documents Array of texts for which to generate embeddings.
* @returns Promise that resolves to a 2D array of embeddings for each input document.
*/
embedDocuments(documents: string[]): Promise<number[][]> {
async embedDocuments(documents: string[]): Promise<number[][]> {
return Promise.all(documents.map((document) => this._embedText(document)));
}
}
39 changes: 25 additions & 14 deletions langchain/src/embeddings/tests/bedrock.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,42 @@ import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime";
import { HNSWLib } from "../../vectorstores/hnswlib.js";
import { BedrockEmbeddings } from "../bedrock.js";

const client = new BedrockRuntimeClient({
region: process.env.BEDROCK_AWS_REGION!,
credentials: {
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
},
});
const getClient = () => {
if (
!process.env.BEDROCK_AWS_REGION ||
!process.env.BEDROCK_AWS_ACCESS_KEY_ID ||
!process.env.BEDROCK_AWS_SECRET_ACCESS_KEY
) {
throw new Error("Missing environment variables for AWS");
}

const client = new BedrockRuntimeClient({
region: process.env.BEDROCK_AWS_REGION,
credentials: {
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID,
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY,
},
});

return client;
};

test("Test BedrockEmbeddings.embedQuery", async () => {
const client = getClient();
const embeddings = new BedrockEmbeddings({
maxRetries: 1,
client,
});
const res = await embeddings.embedQuery("Hello world");
console.log(res);
// console.log(res);
expect(typeof res[0]).toBe("number");
});

test("Test BedrockEmbeddings.embedDocuments with passed region and credentials", async () => {
const client = getClient();
const embeddings = new BedrockEmbeddings({
maxRetries: 1,
region: process.env.BEDROCK_AWS_REGION!,
credentials: {
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
},
client,
});
const res = await embeddings.embedDocuments([
"Hello world",
Expand All @@ -41,14 +51,15 @@ test("Test BedrockEmbeddings.embedDocuments with passed region and credentials",
"six documents",
"to test pagination",
]);
console.log(res);
// console.log(res);
expect(res).toHaveLength(6);
res.forEach((r) => {
expect(typeof r[0]).toBe("number");
});
});

test("Test end to end with HNSWLib", async () => {
const client = getClient();
const vectorStore = await HNSWLib.fromTexts(
["Hello world", "Bye bye", "hello nice world"],
[{ id: 2 }, { id: 1 }, { id: 3 }],
Expand Down

0 comments on commit 85586fa

Please sign in to comment.