Skip to content

Commit

Permalink
update frontend to match backend, changed scheduler list to match the…
Browse files Browse the repository at this point in the history
… updated schedulers
  • Loading branch information
derrian-distro committed Jul 21, 2024
1 parent 80e2c49 commit cbd4be9
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 95 deletions.
127 changes: 38 additions & 89 deletions main_ui_files/OptimizerUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,73 +30,45 @@ def __init__(self, parent: QWidget = None) -> None:
def setup_widget(self) -> None:
super().setup_widget()
self.widget.setupUi(self.content)
self.widget.optimizer_item_widget.layout().setAlignment(
QtCore.Qt.AlignmentFlag.AlignTop
)
self.widget.optimizer_item_widget.layout().setAlignment(QtCore.Qt.AlignmentFlag.AlignTop)
for opt_arg in self.opt_args:
self.widget.optimizer_item_widget.layout().addWidget(opt_arg)
opt_arg.delete_item.connect(self.remove_optimizer_arg)
opt_arg.item_updated.connect(self.modify_optimizer_args)
self.modify_optimizer_args()

def setup_connections(self) -> None:
self.widget.optimizer_type_selector.currentTextChanged.connect(
self.change_optimizer
)
self.widget.lr_scheduler_selector.currentTextChanged.connect(
self.change_scheduler
)
self.widget.optimizer_type_selector.currentTextChanged.connect(self.change_optimizer)
self.widget.lr_scheduler_selector.currentTextChanged.connect(self.change_scheduler)
self.widget.loss_type_selector.currentTextChanged.connect(self.change_loss_type)
self.widget.main_lr_input.textChanged.connect(
lambda x: self.edit_lr("learning_rate", x)
)
self.widget.main_lr_input.textChanged.connect(lambda x: self.edit_lr("learning_rate", x))
self.widget.warmup_enable.clicked.connect(self.enable_disable_warmup)
self.widget.warmup_input.valueChanged.connect(
lambda x: self.edit_args("warmup_ratio", round(x, 2), True)
)
self.widget.min_lr_input.textChanged.connect(
lambda x: self.edit_lr_args("min_lr", x, True)
)
self.widget.min_lr_input.textChanged.connect(lambda x: self.edit_lr_args("min_lr", x, True))
self.widget.cosine_restart_input.valueChanged.connect(
lambda x: self.edit_args("lr_scheduler_num_cycles", x)
)
self.widget.unet_lr_enable.clicked.connect(self.enable_disable_unet)
self.widget.unet_lr_input.textChanged.connect(
lambda x: self.edit_lr("unet_lr", x, True)
)
self.widget.poly_power_input.valueChanged.connect(
lambda x: self.edit_args("lr_scheduler_power", x)
)
self.widget.unet_lr_input.textChanged.connect(lambda x: self.edit_lr("unet_lr", x, True))
self.widget.poly_power_input.valueChanged.connect(lambda x: self.edit_args("lr_scheduler_power", x))
self.widget.te_lr_enable.clicked.connect(self.enable_disable_te)
self.widget.te_lr_input.textChanged.connect(
lambda x: self.edit_lr("text_encoder_lr", x)
)
self.widget.gamma_input.valueChanged.connect(
lambda x: self.edit_lr_args("gamma", 1 - x)
)
self.widget.scale_weight_enable.clicked.connect(
self.enable_disable_scale_weight_norms
)
self.widget.te_lr_input.textChanged.connect(lambda x: self.edit_lr("text_encoder_lr", x))
self.widget.gamma_input.valueChanged.connect(lambda x: self.edit_lr_args("gamma", 1 - x))
self.widget.scale_weight_enable.clicked.connect(self.enable_disable_scale_weight_norms)
self.widget.scale_weight_input.valueChanged.connect(
lambda x: self.edit_args("scale_weight_norms", x, True)
)
self.widget.max_grad_norm_input.valueChanged.connect(
lambda x: self.edit_args("max_grad_norm", x)
)
self.widget.max_grad_norm_input.valueChanged.connect(lambda x: self.edit_args("max_grad_norm", x))
self.widget.min_snr_enable.clicked.connect(self.enable_disable_min_snr_gamma)
self.widget.min_snr_input.valueChanged.connect(
lambda x: self.edit_args("min_snr_gamma", x)
)
self.widget.zero_term_enable.clicked.connect(
lambda x: self.edit_args("zero_terminal_snr", x, True)
)
self.widget.min_snr_input.valueChanged.connect(lambda x: self.edit_args("min_snr_gamma", x))
self.widget.zero_term_enable.clicked.connect(lambda x: self.edit_args("zero_terminal_snr", x, True))
self.widget.masked_loss_enable.clicked.connect(self.enable_disable_masked_loss)
self.widget.huber_schedule_selector.currentTextChanged.connect(
lambda x: self.edit_args("huber_schedule", x.lower())
)
self.widget.huber_param_input.valueChanged.connect(
lambda x: self.edit_args("huber_c", round(x, 4))
)
self.widget.huber_param_input.valueChanged.connect(lambda x: self.edit_args("huber_c", round(x, 4)))
self.widget.add_opt_button.clicked.connect(self.add_optimizer_arg)

