Skip to content

Commit

Permalink
Runnable binding stream fixes (langchain-ai#2587)
Browse files Browse the repository at this point in the history
* Start runnables streaming updates

* Fix test
  • Loading branch information
jacoblee93 authored Sep 11, 2023
1 parent c706e8c commit abb9491
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 84 deletions.
206 changes: 122 additions & 84 deletions langchain/src/schema/runnable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -317,17 +317,120 @@ export abstract class Runnable<
* @param generator
* @param options
*/
transform?(
async *transform(
generator: AsyncGenerator<RunInput>,
options: Partial<CallOptions>
): AsyncGenerator<RunOutput>;
): AsyncGenerator<RunOutput> {
let finalChunk;
for await (const chunk of generator) {
if (!finalChunk) {
finalChunk = chunk;
} else {
// Make a best effort to gather, for any type that supports concat.
// This method should throw an error if gathering fails.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
finalChunk = (finalChunk as any).concat(chunk);
}
}
yield* this._streamIterator(finalChunk, options);
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
static isRunnable(thing: any): thing is Runnable {
return thing.lc_runnable;
}
}

/**
* A runnable that delegates calls to another runnable with a set of kwargs.
*/
export class RunnableBinding<
RunInput,
RunOutput,
CallOptions extends BaseCallbackConfig
> extends Runnable<RunInput, RunOutput, CallOptions> {
static lc_name() {
return "RunnableBinding";
}

lc_namespace = ["langchain", "schema", "runnable"];

lc_serializable = true;

bound: Runnable<RunInput, RunOutput, CallOptions>;

protected kwargs: Partial<CallOptions>;

constructor(fields: {
bound: Runnable<RunInput, RunOutput, CallOptions>;
kwargs: Partial<CallOptions>;
}) {
super(fields);
this.bound = fields.bound;
this.kwargs = fields.kwargs;
}

bind(
kwargs: Partial<CallOptions>
): RunnableBinding<RunInput, RunOutput, CallOptions> {
return new RunnableBinding({
bound: this.bound,
kwargs: { ...this.kwargs, ...kwargs },
});
}

async invoke(
input: RunInput,
options?: Partial<CallOptions>
): Promise<RunOutput> {
return this.bound.invoke(input, { ...options, ...this.kwargs });
}

async batch(
inputs: RunInput[],
options?: Partial<CallOptions> | Partial<CallOptions>[],
batchOptions?: { maxConcurrency?: number }
): Promise<RunOutput[]> {
const mergedOptions = Array.isArray(options)
? options.map((individualOption) => ({
...individualOption,
...this.kwargs,
}))
: { ...options, ...this.kwargs };
return this.bound.batch(inputs, mergedOptions, batchOptions);
}

async *_streamIterator(
input: RunInput,
options?: Partial<CallOptions> | undefined
) {
yield* this.bound._streamIterator(input, { ...options, ...this.kwargs });
}

async stream(
input: RunInput,
options?: Partial<CallOptions> | undefined
): Promise<IterableReadableStream<RunOutput>> {
return this.bound.stream(input, { ...options, ...this.kwargs });
}

async *transform(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
generator: AsyncGenerator<RunInput>,
options: Partial<CallOptions>
): AsyncGenerator<RunOutput> {
yield* this.bound.transform(generator, options);
}

static isRunnableBinding(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
thing: any
// eslint-disable-next-line @typescript-eslint/no-explicit-any
): thing is RunnableBinding<any, any, any> {
return thing.bound && Runnable.isRunnable(thing.bound);
}
}

/**
* A sequence of runnables, where the output of each is the input of the next.
*/
Expand Down Expand Up @@ -482,14 +585,24 @@ export class RunnableSequence<
);
let nextStepInput = input;
const steps = [this.first, ...this.middle, this.last];
// Find the index of the last runnable in the sequence that doesn't have a .transform() method
// Find the index of the last runnable in the sequence that doesn't have an overridden .transform() method
// and start streaming from there
const streamingStartStepIndex =
const streamingStartStepIndex = Math.min(
steps.length - 1,
steps.length -
[...steps]
.reverse()
.findIndex((step) => typeof step.transform !== "function") -
1;
[...steps].reverse().findIndex((step) => {
const isDefaultImplementation =
step.transform === Runnable.prototype.transform;
const boundRunnableIsDefaultImplementation =
RunnableBinding.isRunnableBinding(step) &&
step.bound?.transform === Runnable.prototype.transform;
return (
isDefaultImplementation || boundRunnableIsDefaultImplementation
);
}) -
1
);

