forked from Lightning-Universe/stable-diffusion-deploy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
269 lines (230 loc) · 11.7 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import os
import time
import uuid
from typing import List, Optional
import lightning as L
import requests
from lightning.app.frontend import StaticWebFrontend
from lightning.app.storage import Drive
from lightning.app.utilities.frontend import AppInfo
from lightning_api_access import APIAccessFrontend
from muse import LoadBalancer, Locust, MuseSlackCommandBot, SafetyCheckerEmbedding, StableDiffusionServe
from muse.CONST import ENABLE_ANALYTICS, MUSE_GPU_TYPE, MUSE_MIN_WORKERS
from muse.utility.analytics import analytics_headers
class ReactUI(L.LightningFlow):
    """Serves the pre-built Muse React frontend as a static web app."""

    def configure_layout(self):
        # Resolve the build output directory relative to this file so the
        # frontend is found regardless of the current working directory.
        build_dir = os.path.join(os.path.dirname(__file__), "muse", "ui", "build")
        return StaticWebFrontend(build_dir)
class APIUsageFlow(L.LightningFlow):
    """Renders an API-usage panel documenting the image-generation endpoint."""

    def __init__(self, api_url: str = ""):
        super().__init__()
        # Base URL of the serving endpoint; populated by the parent flow once
        # the load balancer reports its URL.
        self.api_url = api_url

    def configure_layout(self):
        # Single documented endpoint: POST /api/predict with a text prompt.
        predict_api = {
            "name": "Generate Image",
            "url": f"{self.api_url}/api/predict",
            "method": "POST",
            "request": {"prompt": "cats in hats", "high_quality": "true"},
            "response": {"image": "data:image/png;base64,<image-actual-content>"},
        }
        return APIAccessFrontend(apis=[predict_api])