def edit_lr(self, name: str, value: str, optional: bool = False) -> None:
Expand All @@ -123,10 +95,7 @@ def edit_lr_args(self, name: str, value: object, optional: bool = False) -> None
@Slot(object)
def remove_optimizer_arg(self, widget: OptimizerItem):
self.layout().removeWidget(widget)
if (
"optimizer_args" in self.args
and widget.arg_name in self.args["optimizer_args"]
):
if "optimizer_args" in self.args and widget.arg_name in self.args["optimizer_args"]:
del self.args["optimizer_args"][widget.arg_name]
widget.deleteLater()
self.opt_args.remove(widget)
Expand Down Expand Up @@ -179,13 +148,13 @@ def change_scheduler(self, value: str) -> None:
self.widget.cosine_restart_input.value(),
True,
)
elif value == "cosine_annealing_warmup_restarts":
elif value == "cosine_annealing_warmup_restarts_(CAWR)":
self.widget.cosine_restart_input.setEnabled(True)
self.widget.min_lr_input.setEnabled(True)
self.widget.gamma_input.setEnabled(True)
self.edit_args(
"lr_scheduler_type",
"LoraEasyCustomOptimizer.CustomOptimizers.CosineAnnealingWarmupRestarts",
"LoraEasyCustomOptimizer.CosineAnnealingWarmRestarts.CosineAnnealingWarmRestarts",
)
self.edit_lr_args("min_lr", self.widget.min_lr_input.text(), True)
self.edit_args(
Expand All @@ -195,19 +164,19 @@ def change_scheduler(self, value: str) -> None:
)
self.edit_lr_args("gamma", 1 - self.widget.gamma_input.value(), True)
return
elif value == "polynomial":
self.widget.poly_power_input.setEnabled(True)
self.edit_args(
"lr_scheduler_power", self.widget.poly_power_input.value(), True
)
elif value == "rex":
elif value == "rex_annealing_warm_restarts_(RAWR)":
self.widget.cosine_restart_input.setEnabled(True)
self.widget.min_lr_input.setEnabled(True)
self.widget.gamma_input.setEnabled(True)
self.edit_args(
"lr_scheduler_type",
"LoraEasyCustomOptimizer.CustomOptimizers.Rex",
"LoraEasyCustomOptimizer.RexAnnealingWarmRestarts.RexAnnealingWarmRestarts",
)
self.edit_lr_args("min_lr", self.widget.min_lr_input.text(), True)
self.edit_lr_args("gamma", 1 - self.widget.gamma_input.value(), True)
return
elif value == "polynomial":
self.widget.poly_power_input.setEnabled(True)
self.edit_args("lr_scheduler_power", self.widget.poly_power_input.value(), True)
self.edit_args("lr_scheduler", value)

