Commit ae9895f
Allow runnables to implement transform streaming (langchain-ai#2156)
* Adds support for transform streaming on runnables
* Adds comment
* Fix types
* Adds EncodingOutputParser
* Rename to BytesOutputParser
1 parent 1c1274d commit ae9895f

File tree

6 files changed: +227 −16 lines
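Taken together, these changes let a runnable sequence keep streaming through any trailing steps that implement the new transform() method, rather than buffering everything before the final step. A rough end-to-end sketch (the OpenAI model and import paths are assumptions for illustration; the FakeStreamingLLM in the tests below exercises the same path):

import { OpenAI } from "langchain/llms/openai";
import { BytesOutputParser } from "langchain/schema/output_parser";

// BytesOutputParser implements transform(), so the sequence streams all the
// way through: each model chunk is UTF-8 encoded and yielded immediately.
const stream = await new OpenAI({ temperature: 0 })
  .pipe(new BytesOutputParser())
  .stream("Hi there!");
const decoder = new TextDecoder();
for await (const chunk of stream) {
  console.log(decoder.decode(chunk));
}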

langchain/src/chains/transform.ts

+7 −7

@@ -11,11 +11,11 @@ export interface TransformChainFields<
   outputVariables: (keyof O extends string ? keyof O : never)[];
 }
 
-export class TransformChain<I extends ChainValues, O extends ChainValues>
-  extends BaseChain
-  implements TransformChainFields<I, O>
-{
-  transform: (values: I, callbacks?: Callbacks) => O | Promise<O>;
+export class TransformChain<
+  I extends ChainValues,
+  O extends ChainValues
+> extends BaseChain {
+  transformFunc: (values: I, callbacks?: Callbacks) => O | Promise<O>;
 
   inputVariables: (keyof I extends string ? keyof I : never)[];
 
@@ -35,12 +35,12 @@ export class TransformChain<I extends ChainValues, O extends ChainValues>
 
   constructor(fields: TransformChainFields<I, O>) {
     super(fields);
-    this.transform = fields.transform;
+    this.transformFunc = fields.transform;
     this.inputVariables = fields.inputVariables;
     this.outputVariables = fields.outputVariables;
   }
 
   async _call(values: I, runManager?: CallbackManagerForChainRun): Promise<O> {
-    return this.transform(values, runManager?.getChild("transform"));
+    return this.transformFunc(values, runManager?.getChild("transform"));
   }
 }
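The transform property becomes transformFunc (the constructor field keeps its old name), presumably so the property no longer collides with the optional transform() streaming method this commit adds to Runnable; that is also why the class drops implements TransformChainFields<I, O>. Callers are unaffected, as in this hypothetical example:

const chain = new TransformChain<{ text: string }, { upper: string }>({
  inputVariables: ["text"],
  outputVariables: ["upper"],
  // Still passed as `transform`; stored internally as `transformFunc`.
  transform: (values) => ({ upper: values.text.toUpperCase() }),
});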

langchain/src/output_parsers/noop.ts

+14 −3

@@ -1,4 +1,15 @@
-import { StringOutputParser } from "../schema/output_parser.js";
+import { BaseOutputParser } from "../schema/output_parser.js";
 
-/** @deprecated Use StringOutputParser instead */
-export class NoOpOutputParser extends StringOutputParser {}
+export class NoOpOutputParser extends BaseOutputParser<string> {
+  lc_namespace = ["langchain", "output_parsers", "default"];
+
+  lc_serializable = true;
+
+  parse(text: string): Promise<string> {
+    return Promise.resolve(text);
+  }
+
+  getFormatInstructions(): string {
+    return "";
+  }
+}
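NoOpOutputParser is detached from StringOutputParser and reimplemented directly on BaseOutputParser: StringOutputParser becomes a streaming BaseTransformOutputParser below, and this keeps the no-op parser's behavior and its ["langchain", "output_parsers", "default"] serialization namespace unchanged.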

langchain/src/schema/output_parser.ts

+51 −2

@@ -1,4 +1,4 @@
-import { Callbacks } from "../callbacks/manager.js";
+import { BaseCallbackConfig, Callbacks } from "../callbacks/manager.js";
 import {
   BasePromptValue,
   Generation,
@@ -101,10 +101,39 @@ export abstract class BaseOutputParser<
   }
 }
 
+/**
+ * Class to parse the output of an LLM call that also allows streaming inputs.
+ */
+export abstract class BaseTransformOutputParser<
+  T = unknown
+> extends BaseOutputParser<T> {
+  async *_transform(
+    inputGenerator: AsyncGenerator<string | BaseMessage>
+  ): AsyncGenerator<T> {
+    for await (const chunk of inputGenerator) {
+      if (typeof chunk === "string") {
+        yield this.parseResult([{ text: chunk }]);
+      } else {
+        yield this.parseResult([{ message: chunk, text: chunk.content }]);
+      }
+    }
+  }
+
+  async *transform(
+    inputGenerator: AsyncGenerator<string | BaseMessage>,
+    options: BaseCallbackConfig
+  ): AsyncGenerator<T> {
+    yield* this._streamWithConfig(this._transform(inputGenerator), {
+      ...options,
+      runType: "parser",
+    });
+  }
+}
+
 /**
  * OutputParser that parses LLMResult into the top likely string.
  */
-export class StringOutputParser extends BaseOutputParser<string> {
+export class StringOutputParser extends BaseTransformOutputParser<string> {
   lc_namespace = ["schema", "output_parser"];
 
   lc_serializable = true;
@@ -118,6 +147,26 @@ export class StringOutputParser extends BaseOutputParser<string> {
   }
 }
 
+/**
+ * OutputParser that parses LLMResult into the top likely string and
+ * encodes it into bytes.
+ */
+export class BytesOutputParser extends BaseTransformOutputParser<Uint8Array> {
+  lc_namespace = ["schema", "output_parser"];
+
+  lc_serializable = true;
+
+  protected textEncoder = new TextEncoder();
+
+  parse(text: string): Promise<Uint8Array> {
+    return Promise.resolve(this.textEncoder.encode(text));
+  }
+
+  getFormatInstructions(): string {
+    return "";
+  }
+}
+
 export class OutputParserException extends Error {
   output?: string;
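Because _transform() parses each incoming chunk independently, a stateless parser gains streaming support just by switching its base class. A minimal sketch of a custom subclass (hypothetical, not part of this commit):

import { BaseTransformOutputParser } from "langchain/schema/output_parser";

export class UpperCaseOutputParser extends BaseTransformOutputParser<string> {
  lc_namespace = ["schema", "output_parser"];

  // Invoked once per streamed chunk via parseResult().
  parse(text: string): Promise<string> {
    return Promise.resolve(text.toUpperCase());
  }

  getFormatInstructions(): string {
    return "";
  }
}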

langchain/src/schema/runnable.ts

+68 −3

@@ -143,6 +143,50 @@ export abstract class Runnable<
     return output;
   }
 
+  protected async *_streamWithConfig<T extends RunOutput>(
+    generator: AsyncGenerator<T>,
+    options?: RunnableConfig & { runType?: string }
+  ) {
+    const callbackManager_ = await CallbackManager.configure(
+      options?.callbacks,
+      undefined,
+      options?.tags,
+      undefined,
+      options?.metadata
+    );
+    // TODO: Find a way to pass the entire streamed value into the callback.
+    const runManager = await callbackManager_?.handleChainStart(
+      this.toJSON(),
+      _coerceToDict("<streamed value>", "input"),
+      undefined,
+      options?.runType
+    );
+    let output;
+    let concatSupported = true;
+    try {
+      for await (const chunk of generator) {
+        yield chunk;
+        if (concatSupported) {
+          if (output === undefined) {
+            output = chunk;
+          } else {
+            try {
+              // eslint-disable-next-line @typescript-eslint/no-explicit-any
+              output = (output as any).concat(chunk);
+            } catch (e) {
+              output = undefined;
+              concatSupported = false;
+            }
+          }
+        }
+      }
+    } catch (e) {
+      await runManager?.handleChainError(e);
+      throw e;
+    }
+    await runManager?.handleChainEnd(_coerceToDict(output, "output"));
+  }
+
   _patchConfig(
     config: Partial<CallOptions> = {},
     callbackManager: CallbackManager | undefined = undefined
@@ -160,6 +204,11 @@ export abstract class Runnable<
     });
   }
 
+  transform?(
+    generator: AsyncGenerator<RunInput>,
+    options: Partial<CallOptions>
+  ): AsyncGenerator<RunOutput>;
+
   // eslint-disable-next-line @typescript-eslint/no-explicit-any
   static isRunnable(thing: any): thing is Runnable {
     return thing.lc_runnable;
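Two things worth noting in the hunks above. First, _streamWithConfig reports a placeholder "<streamed value>" as the run input, then tries to rebuild the final output by calling .concat() on successive chunks; strings and message chunks concatenate, but chunk types without a concat method (the Uint8Array chunks from BytesOutputParser, for instance, since typed arrays have none) trip the catch branch, disable accumulation, and leave handleChainEnd with an undefined output. Second, transform?() is declared optional on purpose: RunnableSequence below feature-detects streaming support with a typeof check rather than requiring every runnable to implement it.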
@@ -314,8 +363,17 @@
       _coerceToDict(input, "input")
     );
     let nextStepInput = input;
+    const steps = [this.first, ...this.middle, this.last];
+    // Find the index of the last runnable in the sequence that doesn't have a .transform() method
+    // and start streaming from there
+    const streamingStartStepIndex =
+      steps.length -
+      [...steps]
+        .reverse()
+        .findIndex((step) => typeof step.transform !== "function") -
+      1;
     try {
-      for (const step of [this.first, ...this.middle]) {
+      for (const step of steps.slice(0, streamingStartStepIndex)) {
         nextStepInput = await step.invoke(
           nextStepInput,
           this._patchConfig(options, runManager?.getChild())
@@ -328,11 +386,18 @@
     let concatSupported = true;
     let finalOutput;
     try {
-      const iterator = await this.last._streamIterator(
+      let finalGenerator = await steps[streamingStartStepIndex]._streamIterator(
         nextStepInput,
         this._patchConfig(options, runManager?.getChild())
       );
-      for await (const chunk of iterator) {
+      for (const step of steps.slice(streamingStartStepIndex + 1)) {
+        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+        finalGenerator = await step.transform!(
+          finalGenerator,
+          this._patchConfig(options, runManager?.getChild())
+        );
+      }
+      for await (const chunk of finalGenerator) {
         yield chunk;
         if (concatSupported) {
           if (finalOutput === undefined) {
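The index arithmetic above reads: measure how far from the end the last step without a transform() sits, then convert that to a forward index. A standalone sketch of the computation (hypothetical step list; only the presence or absence of transform matters):

const steps: { name: string; transform?: unknown }[] = [
  { name: "prompt" },                                   // no transform
  { name: "llm" },                                      // no transform
  { name: "parser", transform: async function* () {} }, // streams
];
// reversed: [parser, llm, prompt] -> findIndex(no transform) = 1 (the llm)
const streamingStartStepIndex =
  steps.length -
  [...steps].reverse().findIndex((s) => typeof s.transform !== "function") -
  1;
console.log(streamingStartStepIndex); // 3 - 1 - 1 = 1

Steps before that index are invoke()d eagerly, the step at the index starts the stream via _streamIterator, and every later step pipes the generator through its transform().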
langchain/src/schema/tests/output_parser.test.ts

+35 −0

@@ -0,0 +1,35 @@
+/* eslint-disable no-promise-executor-return */
+
+import { test } from "@jest/globals";
+import { LLM } from "../../llms/base.js";
+import { GenerationChunk } from "../index.js";
+import { BytesOutputParser } from "../output_parser.js";
+
+class FakeStreamingLLM extends LLM {
+  _llmType() {
+    return "fake";
+  }
+
+  async _call(prompt: string): Promise<string> {
+    return prompt;
+  }
+
+  async *_streamResponseChunks(input: string) {
+    for (const c of input) {
+      await new Promise((resolve) => setTimeout(resolve, 50));
+      yield { text: c, generationInfo: {} } as GenerationChunk;
+    }
+  }
+}
+
+test("BytesOutputParser", async () => {
+  const llm = new FakeStreamingLLM({});
+  const stream = await llm.pipe(new BytesOutputParser()).stream("Hi there!");
+  const chunks = [];
+  const decoder = new TextDecoder();
+  for await (const chunk of stream) {
+    chunks.push(decoder.decode(chunk));
+  }
+  expect(chunks.length).toEqual("Hi there!".length);
+  expect(chunks.join("")).toEqual("Hi there!");
+});
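The length assertion works because FakeStreamingLLM yields one GenerationChunk per input character and BytesOutputParser encodes each chunk independently: nine characters in, nine Uint8Array chunks out.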

langchain/src/schema/tests/runnable.test.ts

+52 −1

@@ -1,11 +1,18 @@
+/* eslint-disable no-promise-executor-return */
+
 import { z } from "zod";
 import { test } from "@jest/globals";
 import { LLM } from "../../llms/base.js";
 import {
   BaseChatModel,
   createChatMessageChunkEncoderStream,
 } from "../../chat_models/base.js";
-import { AIMessage, BaseMessage, ChatResult } from "../index.js";
+import {
+  AIMessage,
+  BaseMessage,
+  ChatResult,
+  GenerationChunk,
+} from "../index.js";
 import {
   ChatPromptTemplate,
   HumanMessagePromptTemplate,
@@ -28,6 +35,23 @@ class FakeLLM extends LLM {
   }
 }
 
+class FakeStreamingLLM extends LLM {
+  _llmType() {
+    return "fake";
+  }
+
+  async _call(prompt: string): Promise<string> {
+    return prompt;
+  }
+
+  async *_streamResponseChunks(input: string) {
+    for (const c of input) {
+      await new Promise((resolve) => setTimeout(resolve, 50));
+      yield { text: c, generationInfo: {} } as GenerationChunk;
+    }
+  }
+}
+
 class FakeChatModel extends BaseChatModel {
   _combineLLMOutput() {
     return [];
@@ -195,3 +219,30 @@ test("Bind kwargs to a runnable with a batch call", async () => {
   console.log(result);
   expect(result).toEqual(["testing", "testing", "testing", "testing"]);
 });
+
+test("Stream the entire way through", async () => {
+  const llm = new FakeStreamingLLM({});
+  const stream = await llm.pipe(new StringOutputParser()).stream("Hi there!");
+  const chunks = [];
+  for await (const chunk of stream) {
+    chunks.push(chunk);
+    console.log(chunk);
+  }
+  expect(chunks.length).toEqual("Hi there!".length);
+  expect(chunks.join("")).toEqual("Hi there!");
+});
+
+test("Don't use intermediate streaming", async () => {
+  const llm = new FakeStreamingLLM({});
+  const stream = await llm
+    .pipe(new StringOutputParser())
+    .pipe(new FakeLLM({}))
+    .stream("Hi there!");
+  const chunks = [];
+  for await (const chunk of stream) {
+    chunks.push(chunk);
+    console.log(chunk);
+  }
+  expect(chunks.length).toEqual(1);
+  expect(chunks[0]).toEqual("Hi there!");
+});
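These two tests pin down both sides of the new sequence logic: with StringOutputParser as the final step, every character streams through individually, while appending a FakeLLM (which implements no transform()) pushes streamingStartStepIndex to the last step, so the upstream steps run eagerly via invoke() and the stream collapses to a single final chunk.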
