Skip to content

Commit

Permalink
[TextGeneration] Add Streaming Functionality (#1246)
Browse files Browse the repository at this point in the history
* add streaming functionality

* remove print

* set back default value

* rebase

* update to yield

* update pipeline.py

* update tests

* refactor out streaming functions and remove yield in process_engine_output

* fix tests

* update pipeline to use kwargs

* rebase

* Update src/deepsparse/transformers/pipelines/text_generation.py
  • Loading branch information
dsikka authored Sep 21, 2023
1 parent fdb5d44 commit cd74aa2
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 86 deletions.
13 changes: 9 additions & 4 deletions src/deepsparse/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union

import numpy
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -259,7 +259,9 @@ def __call__(self, *args, **kwargs) -> BaseModel:
)

# join together the batches of size `self._batch_size`
engine_outputs = self.join_engine_outputs(batch_outputs, orig_batch_size)
engine_outputs = self.join_engine_outputs(
batch_outputs, orig_batch_size, **context
)
timer.stop(InferenceStages.ENGINE_FORWARD)

self.log(
Expand All @@ -280,7 +282,10 @@ def __call__(self, *args, **kwargs) -> BaseModel:
# ------ POSTPROCESSING ------
timer.start(InferenceStages.POST_PROCESS)
pipeline_outputs = self.process_engine_outputs(engine_outputs, **context)
if not isinstance(pipeline_outputs, self.output_schema):
if not (
isinstance(pipeline_outputs, (self.output_schema, Generator))
or isinstance(pipeline_outputs, Generator)
):
raise ValueError(
f"Outputs of {self.__class__} must be instances of "
f"{self.output_schema} found output of type "
Expand Down Expand Up @@ -467,7 +472,7 @@ def to_config(self) -> "PipelineConfig":
)

def join_engine_outputs(
self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int, **kwargs
) -> List[numpy.ndarray]:
"""
Joins list of engine outputs together into one list.
Expand Down
210 changes: 129 additions & 81 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import numpy
import onnx
from pydantic import BaseModel, Field
from transformers import TextStreamer

from deepsparse import Pipeline
from deepsparse.pipeline import DEEPSPARSE_ENGINE
Expand Down Expand Up @@ -61,6 +60,7 @@ class FinishReason(Enum):
STOP = "stop"
LENGTH = "length"
TIME = "time"
CALLBACK = "callback"


class TextGenerationInput(BaseModel):
Expand Down Expand Up @@ -106,12 +106,12 @@ class Config:
"to have consistent length so one "
"can compute metric in a batched fashion. ",
)
streamer: Optional[TextStreamer] = Field(
default=None,
description="Streamer object that will be used to stream the "
"generated sequences. Generated tokens are passed through "
"`streamer.put(token_ids)` and the streamer is responsible "
"for any further processing.",
streaming: bool = Field(
default=False,
description="Whether to stream the results back as they are generated. If "
"True, then the results are returned as a generator object which yields "
"the results as they are generated. If False, then the results are returned "
"as a list after it has completed.",
)
callback: Optional[Callable[[Any], Union[bool, Any]]] = Field(
default=None,
Expand Down Expand Up @@ -161,7 +161,7 @@ class GeneratedText(BaseModel):
"The scores have the shape [sequence_length, vocab_size]"
)
finished: bool = Field(description="Whether generation has stopped.")
finished_reason: str = Field(
finished_reason: Optional[str] = Field(
description="The reason for generation to stop. "
"Defined by FinishReason. One of stop, length, or time."
)
Expand Down Expand Up @@ -473,9 +473,9 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

context = dict(
prompts=original_inputs,
streaming=inputs.streaming,
num_generated_predictions=inputs.num_generated_predictions,
return_logits=inputs.return_logits,
streamer=inputs.streamer,
include_prompt_logits=inputs.include_prompt_logits,
callback=inputs.callback,
stop=inputs.stop,
Expand All @@ -488,6 +488,40 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

return engine_input, context

def _create_generated_text_output(
self,
sequence: str,
finish_reason: Optional[FinishReason] = None,
logits: Optional[numpy.array] = None,
):
if finish_reason:
return GeneratedText(
text=sequence,
score=logits,
finished=True,
finished_reason=finish_reason.value,
)
return GeneratedText(
text=sequence,
score=logits,
finished=False,
)

def _stream_engine_outputs(self, engine_outputs, prompts, kwargs):
for output in engine_outputs:
generated_tokens, generated_logits, finished_reason = output
logits = generated_logits if kwargs.get("return_logits") else None
generation = self._create_generated_text_output(
self.tokenizer.batch_decode(generated_tokens)[0],
finished_reason[0],
logits,
)
yield TextGenerationOutput(
created=datetime.datetime.now(),
prompts=prompts,
generations=[generation],
)

def process_engine_outputs(
self, engine_outputs: List[Union[numpy.ndarray, FinishReason]], **kwargs
) -> TextGenerationOutput:
Expand All @@ -497,33 +531,29 @@ def process_engine_outputs(
:param engine_outputs: the outputs from the engine
:return: the output schema for the pipeline
"""
generated_tokens, generated_logits, finished_reason, *debug = engine_outputs
finished_reason = [f[0] for f in finished_reason]

prompts = kwargs.get("prompts")
streaming = kwargs.get("streaming")

if streaming:
return self._stream_engine_outputs(engine_outputs, prompts, kwargs)

generated_tokens, generated_logits, finished_reason, *debug = list(
*engine_outputs
)
sequences = self.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)
num_preds = kwargs.get("num_generated_predictions", 1)
prompts = kwargs.get("prompts")