class MuseFlow(L.LightningFlow):
    """The MuseFlow is a LightningFlow component that handles all the servers and uses load balancer to spawn up and
    shutdown based on current requests in the queue.
    Args:
        initial_num_workers: Number of works to start when app initializes.
        autoscale_interval: Number of seconds to wait before checking whether to upscale or downscale the works.
        max_batch_size: Number of requests to process at once.
        batch_timeout_secs: Number of seconds to wait before sending the requests to process.
        gpu_type: GPU type to use for the works.
        max_workers: Max numbers of works to spawn to handle the incoming requests.
        autoscale_down_limit: Lower limit to determine when to stop works.
        autoscale_up_limit: Upper limit to determine when to spawn up a new work.
    """
    def __init__(
        self,
        initial_num_workers: int = MUSE_MIN_WORKERS,
        autoscale_interval: int = 1 * 30,
        max_batch_size: int = 12,
        batch_timeout_secs: int = 2,
        gpu_type: str = MUSE_GPU_TYPE,
        max_workers: int = 20,
        autoscale_down_limit: Optional[int] = None,
        autoscale_up_limit: Optional[int] = None,
        load_testing: Optional[bool] = False,
    ):
        super().__init__()
        # Frontend chrome flag consumed by the Lightning UI.
        self.hide_footer_shadow = True
        # Becomes True once the load balancer has been handed at least one server URL.
        self.load_balancer_started = False
        self._initial_num_workers = initial_num_workers
        # Count of currently registered model-serving works (see add_work/remove_work).
        self._num_workers = 0
        # Maps integer index -> attribute name that holds the corresponding work.
        self._work_registry = {}
        self.autoscale_interval = autoscale_interval
        self.max_workers = max_workers
        # Default scaling thresholds are derived from the initial pool size when not given.
        self.autoscale_down_limit = autoscale_down_limit or initial_num_workers
        self.autoscale_up_limit = autoscale_up_limit or initial_num_workers * max_batch_size
        # NOTE(review): os.getenv returns a *string* when the var is set, which is
        # truthy even for "0"/"false" — confirm that is the intended behavior.
        self.load_testing = load_testing or os.getenv("MUSE_LOAD_TESTING", False)
        # Incremented each run() pass once the balancer is live; keeps the flow looping.
        self.fake_trigger = 0
        self.gpu_type = gpu_type
        # Timestamp of the last autoscaling decision, used to rate-limit autoscale().
        self._last_autoscale = time.time()
        # Create Drive to store Safety Checker embeddings
        self.safety_embeddings_drive = Drive("lit://embeddings")
        # Safety Checker Embedding Work to create and store embeddings in the Drive
        self.safety_checker_embedding_work = SafetyCheckerEmbedding(drive=self.safety_embeddings_drive)
        self.load_balancer = LoadBalancer(
            max_batch_size=max_batch_size, batch_timeout_secs=batch_timeout_secs, cache_calls=True, parallel=True
        )
        # Pre-register the initial pool of model servers; they are started later in run().
        for i in range(initial_num_workers):
            work = StableDiffusionServe(
                safety_embeddings_drive=self.safety_embeddings_drive,
                safety_embeddings_filename=self.safety_checker_embedding_work.safety_embeddings_filename,
                cloud_compute=L.CloudCompute(gpu_type, disk_size=30),
                cache_calls=True,
                parallel=True,
                start_with_flow=False,
            )
            self.add_work(work)
        self.slack_bot = MuseSlackCommandBot(command="/muse")
        # The Locust load-testing work is only created when load testing is enabled.
        if self.load_testing:
            self.locust = Locust(locustfile="./scripts/locustfile.py")
        self.printed_url = False
        self.slack_bot_url = ""
        self.dream_url = ""
        self.ui = ReactUI()
        self.api_component = APIUsageFlow()
        # Set once the safety-checker embedding work has finished successfully.
        self.safety_embeddings_ready = False
    @property
    def ready(self) -> bool:
        # The app is considered ready as soon as the load balancer is.
        return self.load_balancer.ready
    @property
    def model_servers(self) -> List[StableDiffusionServe]:
        """Return the currently registered model-serving works, in index order."""
        works = []
        for i in range(self._num_workers):
            work: StableDiffusionServe = self.get_work(i)
            works.append(work)
        return works
    def add_work(self, work) -> str:
        """Register *work* under a unique attribute name and return that name.

        Works must live as attributes on the flow for Lightning to manage them;
        the uuid suffix keeps names unique even after removals.
        """
        work_attribute = uuid.uuid4().hex
        work_attribute = f"model_serve_{self._num_workers}_{str(work_attribute)}"
        setattr(self, work_attribute, work)
        self._work_registry[self._num_workers] = work_attribute
        self._num_workers += 1
        return work_attribute
    def remove_work(self, index: int) -> str:
        """Stop the work at *index*, drop it from the registry, and return its attribute name."""
        work_attribute = self._work_registry[index]
        del self._work_registry[index]
        work = getattr(self, work_attribute)
        work.stop()
        self._num_workers -= 1
        return work_attribute
    def get_work(self, index: int):
        """Look up a registered work by its registry index."""
        work_attribute = self._work_registry[index]
        work = getattr(self, work_attribute)
        return work
    def run(self): # noqa: C901
        """Main flow loop: start the core works, wire URLs together, and autoscale."""
        if os.environ.get("TESTING_LAI"):
            print("⚡ Lightning Dream App! ⚡")
        # provision these works early
        if not self.load_balancer.is_running:
            self.load_balancer.run([])
        if not self.slack_bot.is_running:
            self.slack_bot.run("")
        # Build the safety-checker embeddings once, then stop that work to free resources.
        if not self.safety_embeddings_ready:
            self.safety_checker_embedding_work.run()
        if not self.safety_embeddings_ready and self.safety_checker_embedding_work.has_succeeded:
            self.safety_embeddings_ready = True
            self.safety_checker_embedding_work.stop()
        for model_serve in self.model_servers:
            model_serve.run()
        if any(model_serve.url for model_serve in self.model_servers) and not self.load_balancer_started:
            # run the load balancer when one of the model servers is ready
            self.load_balancer.run([serve.url for serve in self.model_servers if serve.url])
            self.load_balancer_started = True
        if self.load_balancer.url: # hack for getting the work url
            # Propagate the balancer URL to the API panel, the UI, and the Slack bot.
            self.api_component.api_url = self.load_balancer.url
            self.dream_url = self.load_balancer.url
            if self.slack_bot is not None:
                self.slack_bot.run(self.load_balancer.url)
                self.slack_bot_url = self.slack_bot.url
                if self.slack_bot.url and not self.printed_url:
                    print("Slack Bot Work ready with URL=", self.slack_bot.url)
                    print("model serve url=", self.load_balancer.url)
                    print("API component url=", self.api_component.state_vars["vars"]["_layout"]["target"])
                    self.printed_url = True
        if self.load_testing and self.load_balancer.url:
            self.locust.run(self.load_balancer.url)
        if self.load_balancer.url:
            # Mutating state forces Lightning to schedule another run() pass.
            self.fake_trigger += 1
            self.autoscale()
    def configure_layout(self):
        # When load testing, the Locust dashboard is shown as a second tab.
        ui = [{"name": "Muse App" if self.load_testing else None, "content": self.ui}]
        if self.load_testing:
            ui.append({"name": "Locust", "content": self.locust.url})
        return ui
    def autoscale(self):
        """Upscale and down scale model inference works based on the number of requests."""
        # Rate-limit scaling decisions to once per autoscale_interval seconds.
        if time.time() - self._last_autoscale < self.autoscale_interval:
            return
        self.load_balancer.update_servers(self.model_servers)
        # Ask the balancer how many requests are currently queued.
        num_requests = int(requests.get(f"{self.load_balancer.url}/num-requests").json())
        num_workers = len(self.model_servers)
        # upscale
        if num_requests > self.autoscale_up_limit and num_workers < self.max_workers:
            # NOTE(review): idx is assigned but unused in this branch.
            idx = self._num_workers
            print(f"Upscale to {self._num_workers + 1}")
            work = StableDiffusionServe(
                safety_embeddings_drive=self.safety_embeddings_drive,
                safety_embeddings_filename=self.safety_checker_embedding_work.safety_embeddings_filename,
                cloud_compute=L.CloudCompute(self.gpu_type, disk_size=30),
                cache_calls=True,
                parallel=True,
            )
            new_work_id = self.add_work(work)
            print("new work id:", new_work_id)
        # downscale
        elif num_requests < self.autoscale_down_limit and num_workers > self._initial_num_workers:
            # Always remove the highest-indexed work; never drop below the initial pool.
            idx = self._num_workers - 1
            print(f"Downscale to {idx}")
            print("prev num servers:", len(self.model_servers))
            removed_id = self.remove_work(idx)
            print("removed:", removed_id)
            print("new num servers:", len(self.model_servers))
        # Re-sync the balancer with the (possibly changed) server pool.
        self.load_balancer.update_servers(self.model_servers)
        self._last_autoscale = time.time()
