Skip to content

Commit

Permalink
add support for masked loss, update backend to update the steps calc …
Browse files Browse the repository at this point in the history
…code to fall in line with kohya
  • Loading branch information
derrian-distro committed May 19, 2024
1 parent 270978e commit 113ecfd
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 100 deletions.
2 changes: 1 addition & 1 deletion backend
Submodule backend updated 2 files
+15 −10 main.py
+47 −26 utils/validation.py
7 changes: 6 additions & 1 deletion main_ui_files/ArgsListUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ class ArgsWidget(QtWidgets.QWidget):
sdxlChecked = Signal(bool)
cacheLatentsChecked = Signal(bool)
keepTokensSepChecked = Signal(bool)
maskedLossChecked = Signal(bool)

def __init__(self, parent: QtWidgets.QWidget = None) -> None:
super().__init__(parent)
self.scroll_area = QtWidgets.QScrollArea()
self.scroll_widget = QtWidgets.QWidget()
self.args_widget_array: list[BaseWidget] = []
self.network_widget = NetworkWidget()
self.optimizer_widget = OptimizerWidget()
self.ti_widget = TextualInversionWidget()
self.ti_widget.setVisible(False)

Expand Down Expand Up @@ -51,11 +53,14 @@ def setup_args_widgets(self) -> None:
general_args.keepTokensSepChecked.connect(
lambda x: self.keepTokensSepChecked.emit(x)
)
self.optimizer_widget.maskedLossChecked.connect(
lambda x: self.maskedLossChecked.emit(x)
)
self.args_widget_array.append(general_args)
self.sdxlChecked.connect(self.network_widget.toggle_sdxl)
self.args_widget_array.append(self.network_widget)
self.args_widget_array.append(self.ti_widget)
self.args_widget_array.append(OptimizerWidget())
self.args_widget_array.append(self.optimizer_widget)
self.args_widget_array.append(SavingWidget())
self.args_widget_array.append(BucketWidget())
self.args_widget_array.append(NoiseOffsetWidget())
Expand Down
3 changes: 3 additions & 0 deletions main_ui_files/MainUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def setup_connections(self) -> None:
self.args_widget.keepTokensSepChecked.connect(
self.subset_widget.enable_disable_variable_keep_tokens
)
self.args_widget.maskedLossChecked.connect(
self.subset_widget.enable_disable_masked_loss
)
self.queue_widget.saveQueue.connect(lambda x: self.save_toml(Path(x)))
self.queue_widget.loadQueue.connect(lambda x: self.load_toml(Path(x)))
self.begin_training_button.clicked.connect(self.start_training)
Expand Down
16 changes: 12 additions & 4 deletions main_ui_files/OptimizerUI.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from PySide6.QtCore import Slot
from PySide6.QtCore import Slot, Signal
from PySide6 import QtCore
from PySide6.QtWidgets import QWidget
from ui_files.OptimizerUI import Ui_optimizer_ui
Expand All @@ -9,6 +9,8 @@


class OptimizerWidget(BaseWidget):
maskedLossChecked = Signal(bool)

def __init__(self, parent: QWidget = None) -> None:
super().__init__(parent)
self.colap.set_title("Optimizer Args")
Expand Down Expand Up @@ -90,6 +92,7 @@ def setup_connections(self) -> None:
self.widget.zero_term_enable.clicked.connect(
lambda x: self.edit_args("zero_terminal_snr", x, True)
)
self.widget.masked_loss_enable.clicked.connect(self.enable_disable_masked_loss)
self.widget.huber_schedule_selector.currentTextChanged.connect(
lambda x: self.edit_args("huber_schedule", x.lower())
)
Expand Down Expand Up @@ -223,10 +226,10 @@ def change_loss_type(self, value: str) -> None:
for arg in args:
if arg in self.args:
del self.args[arg]
self.widget.huber_schedule_selector.setEnabled(value == "huber")
self.widget.huber_param_input.setEnabled(value == "huber")
self.widget.huber_schedule_selector.setEnabled(value != "l2")
self.widget.huber_param_input.setEnabled(value != "l2")
self.edit_args("loss_type", value)
if value != "huber":
if value == "l2":
return
self.edit_args(
"huber_schedule", self.widget.huber_schedule_selector.currentText().lower()
Expand Down Expand Up @@ -280,6 +283,11 @@ def enable_disable_min_snr_gamma(self, checked: bool) -> None:
return
self.edit_args("min_snr_gamma", self.widget.min_snr_input.value(), True)

@Slot(bool)
def enable_disable_masked_loss(self, checked: bool) -> None:
self.edit_args("masked_loss", checked, True)
self.maskedLossChecked.emit(checked)

def load_args(self, args: dict) -> bool:
args: dict = args.get(self.name, {})

Expand Down
18 changes: 16 additions & 2 deletions main_ui_files/SubsetListUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, parent: QWidget = None) -> None:
QtCore.Qt.AlignmentFlag.AlignTop
)
self.widget.subset_scroll_area.setWidgetResizable(True)
self.masked_loss_checked = False
self.cache_latents_checked = False
self.variable_keep_tokens_checked = False
self.args = {}
Expand All @@ -38,7 +39,12 @@ def add_empty_subset(self, display_name: str = "") -> SubsetWidget:
subset = SubsetWidget(display_name=display_name, name=name)
subset.colap.extra_elem.clicked.connect(lambda: self.remove_subset(subset))
subset.edited.connect(self.update_args)
subset.enable_disable_cache_dependants(self.cache_latents_checked)

