Skip to content

Commit

Permalink
Merge pull request #48 from houseofsecrets/develop
Browse files Browse the repository at this point in the history
Display selected checkpoint & vae
  • Loading branch information
Danamir authored May 10, 2023
2 parents 2bfcec9 + 6dd6599 commit 86aa6f9
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 11 deletions.
13 changes: 12 additions & 1 deletion scripts/common/state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from .utils import load_config, update_size
from .utils import load_config, update_size, fetch_configuration


class State:
Expand All @@ -10,6 +10,7 @@ class State:
configuration = {
"config_file": "config.json",
"config": {},
"webui_config": {}
}
presets = {
"presets_file": "presets.json",
Expand All @@ -21,6 +22,8 @@ class State:
"busy": False,
}
render = {
"checkpoint": None,
"vae": None,
"hr_scales": [],
"hr_scale": 1.0,
"hr_scale_prev": 1.25,
Expand Down Expand Up @@ -175,6 +178,14 @@ def update_settings(self):
with open(self.json_file, "w") as f:
json.dump(settings, f, indent=4)

def update_webui_config(self):
"""
Update webui configuration from the API.
"""
self.configuration["webui_config"] = fetch_configuration(self)
self.render['checkpoint'] = self.configuration["webui_config"]['sd_model_checkpoint']
self.render['vae'] = self.configuration["webui_config"]['sd_vae']

def __setitem__(self, key, value):
setattr(self, key, value)

Expand Down
45 changes: 45 additions & 0 deletions scripts/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import functools
import os
import random
import re
import shutil
import threading
import base64
import io
import json
import time
import math

import requests
from PIL import Image
from psd_tools import PSDImage

Expand Down Expand Up @@ -312,5 +315,47 @@ def get_img2img_json(state):
return json_data


def fetch_configuration(state):
"""
Request current configuration from the webui API.
:return: The configuration JSON.
"""

response = requests.get(url=f'{state.server["url"]}/sdapi/v1/options')
if response.status_code == 200:
r = response.json()
return r
else:
return {}


checkpoint_pattern = re.compile(r'^(?P<dir>.*(?:\\|\/))?(?P<name>.*?)(?P<vae>\.vae)?(?P<ext>\.safetensors|\.pt|\.ckpt) ?(?P<hash>\[[^\]]*\])?.*')


def ckpt_name(name, display_dir=False, display_ext=False, display_hash=False):
"""
Clean checkpoint name.
:param str name: Checkpoint name.
:param bool display_dir: Display full path.
:param bool display_ext: Display checkpoint extension.
:param bool display_hash: Display checkpoint hash.
:return: Cleaned checkpoint name.
"""

replace = ''
if display_dir:
replace += r'\g<dir>'

replace += r'\g<name>'

if display_ext:
replace += r'\g<vae>\g<ext>'

if display_hash:
replace += r' \g<hash>'

return checkpoint_pattern.sub(replace, name)


# Type hinting imports:
from .state import State
26 changes: 16 additions & 10 deletions scripts/views/PygameView.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from PIL import Image, ImageOps
import tkinter as tk
from tkinter import filedialog, simpledialog
from scripts.common.utils import payload_submit, update_config, save_preset, update_size, new_random_seed
from scripts.common.utils import payload_submit, update_config, save_preset, update_size, new_random_seed, ckpt_name
from scripts.common.cn_requests import fetch_controlnet_models, progress_request, fetch_detect_image, fetch_img2img, post_request
from scripts.common.output_files_utils import autosave_image, save_image
from scripts.common.state import State
Expand Down Expand Up @@ -715,12 +715,16 @@ def display_configuration(self, wrap=True):
:param bool wrap: Wrap long text.
"""

self.state.update_webui_config()

fields = [
'--Prompt',
'state/gen_settings/prompt',
'state/gen_settings/negative_prompt',
'state/gen_settings/seed',
'--Render',
'state/render/checkpoint',
'state/render/vae',
'state/render/render_size',
'settings.steps',
'settings.cfg_scale',
Expand All @@ -733,7 +737,6 @@ def display_configuration(self, wrap=True):
'state/control_net/controlnet_weight',
'state/control_net/controlnet_guidance_end',
'state/render/pixel_perfect',
'--Misc',
'state/detectors/detector'
]

Expand All @@ -759,7 +762,7 @@ def display_configuration(self, wrap=True):

if '.' in field:
field = field.split('.')
var = globals().get(field[0], None)
var = globals().get(field[0], locals().get(field[0], None))
if var is None:
continue

Expand All @@ -779,19 +782,22 @@ def display_configuration(self, wrap=True):
field_value = getattr(self.state, field_components[0])[field_components[1]]
else:
label = field
field_value = globals().get(field, None)

if 'size' in label and isinstance(field_value, tuple) and len(field_value) == 2:
field_value = f"{field_value[0]}x{field_value[1]}"
field_value = globals().get(field, locals().get(field, None))

if field_value is not None:
value = field_value

if label and value is not None:
value = str(value)
# prettify
label = label.replace('_', ' ')
if label.endswith('prompt'):
value = value.replace(', ', ',').replace(',', ', ') # nicer prompt display
elif 'size' in label and isinstance(value, tuple) and len(value) == 2:
value = f"{value[0]}x{value[1]}"
elif label in ('checkpoint', 'vae'):
value = ckpt_name(value)
else:
value = str(value)

# wrap text
if wrap and len(value) > wrap:
Expand Down Expand Up @@ -1180,9 +1186,9 @@ def main(self):
if self.shift_down:
# cycle detectors
self.state.detectors["detector"] = self.state.detectors["list"][(self.state.detectors["list"].index(self.state.detectors["detector"])+1) % len(self.state.detectors["list"])]
self.osd(text=f"ControlNet detector: {self.state.detectors['detector']}")
self.osd(text=f"ControlNet detector: {self.state.detectors['detector'].replace('_', ' ')}")
else:
self.osd(text=f"Detect {self.state.detectors['detector']}")
self.osd(text=f"Detect {self.state.detectors['detector'].replace('_', ' ')}")
detector = str(self.state.detectors['detector'])

t = threading.Thread(target=functools.partial(self.controlnet_detect, detector))
Expand Down

0 comments on commit 86aa6f9

Please sign in to comment.