try {
for (const step of steps.slice(0, streamingStartStepIndex)) {
nextStepInput = await step.invoke(
Expand All @@ -509,8 +622,7 @@ export class RunnableSequence<
this._patchConfig(options, runManager?.getChild())
);
for (const step of steps.slice(streamingStartStepIndex + 1)) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
finalGenerator = await step.transform!(
finalGenerator = await step.transform(
finalGenerator,
this._patchConfig(options, runManager?.getChild())
);
Expand Down Expand Up @@ -699,80 +811,6 @@ export class RunnablePassthrough<RunInput> extends Runnable<
}
}

/**
* A runnable that delegates calls to another runnable with a set of kwargs.
*/
export class RunnableBinding<
RunInput,
RunOutput,
CallOptions extends BaseCallbackConfig
> extends Runnable<RunInput, RunOutput, CallOptions> {
static lc_name() {
return "RunnableBinding";
}

lc_namespace = ["langchain", "schema", "runnable"];

lc_serializable = true;

protected bound: Runnable<RunInput, RunOutput, CallOptions>;

protected kwargs: Partial<CallOptions>;

constructor(fields: {
bound: Runnable<RunInput, RunOutput, CallOptions>;
kwargs: Partial<CallOptions>;
}) {
super(fields);
this.bound = fields.bound;
this.kwargs = fields.kwargs;
}

bind(
kwargs: Partial<CallOptions>
): RunnableBinding<RunInput, RunOutput, CallOptions> {
return new RunnableBinding({
bound: this.bound,
kwargs: { ...this.kwargs, ...kwargs },
});
}

async invoke(
input: RunInput,
options?: Partial<CallOptions>
): Promise<RunOutput> {
return this.bound.invoke(input, { ...options, ...this.kwargs });
}

async batch(
inputs: RunInput[],
options?: Partial<CallOptions> | Partial<CallOptions>[],
batchOptions?: { maxConcurrency?: number }
): Promise<RunOutput[]> {
const mergedOptions = Array.isArray(options)
? options.map((individualOption) => ({
...individualOption,
...this.kwargs,
}))
: { ...options, ...this.kwargs };
return this.bound.batch(inputs, mergedOptions, batchOptions);
}

async *_streamIterator(
input: RunInput,
options?: Partial<CallOptions> | undefined
) {
yield* this.bound._streamIterator(input, { ...options, ...this.kwargs });
}

async stream(
input: RunInput,
options?: Partial<CallOptions> | undefined
): Promise<IterableReadableStream<RunOutput>> {
return this.bound.stream(input, { ...options, ...this.kwargs });
}
}

export type RouterInput = {
key: string;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand Down
13 changes: 13 additions & 0 deletions langchain/src/schema/tests/runnable.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,16 @@ test("Stream with RunnableBinding", async () => {
expect(chunks.length).toEqual("Hi there!".length);
expect(chunks.join("")).toEqual("Hi there!");
});

test("Stream through a RunnableBinding if the bound runnable implements transform", async () => {
const llm = new FakeStreamingLLM({}).bind({ stop: ["dummy"] });
const outputParser = new StringOutputParser().bind({ callbacks: [] });
const stream = await llm.pipe(outputParser).stream("Hi there!");
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
console.log(chunk);
}
expect(chunks.length).toEqual("Hi there!".length);
expect(chunks.join("")).toEqual("Hi there!");
});

0 comments on commit abb9491

Please sign in to comment.