Skip to content

Commit 464d907

Browse files
Merge pull request huggingface#45 from ChristiaensBert/feature/add-processor-class
add Processor class
2 parents 9adbed7 + da0d1ef commit 464d907

File tree

1 file changed

+119
-0
lines changed

1 file changed

+119
-0
lines changed

src/controlnet_aux/processor.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""
2+
This file contains a Processor that can be used to process images with controlnet aux processors
3+
"""
4+
import io
5+
from typing import Union
6+
7+
from PIL import Image
8+
import numpy as np
9+
import torch
10+
from controlnet_aux import (HEDdetector,
11+
MidasDetector,
12+
MLSDdetector,
13+
OpenposeDetector,
14+
PidiNetDetector,
15+
NormalBaeDetector,
16+
LineartDetector,
17+
LineartAnimeDetector,
18+
CannyDetector,
19+
ContentShuffleDetector,
20+
ZoeDetector,
21+
MediapipeFaceDetector)
22+
23+
24+
MODELS = {
25+
# checkpoint models
26+
'hed': {'class': HEDdetector, 'checkpoint': True},
27+
'midas': {'class': MidasDetector, 'checkpoint': True},
28+
'mlsd': {'class': MLSDdetector, 'checkpoint': True},
29+
'openpose': {'class': OpenposeDetector, 'checkpoint': True},
30+
'pidinet': {'class': PidiNetDetector, 'checkpoint': True},
31+
'normalbae': {'class': NormalBaeDetector, 'checkpoint': True},
32+
'lineart': {'class': LineartDetector, 'checkpoint': True},
33+
'lineart_coarse': {'class': LineartDetector, 'checkpoint': True},
34+
'lineart_anime': {'class': LineartAnimeDetector, 'checkpoint': True},
35+
'zoe': {'class': ZoeDetector, 'checkpoint': True},
36+
# instantiate
37+
'content_shuffle': {'class': ContentShuffleDetector, 'checkpoint': False},
38+
'mediapipe_face': {'class': MediapipeFaceDetector, 'checkpoint': False},
39+
'canny': {'class': CannyDetector, 'checkpoint': False},
40+
}
41+
42+
# @patrickvonplaten, I can change this so people can pass their own parameters
43+
# but for my use case I'm using this Dictionary
44+
MODEL_PARAMS = {
45+
'hed': {'resize': False},
46+
'midas': {'resize': 512},
47+
'mlsd': {'resize': False},
48+
'openpose': {'resize': False, 'hand_and_face': True},
49+
'pidinet': {'resize': False, 'safe': True},
50+
'normalbae': {'resize': False},
51+
'lineart': {'resize': False, 'coarse': True},
52+
'lineart_coarse': {'resize': False, 'coarse': True},
53+
'lineart_anime': {'resize': False},
54+
'canny': {'resize': False},
55+
'content_shuffle': {'resize': False},
56+
'zoe': {'resize': False},
57+
'mediapipe_face': {'resize': False},
58+
}
59+
60+
61+
class Processor:
62+
def __init__(self, processor_id: str) -> 'Processor':
63+
"""Processor that can be used to process images with controlnet aux processors
64+
65+
Args:
66+
processor_id (str): processor name
67+
68+
Returns:
69+
Processor: Processor object
70+
"""
71+
print(f"Loading {processor_id} processor")
72+
self.processor_id = processor_id
73+
self.processor = self.load_processor(self.processor_id)
74+
self.params = MODEL_PARAMS[self.processor_id]
75+
self.resize = self.params.pop('resize', False)
76+
if self.resize:
77+
# print warning: image will be resized
78+
print(f"Warning: {self.processor_id} will resize image to {self.resize}x{self.resize}")
79+
80+
def load_processor(self, processor_id: str):
81+
"""Load controlnet aux processors
82+
83+
Args:
84+
processor_id (str): processor name
85+
"""
86+
processor = MODELS[processor_id]['class']
87+
88+
if MODELS[processor_id]['checkpoint']:
89+
processor = processor.from_pretrained("lllyasviel/Annotators")
90+
else:
91+
processor = processor()
92+
return processor
93+
94+
def __call__(self, image: Union[Image.Image, bytes],
95+
to_bytes: bool = True) -> Union[Image.Image, bytes]:
96+
"""processes an image with a controlnet aux processor
97+
98+
Args:
99+
image (Union[Image.Image, bytes]): input image in bytes or PIL Image
100+
to_bytes (bool): whether to return bytes or PIL Image
101+
102+
Returns:
103+
Union[Image.Image, bytes]: processed image in bytes or PIL Image
104+
"""
105+
# check if bytes or PIL Image
106+
if isinstance(image, bytes):
107+
image = Image.open(io.BytesIO(image)).convert("RGB")
108+
109+
if self.resize:
110+
image = image.resize((self.resize, self.resize))
111+
112+
processed_image = self.processor(image, **self.params)
113+
114+
if to_bytes:
115+
output_bytes = io.BytesIO()
116+
processed_image.save(output_bytes, format='JPEG')
117+
return output_bytes.getvalue()
118+
else:
119+
return processed_image

0 commit comments

Comments
 (0)