Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pipeline Refactor] Migration #1460

Merged
merged 20 commits into from
Dec 11, 2023
Merged
Prev Previous commit
Next Next commit
update codegen alias to use the new registry and text generation pipe…
…line
  • Loading branch information
dsikka committed Dec 8, 2023
commit 8c71397df33835e6200e3ad27a472d27ec85dbf0
27 changes: 25 additions & 2 deletions src/deepsparse/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,22 @@ class SupportedTasks:
bloom=AliasedTask("bloom", []),
)

code_generation = namedtuple(
"code_generation", ["code_generation", "code_gen", "codegen"]
)(
code_generation=AliasedTask("code_generation", []),
code_gen=AliasedTask("code_gen", []),
codegen=AliasedTask("codegen", []),
)

image_classification = namedtuple("image_classification", ["image_classification"])(
image_classification=AliasedTask(
"image_classification",
["image_classification"],
),
)

all_task_categories = [text_generation]
all_task_categories = [text_generation, code_generation, image_classification]

@classmethod
def check_register_task(
Expand All @@ -107,6 +115,9 @@ def check_register_task(
if cls.is_text_generation(task):
import deepsparse.transformers.pipelines.text_generation # noqa: F401

elif cls.is_code_generation(task):
import deepsparse.transformers.pipelines.code_generation # noqa: F401

elif cls.is_image_classification(task):
# trigger image classification pipelines to
# register with Pipeline.register
Expand Down Expand Up @@ -142,7 +153,7 @@ def is_image_classification(cls, task: str) -> bool:

@classmethod
def task_names(cls):
task_names = ["custom"]
task_names = []
for task_category in cls.all_task_categories:
for task in task_category:
unique_aliases = (
Expand All @@ -151,6 +162,18 @@ def task_names(cls):
task_names += (task._name, *unique_aliases)
return task_names

@classmethod
def is_code_generation(cls, task: str) -> bool:
"""
:param task: the name of the task to check whether it is a text generation task
such as codegen
:return: True if it is a text generation task, False otherwise
"""
return any(
code_generation_task.matches(task)
for code_generation_task in cls.code_generation
)


def dynamic_import_task(module_or_path: str) -> str:
"""
Expand Down
11 changes: 3 additions & 8 deletions src/deepsparse/transformers/pipelines/code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,14 @@
# limitations under the License.


from deepsparse.legacy import Pipeline
from deepsparse.legacy.transformers.pipelines.text_generation import (
TextGenerationPipeline,
)
from deepsparse.operators import OperatorRegistry
from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline


__all__ = ["CodeGenerationPipeline"]


@Pipeline.register(
task="code_generation",
task_aliases=["codegen"],
)
@OperatorRegistry.register(name=["code_generation", "code_gen", "codegen"])
class CodeGenerationPipeline(TextGenerationPipeline):
"""
Subclass of text generation pipeline to support any defaults or
Expand Down
Loading