subset.enable_disable_masked_loss(self.masked_loss_checked)
subset.enable_disable_random_crop(
any([self.masked_loss_checked, self.cache_latents_checked])
)
subset.enable_disable_color_aug(self.cache_latents_checked)
subset.enable_disable_keep_tokens(self.variable_keep_tokens_checked)
if not self.elements:
subset.colap.toggle_collapsed()
Expand Down Expand Up @@ -75,10 +81,17 @@ def add_from_root_folder(self) -> None:
def update_args(self, subset_args: dict, subset_name: str) -> None:
self.dataset_args[subset_name] = subset_args

def enable_disable_masked_loss(self, checked: bool) -> None:
self.masked_loss_checked = checked
for elem in self.elements:
elem.enable_disable_masked_loss(checked)
elem.enable_disable_random_crop(any([checked, self.cache_latents_checked]))

def enable_disable_cache_latents(self, checked: bool) -> None:
self.cache_latents_checked = checked
for elem in self.elements:
elem.enable_disable_cache_dependants(checked)
elem.enable_disable_random_crop(any([checked, self.masked_loss_checked]))
elem.enable_disable_color_aug(checked)

def enable_disable_variable_keep_tokens(self, checked: bool) -> None:
self.variable_keep_tokens_checked = checked
Expand All @@ -103,6 +116,7 @@ def load_dataset_args(self, dataset_args: dict) -> bool:
)
elem = self.add_empty_subset(subset_name)
elem.load_dataset_args(subset)
self.enable_disable_masked_loss(self.masked_loss_checked)
self.enable_disable_cache_latents(self.cache_latents_checked)
self.enable_disable_variable_keep_tokens(self.variable_keep_tokens_checked)
return True
87 changes: 60 additions & 27 deletions main_ui_files/SubsetUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from PySide6.QtGui import QIcon
from PySide6.QtWidgets import QWidget, QFileDialog
from PySide6.QtCore import Signal
from modules.BaseWidget import BaseWidget
from modules.DragDropLineEdit import DragDropLineEdit
from modules.BaseWidget import BaseWidget
from ui_files.sub_dataset_input import Ui_sub_dataset_input
from ui_files.sub_dataset_extra_input import Ui_sub_dataset_extra_input

Expand Down Expand Up @@ -43,19 +43,31 @@ def setup_widget(self) -> None:
self.widget.image_folder_selector.setIcon(
QIcon(str(Path("icons/more-horizontal.svg")))
)
self.widget.masked_image_input.setMode("folder")
self.widget.masked_image_input.highlight = True
self.widget.masked_image_selector.setIcon(
QIcon(str(Path("icons/more-horizontal.svg")))
)
self.extra_widget.face_crop_group.setChecked(False)
self.extra_widget.caption_dropout_group.setChecked(False)
self.extra_widget.token_warmup_group.setChecked(False)

def setup_connections(self) -> None:
self.widget.image_folder_input.textChanged.connect(
lambda x: self.edit_dataset_args("image_dir", x, optional=True)
)
self.widget.image_folder_input.editingFinished.connect(
lambda: self.check_validity(self.widget.image_folder_input)
lambda x: self.edit_dataset_args("image_dir", x, True)
)
self.widget.image_folder_selector.clicked.connect(
lambda: self.set_folder_from_dialog("Subset Image Folder")
lambda: self.set_folder_from_dialog(
"Subset Image Folder", self.widget.image_folder_input
)
)
self.widget.masked_image_input.textChanged.connect(
lambda x: self.edit_dataset_args("conditioning_data_dir", x, True)
)
self.widget.masked_image_selector.clicked.connect(
lambda: self.set_folder_from_dialog(
"Masked Image Folder", self.widget.masked_image_input, False
)
)
self.widget.repeats_input.valueChanged.connect(
lambda x: self.edit_dataset_args("num_repeats", x)
Expand Down Expand Up @@ -110,24 +122,23 @@ def setup_connections(self) -> None:
lambda x: self.edit_dataset_args("token_warmup_step", x)
)

