Skip to content

Commit

Permalink
Add Ollama embeddings support (langchain-ai#2520)
Browse files Browse the repository at this point in the history
* ✨ Add Ollama Embeddings Class

* ✅ Add simple tests and refactor a bit

* fix inheritance

* authored by @smndtrl: Add ollama embeddings to configs

* ♻️ change sorting of class members

* 🎨 implement changes suggested by @jacoblee93

* refactor constructor to allow for no arguments

* consistancy fix

* Add first draft of documentation

* fix broken link in doc
  • Loading branch information
Basti-an authored Sep 7, 2023
1 parent b8d5a4d commit aa398cd
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Ollama

The `OllamaEmbeddings` class uses the `/api/embeddings` route of a locally hosted [Ollama](https://ollama.ai) server to generate embeddings for given texts.

## Setup

Follow [these instructions](https://github.com/jmorganca/ollama) to set up and run a local Ollama instance.

## Usage

Basic usage:

```typescript
import { OllamaEmbeddings } from "langchain/embeddings/ollama";

const embeddings = new OllamaEmbeddings({
model: "llama2", // default value
baseUrl: "http://localhost:11434", // default value
});
```

Ollama [model parameters](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values) are also supported:

```typescript
import { OllamaEmbeddings } from "langchain/embeddings/ollama";

const embeddings = new OllamaEmbeddings({
model: "llama2", // default value
baseUrl: "http://localhost:11434", // default value
requestOptions: {
useMMap: true, // use_mmap 1
numThread: 6, // num_thread 6
numGpu: 1, // num_gpu 1
},
});
```

## Example usage

```typescript
import { OllamaEmbeddings } from "langchain/embeddings/ollama";

const embeddings = new OllamaEmbeddings({
model: "llama2", // default value
baseUrl: "http://localhost:11434", // default value
requestOptions: {
useMMap: true,
numThread: 6,
numGpu: 1,
},
});

const documents = ["Hello World!", "Bye Bye"];

const documentEmbeddings = await embeddings.embedDocuments(documents);
```
3 changes: 3 additions & 0 deletions langchain/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ embeddings/cache_backed.d.ts
embeddings/fake.cjs
embeddings/fake.js
embeddings/fake.d.ts
embeddings/ollama.cjs
embeddings/ollama.js
embeddings/ollama.d.ts
embeddings/openai.cjs
embeddings/openai.js
embeddings/openai.d.ts
Expand Down
8 changes: 8 additions & 0 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@
"embeddings/fake.cjs",
"embeddings/fake.js",
"embeddings/fake.d.ts",
"embeddings/ollama.cjs",
"embeddings/ollama.js",
"embeddings/ollama.d.ts",
"embeddings/openai.cjs",
"embeddings/openai.js",
"embeddings/openai.d.ts",
Expand Down Expand Up @@ -1229,6 +1232,11 @@
"import": "./embeddings/fake.js",
"require": "./embeddings/fake.cjs"
},
"./embeddings/ollama": {
"types": "./embeddings/ollama.d.ts",
"import": "./embeddings/ollama.js",
"require": "./embeddings/ollama.cjs"
},
"./embeddings/openai": {
"types": "./embeddings/openai.d.ts",
"import": "./embeddings/openai.js",
Expand Down
6 changes: 4 additions & 2 deletions langchain/scripts/create-entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ const entrypoints = {
"embeddings/base": "embeddings/base",
"embeddings/cache_backed": "embeddings/cache_backed",
"embeddings/fake": "embeddings/fake",
"embeddings/ollama": "embeddings/ollama",
"embeddings/openai": "embeddings/openai",
"embeddings/cohere": "embeddings/cohere",
"embeddings/tensorflow": "embeddings/tensorflow",
Expand Down Expand Up @@ -225,7 +226,7 @@ const entrypoints = {
"storage/in_memory": "storage/in_memory",
"storage/ioredis": "storage/ioredis",
// hub
"hub": "hub",
hub: "hub",
// utilities
"util/math": "util/math",
// experimental
Expand All @@ -235,7 +236,8 @@ const entrypoints = {
"experimental/plan_and_execute": "experimental/plan_and_execute/index",
"experimental/multimodal_embeddings/googlevertexai":
"experimental/multimodal_embeddings/googlevertexai",
"experimental/chat_models/anthropic_functions": "experimental/chat_models/anthropic_functions",
"experimental/chat_models/anthropic_functions":
"experimental/chat_models/anthropic_functions",
// evaluation
evaluation: "evaluation/index",
};
Expand Down
135 changes: 135 additions & 0 deletions langchain/src/embeddings/ollama.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import { OllamaInput, OllamaRequestParams } from "../util/ollama.js";
import { Embeddings, EmbeddingsParams } from "./base.js";

/** Camel-cased Ollama model parameters accepted by OllamaEmbeddings — everything from OllamaInput except the connection settings (baseUrl, model). */
type CamelCasedRequestOptions = Omit<OllamaInput, "baseUrl" | "model">;

/**
* Interface for OllamaEmbeddings parameters. Extends EmbeddingsParams and
* defines additional parameters specific to the OllamaEmbeddings class.
*/
interface OllamaEmbeddingsParams extends EmbeddingsParams {
/** The Ollama model to use, e.g: "llama2:13b". Defaults to "llama2". */
model?: string;

/** Base URL of the Ollama server, defaults to "http://localhost:11434" */
baseUrl?: string;

/** Advanced Ollama API request parameters in camelCase, see
* https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
* for details of the available parameters.
* Keys are converted to snake_case before being sent to the server.
*/
requestOptions?: CamelCasedRequestOptions;
}

/**
 * Embeddings implementation backed by a locally hosted Ollama server.
 * Sends one POST request per input string to the `/api/embeddings` route.
 */
export class OllamaEmbeddings extends Embeddings {
  /** The Ollama model used to generate embeddings, e.g. "llama2:13b". */
  model = "llama2";

  /** Base URL of the Ollama server. */
  baseUrl = "http://localhost:11434";

  /** Snake-cased model parameters forwarded verbatim to the Ollama API. */
  requestOptions?: OllamaRequestParams["options"];

  constructor(params?: OllamaEmbeddingsParams) {
    super(params || {});

    if (params?.model) {
      this.model = params.model;
    }

    if (params?.baseUrl) {
      this.baseUrl = params.baseUrl;
    }

    if (params?.requestOptions) {
      this.requestOptions = this._convertOptions(params.requestOptions);
    }
  }

  /**
   * Convert camelCased Ollama request options (e.g. "useMMap") to the
   * snake_cased equivalents the Ollama API actually uses.
   * Kept for consistency with the llms/Ollama and chatModels/Ollama classes.
   * Keys that have no mapping entry are silently dropped.
   */
  _convertOptions(requestOptions: CamelCasedRequestOptions) {
    const snakeCasedOptions: Record<string, unknown> = {};
    const mapping: Record<keyof CamelCasedRequestOptions, string> = {
      embeddingOnly: "embedding_only",
      f16KV: "f16_kv",
      frequencyPenalty: "frequency_penalty",
      logitsAll: "logits_all",
      lowVram: "low_vram",
      mainGpu: "main_gpu",
      mirostat: "mirostat",
      mirostatEta: "mirostat_eta",
      mirostatTau: "mirostat_tau",
      numBatch: "num_batch",
      numCtx: "num_ctx",
      numGpu: "num_gpu",
      numGqa: "num_gqa",
      numKeep: "num_keep",
      numThread: "num_thread",
      penalizeNewline: "penalize_newline",
      presencePenalty: "presence_penalty",
      repeatLastN: "repeat_last_n",
      repeatPenalty: "repeat_penalty",
      ropeFrequencyBase: "rope_frequency_base",
      ropeFrequencyScale: "rope_frequency_scale",
      temperature: "temperature",
      stop: "stop",
      tfsZ: "tfs_z",
      topK: "top_k",
      topP: "top_p",
      typicalP: "typical_p",
      useMLock: "use_mlock",
      useMMap: "use_mmap",
      vocabOnly: "vocab_only",
    };

    for (const [key, value] of Object.entries(requestOptions)) {
      const snakeCasedOption = mapping[key as keyof CamelCasedRequestOptions];
      if (snakeCasedOption) {
        snakeCasedOptions[snakeCasedOption] = value;
      }
    }
    return snakeCasedOptions;
  }

  /**
   * POST a single prompt to the server's `/api/embeddings` route and return
   * the embedding vector from the response body.
   * @throws Error when the server responds with a non-2xx status.
   */
  async _request(prompt: string): Promise<number[]> {
    const { model, baseUrl, requestOptions } = this;

    const response = await fetch(`${baseUrl}/api/embeddings`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        prompt,
        model,
        options: requestOptions,
      }),
    });
    if (!response.ok) {
      throw new Error(
        `Request to Ollama server failed: ${response.status} ${response.statusText}`
      );
    }

    // Narrow the untyped JSON body to the one field we read.
    // NOTE(review): assumes the server always returns { embedding: number[] }
    // on success — confirm against the Ollama API docs.
    const json = (await response.json()) as { embedding: number[] };
    return json.embedding;
  }

  /**
   * Embed each string with one request apiece, routed through `this.caller`
   * (the retry/rate-limit helper from the Embeddings base class).
   * Requests are issued sequentially — presumably to avoid overwhelming a
   * local server; confirm before parallelizing.
   */
  async _embed(strings: string[]): Promise<number[][]> {
    const embeddings: number[][] = [];

    // The original used `for await`, but `strings` is a plain synchronous
    // array — a regular for..of with await is the correct idiom and
    // behaves identically here.
    for (const prompt of strings) {
      const embedding = await this.caller.call(() => this._request(prompt));
      embeddings.push(embedding);
    }

    return embeddings;
  }

  /** Embed a list of documents, one request per document. */
  async embedDocuments(documents: string[]) {
    return this._embed(documents);
  }

  /** Embed a single query string; returns the first (only) vector. */
  async embedQuery(document: string) {
    return (await this.embedDocuments([document]))[0];
  }
}
16 changes: 16 additions & 0 deletions langchain/src/embeddings/tests/ollama.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { test, expect } from "@jest/globals";
import { OllamaEmbeddings } from "../ollama.js";

test("Test OllamaEmbeddings.embedQuery", async () => {
  // Default configuration: llama2 model against http://localhost:11434.
  const model = new OllamaEmbeddings();
  const vector = await model.embedQuery("Hello world");
  expect(typeof vector[0]).toBe("number");
});

test("Test OllamaEmbeddings.embedDocuments", async () => {
  // One embedding per input document, each a numeric vector.
  const model = new OllamaEmbeddings();
  const vectors = await model.embedDocuments(["Hello world", "Bye bye"]);
  expect(vectors).toHaveLength(2);
  for (const vector of vectors) {
    expect(typeof vector[0]).toBe("number");
  }
});
1 change: 1 addition & 0 deletions langchain/src/load/import_map.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export * as chains__openai_functions from "../chains/openai_functions/index.js";
export * as embeddings__base from "../embeddings/base.js";
export * as embeddings__cache_backed from "../embeddings/cache_backed.js";
export * as embeddings__fake from "../embeddings/fake.js";
export * as embeddings__ollama from "../embeddings/ollama.js";
export * as embeddings__openai from "../embeddings/openai.js";
export * as embeddings__minimax from "../embeddings/minimax.js";
export * as llms__base from "../llms/base.js";
Expand Down
17 changes: 4 additions & 13 deletions langchain/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
"outDir": "../dist",
"rootDir": "./src",
"target": "ES2021",
"lib": [
"ES2021",
"ES2022.Object",
"DOM"
],
"lib": ["ES2021", "ES2022.Object", "DOM"],
"module": "ES2020",
"moduleResolution": "nodenext",
"esModuleInterop": true,
Expand All @@ -22,14 +18,8 @@
"allowJs": true,
"strict": true
},
"include": [
"src/**/*"
],
"exclude": [
"node_modules",
"dist",
"docs"
],
"include": ["src/**/*"],
"exclude": ["node_modules", "dist", "docs"],
"typedocOptions": {
"entryPoints": [
"src/load/index.ts",
Expand All @@ -55,6 +45,7 @@
"src/embeddings/base.ts",
"src/embeddings/cache_backed.ts",
"src/embeddings/fake.ts",
"src/embeddings/ollama.ts",
"src/embeddings/openai.ts",
"src/embeddings/cohere.ts",
"src/embeddings/tensorflow.ts",
Expand Down

0 comments on commit aa398cd

Please sign in to comment.