if __name__ == "__main__":
    # Entry point: wrap the root MuseFlow with page metadata (title, favicon,
    # social-media meta tags) and an optional root-path override from the env.
    app = L.LightningApp(
        MuseFlow(),
        info=AppInfo(
            title="Use AI to inspire your art.",
            favicon="https://storage.googleapis.com/grid-static/muse/favicon.ico",
            description="Bring your words to life in seconds - powered by Lightning AI and Stable Diffusion.",
            image="https://storage.googleapis.com/grid-static/header.png",
            # NOTE(review): several adjacent string literals below have no
            # separating comma, so Python concatenates them into single list
            # items containing multiple tags. The tags still emit as consecutive
            # HTML, but confirm the concatenation is intentional.
            meta_tags=[
                '<meta name="theme-color" content="#792EE5" />',
                '<meta name="image" content="https://storage.googleapis.com/grid-static/header.png">'
                '<meta itemprop="name" content="Use AI to inspire your art.">'
                '<meta itemprop="description" content="Bring your words to life in seconds - powered by Lightning AI and Stable Diffusion.">'  # noqa
                '<meta itemprop="image" content="https://storage.googleapis.com/grid-static/header.png">'
                # <!-- Twitter -->
                '<meta name="twitter:card" content="summary">'
                '<meta name="twitter:title" content="Use AI to inspire your art.">'
                '<meta name="twitter:description" content="Bring your words to life in seconds - powered by Lightning AI and Stable Diffusion.">'  # noqa
                '<meta name="twitter:site" content="https://lightning.ai/muse">'
                '<meta name="twitter:domain" content="https://lightning.ai/muse">'
                '<meta name="twitter:creator" content="@LightningAI">'
                '<meta name="twitter:image:src" content="https://storage.googleapis.com/grid-static/header.png">'
                # <!-- Open Graph general (Facebook, Pinterest & Google+) -->
                '<meta name="og:title" content="Use AI to inspire your art.">'
                '<meta name="og:description" content="Bring your words to life in seconds - powered by Lightning AI and Stable Diffusion.">'  # noqa
                '<meta name="og:url" content="https://lightning.ai/muse">'
                '<meta property="og:image" content="https://storage.googleapis.com/grid-static/header.png" />',
                '<meta property="og:image:type" content="image/png" />',
                '<meta property="og:image:height" content="1114" />'
                '<meta property="og:image:width" content="1112" />',
                # Analytics tags are only injected when analytics is enabled.
                *(analytics_headers if ENABLE_ANALYTICS else []),
            ],
        ),
        root_path=os.getenv("MUSE_ROOT_PATH", ""),
    )