def change_loss_type(self, value: str) -> None:
Expand All @@ -221,9 +190,7 @@ def change_loss_type(self, value: str) -> None:
self.edit_args("loss_type", value)
if value == "l2":
return
self.edit_args(
"huber_schedule", self.widget.huber_schedule_selector.currentText().lower()
)
self.edit_args("huber_schedule", self.widget.huber_schedule_selector.currentText().lower())
self.edit_args("huber_c", round(self.widget.huber_param_input.value(), 4))

@Slot(bool)
Expand Down Expand Up @@ -260,9 +227,7 @@ def enable_disable_scale_weight_norms(self, checked: bool) -> None:
self.widget.scale_weight_input.setEnabled(checked)
if not checked:
return
self.edit_args(
"scale_weight_norms", self.widget.scale_weight_input.value(), True
)
self.edit_args("scale_weight_norms", self.widget.scale_weight_input.value(), True)

@Slot(bool)
def enable_disable_min_snr_gamma(self, checked: bool) -> None:
Expand All @@ -288,46 +253,34 @@ def load_args(self, args: dict) -> bool:
)
if "lr_scheduler_type" in args:
self.widget.lr_scheduler_selector.setCurrentText(
"cosine annealing warmup restarts"
if args["lr_scheduler_type"].split(".")[-1] != "Rex"
else "rex"
"cosine annealing warm restarts (CAWR)"
if args["lr_scheduler_type"].split(".")[-1] != "RexAnnealingWarmRestarts"
else "rex annealing warm restarts (RAWR)"
)
else:
self.widget.lr_scheduler_selector.setCurrentText(
args.get("lr_scheduler", "cosine").replace("_", " ")
)
self.widget.loss_type_selector.setCurrentText(
args.get("loss_type", "L2").replace("_", " ").title()
)
self.widget.loss_type_selector.setCurrentText(args.get("loss_type", "L2").replace("_", " ").title())
self.widget.main_lr_input.setText(str(args.get("learning_rate", "1e-4")))
self.widget.warmup_enable.setChecked(bool(args.get("warmup_ratio", False)))
self.widget.warmup_input.setValue(args.get("warmup_ratio", 0.0))
self.widget.min_lr_input.setText(
str(args.get("lr_scheduler_args", {}).get("min_lr", "1e-6"))
)
self.widget.cosine_restart_input.setValue(
args.get("lr_scheduler_num_cycles", 1)
)
self.widget.min_lr_input.setText(str(args.get("lr_scheduler_args", {}).get("min_lr", "1e-6")))
self.widget.cosine_restart_input.setValue(args.get("lr_scheduler_num_cycles", 1))
self.widget.unet_lr_enable.setChecked(bool(args.get("unet_lr", False)))
self.widget.unet_lr_input.setText(str(args.get("unet_lr", "1e-4")))
self.widget.poly_power_input.setValue(args.get("lr_scheduler_power", 1.0))
self.widget.te_lr_enable.setChecked(bool(args.get("text_encoder_lr", False)))
self.widget.te_lr_input.setText(str(args.get("text_encoder_lr", "1e-4")))
self.widget.gamma_input.setValue(
round(1 - args.get("lr_scheduler_args", {}).get("gamma", 0.9), 2)
)
self.widget.scale_weight_enable.setChecked(
bool(args.get("scale_weight_norms", False))
)
self.widget.gamma_input.setValue(round(1 - args.get("lr_scheduler_args", {}).get("gamma", 0.9), 2))
self.widget.scale_weight_enable.setChecked(bool(args.get("scale_weight_norms", False)))
self.widget.scale_weight_input.setValue(args.get("scale_weight_norms", 1.0))
self.widget.max_grad_norm_input.setValue(args.get("max_grad_norm", 1.0))
self.widget.min_snr_enable.setChecked(bool(args.get("min_snr_gamma", False)))
self.widget.min_snr_input.setValue(args.get("min_snr_gamma", 5))
self.widget.zero_term_enable.setChecked(args.get("zero_terminal_snr", False))
self.widget.huber_schedule_selector.setCurrentIndex(
{"snr": 0, "exponential": 1, "constant": 2}.get(
args.get("huber_schedule", "snr").lower(), 0
)
{"snr": 0, "exponential": 1, "constant": 2}.get(args.get("huber_schedule", "snr").lower(), 0)
)
self.widget.huber_param_input.setValue(args.get("huber_c", 0.1))

