import os
import time
import uuid
from typing import List, Optional

import lightning as L
import requests
from lightning.app.frontend import StaticWebFrontend
from lightning.app.storage import Drive
from lightning.app.utilities.frontend import AppInfo
from lightning_api_access import APIAccessFrontend

from muse import LoadBalancer, Locust, MuseSlackCommandBot, SafetyCheckerEmbedding, StableDiffusionServe
from muse.CONST import ENABLE_ANALYTICS, MUSE_GPU_TYPE, MUSE_MIN_WORKERS
from muse.utility.analytics import analytics_headers


class ReactUI(L.LightningFlow):
    def configure_layout(self):
        return StaticWebFrontend(os.path.join(os.path.dirname(__file__), "muse", "ui", "build"))


class APIUsageFlow(L.LightningFlow):
    def __init__(self, api_url: str = ""):
        super().__init__()
        self.api_url = api_url

    def configure_layout(self):
        return APIAccessFrontend(
            apis=[
                {
                    "name": "Generate Image",
                    "url": f"{self.api_url}/api/predict",
                    "method": "POST",
                    "request": {"prompt": "cats in hats", "high_quality": "true"},
                    "response": {"image": "data:image/png;base64,<image-actual-content>"},
                }
            ]
        )


class MuseFlow(L.LightningFlow):
    """The MuseFlow is a LightningFlow component that handles all the servers and uses load balancer to spawn up and
    shutdown based on current requests in the queue.

    Args:
        initial_num_workers: Number of works to start when app initializes.
        autoscale_interval: Number of seconds to wait before checking whether to upscale or downscale the works.
        max_batch_size: Number of requests to process at once.
        batch_timeout_secs: Number of seconds to wait before sending the requests to process.
        gpu_type: GPU type to use for the works.
        max_workers: Max numbers of works to spawn to handle the incoming requests.
        autoscale_down_limit: Lower limit to determine when to stop works.
        autoscale_up_limit: Upper limit to determine when to spawn up a new work.
    """

    def __init__(
        self,
        initial_num_workers: int = MUSE_MIN_WORKERS,
        autoscale_interval: int = 1 * 30,
        max_batch_size: int = 12,
        batch_timeout_secs: int = 2,
        gpu_type: str = MUSE_GPU_TYPE,
        max_workers: int = 20,
        autoscale_down_limit: Optional[int] = None,
        autoscale_up_limit: Optional[int] = None,
        load_testing: Optional[bool] = False,
    ):
        super().__init__()
        self.hide_footer_shadow = True
        self.load_balancer_started = False
        self._initial_num_workers = initial_num_workers
        self._num_workers = 0
        self._work_registry = {}
        self.autoscale_interval = autoscale_interval
        self.max_workers = max_workers
        self.autoscale_down_limit = autoscale_down_limit or initial_num_workers
        self.autoscale_up_limit = autoscale_up_limit or initial_num_workers * max_batch_size
        self.load_testing = load_testing or os.getenv("MUSE_LOAD_TESTING", False)
        self.fake_trigger = 0
        self.gpu_type = gpu_type
        self._last_autoscale = time.time()

        # Create Drive to store Safety Checker embeddings
        self.safety_embeddings_drive = Drive("lit://embeddings")

        # Safety Checker Embedding Work to create and store embeddings in the Drive
        self.safety_checker_embedding_work = SafetyCheckerEmbedding(drive=self.safety_embeddings_drive)

        self.load_balancer = LoadBalancer(
            max_batch_size=max_batch_size, batch_timeout_secs=batch_timeout_secs, cache_calls=True, parallel=True
        )
        for i in range(initial_num_workers):
            work = StableDiffusionServe(
                safety_embeddings_drive=self.safety_embeddings_drive,
                safety_embeddings_filename=self.safety_checker_embedding_work.safety_embeddings_filename,
                cloud_compute=L.CloudCompute(gpu_type, disk_size=30),
                cache_calls=True,
                parallel=True,
                start_with_flow=False,
            )
            self.add_work(work)

        self.slack_bot = MuseSlackCommandBot(command="/muse")
        if self.load_testing:
            self.locust = Locust(locustfile="./scripts/locustfile.py")
        self.printed_url = False
        self.slack_bot_url = ""
        self.dream_url = ""
        self.ui = ReactUI()
        self.api_component = APIUsageFlow()

        self.safety_embeddings_ready = False

    @property
    def ready(self) -> bool:
        return self.load_balancer.ready

    @property
    def model_servers(self) -> List[StableDiffusionServe]:
        works = []
        for i in range(self._num_workers):
            work: StableDiffusionServe = self.get_work(i)
            works.append(work)
        return works

    def add_work(self, work) -> str:
        work_attribute = uuid.uuid4().hex
        work_attribute = f"model_serve_{self._num_workers}_{str(work_attribute)}"
        setattr(self, work_attribute, work)
        self._work_registry[self._num_workers] = work_attribute
        self._num_workers += 1
        return work_attribute

    def remove_work(self, index: int) -> str:
        work_attribute = self._work_registry[index]
        del self._work_registry[index]
        work = getattr(self, work_attribute)
        work.stop()
        self._num_workers -= 1
        return work_attribute

    def get_work(self, index: int):
        work_attribute = self._work_registry[index]
        work = getattr(self, work_attribute)
        return work

    def run(self):  # noqa: C901
        if os.environ.get("TESTING_LAI"):
            print("⚡ Lightning Dream App! ⚡")

        # provision these works early
        if not self.load_balancer.is_running:
            self.load_balancer.run([])
        if not self.slack_bot.is_running:
            self.slack_bot.run("")

        if not self.safety_embeddings_ready:
            self.safety_checker_embedding_work.run()

        if not self.safety_embeddings_ready and self.safety_checker_embedding_work.has_succeeded:
            self.safety_embeddings_ready = True
            self.safety_checker_embedding_work.stop()

        for model_serve in self.model_servers:
            model_serve.run()

        if any(model_serve.url for model_serve in self.model_servers) and not self.load_balancer_started:
            # run the load balancer when one of the model servers is ready
            self.load_balancer.run([serve.url for serve in self.model_servers if serve.url])
            self.load_balancer_started = True

        if self.load_balancer.url:  # hack for getting the work url
            self.api_component.api_url = self.load_balancer.url
            self.dream_url = self.load_balancer.url
            if self.slack_bot is not None:
                self.slack_bot.run(self.load_balancer.url)
                self.slack_bot_url = self.slack_bot.url
                if self.slack_bot.url and not self.printed_url:
                    print("Slack Bot Work ready with URL=", self.slack_bot.url)
                    print("model serve url=", self.load_balancer.url)
                    print("API component url=", self.api_component.state_vars["vars"]["_layout"]["target"])
                    self.printed_url = True

        if self.load_testing and self.load_balancer.url:
            self.locust.run(self.load_balancer.url)

        if self.load_balancer.url:
            self.fake_trigger += 1
            self.autoscale()

    def configure_layout(self):
        ui = [{"name": "Muse App" if self.load_testing else None, "content": self.ui}]
        if self.load_testing:
            ui.append({"name": "Locust", "content": self.locust.url})

        return ui

    def autoscale(self):
        """Upscale and down scale model inference works based on the number of requests."""
        if time.time() - self._last_autoscale < self.autoscale_interval:
            return

        self.load_balancer.update_servers(self.model_servers)

        num_requests = int(requests.get(f"{self.load_balancer.url}/num-requests").json())
        num_workers = len(self.model_servers)

        # upscale
        if num_requests > self.autoscale_up_limit and num_workers < self.max_workers:
            idx = self._num_workers
            print(f"Upscale to {self._num_workers + 1}")
            work = StableDiffusionServe(
                safety_embeddings_drive=self.safety_embeddings_drive,
                safety_embeddings_filename=self.safety_checker_embedding_work.safety_embeddings_filename,
                cloud_compute=L.CloudCompute(self.gpu_type, disk_size=30),
                cache_calls=True,
                parallel=True,
            )
            new_work_id = self.add_work(work)
            print("new work id:", new_work_id)

        # downscale
        elif num_requests < self.autoscale_down_limit and num_workers > self._initial_num_workers:
            idx = self._num_workers - 1
            print(f"Downscale to {idx}")
            print("prev num servers:", len(self.model_servers))
            removed_id = self.remove_work(idx)
            print("removed:", removed_id)
            print("new num servers:", len(self.model_servers))
            self.load_balancer.update_servers(self.model_servers)
        self._last_autoscale = time.time()