def _create_generated_text_output(
sequence: str,
finish_reason: FinishReason,
logits: Optional[numpy.array] = None,
):
return GeneratedText(
text=sequence,
score=logits,
finished=True,
finished_reason=finish_reason.value,
)

logits = generated_logits if kwargs.get("return_logits") else None

num_preds = kwargs.get("num_generated_predictions", 1)
finished_reason = [f[0] for f in finished_reason]

if logits is not None:
generations = list(
self.executor.map(
_create_generated_text_output,
self._create_generated_text_output,
sequences,
finished_reason,
logits,
Expand All @@ -532,7 +562,7 @@ def _create_generated_text_output(
else:
generations = list(
self.executor.map(
_create_generated_text_output, sequences, finished_reason
self._create_generated_text_output, sequences, finished_reason
)
)

Expand Down Expand Up @@ -582,8 +612,8 @@ def engine_forward(
# names in this context

with self.timer_manager.new_timer_context(total_inference=False) as timer:
streamer = context.get("streamer")
finished_reason = []
streaming = context.get("streaming")

if not self.cache_support_enabled:
prompt_logits = self.multitoken_engine(engine_inputs)
Expand All @@ -610,9 +640,6 @@ def engine_forward(
)
token_generator.generate(prompt_logits[-1][0, -1, :])

if streamer is not None:
streamer.put(numpy.array(token_generator.tokens))

# create the generated output
max_tokens = context.get("max_tokens", 0)
max_tokens = max_tokens if max_tokens > 0 else (100 * self.sequence_length)
Expand All @@ -638,9 +665,6 @@ def engine_forward(
generated_tokens.append(token)
generated_logits.append(logits)

if streamer is not None:
streamer.put(numpy.array([token]))

if (
token == self.tokenizer.eos_token_id
and not self.force_max_tokens
Expand All @@ -656,30 +680,38 @@ def engine_forward(
finished_reason.append(FinishReason.STOP)
break

# TODO: Add any generic callback reason?
if callback is not None and callback(token) is False:
_LOGGER.debug(
"callback %s returned False, stopping generation."
% callback.__qualname__
)
finished_reason.append(FinishReason.CALLBACK)
break

if len(generated_tokens) == max_tokens:
finished_reason.append(FinishReason.LENGTH)

if streamer is not None:
streamer.end()
if streaming:
yield (numpy.array([token]), numpy.array([logits]), [None])

returns = (
numpy.array([generated_tokens]),
numpy.concatenate(generated_logits, axis=1),
finished_reason,
)
if streaming:
yield (
numpy.array([token]),
numpy.array([logits]),
[finished_reason[-1]],
)

if not streaming:
returns = (
numpy.array([generated_tokens]),
numpy.concatenate(generated_logits, axis=1),
finished_reason,
)

if self._debug is True:
return *returns, session
if self._debug is True:
yield *returns, session

return returns
yield returns

def prompt_inference(
self,
Expand Down Expand Up @@ -870,6 +902,7 @@ def join_engine_outputs(
self,
batch_outputs: List[List[Union[numpy.ndarray, FinishReason]]],
orig_batch_size: int,
**kwargs,
) -> List[Union[numpy.ndarray, FinishReason]]:
"""
Takes a list of outputs (batches) from the engine
Expand All @@ -881,48 +914,63 @@ def join_engine_outputs(
:param orig_batch_size: The original batch size
:return: A list of joined outputs
"""
tokens, logits, finish_reason, *debug = zip(*batch_outputs)
if self.cache_support_enabled:
# if the model has kv cache, we need to account for
# the fact that the predicted outputs may have
# different lengths

# find the longest sequence in the batch of tokens
max_len = max([token.shape[1] for token in tokens])

# pad all tokens to the same length
tokens = [
pad_to_fixed_length(
array=prediction,
max_len=max_len,
value=self.tokenizer.pad_token_id,
axis=1,
)
for prediction in tokens
]
streaming = kwargs.get("streaming")
if streaming:
for batch in batch_outputs:
for outputs in batch:
yield outputs
else:
batch_outputs = [list(*b) for b in batch_outputs]
tokens, logits, finish_reason, *debug = zip(*batch_outputs)
if self.cache_support_enabled:
# if the model has kv cache, we need to account for
# the fact that the predicted outputs may have
# different lengths

# find the longest sequence in the batch of tokens
max_len = max([token.shape[1] for token in tokens])

# pad all tokens to the same length
tokens = [
pad_to_fixed_length(
array=prediction,
max_len=max_len,
value=self.tokenizer.pad_token_id,
axis=1,
)
for prediction in tokens
]

# find the longest sequence in the batch of logits
max_len = max([logits.shape[1] for logits in logits])
# find the longest sequence in the batch of logits
max_len = max([logits.shape[1] for logits in logits])

# pad all logits to the same length
logits = [
pad_to_fixed_length(array=single_logits, max_len=max_len, axis=1)
for single_logits in logits
]
# pad all logits to the same length
logits = [
pad_to_fixed_length(array=single_logits, max_len=max_len, axis=1)
for single_logits in logits
]

tokens = numpy.concatenate(tokens, axis=0)
logits = numpy.concatenate(logits, axis=0)
tokens = numpy.concatenate(tokens, axis=0)
logits = numpy.concatenate(logits, axis=0)

if debug:
sessions = debug[0]
kv_cache_state = numpy.stack(session.cached_inputs for session in sessions)
num_processed_tokens = numpy.stack(
session.total_num_processed_tokens for session in sessions
)
if debug:
sessions = debug[0]
kv_cache_state = numpy.stack(
session.cached_inputs for session in sessions
)
num_processed_tokens = numpy.stack(
session.total_num_processed_tokens for session in sessions
)

return [tokens, logits, finish_reason, kv_cache_state, num_processed_tokens]
yield [
tokens,
logits,
finish_reason,
kv_cache_state,
num_processed_tokens,
]

return [tokens, logits, finish_reason]
yield [tokens, logits, finish_reason]

@staticmethod
def causal_mask_input_present(model_path: str) -> bool:
Expand Down
Loading

0 comments on commit cd74aa2

Please sign in to comment.