Expand All @@ -349,13 +302,9 @@ def load_args(self, args: dict) -> bool:
self.enable_disable_warmup(self.widget.warmup_enable.isChecked())
self.enable_disable_unet(self.widget.unet_lr_enable.isChecked())
self.enable_disable_te(self.widget.te_lr_enable.isChecked())
self.enable_disable_scale_weight_norms(
self.widget.scale_weight_enable.isChecked()
)
self.enable_disable_scale_weight_norms(self.widget.scale_weight_enable.isChecked())
self.edit_args("max_grad_norm", self.widget.max_grad_norm_input.value())
self.enable_disable_min_snr_gamma(self.widget.min_snr_enable.isChecked())
self.edit_args(
"zero_terminal_snr", self.widget.zero_term_enable.isChecked(), True
)
self.edit_args("zero_terminal_snr", self.widget.zero_term_enable.isChecked(), True)
self.modify_optimizer_args()
return True
6 changes: 3 additions & 3 deletions ui_files/OptimizerUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
################################################################################
## Form generated from reading UI file 'OptimizerUI.ui'
##
## Created by: Qt User Interface Compiler version 6.7.0
## Created by: Qt User Interface Compiler version 6.7.1
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
Expand Down Expand Up @@ -382,13 +382,13 @@ def retranslateUi(self, optimizer_ui):
self.lr_scheduler_label.setText(QCoreApplication.translate("optimizer_ui", u"LR Scheduler", None))
self.lr_scheduler_selector.setItemText(0, QCoreApplication.translate("optimizer_ui", u"cosine", None))
self.lr_scheduler_selector.setItemText(1, QCoreApplication.translate("optimizer_ui", u"cosine with restarts", None))
self.lr_scheduler_selector.setItemText(2, QCoreApplication.translate("optimizer_ui", u"cosine annealing warmup restarts", None))
self.lr_scheduler_selector.setItemText(2, QCoreApplication.translate("optimizer_ui", u"cosine annealing warm restarts (CAWR)", None))
self.lr_scheduler_selector.setItemText(3, QCoreApplication.translate("optimizer_ui", u"linear", None))
self.lr_scheduler_selector.setItemText(4, QCoreApplication.translate("optimizer_ui", u"constant", None))
self.lr_scheduler_selector.setItemText(5, QCoreApplication.translate("optimizer_ui", u"constant with warmup", None))
self.lr_scheduler_selector.setItemText(6, QCoreApplication.translate("optimizer_ui", u"adafactor", None))
self.lr_scheduler_selector.setItemText(7, QCoreApplication.translate("optimizer_ui", u"polynomial", None))
self.lr_scheduler_selector.setItemText(8, QCoreApplication.translate("optimizer_ui", u"rex", None))
self.lr_scheduler_selector.setItemText(8, QCoreApplication.translate("optimizer_ui", u"rex annealing warm restarts (RAWR)", None))

#if QT_CONFIG(tooltip)
self.lr_scheduler_selector.setToolTip(QCoreApplication.translate("optimizer_ui", u"<html><head/><body><p>The scheduler for the lr. The ones I use personally are cosine and cosine with restarts.</p></body></html>", None))
Expand Down
4 changes: 2 additions & 2 deletions ui_files/OptimizerUI.ui
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@
</item>
<item>
<property name="text">
<string>cosine annealing warmup restarts</string>
<string>cosine annealing warm restarts (CAWR)</string>
</property>
</item>
<item>
Expand Down Expand Up @@ -343,7 +343,7 @@
</item>
<item>
<property name="text">
<string>rex</string>
<string>rex annealing warm restarts (RAWR)</string>
</property>
</item>
</widget>
Expand Down

0 comments on commit cbd4be9

Please sign in to comment.