From abb9491696e7f223c2328279889c811a1d12139b Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Mon, 11 Sep 2023 09:52:20 -0700 Subject: [PATCH] Runnable binding stream fixes (#2587) * Start runnables streaming updates * Fix test --- langchain/src/schema/runnable.ts | 206 ++++++++++++-------- langchain/src/schema/tests/runnable.test.ts | 13 ++ 2 files changed, 135 insertions(+), 84 deletions(-) diff --git a/langchain/src/schema/runnable.ts b/langchain/src/schema/runnable.ts index e21cc409eacd..0ee5ac469ce9 100644 --- a/langchain/src/schema/runnable.ts +++ b/langchain/src/schema/runnable.ts @@ -317,10 +317,23 @@ export abstract class Runnable< * @param generator * @param options */ - transform?( + async *transform( generator: AsyncGenerator, options: Partial - ): AsyncGenerator; + ): AsyncGenerator { + let finalChunk; + for await (const chunk of generator) { + if (!finalChunk) { + finalChunk = chunk; + } else { + // Make a best effort to gather, for any type that supports concat. + // This method should throw an error if gathering fails. + // eslint-disable-next-line @typescript-eslint/no-explicit-any + finalChunk = (finalChunk as any).concat(chunk); + } + } + yield* this._streamIterator(finalChunk, options); + } // eslint-disable-next-line @typescript-eslint/no-explicit-any static isRunnable(thing: any): thing is Runnable { @@ -328,6 +341,96 @@ export abstract class Runnable< } } +/** + * A runnable that delegates calls to another runnable with a set of kwargs. + */ +export class RunnableBinding< + RunInput, + RunOutput, + CallOptions extends BaseCallbackConfig +> extends Runnable { + static lc_name() { + return "RunnableBinding"; + } + + lc_namespace = ["langchain", "schema", "runnable"]; + + lc_serializable = true; + + bound: Runnable; + + protected kwargs: Partial; + + constructor(fields: { + bound: Runnable; + kwargs: Partial; + }) { + super(fields); + this.bound = fields.bound; + this.kwargs = fields.kwargs; + } + + bind( + kwargs: Partial + ): RunnableBinding { + return new RunnableBinding({ + bound: this.bound, + kwargs: { ...this.kwargs, ...kwargs }, + }); + } + + async invoke( + input: RunInput, + options?: Partial + ): Promise { + return this.bound.invoke(input, { ...options, ...this.kwargs }); + } + + async batch( + inputs: RunInput[], + options?: Partial | Partial[], + batchOptions?: { maxConcurrency?: number } + ): Promise { + const mergedOptions = Array.isArray(options) + ? options.map((individualOption) => ({ + ...individualOption, + ...this.kwargs, + })) + : { ...options, ...this.kwargs }; + return this.bound.batch(inputs, mergedOptions, batchOptions); + } + + async *_streamIterator( + input: RunInput, + options?: Partial | undefined + ) { + yield* this.bound._streamIterator(input, { ...options, ...this.kwargs }); + } + + async stream( + input: RunInput, + options?: Partial | undefined + ): Promise> { + return this.bound.stream(input, { ...options, ...this.kwargs }); + } + + async *transform( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + generator: AsyncGenerator, + options: Partial + ): AsyncGenerator { + yield* this.bound.transform(generator, options); + } + + static isRunnableBinding( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + thing: any + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ): thing is RunnableBinding { + return thing.bound && Runnable.isRunnable(thing.bound); + } +} + /** * A sequence of runnables, where the output of each is the input of the next. */ @@ -482,14 +585,24 @@ export class RunnableSequence< ); let nextStepInput = input; const steps = [this.first, ...this.middle, this.last]; - // Find the index of the last runnable in the sequence that doesn't have a .transform() method + // Find the index of the last runnable in the sequence that doesn't have an overridden .transform() method // and start streaming from there - const streamingStartStepIndex = + const streamingStartStepIndex = Math.min( + steps.length - 1, steps.length - - [...steps] - .reverse() - .findIndex((step) => typeof step.transform !== "function") - - 1; + [...steps].reverse().findIndex((step) => { + const isDefaultImplementation = + step.transform === Runnable.prototype.transform; + const boundRunnableIsDefaultImplementation = + RunnableBinding.isRunnableBinding(step) && + step.bound?.transform === Runnable.prototype.transform; + return ( + isDefaultImplementation || boundRunnableIsDefaultImplementation + ); + }) - + 1 + ); + try { for (const step of steps.slice(0, streamingStartStepIndex)) { nextStepInput = await step.invoke( @@ -509,8 +622,7 @@ export class RunnableSequence< this._patchConfig(options, runManager?.getChild()) ); for (const step of steps.slice(streamingStartStepIndex + 1)) { - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - finalGenerator = await step.transform!( + finalGenerator = await step.transform( finalGenerator, this._patchConfig(options, runManager?.getChild()) ); @@ -699,80 +811,6 @@ export class RunnablePassthrough extends Runnable< } } -/** - * A runnable that delegates calls to another runnable with a set of kwargs. - */ -export class RunnableBinding< - RunInput, - RunOutput, - CallOptions extends BaseCallbackConfig -> extends Runnable { - static lc_name() { - return "RunnableBinding"; - } - - lc_namespace = ["langchain", "schema", "runnable"]; - - lc_serializable = true; - - protected bound: Runnable; - - protected kwargs: Partial; - - constructor(fields: { - bound: Runnable; - kwargs: Partial; - }) { - super(fields); - this.bound = fields.bound; - this.kwargs = fields.kwargs; - } - - bind( - kwargs: Partial - ): RunnableBinding { - return new RunnableBinding({ - bound: this.bound, - kwargs: { ...this.kwargs, ...kwargs }, - }); - } - - async invoke( - input: RunInput, - options?: Partial - ): Promise { - return this.bound.invoke(input, { ...options, ...this.kwargs }); - } - - async batch( - inputs: RunInput[], - options?: Partial | Partial[], - batchOptions?: { maxConcurrency?: number } - ): Promise { - const mergedOptions = Array.isArray(options) - ? options.map((individualOption) => ({ - ...individualOption, - ...this.kwargs, - })) - : { ...options, ...this.kwargs }; - return this.bound.batch(inputs, mergedOptions, batchOptions); - } - - async *_streamIterator( - input: RunInput, - options?: Partial | undefined - ) { - yield* this.bound._streamIterator(input, { ...options, ...this.kwargs }); - } - - async stream( - input: RunInput, - options?: Partial | undefined - ): Promise> { - return this.bound.stream(input, { ...options, ...this.kwargs }); - } -} - export type RouterInput = { key: string; // eslint-disable-next-line @typescript-eslint/no-explicit-any diff --git a/langchain/src/schema/tests/runnable.test.ts b/langchain/src/schema/tests/runnable.test.ts index 65dbc3ad32f6..cb20c6889cf5 100644 --- a/langchain/src/schema/tests/runnable.test.ts +++ b/langchain/src/schema/tests/runnable.test.ts @@ -349,3 +349,16 @@ test("Stream with RunnableBinding", async () => { expect(chunks.length).toEqual("Hi there!".length); expect(chunks.join("")).toEqual("Hi there!"); }); + +test("Stream through a RunnableBinding if the bound runnable implements transform", async () => { + const llm = new FakeStreamingLLM({}).bind({ stop: ["dummy"] }); + const outputParser = new StringOutputParser().bind({ callbacks: [] }); + const stream = await llm.pipe(outputParser).stream("Hi there!"); + const chunks = []; + for await (const chunk of stream) { + chunks.push(chunk); + console.log(chunk); + } + expect(chunks.length).toEqual("Hi there!".length); + expect(chunks.join("")).toEqual("Hi there!"); +});