Add Ollama embeddings support (langchain-ai#2520)
* ✨ Add Ollama Embeddings class
* ✅ Add simple tests and refactor a bit
* Fix inheritance
* Authored by @smndtrl: Add ollama embeddings to configs
* ♻️ Change sorting of class members
* 🎨 Implement changes suggested by @jacoblee93
* Refactor constructor to allow for no arguments
* Consistency fix
* Add first draft of documentation
* Fix broken link in doc
Showing 8 changed files with 227 additions and 15 deletions.
docs/extras/modules/data_connection/text_embedding/integrations/ollama.mdx (new file, 56 additions)
# Ollama

The `OllamaEmbeddings` class uses the `/api/embeddings` route of a locally hosted [Ollama](https://ollama.ai) server to generate embeddings for given texts.

## Setup

Follow [these instructions](https://github.com/jmorganca/ollama) to set up and run a local Ollama instance.
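Once the server is running, you can optionally confirm it is reachable from Node before wiring it into LangChain. A minimal sketch, assuming the default port and Node 18+ (where `fetch` is global):

```typescript
// Quick reachability check against the default Ollama address.
// Adjust the URL if your server runs elsewhere.
const response = await fetch("http://localhost:11434");
console.log(
  response.ok ? "Ollama server is reachable" : `Unexpected status: ${response.status}`
);
```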
## Usage

Basic usage:

```typescript
import { OllamaEmbeddings } from "langchain/embeddings/ollama";

const embeddings = new OllamaEmbeddings({
  model: "llama2", // default value
  baseUrl: "http://localhost:11434", // default value
});
```

Ollama [model parameters](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values) are also supported, passed in camelCase via `requestOptions`:

```typescript
import { OllamaEmbeddings } from "langchain/embeddings/ollama";

const embeddings = new OllamaEmbeddings({
  model: "llama2", // default value
  baseUrl: "http://localhost:11434", // default value
  requestOptions: {
    useMMap: true, // use_mmap 1
    numThread: 6, // num_thread 6
    numGpu: 1, // num_gpu 1
  },
});
```

## Example usage

```typescript
import { OllamaEmbeddings } from "langchain/embeddings/ollama";

const embeddings = new OllamaEmbeddings({
  model: "llama2", // default value
  baseUrl: "http://localhost:11434", // default value
  requestOptions: {
    useMMap: true,
    numThread: 6,
    numGpu: 1,
  },
});

const documents = ["Hello World!", "Bye Bye"];

const documentEmbeddings = await embeddings.embedDocuments(documents);
```
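A single query string can be embedded with `embedQuery`, which the class also exposes (see the implementation below); it returns one vector rather than an array of vectors. A minimal follow-on sketch, reusing the `embeddings` instance from the example above:

```typescript
// Embed a single query string; returns a single embedding vector (number[]).
const queryEmbedding = await embeddings.embedQuery("Hello World!");
```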
New file (135 additions): the `OllamaEmbeddings` class implementation.
```typescript
import { OllamaInput, OllamaRequestParams } from "../util/ollama.js";
import { Embeddings, EmbeddingsParams } from "./base.js";

type CamelCasedRequestOptions = Omit<OllamaInput, "baseUrl" | "model">;

/**
 * Interface for OllamaEmbeddings parameters. Extends EmbeddingsParams and
 * defines additional parameters specific to the OllamaEmbeddings class.
 */
interface OllamaEmbeddingsParams extends EmbeddingsParams {
  /** The Ollama model to use, e.g: "llama2:13b" */
  model?: string;

  /** Base URL of the Ollama server, defaults to "http://localhost:11434" */
  baseUrl?: string;

  /** Advanced Ollama API request parameters in camelCase, see
   * https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
   * for details of the available parameters.
   */
  requestOptions?: CamelCasedRequestOptions;
}

export class OllamaEmbeddings extends Embeddings {
  model = "llama2";

  baseUrl = "http://localhost:11434";

  requestOptions?: OllamaRequestParams["options"];

  constructor(params?: OllamaEmbeddingsParams) {
    super(params || {});

    if (params?.model) {
      this.model = params.model;
    }

    if (params?.baseUrl) {
      this.baseUrl = params.baseUrl;
    }

    if (params?.requestOptions) {
      this.requestOptions = this._convertOptions(params.requestOptions);
    }
  }

  /** convert camelCased Ollama request options like "useMMap" to
   * the snake_cased equivalent which the ollama API actually uses.
   * Used only for consistency with the llms/Ollama and chatModels/Ollama classes
   */
  _convertOptions(requestOptions: CamelCasedRequestOptions) {
    const snakeCasedOptions: Record<string, unknown> = {};
    const mapping: Record<keyof CamelCasedRequestOptions, string> = {
      embeddingOnly: "embedding_only",
      f16KV: "f16_kv",
      frequencyPenalty: "frequency_penalty",
      logitsAll: "logits_all",
      lowVram: "low_vram",
      mainGpu: "main_gpu",
      mirostat: "mirostat",
      mirostatEta: "mirostat_eta",
      mirostatTau: "mirostat_tau",
      numBatch: "num_batch",
      numCtx: "num_ctx",
      numGpu: "num_gpu",
      numGqa: "num_gqa",
      numKeep: "num_keep",
      numThread: "num_thread",
      penalizeNewline: "penalize_newline",
      presencePenalty: "presence_penalty",
      repeatLastN: "repeat_last_n",
      repeatPenalty: "repeat_penalty",
      ropeFrequencyBase: "rope_frequency_base",
      ropeFrequencyScale: "rope_frequency_scale",
      temperature: "temperature",
      stop: "stop",
      tfsZ: "tfs_z",
      topK: "top_k",
      topP: "top_p",
      typicalP: "typical_p",
      useMLock: "use_mlock",
      useMMap: "use_mmap",
      vocabOnly: "vocab_only",
    };

    for (const [key, value] of Object.entries(requestOptions)) {
      const snakeCasedOption = mapping[key as keyof CamelCasedRequestOptions];
      if (snakeCasedOption) {
        snakeCasedOptions[snakeCasedOption] = value;
      }
    }
    return snakeCasedOptions;
  }

  async _request(prompt: string): Promise<number[]> {
    const { model, baseUrl, requestOptions } = this;

    const response = await fetch(`${baseUrl}/api/embeddings`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        prompt,
        model,
        options: requestOptions,
      }),
    });
    if (!response.ok) {
      throw new Error(
        `Request to Ollama server failed: ${response.status} ${response.statusText}`
      );
    }

    const json = await response.json();
    return json.embedding;
  }

  async _embed(strings: string[]): Promise<number[][]> {
    const embeddings: number[][] = [];

    for await (const prompt of strings) {
      const embedding = await this.caller.call(() => this._request(prompt));
      embeddings.push(embedding);
    }

    return embeddings;
  }

  async embedDocuments(documents: string[]) {
    return this._embed(documents);
  }

  async embedQuery(document: string) {
    return (await this.embedDocuments([document]))[0];
  }
}
```
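To illustrate what the class sends over the wire, here is a minimal sketch of the conversion performed by `_convertOptions` and the resulting `/api/embeddings` request body. The option values are illustrative only; the body shape follows directly from `_request` above.

```typescript
import { OllamaEmbeddings } from "langchain/embeddings/ollama";

const embeddings = new OllamaEmbeddings({
  model: "llama2",
  requestOptions: { useMMap: true, numThread: 6 },
});

// Each embedded string results in one POST to `${baseUrl}/api/embeddings`
// with a JSON body equivalent to:
// {
//   "prompt": "Hello World!",
//   "model": "llama2",
//   "options": { "use_mmap": true, "num_thread": 6 }
// }
const [vector] = await embeddings.embedDocuments(["Hello World!"]);
console.log(vector.length); // embedding dimensionality (model-dependent)
```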
New file (16 additions): tests for the `OllamaEmbeddings` class.
```typescript
import { test, expect } from "@jest/globals";
import { OllamaEmbeddings } from "../ollama.js";

test("Test OllamaEmbeddings.embedQuery", async () => {
  const embeddings = new OllamaEmbeddings();
  const res = await embeddings.embedQuery("Hello world");
  expect(typeof res[0]).toBe("number");
});

test("Test OllamaEmbeddings.embedDocuments", async () => {
  const embeddings = new OllamaEmbeddings();
  const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]);
  expect(res).toHaveLength(2);
  expect(typeof res[0][0]).toBe("number");
  expect(typeof res[1][0]).toBe("number");
});
```
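These tests call a live Ollama server at the default address. For environments without one, a hypothetical mocked-fetch variant might look like the sketch below; it is not part of this commit and assumes Node 18+, where `fetch` and `Response` are globals.

```typescript
import { jest, test, expect } from "@jest/globals";
import { OllamaEmbeddings } from "../ollama.js";

test("OllamaEmbeddings.embedQuery against a mocked server", async () => {
  // Stub global fetch so no live Ollama server is needed.
  const fetchMock = jest.spyOn(globalThis, "fetch").mockResolvedValue(
    new Response(JSON.stringify({ embedding: [0.1, 0.2, 0.3] }), {
      status: 200,
      headers: { "Content-Type": "application/json" },
    })
  );

  const embeddings = new OllamaEmbeddings();
  const res = await embeddings.embedQuery("Hello world");

  expect(res).toEqual([0.1, 0.2, 0.3]);
  expect(fetchMock).toHaveBeenCalledWith(
    "http://localhost:11434/api/embeddings",
    expect.objectContaining({ method: "POST" })
  );

  fetchMock.mockRestore();
});
```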