7
7
from huggingface_hub import hf_hub_download
8
8
from PIL import Image
9
9
10
- from ..util import HWC3
10
+ from ..util import HWC3 , resize_image
11
11
from .api import MiDaSInference
12
12
13
13
@@ -36,14 +36,17 @@ def to(self, device):
36
36
self .model .to (device )
37
37
return self
38
38
39
- def __call__ (self , input_image , a = np .pi * 2.0 , bg_th = 0.1 , depth_and_normal = False ):
39
+ def __call__ (self , input_image , a = np .pi * 2.0 , bg_th = 0.1 , depth_and_normal = False , detect_resolution = 512 , image_resolution = 512 , output_type = None ):
40
40
device = next (iter (self .model .parameters ())).device
41
- input_type = "np"
42
- if isinstance (input_image , Image .Image ):
43
- input_image = np .array (input_image )
44
- input_type = "pil"
45
-
41
+ if not isinstance (input_image , np .ndarray ):
42
+ input_image = np .array (input_image , dtype = np .uint8 )
43
+ output_type = output_type or "pil"
44
+ else :
45
+ output_type = output_type or "np"
46
+
46
47
input_image = HWC3 (input_image )
48
+ input_image = resize_image (input_image , detect_resolution )
49
+
47
50
assert input_image .ndim == 3
48
51
image_depth = input_image
49
52
with torch .no_grad ():
@@ -70,9 +73,19 @@ def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False
70
73
normal /= np .sum (normal ** 2.0 , axis = 2 , keepdims = True ) ** 0.5
71
74
normal_image = (normal * 127.5 + 127.5 ).clip (0 , 255 ).astype (np .uint8 )[:, :, ::- 1 ]
72
75
73
- if input_type == "pil" :
76
+ depth_image = HWC3 (depth_image )
77
+ if depth_and_normal :
78
+ normal_image = HWC3 (normal_image )
79
+
80
+ img = resize_image (input_image , image_resolution )
81
+ H , W , C = img .shape
82
+
83
+ depth_image = cv2 .resize (depth_image , (W , H ), interpolation = cv2 .INTER_LINEAR )
84
+ if depth_and_normal :
85
+ normal_image = cv2 .resize (normal_image , (W , H ), interpolation = cv2 .INTER_LINEAR )
86
+
87
+ if output_type == "pil" :
74
88
depth_image = Image .fromarray (depth_image )
75
- depth_image = depth_image .convert ("RGB" )
76
89
if depth_and_normal :
77
90
normal_image = Image .fromarray (normal_image )
78
91
0 commit comments