def check_validity(self, elem: DragDropLineEdit) -> None:
elem.dirty = True
if not elem.allow_empty or elem.text() != "":
elem.update_stylesheet()
else:
elem.setStyleSheet("")

def edit_dataset_args(
self, name: str, value: object, optional: bool = False
) -> None:
super().edit_dataset_args(name, value, optional)
self.edited.emit(self.dataset_args, self.name)

def set_folder_from_dialog(self, title_str: str = "", path: Path = None) -> None:
def set_folder_from_dialog(
self,
title_str: str,
element: DragDropLineEdit,
calc_repeats: bool = True,
path: Path = None,
) -> None:
if path and path.exists():
file_name = path
else:
default_dir = Path(self.widget.image_folder_input.text())
default_dir = Path(element.text())
file_name = QFileDialog.getExistingDirectory(
self,
title_str,
Expand All @@ -136,11 +147,24 @@ def set_folder_from_dialog(self, title_str: str = "", path: Path = None) -> None
if not file_name:
return
file_name = Path(file_name)
element.setText(file_name.as_posix())
element.update_stylesheet()
if not calc_repeats:
return
with contextlib.suppress(ValueError):
repeats = int(file_name.name.split("_")[0])
self.widget.repeats_input.setValue(repeats)
self.widget.image_folder_input.setText(file_name.as_posix())
self.widget.image_folder_input.update_stylesheet()

def enable_disable_masked_loss(self, checked: bool) -> None:
if "conditioning_data_dir" in self.dataset_args:
del self.dataset_args["conditioning_data_dir"]
self.widget.masked_image_input.setEnabled(checked)
self.widget.masked_image_selector.setEnabled(checked)
self.edit_dataset_args(
"conditioning_data_dir",
self.widget.masked_image_input.text() if checked else False,
True,
)

def enable_disable_face_crop(self, checked: bool) -> None:
if "face_crop_aug_range" in self.dataset_args:
Expand Down Expand Up @@ -190,27 +214,30 @@ def enable_disable_token_warmup(self, checked: bool) -> None:
args[1], self.extra_widget.token_warmup_step_input.value()
)

def enable_disable_cache_dependants(self, checked: bool) -> None:
self.widget.color_augment_enable.setEnabled(not checked)
def enable_disable_random_crop(self, checked: bool) -> None:
if "random_crop" in self.dataset_args:
del self.dataset_args["random_crop"]
self.widget.random_crop_enable.setEnabled(not checked)
for arg in ["color_aug", "random_crop"]:
if arg in self.dataset_args:
del self.dataset_args[arg]
self.edit_dataset_args(
"color_aug",
False if checked else self.widget.color_augment_enable.isChecked(),
"random_crop",
False if checked else self.widget.random_crop_enable.isChecked(),
True,
)

def enable_disable_color_aug(self, checked: bool) -> None:
if "color_aug" in self.dataset_args:
del self.dataset_args["color_aug"]
self.widget.color_augment_enable.setEnabled(not checked)
self.edit_dataset_args(
"random_crop",
False if checked else self.widget.random_crop_enable.isChecked(),
"color_aug",
False if checked else self.widget.color_augment_enable.isChecked(),
True,
)