if __name__ == "__main__":
    app = L.LightningApp(
        MuseFlow(),
        info=AppInfo(
            title="Use AI to inspire your art.",
            favicon="https://storage.googleapis.com/grid-static/muse/favicon.ico",
            description="Bring your words to life in seconds - powered by Lightning AI and Stable Diffusion.",
            image="https://storage.googleapis.com/grid-static/header.png",
            meta_tags=[
                '<meta name="theme-color" content="#792EE5" />',
                '<meta name="image" content="https://storage.googleapis.com/grid-static/header.png">'
                '<meta itemprop="name" content="Use AI to inspire your art.">'
                '<meta itemprop="description" content="Bring your words to life in seconds - powered by Lightning AI and Stable Diffusion.">'  # noqa
                '<meta itemprop="image" content="https://storage.googleapis.com/grid-static/header.png">'
                # <!-- Twitter -->
                '<meta name="twitter:card" content="summary">'
                '<meta name="twitter:title" content="Use AI to inspire your art.">'
                '<meta name="twitter:description" content="Bring your words to life in seconds - powered by Lightning AI and Stable Diffusion.">'  # noqa
                '<meta name="twitter:site" content="https://lightning.ai/muse">'
                '<meta name="twitter:domain" content="https://lightning.ai/muse">'
                '<meta name="twitter:creator" content="@LightningAI">'
                '<meta name="twitter:image:src" content="https://storage.googleapis.com/grid-static/header.png">'
                # <!-- Open Graph general (Facebook, Pinterest & Google+) -->
                '<meta name="og:title" content="Use AI to inspire your art.">'
                '<meta name="og:description" content="Bring your words to life in seconds - powered by Lightning AI and Stable Diffusion.">'  # noqa
                '<meta name="og:url" content="https://lightning.ai/muse">'
                '<meta property="og:image" content="https://storage.googleapis.com/grid-static/header.png" />',
                '<meta property="og:image:type" content="image/png" />',
                '<meta property="og:image:height" content="1114" />'
                '<meta property="og:image:width" content="1112" />',
                *(analytics_headers if ENABLE_ANALYTICS else []),
            ],
        ),
        root_path=os.getenv("MUSE_ROOT_PATH", ""),
    )