def enable_disable_keep_tokens(self, checked: bool) -> None:
self.widget.keep_tokens_input.setEnabled(not checked)
if "keep_tokens" in self.dataset_args:
del self.dataset_args["keep_tokens"]
self.widget.keep_tokens_input.setEnabled(not checked)
self.edit_dataset_args(
"keep_tokens",
False if checked else self.widget.keep_tokens_input.value(),
Expand All @@ -220,6 +247,9 @@ def enable_disable_keep_tokens(self, checked: bool) -> None:
def load_dataset_args(self, dataset_args: dict) -> bool:
# update element inputs
self.widget.image_folder_input.setText(dataset_args.get("image_dir", ""))
self.widget.masked_image_input.setText(
dataset_args.get("conditioning_data_dir", "")
)
self.widget.repeats_input.setValue(dataset_args.get("num_repeats", 1))
self.widget.shuffle_captions_enable.setChecked(
dataset_args.get("shuffle_caption", False)
Expand Down Expand Up @@ -280,6 +310,9 @@ def load_dataset_args(self, dataset_args: dict) -> bool:

# edit dataset args to match
self.edit_dataset_args("image_dir", self.widget.image_folder_input.text(), True)
self.edit_dataset_args(
"conditioning_data_dir", self.widget.masked_image_input.text(), True
)
self.edit_dataset_args("num_repeats", self.widget.repeats_input.value())
self.edit_dataset_args(
"shuffle_caption", self.widget.shuffle_captions_enable.isChecked(), True
Expand Down
21 changes: 16 additions & 5 deletions ui_files/OptimizerUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
QImage, QKeySequence, QLinearGradient, QPainter,
QPalette, QPixmap, QRadialGradient, QTransform)
from PySide6.QtWidgets import (QApplication, QCheckBox, QFormLayout, QGridLayout,
QLabel, QPushButton, QScrollArea, QSizePolicy,
QVBoxLayout, QWidget)
QHBoxLayout, QLabel, QPushButton, QScrollArea,
QSizePolicy, QVBoxLayout, QWidget)

from modules.LineEditHighlight import LineEditWithHighlight
from modules.ScrollOnSelect import (ComboBox, DoubleSpinBox, SpinBox, TabView)
Expand All @@ -26,7 +26,7 @@ class Ui_optimizer_ui(object):
def setupUi(self, optimizer_ui):
if not optimizer_ui.objectName():
optimizer_ui.setObjectName(u"optimizer_ui")
optimizer_ui.resize(421, 350)
optimizer_ui.resize(463, 354)
self.verticalLayout = QVBoxLayout(optimizer_ui)
self.verticalLayout.setObjectName(u"verticalLayout")
self.verticalLayout.setContentsMargins(0, 0, 0, 0)
Expand Down Expand Up @@ -256,10 +256,20 @@ def setupUi(self, optimizer_ui):

self.formLayout_3.setWidget(4, QFormLayout.FieldRole, self.max_grad_norm_input)

self.horizontalLayout_2 = QHBoxLayout()
self.horizontalLayout_2.setObjectName(u"horizontalLayout_2")
self.zero_term_enable = QCheckBox(self.optimizer_tab_main)
self.zero_term_enable.setObjectName(u"zero_term_enable")

self.formLayout_3.setWidget(5, QFormLayout.SpanningRole, self.zero_term_enable)
self.horizontalLayout_2.addWidget(self.zero_term_enable)

self.masked_loss_enable = QCheckBox(self.optimizer_tab_main)
self.masked_loss_enable.setObjectName(u"masked_loss_enable")

self.horizontalLayout_2.addWidget(self.masked_loss_enable)


self.formLayout_3.setLayout(5, QFormLayout.SpanningRole, self.horizontalLayout_2)

self.label_6 = QLabel(self.optimizer_tab_main)
self.label_6.setObjectName(u"label_6")
Expand Down Expand Up @@ -294,7 +304,7 @@ def setupUi(self, optimizer_ui):
self.scrollArea.setWidgetResizable(True)
self.optimizer_item_widget = QWidget()
self.optimizer_item_widget.setObjectName(u"optimizer_item_widget")
self.optimizer_item_widget.setGeometry(QRect(0, 0, 363, 233))
self.optimizer_item_widget.setGeometry(QRect(0, 0, 98, 28))
self.verticalLayout_3 = QVBoxLayout(self.optimizer_item_widget)
self.verticalLayout_3.setObjectName(u"verticalLayout_3")
self.scrollArea.setWidget(self.optimizer_item_widget)
Expand Down Expand Up @@ -406,6 +416,7 @@ def retranslateUi(self, optimizer_ui):
#endif // QT_CONFIG(tooltip)
self.label_2.setText(QCoreApplication.translate("optimizer_ui", u"Max Grad Norm", None))
self.zero_term_enable.setText(QCoreApplication.translate("optimizer_ui", u"Zero Term SNR", None))
self.masked_loss_enable.setText(QCoreApplication.translate("optimizer_ui", u"Masked Loss", None))
self.label_6.setText(QCoreApplication.translate("optimizer_ui", u"Huber Param", None))
self.tabWidget.setTabText(self.tabWidget.indexOf(self.optimizer_tab_main), QCoreApplication.translate("optimizer_ui", u"Main Args", None))
self.add_opt_button.setText(QCoreApplication.translate("optimizer_ui", u"Add Optimizer Arg", None))
Expand Down
Loading

0 comments on commit 113ecfd

Please sign in to comment.