From 9f342d644429530532c427a18552d57d260a355a Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Wed, 24 Dec 2014 19:08:17 +0000 Subject: [PATCH 01/15] fixes visualize_shape_model --- menpofit/visualize/widgets/base.py | 330 +++++++++++++++++++---------- 1 file changed, 217 insertions(+), 113 deletions(-) diff --git a/menpofit/visualize/widgets/base.py b/menpofit/visualize/widgets/base.py index cb052a3..c58d199 100644 --- a/menpofit/visualize/widgets/base.py +++ b/menpofit/visualize/widgets/base.py @@ -1,4 +1,13 @@ -from menpo.visualize.widgets.helpers import (figure_options, +import numpy as np +from collections import OrderedDict +from IPython.html.widgets import (FloatTextWidget, TextWidget, PopupWidget, + ContainerWidget, TabWidget, FloatSliderWidget, + RadioButtonsWidget, CheckboxWidget, + DropdownWidget, AccordionWidget, ButtonWidget) + +from menpo.visualize.widgets.options import (viewer_options, + format_viewer_options, + figure_options, format_figure_options, figure_options_two_scales, format_figure_options_two_scales, @@ -23,19 +32,10 @@ format_plot_options, save_figure_options, format_save_figure_options) -from menpo.visualize.widgets.base import (_plot_figure, _plot_graph, - _plot_eigenvalues, - _check_n_parameters, - _raw_info_string_to_latex, +from menpo.visualize.widgets.tools import logo, format_logo +from menpo.visualize.widgets.base import (_visualize, _raw_info_string_to_latex, _extract_groups_labels) -from IPython.html.widgets import (FloatTextWidget, TextWidget, PopupWidget, - ContainerWidget, TabWidget, FloatSliderWidget, - RadioButtonsWidget, CheckboxWidget, - DropdownWidget, AccordionWidget, ButtonWidget) -from IPython.display import display, clear_output -import matplotlib.pylab as plt -import numpy as np -from collections import OrderedDict +from menpo.visualize.viewmatplotlib import MatplotlibImageViewer2d # This glyph import is called frequently during visualisation, so we ensure # that we only import it once @@ -43,8 +43,8 @@ def visualize_shape_model(shape_models, n_parameters=5, - parameters_bounds=(-3.0, 3.0), figure_size=(7, 7), - mode='multiple', popup=False, **kwargs): + parameters_bounds=(-3.0, 3.0), figure_size=(6, 4), + mode='multiple', popup=False): r""" Allows the dynamic visualization of a multilevel shape model. @@ -53,32 +53,28 @@ def visualize_shape_model(shape_models, n_parameters=5, shape_models : `list` of :map:`PCAModel` or subclass The multilevel shape model to be displayed. Note that each level can have different number of components. - n_parameters : `int` or `list` of `int` or None, optional The number of principal components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + If `int`, then the number of sliders per level is the minimum between + `n_parameters` and the number of active components per level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. - figure_size : (`int`, `int`), optional The size of the plotted figures. - - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. - - kwargs : `dict`, optional - Passed through to the viewer. + mode : {``single``, ``multiple``}, optional + If ``single``, only a single slider is constructed along with a drop + down menu. If ``multiple``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. """ + import IPython.html.widgets as ipywidgets + import IPython.display as ipydisplay + import matplotlib.pyplot as plt + from matplotlib import collections as mc + # make sure that shape_models is a list even with one member if not isinstance(shape_models, list): shape_models = [shape_models] @@ -93,49 +89,101 @@ def visualize_shape_model(shape_models, n_parameters=5, # the returned n_parameters is a list of len n_levels n_parameters = _check_n_parameters(n_parameters, n_levels, max_n_params) + # initial options dictionaries + lines_options = {'render_lines': True, + 'line_width': 1, + 'line_colour': ['r'], + 'line_style': '-'} + markers_options = {'render_markers': True, + 'marker_size': 20, + 'marker_face_colour': ['r'], + 'marker_edge_colour': ['k'], + 'marker_style': 'o', + 'marker_edge_width': 1} + figure_options = {'x_scale': 1., + 'y_scale': 1., + 'render_axes': False, + 'axes_font_name': 'sans-serif', + 'axes_font_size': 10, + 'axes_font_style': 'normal', + 'axes_font_weight': 'normal', + 'axes_x_limits': None, + 'axes_y_limits': None} + viewer_options_default = {'lines': lines_options, + 'markers': markers_options, + 'figure': figure_options} + # Define plot function def plot_function(name, value): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) - # get params + # get selected level level = 0 if n_levels > 1: level = level_wid.value - def_mode = mode_wid.value - axis_mode = axes_mode_wid.value - parameters_values = model_parameters_wid.parameters_values - x_scale = figure_options_wid.x_scale - y_scale = figure_options_wid.y_scale - axes_visible = figure_options_wid.axes_visible # compute weights + parameters_values = model_parameters_wid.parameters_values weights = (parameters_values * - shape_models[level].eigenvalues[:len(parameters_values)] ** 0.5) + shape_models[level].eigenvalues[:len(parameters_values)] ** + 0.5) # compute the mean mean = shape_models[level].mean() - # select figure - figure_id = plt.figure(save_figure_wid.figure_id.number) - - # invert axis if image mode is enabled - if axis_mode == 1: - plt.gca().invert_yaxis() + tmp1 = viewer_options_wid.selected_values[0]['lines'] + tmp2 = viewer_options_wid.selected_values[0]['markers'] + tmp3 = viewer_options_wid.selected_values[0]['figure'] + new_figure_size = (tmp3['x_scale'] * figure_size[0], + tmp3['y_scale'] * figure_size[1]) # compute and show instance - if def_mode == 1: + if mode_wid.value == 1: # Deformation mode # compute instance instance = shape_models[level].instance(weights) # plot if mean_wid.value: - mean.view(image_view=axis_mode == 1, colour_array='y', - **kwargs) - plt.hold(True) - instance.view(image_view=axis_mode == 1, **kwargs) + mean.view( + figure_id=save_figure_wid.renderer[0].figure_id, + new_figure=False, image_view=axes_mode_wid.value == 1, + render_lines=tmp1['render_lines'], + line_colour='y', + line_style='solid', line_width=tmp1['line_width'], + render_markers=tmp2['render_markers'], + marker_style=tmp2['marker_style'], + marker_size=tmp2['marker_size'], marker_face_colour='y', + marker_edge_colour='y', + marker_edge_width=tmp2['marker_edge_width'], + render_axes=False, figure_size=None) + + renderer = instance.view( + figure_id=save_figure_wid.renderer[0].figure_id, + new_figure=False, image_view=axes_mode_wid.value==1, + render_lines=tmp1['render_lines'], + line_colour=tmp1['line_colour'][0], + line_style=tmp1['line_style'], line_width=tmp1['line_width'], + render_markers=tmp2['render_markers'], + marker_style=tmp2['marker_style'], + marker_size=tmp2['marker_size'], + marker_face_colour=tmp2['marker_face_colour'], + marker_edge_colour=tmp2['marker_edge_colour'], + marker_edge_width=tmp2['marker_edge_width'], + render_axes=tmp3['render_axes'], + axes_font_name=tmp3['axes_font_name'], + axes_font_size=tmp3['axes_font_size'], + axes_font_style=tmp3['axes_font_style'], + axes_font_weight=tmp3['axes_font_weight'], + axes_x_limits=tmp3['axes_x_limits'], + axes_y_limits=tmp3['axes_y_limits'], + figure_size=new_figure_size, + label=None) + + if mean_wid.value and axes_mode_wid.value == 1: + plt.gca().invert_yaxis() # instance range tmp_range = instance.range() @@ -146,8 +194,28 @@ def plot_function(name, value): instance_upper = shape_models[level].instance(weights) # plot - mean.view(image_view=axis_mode == 1, **kwargs) - plt.hold(True) + renderer = mean.view( + figure_id=save_figure_wid.renderer[0].figure_id, + new_figure=False, image_view=axes_mode_wid.value == 1, + render_lines=tmp1['render_lines'], + line_colour=tmp1['line_colour'][0], + line_style=tmp1['line_style'], line_width=tmp1['line_width'], + render_markers=tmp2['render_markers'], + marker_style=tmp2['marker_style'], + marker_size=tmp2['marker_size'], + marker_face_colour=tmp2['marker_face_colour'], + marker_edge_colour=tmp2['marker_edge_colour'], + marker_edge_width=tmp2['marker_edge_width'], + render_axes=tmp3['render_axes'], + axes_font_name=tmp3['axes_font_name'], + axes_font_size=tmp3['axes_font_size'], + axes_font_style=tmp3['axes_font_style'], + axes_font_weight=tmp3['axes_font_weight'], + axes_x_limits=tmp3['axes_x_limits'], + axes_y_limits=tmp3['axes_y_limits'], + figure_size=new_figure_size, label=None) + + ax = plt.gca() for p in range(mean.n_points): xm = mean.points[p, 0] ym = mean.points[p, 1] @@ -155,33 +223,27 @@ def plot_function(name, value): yl = instance_lower.points[p, 1] xu = instance_upper.points[p, 0] yu = instance_upper.points[p, 1] - if axis_mode == 1: + if axes_mode_wid.value == 1: # image mode - plt.plot([ym, yl], [xm, xl], 'r-', lw=2) - plt.plot([ym, yu], [xm, xu], 'g-', lw=2) + lines = [[(ym, xm), (yl, xl)], [(ym, xm), (yu, xu)]] else: # point cloud mode - plt.plot([xm, xl], [ym, yl], 'r-', lw=2) - plt.plot([xm, xu], [ym, yu], 'g-', lw=2) + lines = [[(xm, ym), (xl, yl)], [(xm, ym), (xu, yu)]] + lc = mc.LineCollection(lines, colors=('g', 'b'), + linestyles='solid', linewidths=2) + ax.add_collection(lc) # instance range tmp_range = mean.range() - plt.hold(False) - plt.gca().axis('equal') - # set figure size - plt.gcf().set_size_inches([x_scale, y_scale] * np.asarray(figure_size)) - # turn axis on/off - if not axes_visible: - plt.axis('off') plt.show() # save the current figure id - save_figure_wid.figure_id = figure_id + save_figure_wid.renderer[0] = renderer # info_wid string info_txt = r""" - Level: {} out of {}. + Level: {} out of {}. {} components in total. {} active components. {:.1f} % variance kept. @@ -199,33 +261,37 @@ def plot_function(name, value): def plot_eigenvalues(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) - # get parameters + # get level level = 0 if n_levels > 1: level = level_wid.value - # get the current figure id - figure_id = save_figure_wid.figure_id - - # show eigenvalues plots - new_figure_id = _plot_eigenvalues(figure_id, shape_models[level], - figure_size, - figure_options_wid.x_scale, - figure_options_wid.y_scale) + # get the current figure id and plot the eigenvalues + new_figure_size = (viewer_options_wid.selected_values[0]['figure']['x_scale'] * 10, + viewer_options_wid.selected_values[0]['figure']['y_scale'] * 3) + plt.subplot(121) + shape_models[level].plot_eigenvalues_ratio( + figure_id=save_figure_wid.renderer[0].figure_id) + plt.subplot(122) + renderer = shape_models[level].plot_eigenvalues_cumulative_ratio( + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size) + plt.show() # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer # create options widgets mode_dict = OrderedDict() mode_dict['Deformation'] = 1 mode_dict['Vectors'] = 2 - mode_wid = RadioButtonsWidget(values=mode_dict, description='Mode:', - value=1) + mode_wid = ipywidgets.RadioButtonsWidget(values=mode_dict, + description='Mode:', value=1) mode_wid.on_trait_change(plot_function, 'value') - mean_wid = CheckboxWidget(value=False, description='Show mean shape') + mean_wid = ipywidgets.CheckboxWidget(value=False, + description='Show mean shape') mean_wid.on_trait_change(plot_function, 'value') # controls mean shape checkbox visibility @@ -243,19 +309,27 @@ def mean_visible(name, value): toggle_show_visible=False, plot_eig_visible=True, plot_eig_function=plot_eigenvalues) - figure_options_wid = figure_options(plot_function, scale_default=1., - show_axes_default=True, - toggle_show_default=True, - toggle_show_visible=False) - axes_mode_wid = RadioButtonsWidget(values={'Image': 1, 'Point cloud': 2}, - description='Axes mode:', value=1) + + # viewer options widget + axes_mode_wid = ipywidgets.RadioButtonsWidget( + values={'Image': 1, 'Point cloud': 2}, description='Axes mode:', + value=2) axes_mode_wid.on_trait_change(plot_function, 'value') - ch = list(figure_options_wid.children) - ch.insert(3, axes_mode_wid) - figure_options_wid.children = ch - info_wid = info_print(toggle_show_default=True, toggle_show_visible=False) - initial_figure_id = plt.figure() - save_figure_wid = save_figure_options(initial_figure_id, + viewer_options_wid = viewer_options(viewer_options_default, + ['lines', 'markers', 'figure_one'], + objects_names=None, + plot_function=plot_function, + toggle_show_visible=False, + toggle_show_default=True) + viewer_options_all = ipywidgets.ContainerWidget(children=[axes_mode_wid, + viewer_options_wid]) + info_wid = info_print(toggle_show_default=True, + toggle_show_visible=False) + + # save figure widget + initial_renderer = MatplotlibImageViewer2d(figure_id=None, new_figure=True, + image=np.zeros((10, 10))) + save_figure_wid = save_figure_options(initial_renderer, toggle_show_default=True, toggle_show_visible=False) @@ -274,32 +348,33 @@ def update_widgets(name, value): radio_str["Level {} (high)".format(l)] = l else: radio_str["Level {}".format(l)] = l - level_wid = RadioButtonsWidget(values=radio_str, - description='Pyramid:', value=0) + level_wid = ipywidgets.RadioButtonsWidget(values=radio_str, + description='Pyramid:', + value=0) level_wid.on_trait_change(update_widgets, 'value') level_wid.on_trait_change(plot_function, 'value') radio_children = [level_wid, mode_wid, mean_wid] else: radio_children = [mode_wid, mean_wid] - radio_wids = ContainerWidget(children=radio_children) - tmp_wid = ContainerWidget(children=[radio_wids, model_parameters_wid]) - wid = TabWidget(children=[tmp_wid, figure_options_wid, info_wid, - save_figure_wid]) + radio_wids = ipywidgets.ContainerWidget(children=radio_children) + tmp_wid = ipywidgets.ContainerWidget(children=[radio_wids, + model_parameters_wid]) + tab_wid = ipywidgets.TabWidget(children=[tmp_wid, viewer_options_all, + info_wid, save_figure_wid]) + logo_wid = logo() + wid = ipywidgets.ContainerWidget(children=[logo_wid, tab_wid]) if popup: - wid = PopupWidget(children=[wid], button_text='Shape Model Menu') + wid = ipywidgets.PopupWidget(children=[wid], + button_text='Shape Model Menu') # display final widget - display(wid) + ipydisplay.display(wid) # set final tab titles - tab_titles = ['Shape parameters', 'Figure options', 'Model info', + tab_titles = ['Shape parameters', 'Viewer options', 'Model info', 'Save figure'] - if popup: - for (k, tl) in enumerate(tab_titles): - wid.children[0].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.set_title(k, tl) + for (k, tl) in enumerate(tab_titles): + tab_wid.set_title(k, tl) # align widgets tmp_wid.remove_class('vbox') @@ -309,11 +384,12 @@ def update_widgets(name, value): container_border='1px solid black', toggle_button_font_weight='bold', border_visible=True) - format_figure_options(figure_options_wid, container_padding='6px', + format_viewer_options(viewer_options_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=False) + border_visible=False, + suboptions_border_visible=True) format_info_print(info_wid, font_size_in_pt='9pt', container_padding='6px', container_margin='6px', container_border='1px solid black', @@ -328,7 +404,7 @@ def update_widgets(name, value): update_widgets('', 0) # Reset value to enable initial visualization - figure_options_wid.children[2].value = False + axes_mode_wid.value = 1 def visualize_appearance_model(appearance_models, n_parameters=5, @@ -1997,3 +2073,31 @@ def gridlinestyle_visibility(name, value): # return widget object if asked if return_widget: return wid + + +def _check_n_parameters(n_params, n_levels, max_n_params): + r""" + Checks the maximum number of components per level either of the shape + or the appearance model. It must be None or int or float or a list of + those containing 1 or {n_levels} elements. + """ + str_error = ("n_params must be None or 1 <= int <= max_n_params or " + "a list of those containing 1 or {} elements").format(n_levels) + if not isinstance(n_params, list): + n_params_list = [n_params] * n_levels + elif len(n_params) == 1: + n_params_list = [n_params[0]] * n_levels + elif len(n_params) == n_levels: + n_params_list = n_params + else: + raise ValueError(str_error) + for i, comp in enumerate(n_params_list): + if comp is None: + n_params_list[i] = max_n_params[i] + else: + if isinstance(comp, int): + if comp > max_n_params[i]: + n_params_list[i] = max_n_params[i] + else: + raise ValueError(str_error) + return n_params_list From 2c0b3dd7d10ef9540628004c214cf471df376eea Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Thu, 25 Dec 2014 02:24:21 +0000 Subject: [PATCH 02/15] fixes visualize_appearance_model --- menpofit/visualize/widgets/base.py | 295 +++++++++++++++++------------ 1 file changed, 178 insertions(+), 117 deletions(-) diff --git a/menpofit/visualize/widgets/base.py b/menpofit/visualize/widgets/base.py index c58d199..79f15f7 100644 --- a/menpofit/visualize/widgets/base.py +++ b/menpofit/visualize/widgets/base.py @@ -35,7 +35,8 @@ from menpo.visualize.widgets.tools import logo, format_logo from menpo.visualize.widgets.base import (_visualize, _raw_info_string_to_latex, _extract_groups_labels) -from menpo.visualize.viewmatplotlib import MatplotlibImageViewer2d +from menpo.visualize.viewmatplotlib import (MatplotlibImageViewer2d, + sample_colours_from_colourmap) # This glyph import is called frequently during visualisation, so we ensure # that we only import it once @@ -408,44 +409,36 @@ def update_widgets(name, value): def visualize_appearance_model(appearance_models, n_parameters=5, - parameters_bounds=(-3.0, 3.0), - figure_size=(7, 7), mode='multiple', - popup=False, **kwargs): + parameters_bounds=(-3.0, 3.0), figure_size=(6, 4), + mode='multiple', popup=False): r""" Allows the dynamic visualization of a multilevel appearance model. Parameters ----------- appearance_models : `list` of :map:`PCAModel` or subclass - The multilevel appearance model to be displayed. Note that each level - can have different attributes, e.g. number of parameters, feature type, - number of channels. - + The multilevel appearance model to be displayed. Note that each level can + have different number of components. n_parameters : `int` or `list` of `int` or None, optional The number of principal components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + If `int`, then the number of sliders per level is the minimum between + `n_parameters` and the number of active components per level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. - figure_size : (`int`, `int`), optional The size of the plotted figures. - - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. - - kwargs : `dict`, optional - Passed through to the viewer. + mode : {``single``, ``multiple``}, optional + If ``single``, only a single slider is constructed along with a drop + down menu. If ``multiple``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. """ + import IPython.html.widgets as ipywidgets + import IPython.display as ipydisplay + import matplotlib.pyplot as plt from menpo.image import MaskedImage # make sure that appearance_models is a list even with one member @@ -462,54 +455,120 @@ def visualize_appearance_model(appearance_models, n_parameters=5, # the returned n_parameters is a list of len n_levels n_parameters = _check_n_parameters(n_parameters, n_levels, max_n_params) - # define plot function + # find initial groups and labels that will be passed to the landmark options + # widget creation + mean_has_landmarks = appearance_models[0].mean().landmarks.n_groups != 0 + if mean_has_landmarks: + all_groups_keys, all_labels_keys = _extract_groups_labels( + appearance_models[0].mean()) + else: + all_groups_keys = [' '] + all_labels_keys = [[' ']] + + # get initial line colours for each available label + if len(all_labels_keys[0]) == 1: + line_colours = ['r'] + else: + line_colours = sample_colours_from_colourmap(len(all_labels_keys[0]), + 'jet') + + # initial options dictionaries + channels_default = 0 + if appearance_models[0].mean().n_channels == 3: + channels_default = None + channels_options_default = \ + {'n_channels': appearance_models[0].mean().n_channels, + 'image_is_masked': isinstance(appearance_models[0].mean(), + MaskedImage), + 'channels': channels_default, + 'glyph_enabled': False, + 'glyph_block_size': 3, + 'glyph_use_negative': False, + 'sum_enabled': False, + 'masked_enabled': isinstance(appearance_models[0].mean(), MaskedImage)} + landmark_options_default = {'render_landmarks': mean_has_landmarks, + 'group_keys': all_groups_keys, + 'labels_keys': all_labels_keys, + 'group': None, + 'with_labels': None} + lines_options = {'render_lines': True, + 'line_width': 1, + 'line_colour': line_colours, + 'line_style': '-'} + markers_options = {'render_markers': True, + 'marker_size': 20, + 'marker_face_colour': ['r'], + 'marker_edge_colour': ['k'], + 'marker_style': 'o', + 'marker_edge_width': 1} + figure_options = {'x_scale': 1., + 'y_scale': 1., + 'render_axes': True, + 'axes_font_name': 'sans-serif', + 'axes_font_size': 10, + 'axes_font_style': 'normal', + 'axes_font_weight': 'normal', + 'axes_x_limits': None, + 'axes_y_limits': None} + viewer_options_default = {'lines': lines_options, + 'markers': markers_options, + 'figure': figure_options} + + # Define plot function def plot_function(name, value): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) # get selected level level = 0 if n_levels > 1: level = level_wid.value - # get parameters values + # compute weights and instance parameters_values = model_parameters_wid.parameters_values - - # compute instance - weights = parameters_values * appearance_models[level].eigenvalues[:len(parameters_values)] ** 0.5 + weights = (parameters_values * + appearance_models[level].eigenvalues[:len(parameters_values)] ** + 0.5) instance = appearance_models[level].instance(weights) - # get the current figure id - figure_id = save_figure_wid.figure_id + # update info text widget + update_info(instance, level, + landmark_options_wid.selected_values['group']) + n_labels = len(landmark_options_wid.selected_values['with_labels']) - # show image with selected options - new_figure_id = _plot_figure( - image=instance, figure_id=figure_id, image_enabled=True, - landmarks_enabled=landmark_options_wid.landmarks_enabled, - image_is_masked=channel_options_wid.image_is_masked, - masked_enabled=channel_options_wid.masked_enabled, - channels=channel_options_wid.channels, - glyph_enabled=channel_options_wid.glyph_enabled, - glyph_block_size=channel_options_wid.glyph_block_size, - glyph_use_negative=channel_options_wid.glyph_use_negative, - sum_enabled=channel_options_wid.sum_enabled, - groups=[landmark_options_wid.group], - with_labels=[landmark_options_wid.with_labels], - groups_colours=dict(), subplots_enabled=False, - subplots_titles=dict(), image_axes_mode=True, - legend_enabled=landmark_options_wid.legend_enabled, - numbering_enabled=landmark_options_wid.numbering_enabled, - x_scale=figure_options_wid.x_scale, - y_scale=figure_options_wid.y_scale, - axes_visible=figure_options_wid.axes_visible, - figure_size=figure_size, **kwargs) + # compute the mean + tmp1 = viewer_options_wid.selected_values[0]['lines'] + tmp2 = viewer_options_wid.selected_values[0]['markers'] + tmp3 = viewer_options_wid.selected_values[0]['figure'] + new_figure_size = (tmp3['x_scale'] * figure_size[0], + tmp3['y_scale'] * figure_size[1]) + renderer = _visualize( + instance, save_figure_wid.renderer[0], True, + landmark_options_wid.selected_values['render_landmarks'], + channel_options_wid.selected_values['image_is_masked'], + channel_options_wid.selected_values['masked_enabled'], + channel_options_wid.selected_values['channels'], + channel_options_wid.selected_values['glyph_enabled'], + channel_options_wid.selected_values['glyph_block_size'], + channel_options_wid.selected_values['glyph_use_negative'], + channel_options_wid.selected_values['sum_enabled'], + [landmark_options_wid.selected_values['group']], + [landmark_options_wid.selected_values['with_labels']], + False, dict(), True, False, + tmp1['render_lines'], tmp1['line_style'], tmp1['line_width'], + tmp1['line_colour'][:n_labels], tmp2['render_markers'], + tmp2['marker_style'], tmp2['marker_size'], tmp2['marker_edge_width'], + tmp2['marker_edge_colour'], tmp2['marker_face_colour'], + False, None, None, None, None, None, None, None, None, None, None, + None, None, None, None, None, None, None, None, None, None, None, + False, None, None, new_figure_size, tmp3['render_axes'], + tmp3['axes_font_name'], tmp3['axes_font_size'], + tmp3['axes_font_style'], tmp3['axes_x_limits'], + tmp3['axes_y_limits'], tmp3['axes_font_weight']) # save the current figure id - save_figure_wid.figure_id = new_figure_id - - # update info text widget - update_info(instance, level, landmark_options_wid.group) + save_figure_wid.renderer[0] = renderer # define function that updates info text def update_info(image, level, group): @@ -538,26 +597,29 @@ def update_info(image, level, group): def plot_eigenvalues(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) - # get parameters + # get level level = 0 if n_levels > 1: level = level_wid.value - # get the current figure id - figure_id = save_figure_wid.figure_id - - # show eigenvalues plots - new_figure_id = _plot_eigenvalues(figure_id, appearance_models[level], - figure_size, - figure_options_wid.x_scale, - figure_options_wid.y_scale) + # get the current figure id and plot the eigenvalues + new_figure_size = (viewer_options_wid.selected_values[0]['figure']['x_scale'] * 10, + viewer_options_wid.selected_values[0]['figure']['y_scale'] * 3) + plt.subplot(121) + appearance_models[level].plot_eigenvalues_ratio( + figure_id=save_figure_wid.renderer[0].figure_id) + plt.subplot(122) + renderer = appearance_models[level].plot_eigenvalues_cumulative_ratio( + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size) + plt.show() # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer - # create options widgets + # create parameters, channels nad landmarks options widgets model_parameters_wid = model_parameters(n_parameters[0], plot_function, params_str='param ', mode=mode, params_bounds=parameters_bounds, @@ -565,37 +627,32 @@ def plot_eigenvalues(name): toggle_show_visible=False, plot_eig_visible=True, plot_eig_function=plot_eigenvalues) - channel_options_wid = channel_options( - appearance_models[0].mean().n_channels, - isinstance(appearance_models[0].mean(), MaskedImage), plot_function, - masked_default=True, toggle_show_default=True, - toggle_show_visible=False) - - # find initial groups and labels that will be passed to the landmark options - # widget creation - mean_has_landmarks = appearance_models[0].mean().landmarks.n_groups != 0 - if mean_has_landmarks: - all_groups_keys, all_labels_keys = _extract_groups_labels( - appearance_models[0].mean()) - else: - all_groups_keys = [' '] - all_labels_keys = [[' ']] - landmark_options_wid = landmark_options( - all_groups_keys, all_labels_keys, plot_function, - toggle_show_default=True, landmarks_default=mean_has_landmarks, - legend_default=False, numbering_default=False, - toggle_show_visible=False) + channel_options_wid = channel_options(channels_options_default, + plot_function=plot_function, + toggle_show_default=True, + toggle_show_visible=False) + landmark_options_wid = landmark_options(landmark_options_default, + plot_function=plot_function, + toggle_show_default=True, + toggle_show_visible=False) # if the mean doesn't have landmarks, then landmarks checkbox should be # disabled - landmark_options_wid.children[1].children[0].disabled = \ - not mean_has_landmarks - figure_options_wid = figure_options(plot_function, scale_default=1., - show_axes_default=True, - toggle_show_default=True, - toggle_show_visible=False) - info_wid = info_print(toggle_show_default=True, toggle_show_visible=False) - initial_figure_id = plt.figure() - save_figure_wid = save_figure_options(initial_figure_id, + landmark_options_wid.children[1].disabled = not mean_has_landmarks + + # viewer options widget + viewer_options_wid = viewer_options(viewer_options_default, + ['lines', 'markers', 'figure_one'], + objects_names=None, + plot_function=plot_function, + toggle_show_visible=False, + toggle_show_default=True) + info_wid = info_print(toggle_show_default=True, + toggle_show_visible=False) + + # save figure widget + initial_renderer = MatplotlibImageViewer2d(figure_id=None, new_figure=True, + image=np.zeros((10, 10))) + save_figure_wid = save_figure_options(initial_renderer, toggle_show_default=True, toggle_show_visible=False) @@ -604,6 +661,7 @@ def update_widgets(name, value): # update model parameters update_model_parameters(model_parameters_wid, n_parameters[value], plot_function, params_str='param ') + # update channel options update_channel_options(channel_options_wid, appearance_models[value].mean().n_channels, @@ -621,31 +679,32 @@ def update_widgets(name, value): radio_str["Level {} (high)".format(l)] = l else: radio_str["Level {}".format(l)] = l - level_wid = RadioButtonsWidget(values=radio_str, - description='Pyramid:', value=0) + level_wid = ipywidgets.RadioButtonsWidget(values=radio_str, + description='Pyramid:', + value=0) level_wid.on_trait_change(update_widgets, 'value') level_wid.on_trait_change(plot_function, 'value') tmp_children.insert(0, level_wid) tmp_wid = ContainerWidget(children=tmp_children) - wid = TabWidget(children=[tmp_wid, channel_options_wid, - landmark_options_wid, figure_options_wid, - info_wid, save_figure_wid]) + tab_wid = ipywidgets.TabWidget(children=[tmp_wid, channel_options_wid, + landmark_options_wid, + viewer_options_wid, + info_wid, save_figure_wid]) + logo_wid = logo() + wid = ipywidgets.ContainerWidget(children=[logo_wid, tab_wid]) if popup: - wid = PopupWidget(children=[wid], button_text='Appearance Model Menu') + wid = ipywidgets.PopupWidget(children=[wid], + button_text='Appearance Model Menu') # display final widget - display(wid) + ipydisplay.display(wid) # set final tab titles tab_titles = ['Appearance parameters', 'Channels options', - 'Landmarks options', 'Figure options', 'Model info', + 'Landmarks options', 'Viewer options', 'Model info', 'Save figure'] - if popup: - for (k, tl) in enumerate(tab_titles): - wid.children[0].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.set_title(k, tl) + for (k, tl) in enumerate(tab_titles): + tab_wid.set_title(k, tl) # align widgets tmp_wid.remove_class('vbox') @@ -665,11 +724,12 @@ def update_widgets(name, value): container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) - format_figure_options(figure_options_wid, container_padding='6px', + format_viewer_options(viewer_options_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=False) + border_visible=False, + suboptions_border_visible=True) format_info_print(info_wid, font_size_in_pt='9pt', container_padding='6px', container_margin='6px', container_border='1px solid black', @@ -680,11 +740,12 @@ def update_widgets(name, value): toggle_button_font_weight='bold', tab_top_margin='0cm', border_visible=False) - # update widgets' state for level 0 + # update widgets' state for image number 0 update_widgets('', 0) # Reset value to enable initial visualization - figure_options_wid.children[2].value = False + viewer_options_wid.children[1].children[1].children[2].children[2].value = \ + False def visualize_aam(aam, n_shape_parameters=5, n_appearance_parameters=5, From 4ab3c1f7c6188c5840acdefcbaad48a8a3ff0de7 Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Thu, 25 Dec 2014 14:45:16 +0000 Subject: [PATCH 03/15] change info_print --- menpofit/visualize/widgets/base.py | 75 +++++++++++++++++++----------- 1 file changed, 48 insertions(+), 27 deletions(-) diff --git a/menpofit/visualize/widgets/base.py b/menpofit/visualize/widgets/base.py index 79f15f7..e1c0f3f 100644 --- a/menpofit/visualize/widgets/base.py +++ b/menpofit/visualize/widgets/base.py @@ -187,7 +187,7 @@ def plot_function(name, value): plt.gca().invert_yaxis() # instance range - tmp_range = instance.range() + instance_range = instance.range() else: # Vectors mode # compute instance @@ -235,13 +235,16 @@ def plot_function(name, value): ax.add_collection(lc) # instance range - tmp_range = mean.range() + instance_range = mean.range() plt.show() # save the current figure id save_figure_wid.renderer[0] = renderer + # update info text widget + update_info(level, instance_range) + # info_wid string info_txt = r""" Level: {} out of {}. @@ -252,12 +255,30 @@ def plot_function(name, value): {} landmark points, {} features. """.format(level + 1, n_levels, shape_models[level].n_components, shape_models[level].n_active_components, - shape_models[level].variance_ratio() * 100, tmp_range[0], - tmp_range[1], mean.n_points, + shape_models[level].variance_ratio() * 100, instance_range[0], + instance_range[1], mean.n_points, shape_models[level].n_features) info_wid.children[1].value = _raw_info_string_to_latex(info_txt) + # define function that updates info text + def update_info(level, instance_range): + lvl_sha_mod = shape_models[level] + info_wid.children[1].children[0].value = "> Level: {} out of {}.".\ + format(level + 1, n_levels) + info_wid.children[1].children[1].value = "> {} components in total.".\ + format(lvl_sha_mod.n_components) + info_wid.children[1].children[2].value = "> {} active components.".\ + format(lvl_sha_mod.n_active_components) + info_wid.children[1].children[3].value = "> {:.1f}% variance kept.".\ + format(lvl_sha_mod.variance_ratio() * 100) + info_wid.children[1].children[4].value = "> Instance range: {:.1f} " \ + "x {:.1f}.".\ + format(instance_range[0], instance_range[1]) + info_wid.children[1].children[5].value = "> {} landmark points, " \ + "{} features.".\ + format(lvl_sha_mod.mean().n_points, lvl_sha_mod.n_features) + # Plot eigenvalues function def plot_eigenvalues(name): # clear current figure, but wait until the new data to be displayed are @@ -324,7 +345,7 @@ def mean_visible(name, value): toggle_show_default=True) viewer_options_all = ipywidgets.ContainerWidget(children=[axes_mode_wid, viewer_options_wid]) - info_wid = info_print(toggle_show_default=True, + info_wid = info_print(n_bullets=6, toggle_show_default=True, toggle_show_visible=False) # save figure widget @@ -391,7 +412,7 @@ def update_widgets(name, value): toggle_button_font_weight='bold', border_visible=False, suboptions_border_visible=True) - format_info_print(info_wid, font_size_in_pt='9pt', container_padding='6px', + format_info_print(info_wid, font_size_in_pt='10pt', container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) @@ -573,25 +594,25 @@ def plot_function(name, value): # define function that updates info text def update_info(image, level, group): lvl_app_mod = appearance_models[level] - - info_txt = r""" - Level: {} out of {}. - {} components in total. - {} active components. - {:.1f}% variance kept. - Reference shape of size {} with {} channel{}. - {} features. - {} landmark points. - Instance: min={:.3f}, max={:.3f} - """.format(level + 1, n_levels, lvl_app_mod.n_components, - lvl_app_mod.n_active_components, - lvl_app_mod.variance_ratio() * 100, image._str_shape, - image.n_channels, 's' * (image.n_channels > 1), - lvl_app_mod.n_features, image.landmarks[group].lms.n_points, - image.pixels.min(), image.pixels.max()) - - # update info widget text - info_wid.children[1].value = _raw_info_string_to_latex(info_txt) + info_wid.children[1].children[0].value = "> Level: {} out of {}.".\ + format(level + 1, n_levels) + info_wid.children[1].children[1].value = "> {} components in total.".\ + format(lvl_app_mod.n_components) + info_wid.children[1].children[2].value = "> {} active components.".\ + format(lvl_app_mod.n_active_components) + info_wid.children[1].children[3].value = "> {:.1f}% variance kept.".\ + format(lvl_app_mod.variance_ratio() * 100) + info_wid.children[1].children[4].value = "> Reference shape of size " \ + "{} with {} channel{}.".\ + format(image._str_shape, + image.n_channels, 's' * (image.n_channels > 1)) + info_wid.children[1].children[5].value = "> {} features.".\ + format(lvl_app_mod.n_features) + info_wid.children[1].children[6].value = "> {} landmark points.".\ + format(image.landmarks[group].lms.n_points) + info_wid.children[1].children[7].value = "> Instance: min={:.3f}, " \ + "max={:.3f}".\ + format(image.pixels.min(), image.pixels.max()) # Plot eigenvalues function def plot_eigenvalues(name): @@ -646,7 +667,7 @@ def plot_eigenvalues(name): plot_function=plot_function, toggle_show_visible=False, toggle_show_default=True) - info_wid = info_print(toggle_show_default=True, + info_wid = info_print(n_bullets=8, toggle_show_default=True, toggle_show_visible=False) # save figure widget @@ -730,7 +751,7 @@ def update_widgets(name, value): toggle_button_font_weight='bold', border_visible=False, suboptions_border_visible=True) - format_info_print(info_wid, font_size_in_pt='9pt', container_padding='6px', + format_info_print(info_wid, font_size_in_pt='10pt', container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) From ef3b333c33023f1cf3cd748397124a4a3024f1e9 Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Thu, 25 Dec 2014 15:50:46 +0000 Subject: [PATCH 04/15] fixes visualize_aam --- menpofit/visualize/widgets/base.py | 421 ++++++++++++++++------------- 1 file changed, 239 insertions(+), 182 deletions(-) diff --git a/menpofit/visualize/widgets/base.py b/menpofit/visualize/widgets/base.py index e1c0f3f..fced800 100644 --- a/menpofit/visualize/widgets/base.py +++ b/menpofit/visualize/widgets/base.py @@ -1,9 +1,5 @@ import numpy as np from collections import OrderedDict -from IPython.html.widgets import (FloatTextWidget, TextWidget, PopupWidget, - ContainerWidget, TabWidget, FloatSliderWidget, - RadioButtonsWidget, CheckboxWidget, - DropdownWidget, AccordionWidget, ButtonWidget) from menpo.visualize.widgets.options import (viewer_options, format_viewer_options, @@ -33,8 +29,7 @@ save_figure_options, format_save_figure_options) from menpo.visualize.widgets.tools import logo, format_logo -from menpo.visualize.widgets.base import (_visualize, _raw_info_string_to_latex, - _extract_groups_labels) +from menpo.visualize.widgets.base import (_visualize, _extract_groups_labels) from menpo.visualize.viewmatplotlib import (MatplotlibImageViewer2d, sample_colours_from_colourmap) @@ -245,22 +240,6 @@ def plot_function(name, value): # update info text widget update_info(level, instance_range) - # info_wid string - info_txt = r""" - Level: {} out of {}. - {} components in total. - {} active components. - {:.1f} % variance kept. - Instance range: {:.1f} x {:.1f}. - {} landmark points, {} features. - """.format(level + 1, n_levels, shape_models[level].n_components, - shape_models[level].n_active_components, - shape_models[level].variance_ratio() * 100, instance_range[0], - instance_range[1], mean.n_points, - shape_models[level].n_features) - - info_wid.children[1].value = _raw_info_string_to_latex(info_txt) - # define function that updates info text def update_info(level, instance_range): lvl_sha_mod = shape_models[level] @@ -706,7 +685,7 @@ def update_widgets(name, value): level_wid.on_trait_change(update_widgets, 'value') level_wid.on_trait_change(plot_function, 'value') tmp_children.insert(0, level_wid) - tmp_wid = ContainerWidget(children=tmp_children) + tmp_wid = ipywidgets.ContainerWidget(children=tmp_children) tab_wid = ipywidgets.TabWidget(children=[tmp_wid, channel_options_wid, landmark_options_wid, viewer_options_wid, @@ -770,8 +749,8 @@ def update_widgets(name, value): def visualize_aam(aam, n_shape_parameters=5, n_appearance_parameters=5, - parameters_bounds=(-3.0, 3.0), figure_size=(7, 7), - mode='multiple', popup=False, **kwargs): + parameters_bounds=(-3.0, 3.0), figure_size=(6, 4), + mode='multiple', popup=False): r""" Allows the dynamic visualization of a multilevel AAM. @@ -781,40 +760,33 @@ def visualize_aam(aam, n_shape_parameters=5, n_appearance_parameters=5, The multilevel AAM to be displayed. Note that each level can have different attributes, e.g. number of active components, feature type, number of channels. - n_shape_parameters : `int` or `list` of `int` or None, optional - The number of shape principal components to be used for the parameters + The number of shape components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + If `int`, then the number of sliders per level is the minimum between + `n_parameters` and the number of active components per level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. n_appearance_parameters : `int` or `list` of `int` or None, optional - The number of appearance principal components to be used for the - parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + The number of appearance components to be used for the parameters + sliders. + If `int`, then the number of sliders per level is the minimum between + `n_parameters` and the number of active components per level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. - figure_size : (`int`, `int`), optional The size of the plotted figures. - - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. - - kwargs : `dict`, optional - Passed through to the viewer. + mode : {``single``, ``multiple``}, optional + If ``single``, only a single slider is constructed along with a drop + down menu. If ``multiple``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. """ + import IPython.html.widgets as ipywidgets + import IPython.display as ipydisplay + import matplotlib.pyplot as plt from menpo.image import MaskedImage # find number of levels @@ -831,53 +803,120 @@ def visualize_aam(aam, n_shape_parameters=5, n_appearance_parameters=5, n_appearance_parameters = _check_n_parameters(n_appearance_parameters, n_levels, max_n_appearance) - # define plot function + # find initial groups and labels that will be passed to the landmark options + # widget creation + mean_has_landmarks = aam.appearance_models[0].mean().landmarks.n_groups != 0 + if mean_has_landmarks: + all_groups_keys, all_labels_keys = _extract_groups_labels( + aam.appearance_models[0].mean()) + else: + all_groups_keys = [' '] + all_labels_keys = [[' ']] + + # get initial line colours for each available label + if len(all_labels_keys[0]) == 1: + line_colours = ['r'] + else: + line_colours = sample_colours_from_colourmap(len(all_labels_keys[0]), + 'jet') + + # initial options dictionaries + channels_default = 0 + if aam.appearance_models[0].mean().n_channels == 3: + channels_default = None + channels_options_default = \ + {'n_channels': aam.appearance_models[0].mean().n_channels, + 'image_is_masked': isinstance(aam.appearance_models[0].mean(), + MaskedImage), + 'channels': channels_default, + 'glyph_enabled': False, + 'glyph_block_size': 3, + 'glyph_use_negative': False, + 'sum_enabled': False, + 'masked_enabled': isinstance(aam.appearance_models[0].mean(), + MaskedImage)} + landmark_options_default = {'render_landmarks': mean_has_landmarks, + 'group_keys': all_groups_keys, + 'labels_keys': all_labels_keys, + 'group': None, + 'with_labels': None} + lines_options = {'render_lines': True, + 'line_width': 1, + 'line_colour': line_colours, + 'line_style': '-'} + markers_options = {'render_markers': True, + 'marker_size': 20, + 'marker_face_colour': ['r'], + 'marker_edge_colour': ['k'], + 'marker_style': 'o', + 'marker_edge_width': 1} + figure_options = {'x_scale': 1., + 'y_scale': 1., + 'render_axes': True, + 'axes_font_name': 'sans-serif', + 'axes_font_size': 10, + 'axes_font_style': 'normal', + 'axes_font_weight': 'normal', + 'axes_x_limits': None, + 'axes_y_limits': None} + viewer_options_default = {'lines': lines_options, + 'markers': markers_options, + 'figure': figure_options} + + # Define plot function def plot_function(name, value): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) # get selected level level = 0 if n_levels > 1: level = level_wid.value - # get weights and compute instance + # compute weights and instance shape_weights = shape_model_parameters_wid.parameters_values appearance_weights = appearance_model_parameters_wid.parameters_values instance = aam.instance(level=level, shape_weights=shape_weights, appearance_weights=appearance_weights) - # get the current figure id - figure_id = save_figure_wid.figure_id + # update info text widget + update_info(aam, instance, level, + landmark_options_wid.selected_values['group']) + n_labels = len(landmark_options_wid.selected_values['with_labels']) - # show image with selected options - new_figure_id = _plot_figure( - image=instance, figure_id=figure_id, image_enabled=True, - landmarks_enabled=landmark_options_wid.landmarks_enabled, - image_is_masked=channel_options_wid.image_is_masked, - masked_enabled=channel_options_wid.masked_enabled, - channels=channel_options_wid.channels, - glyph_enabled=channel_options_wid.glyph_enabled, - glyph_block_size=channel_options_wid.glyph_block_size, - glyph_use_negative=channel_options_wid.glyph_use_negative, - sum_enabled=channel_options_wid.sum_enabled, - groups=[landmark_options_wid.group], - with_labels=[landmark_options_wid.with_labels], - groups_colours=dict(), subplots_enabled=False, - subplots_titles=dict(), image_axes_mode=True, - legend_enabled=landmark_options_wid.legend_enabled, - numbering_enabled=landmark_options_wid.numbering_enabled, - x_scale=figure_options_wid.x_scale, - y_scale=figure_options_wid.y_scale, - axes_visible=figure_options_wid.axes_visible, - figure_size=figure_size, **kwargs) + # plot + tmp1 = viewer_options_wid.selected_values[0]['lines'] + tmp2 = viewer_options_wid.selected_values[0]['markers'] + tmp3 = viewer_options_wid.selected_values[0]['figure'] + new_figure_size = (tmp3['x_scale'] * figure_size[0], + tmp3['y_scale'] * figure_size[1]) + renderer = _visualize( + instance, save_figure_wid.renderer[0], True, + landmark_options_wid.selected_values['render_landmarks'], + channel_options_wid.selected_values['image_is_masked'], + channel_options_wid.selected_values['masked_enabled'], + channel_options_wid.selected_values['channels'], + channel_options_wid.selected_values['glyph_enabled'], + channel_options_wid.selected_values['glyph_block_size'], + channel_options_wid.selected_values['glyph_use_negative'], + channel_options_wid.selected_values['sum_enabled'], + [landmark_options_wid.selected_values['group']], + [landmark_options_wid.selected_values['with_labels']], + False, dict(), True, False, + tmp1['render_lines'], tmp1['line_style'], tmp1['line_width'], + tmp1['line_colour'][:n_labels], tmp2['render_markers'], + tmp2['marker_style'], tmp2['marker_size'], tmp2['marker_edge_width'], + tmp2['marker_edge_colour'], tmp2['marker_face_colour'], + False, None, None, None, None, None, None, None, None, None, None, + None, None, None, None, None, None, None, None, None, None, None, + False, None, None, new_figure_size, tmp3['render_axes'], + tmp3['axes_font_name'], tmp3['axes_font_size'], + tmp3['axes_font_style'], tmp3['axes_x_limits'], + tmp3['axes_y_limits'], tmp3['axes_font_weight']) # save the current figure id - save_figure_wid.figure_id = new_figure_id - - # update info text widget - update_info(aam, instance, level, landmark_options_wid.group) + save_figure_wid.renderer[0] = renderer # define function that updates info text def update_info(aam, instance, level, group): @@ -914,113 +953,124 @@ def update_info(aam, instance, level, group): else: tmp_pyramid = "Features were extracted at each pyramid level." - # Formatting is a bit ugly but this is MUCH easier to read. - info_txt = r""" - {} training images. - Warp using {} transform. - Level {}/{} (downscale={:.1f}). - {} - {} - {} - Reference frame of length {} ({} x {}C, {} x {}C). - {} shape components ({:.2f}% of variance) - {} appearance components ({:.2f}% of variance) - {} landmark points. - Instance: min={:.3f} , max={:.3f} - """.format(aam.n_training_images, aam.transform.__name__, - level + 1, - aam.n_levels, aam.downscale, tmp_shape_models, - tmp_pyramid, tmp_feat, lvl_app_mod.n_features, - tmplt_inst.n_true_pixels(), n_channels, - tmplt_inst._str_shape, n_channels, - lvl_shape_mod.n_components, - lvl_shape_mod.variance_ratio() * 100, - lvl_app_mod.n_components, - lvl_app_mod.variance_ratio() * 100, - instance.landmarks[group].lms.n_points, - instance.pixels.min(), instance.pixels.max()) - - info_wid.children[1].value = _raw_info_string_to_latex(info_txt) + # update info widgets + info_wid.children[1].children[0].value = "> {} training images.".\ + format(aam.n_training_images) + info_wid.children[1].children[1].value = "> Warp using {} transform.".\ + format(aam.transform.__name__) + info_wid.children[1].children[2].value = "> Level {}/{} " \ + "(downscale={:.1f}).".\ + format(level + 1, aam.n_levels, aam.downscale) + info_wid.children[1].children[3].value = "> {}".format(tmp_shape_models) + info_wid.children[1].children[4].value = "> {}".format(tmp_pyramid) + info_wid.children[1].children[5].value = "> {}".format(tmp_feat) + info_wid.children[1].children[6].value = "> Reference frame of " \ + "length {} ({} x {}C, {} x " \ + "{}C).".\ + format(lvl_app_mod.n_features, tmplt_inst.n_true_pixels(), + n_channels, tmplt_inst._str_shape, n_channels) + info_wid.children[1].children[7].value = "> {} shape components " \ + "({:.2f}% of variance).".\ + format(lvl_shape_mod.n_components, + lvl_shape_mod.variance_ratio() * 100) + info_wid.children[1].children[8].value = "> {} appearance components " \ + "({:.2f}% of variance).".\ + format(lvl_app_mod.n_components, lvl_app_mod.variance_ratio() * 100) + info_wid.children[1].children[9].value = "> {} landmark points.".\ + format(instance.landmarks[group].lms.n_points) + info_wid.children[1].children[10].value = "> Instance: min={:.3f} , " \ + "max={:.3f}.".\ + format(instance.pixels.min(), instance.pixels.max()) # Plot shape eigenvalues function def plot_shape_eigenvalues(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) - # get parameters + # get level level = 0 if n_levels > 1: level = level_wid.value - # get the current figure id - figure_id = save_figure_wid.figure_id - - # show eigenvalues plots - new_figure_id = _plot_eigenvalues(figure_id, aam.shape_models[level], - figure_size, - figure_options_wid.x_scale, - figure_options_wid.y_scale) + # get the current figure id and plot the eigenvalues + new_figure_size = (viewer_options_wid.selected_values[0]['figure']['x_scale'] * 10, + viewer_options_wid.selected_values[0]['figure']['y_scale'] * 3) + plt.subplot(121) + aam.shape_models[level].plot_eigenvalues_ratio( + figure_id=save_figure_wid.renderer[0].figure_id) + plt.subplot(122) + renderer = aam.shape_models[level].plot_eigenvalues_cumulative_ratio( + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size) + plt.show() # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer # Plot appearance eigenvalues function def plot_appearance_eigenvalues(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) - # get parameters + # get level level = 0 if n_levels > 1: level = level_wid.value - # get the current figure id - figure_id = save_figure_wid.figure_id - - # show eigenvalues plots - new_figure_id = _plot_eigenvalues(figure_id, - aam.appearance_models[level], - figure_size, - figure_options_wid.x_scale, - figure_options_wid.y_scale) + # get the current figure id and plot the eigenvalues + new_figure_size = (viewer_options_wid.selected_values[0]['figure']['x_scale'] * 10, + viewer_options_wid.selected_values[0]['figure']['y_scale'] * 3) + plt.subplot(121) + aam.appearance_models[level].plot_eigenvalues_ratio( + figure_id=save_figure_wid.renderer[0].figure_id) + plt.subplot(122) + renderer = aam.appearance_models[level].plot_eigenvalues_cumulative_ratio( + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size) + plt.show() # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer - # create options widgets + # create parameters, channels nad landmarks options widgets shape_model_parameters_wid = model_parameters( n_shape_parameters[0], plot_function, params_str='param ', mode=mode, - params_bounds=parameters_bounds, toggle_show_default=False, - toggle_show_visible=True, toggle_show_name='Shape Parameters', + params_bounds=parameters_bounds, toggle_show_default=True, + toggle_show_visible=False, toggle_show_name='Shape Parameters', plot_eig_visible=True, plot_eig_function=plot_shape_eigenvalues) appearance_model_parameters_wid = model_parameters( n_appearance_parameters[0], plot_function, params_str='param ', - mode=mode, params_bounds=parameters_bounds, toggle_show_default=False, - toggle_show_visible=True, toggle_show_name='Appearance Parameters', + mode=mode, params_bounds=parameters_bounds, toggle_show_default=True, + toggle_show_visible=False, toggle_show_name='Appearance Parameters', plot_eig_visible=True, plot_eig_function=plot_appearance_eigenvalues) - channel_options_wid = channel_options( - aam.appearance_models[0].mean().n_channels, - isinstance(aam.appearance_models[0].mean(), MaskedImage), plot_function, - masked_default=True, toggle_show_default=True, - toggle_show_visible=False) - all_groups_keys, all_labels_keys = \ - _extract_groups_labels(aam.appearance_models[0].mean()) - landmark_options_wid = landmark_options(all_groups_keys, all_labels_keys, - plot_function, + channel_options_wid = channel_options(channels_options_default, + plot_function=plot_function, + toggle_show_default=True, + toggle_show_visible=False) + landmark_options_wid = landmark_options(landmark_options_default, + plot_function=plot_function, toggle_show_default=True, - landmarks_default=True, - legend_default=False, - numbering_default=False, toggle_show_visible=False) - figure_options_wid = figure_options(plot_function, scale_default=1., - show_axes_default=True, - toggle_show_default=True, - toggle_show_visible=False) - info_wid = info_print(toggle_show_default=True, toggle_show_visible=False) - initial_figure_id = plt.figure() - save_figure_wid = save_figure_options(initial_figure_id, + # if the mean doesn't have landmarks, then landmarks checkbox should be + # disabled + landmark_options_wid.children[1].disabled = not mean_has_landmarks + + # viewer options widget + viewer_options_wid = viewer_options(viewer_options_default, + ['lines', 'markers', 'figure_one'], + objects_names=None, + plot_function=plot_function, + toggle_show_visible=False, + toggle_show_default=True) + info_wid = info_print(n_bullets=11, toggle_show_default=True, + toggle_show_visible=False) + + # save figure widget + initial_renderer = MatplotlibImageViewer2d(figure_id=None, new_figure=True, + image=np.zeros((10, 10))) + save_figure_wid = save_figure_options(initial_renderer, toggle_show_default=True, toggle_show_visible=False) @@ -1034,6 +1084,7 @@ def update_widgets(name, value): update_model_parameters(appearance_model_parameters_wid, n_appearance_parameters[value], plot_function, params_str='param ') + # update channel options update_channel_options(channel_options_wid, aam.appearance_models[value].mean().n_channels, @@ -1041,7 +1092,7 @@ def update_widgets(name, value): MaskedImage)) # create final widget - model_parameters_wid = ContainerWidget( + model_parameters_wid = ipywidgets.AccordionWidget( children=[shape_model_parameters_wid, appearance_model_parameters_wid]) tmp_children = [model_parameters_wid] if n_levels > 1: @@ -1053,45 +1104,49 @@ def update_widgets(name, value): radio_str["Level {} (high)".format(l)] = l else: radio_str["Level {}".format(l)] = l - level_wid = RadioButtonsWidget(values=radio_str, - description='Pyramid:', value=0) + level_wid = ipywidgets.RadioButtonsWidget(values=radio_str, + description='Pyramid:', + value=0) level_wid.on_trait_change(update_widgets, 'value') level_wid.on_trait_change(plot_function, 'value') tmp_children.insert(0, level_wid) - tmp_wid = ContainerWidget(children=tmp_children) - wid = TabWidget(children=[tmp_wid, channel_options_wid, - landmark_options_wid, figure_options_wid, - info_wid, save_figure_wid]) + tmp_wid = ipywidgets.ContainerWidget(children=tmp_children) + tab_wid = ipywidgets.TabWidget(children=[tmp_wid, channel_options_wid, + landmark_options_wid, + viewer_options_wid, + info_wid, save_figure_wid]) + logo_wid = logo() + wid = ipywidgets.ContainerWidget(children=[logo_wid, tab_wid]) if popup: - wid = PopupWidget(children=[wid], button_text='AAM Menu') + wid = ipywidgets.PopupWidget(children=[wid], + button_text='AAM Menu') # display final widget - display(wid) + ipydisplay.display(wid) # set final tab titles - tab_titles = ['AAM parameters', 'Channels options', 'Landmarks options', - 'Figure options', 'Model info', 'Save figure'] - if popup: - for (k, tl) in enumerate(tab_titles): - wid.children[0].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.set_title(k, tl) + tab_titles = ['AAM parameters', 'Channels options', + 'Landmarks options', 'Viewer options', 'Model info', + 'Save figure'] + for (k, tl) in enumerate(tab_titles): + tab_wid.set_title(k, tl) + tab_titles = ['Shape parameters', 'Appearance parameters'] + for (k, tl) in enumerate(tab_titles): + model_parameters_wid.set_title(k, tl) # align widgets - if n_levels > 1: - tmp_wid.remove_class('vbox') - tmp_wid.add_class('hbox') + tmp_wid.remove_class('vbox') + tmp_wid.add_class('hbox') format_model_parameters(shape_model_parameters_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=True) + border_visible=False) format_model_parameters(appearance_model_parameters_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=True) + border_visible=False) format_channel_options(channel_options_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', @@ -1102,12 +1157,13 @@ def update_widgets(name, value): container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) - format_figure_options(figure_options_wid, container_padding='6px', + format_viewer_options(viewer_options_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=False) - format_info_print(info_wid, font_size_in_pt='9pt', container_padding='6px', + border_visible=False, + suboptions_border_visible=True) + format_info_print(info_wid, font_size_in_pt='10pt', container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) @@ -1117,11 +1173,12 @@ def update_widgets(name, value): toggle_button_font_weight='bold', tab_top_margin='0cm', border_visible=False) - # update widgets' state for level 0 + # update widgets' state for image number 0 update_widgets('', 0) # Reset value to enable initial visualization - figure_options_wid.children[2].value = False + viewer_options_wid.children[1].children[1].children[2].children[2].value = \ + False def visualize_atm(atm, n_shape_parameters=5, parameters_bounds=(-3.0, 3.0), From dd1ee3ba0b18a9505e44681c2124a9b87a2451d6 Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Fri, 26 Dec 2014 03:10:13 +0000 Subject: [PATCH 05/15] adds image_options --- menpofit/visualize/widgets/base.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/menpofit/visualize/widgets/base.py b/menpofit/visualize/widgets/base.py index fced800..cfab2ec 100644 --- a/menpofit/visualize/widgets/base.py +++ b/menpofit/visualize/widgets/base.py @@ -510,9 +510,12 @@ def visualize_appearance_model(appearance_models, n_parameters=5, 'axes_font_weight': 'normal', 'axes_x_limits': None, 'axes_y_limits': None} + image_options = {'interpolation': 'none', + 'alpha': 1.} viewer_options_default = {'lines': lines_options, 'markers': markers_options, - 'figure': figure_options} + 'figure': figure_options, + 'image': image_options} # Define plot function def plot_function(name, value): @@ -541,6 +544,7 @@ def plot_function(name, value): tmp1 = viewer_options_wid.selected_values[0]['lines'] tmp2 = viewer_options_wid.selected_values[0]['markers'] tmp3 = viewer_options_wid.selected_values[0]['figure'] + tmp4 = viewer_options_wid.selected_values[0]['image'] new_figure_size = (tmp3['x_scale'] * figure_size[0], tmp3['y_scale'] * figure_size[1]) renderer = _visualize( @@ -565,7 +569,8 @@ def plot_function(name, value): False, None, None, new_figure_size, tmp3['render_axes'], tmp3['axes_font_name'], tmp3['axes_font_size'], tmp3['axes_font_style'], tmp3['axes_x_limits'], - tmp3['axes_y_limits'], tmp3['axes_font_weight']) + tmp3['axes_y_limits'], tmp3['axes_font_weight'], + tmp4['interpolation'], tmp4['alpha']) # save the current figure id save_figure_wid.renderer[0] = renderer @@ -641,7 +646,8 @@ def plot_eigenvalues(name): # viewer options widget viewer_options_wid = viewer_options(viewer_options_default, - ['lines', 'markers', 'figure_one'], + ['lines', 'markers', 'figure_one', + 'image'], objects_names=None, plot_function=plot_function, toggle_show_visible=False, @@ -859,9 +865,12 @@ def visualize_aam(aam, n_shape_parameters=5, n_appearance_parameters=5, 'axes_font_weight': 'normal', 'axes_x_limits': None, 'axes_y_limits': None} + image_options = {'interpolation': 'none', + 'alpha': 1.0} viewer_options_default = {'lines': lines_options, 'markers': markers_options, - 'figure': figure_options} + 'figure': figure_options, + 'image': image_options} # Define plot function def plot_function(name, value): @@ -889,6 +898,7 @@ def plot_function(name, value): tmp1 = viewer_options_wid.selected_values[0]['lines'] tmp2 = viewer_options_wid.selected_values[0]['markers'] tmp3 = viewer_options_wid.selected_values[0]['figure'] + tmp4 = viewer_options_wid.selected_values[0]['image'] new_figure_size = (tmp3['x_scale'] * figure_size[0], tmp3['y_scale'] * figure_size[1]) renderer = _visualize( @@ -913,7 +923,8 @@ def plot_function(name, value): False, None, None, new_figure_size, tmp3['render_axes'], tmp3['axes_font_name'], tmp3['axes_font_size'], tmp3['axes_font_style'], tmp3['axes_x_limits'], - tmp3['axes_y_limits'], tmp3['axes_font_weight']) + tmp3['axes_y_limits'], tmp3['axes_font_weight'], + tmp4['interpolation'], tmp4['alpha']) # save the current figure id save_figure_wid.renderer[0] = renderer @@ -1059,7 +1070,8 @@ def plot_appearance_eigenvalues(name): # viewer options widget viewer_options_wid = viewer_options(viewer_options_default, - ['lines', 'markers', 'figure_one'], + ['lines', 'markers', 'figure_one', + 'image'], objects_names=None, plot_function=plot_function, toggle_show_visible=False, From 80662cad9cb8be69fac2d3d97d4ee37a9ff59d62 Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Fri, 26 Dec 2014 14:31:50 +0000 Subject: [PATCH 06/15] fixes visualize_atm --- menpofit/visualize/widgets/base.py | 340 ++++++++++++++++++----------- 1 file changed, 210 insertions(+), 130 deletions(-) diff --git a/menpofit/visualize/widgets/base.py b/menpofit/visualize/widgets/base.py index cfab2ec..e81ffe0 100644 --- a/menpofit/visualize/widgets/base.py +++ b/menpofit/visualize/widgets/base.py @@ -1194,7 +1194,7 @@ def update_widgets(name, value): def visualize_atm(atm, n_shape_parameters=5, parameters_bounds=(-3.0, 3.0), - figure_size=(7, 7), mode='multiple', popup=False, **kwargs): + figure_size=(6, 4), mode='multiple', popup=False): r""" Allows the dynamic visualization of a multilevel ATM. @@ -1204,32 +1204,26 @@ def visualize_atm(atm, n_shape_parameters=5, parameters_bounds=(-3.0, 3.0), The multilevel ATM to be displayed. Note that each level can have different attributes, e.g. number of active components, feature type, number of channels. - n_shape_parameters : `int` or `list` of `int` or None, optional - The number of shape principal components to be used for the parameters + The number of shape components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + If `int`, then the number of sliders per level is the minimum between + `n_parameters` and the number of active components per level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. - figure_size : (`int`, `int`), optional The size of the plotted figures. - - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. - - kwargs : `dict`, optional - Passed through to the viewer. + mode : {``single``, ``multiple``}, optional + If ``single``, only a single slider is constructed along with a drop + down menu. If ``multiple``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. """ + import IPython.html.widgets as ipywidgets + import IPython.display as ipydisplay + import matplotlib.pyplot as plt from menpo.image import MaskedImage # find number of levels @@ -1243,51 +1237,123 @@ def visualize_atm(atm, n_shape_parameters=5, parameters_bounds=(-3.0, 3.0), n_shape_parameters = _check_n_parameters(n_shape_parameters, n_levels, max_n_shape) - # define plot function + # find initial groups and labels that will be passed to the landmark options + # widget creation + template_has_landmarks = atm.warped_templates[0].landmarks.n_groups != 0 + if template_has_landmarks: + all_groups_keys, all_labels_keys = _extract_groups_labels( + atm.warped_templates[0]) + else: + all_groups_keys = [' '] + all_labels_keys = [[' ']] + + # get initial line colours for each available label + if len(all_labels_keys[0]) == 1: + line_colours = ['r'] + else: + line_colours = sample_colours_from_colourmap(len(all_labels_keys[0]), + 'jet') + + # initial options dictionaries + channels_default = 0 + if atm.warped_templates[0].n_channels == 3: + channels_default = None + channels_options_default = \ + {'n_channels': atm.warped_templates[0].n_channels, + 'image_is_masked': isinstance(atm.warped_templates[0], + MaskedImage), + 'channels': channels_default, + 'glyph_enabled': False, + 'glyph_block_size': 3, + 'glyph_use_negative': False, + 'sum_enabled': False, + 'masked_enabled': isinstance(atm.warped_templates[0], + MaskedImage)} + landmark_options_default = {'render_landmarks': template_has_landmarks, + 'group_keys': all_groups_keys, + 'labels_keys': all_labels_keys, + 'group': None, + 'with_labels': None} + lines_options = {'render_lines': True, + 'line_width': 1, + 'line_colour': line_colours, + 'line_style': '-'} + markers_options = {'render_markers': True, + 'marker_size': 20, + 'marker_face_colour': ['r'], + 'marker_edge_colour': ['k'], + 'marker_style': 'o', + 'marker_edge_width': 1} + figure_options = {'x_scale': 1., + 'y_scale': 1., + 'render_axes': True, + 'axes_font_name': 'sans-serif', + 'axes_font_size': 10, + 'axes_font_style': 'normal', + 'axes_font_weight': 'normal', + 'axes_x_limits': None, + 'axes_y_limits': None} + image_options = {'interpolation': 'none', + 'alpha': 1.0} + viewer_options_default = {'lines': lines_options, + 'markers': markers_options, + 'figure': figure_options, + 'image': image_options} + + # Define plot function def plot_function(name, value): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) # get selected level level = 0 if n_levels > 1: level = level_wid.value - # get weights and compute instance + # compute weights and instance shape_weights = shape_model_parameters_wid.parameters_values instance = atm.instance(level=level, shape_weights=shape_weights) - # get the current figure id - figure_id = save_figure_wid.figure_id + # update info text widget + update_info(atm, instance, level, + landmark_options_wid.selected_values['group']) + n_labels = len(landmark_options_wid.selected_values['with_labels']) - # show image with selected options - new_figure_id = _plot_figure( - image=instance, figure_id=figure_id, image_enabled=True, - landmarks_enabled=landmark_options_wid.landmarks_enabled, - image_is_masked=channel_options_wid.image_is_masked, - masked_enabled=channel_options_wid.masked_enabled, - channels=channel_options_wid.channels, - glyph_enabled=channel_options_wid.glyph_enabled, - glyph_block_size=channel_options_wid.glyph_block_size, - glyph_use_negative=channel_options_wid.glyph_use_negative, - sum_enabled=channel_options_wid.sum_enabled, - groups=[landmark_options_wid.group], - with_labels=[landmark_options_wid.with_labels], - groups_colours=dict(), subplots_enabled=False, - subplots_titles=dict(), image_axes_mode=True, - legend_enabled=landmark_options_wid.legend_enabled, - numbering_enabled=landmark_options_wid.numbering_enabled, - x_scale=figure_options_wid.x_scale, - y_scale=figure_options_wid.y_scale, - axes_visible=figure_options_wid.axes_visible, - figure_size=figure_size, **kwargs) + # plot + tmp1 = viewer_options_wid.selected_values[0]['lines'] + tmp2 = viewer_options_wid.selected_values[0]['markers'] + tmp3 = viewer_options_wid.selected_values[0]['figure'] + tmp4 = viewer_options_wid.selected_values[0]['image'] + new_figure_size = (tmp3['x_scale'] * figure_size[0], + tmp3['y_scale'] * figure_size[1]) + renderer = _visualize( + instance, save_figure_wid.renderer[0], True, + landmark_options_wid.selected_values['render_landmarks'], + channel_options_wid.selected_values['image_is_masked'], + channel_options_wid.selected_values['masked_enabled'], + channel_options_wid.selected_values['channels'], + channel_options_wid.selected_values['glyph_enabled'], + channel_options_wid.selected_values['glyph_block_size'], + channel_options_wid.selected_values['glyph_use_negative'], + channel_options_wid.selected_values['sum_enabled'], + [landmark_options_wid.selected_values['group']], + [landmark_options_wid.selected_values['with_labels']], + False, dict(), True, False, + tmp1['render_lines'], tmp1['line_style'], tmp1['line_width'], + tmp1['line_colour'][:n_labels], tmp2['render_markers'], + tmp2['marker_style'], tmp2['marker_size'], tmp2['marker_edge_width'], + tmp2['marker_edge_colour'], tmp2['marker_face_colour'], + False, None, None, None, None, None, None, None, None, None, None, + None, None, None, None, None, None, None, None, None, None, None, + False, None, None, new_figure_size, tmp3['render_axes'], + tmp3['axes_font_name'], tmp3['axes_font_size'], + tmp3['axes_font_style'], tmp3['axes_x_limits'], + tmp3['axes_y_limits'], tmp3['axes_font_weight'], + tmp4['interpolation'], tmp4['alpha']) # save the current figure id - save_figure_wid.figure_id = new_figure_id - - # update info text widget - update_info(atm, instance, level, landmark_options_wid.group) + save_figure_wid.renderer[0] = renderer # define function that updates info text def update_info(atm, instance, level, group): @@ -1322,82 +1388,92 @@ def update_info(atm, instance, level, group): else: tmp_pyramid = "Features were extracted at each pyramid level." - # Formatting is a bit ugly but this is MUCH easier to read. - info_txt = r""" - {} training shapes. - Warp using {} transform. - Level {}/{} (downscale={:.1f}). - {} - {} - {} - Reference frame of length {} ({} x {}C, {} x {}C). - {} shape components ({:.2f}% of variance) - {} landmark points. - Instance: min={:.3f} , max={:.3f} - """.format(atm.n_training_shapes, atm.transform.__name__, - level + 1, - atm.n_levels, atm.downscale, tmp_shape_models, - tmp_pyramid, tmp_feat, - tmplt_inst.n_true_pixels() * n_channels, - tmplt_inst.n_true_pixels(), n_channels, - tmplt_inst._str_shape, n_channels, - lvl_shape_mod.n_components, - lvl_shape_mod.variance_ratio() * 100, - instance.landmarks[group].lms.n_points, - instance.pixels.min(), instance.pixels.max()) - - info_wid.children[1].value = _raw_info_string_to_latex(info_txt) + # update info widgets + info_wid.children[1].children[0].value = "> {} training shapes.".\ + format(atm.n_training_shapes) + info_wid.children[1].children[1].value = "> Warp using {} transform.".\ + format(atm.transform.__name__) + info_wid.children[1].children[2].value = "> Level {}/{} " \ + "(downscale={:.1f}).".\ + format(level + 1, atm.n_levels, atm.downscale) + info_wid.children[1].children[3].value = "> {}".format(tmp_shape_models) + info_wid.children[1].children[4].value = "> {}".format(tmp_pyramid) + info_wid.children[1].children[5].value = "> {}".format(tmp_feat) + info_wid.children[1].children[6].value = "> Reference frame of " \ + "length {} ({} x {}C, {} x " \ + "{}C).".\ + format(tmplt_inst.n_true_pixels() * n_channels, + tmplt_inst.n_true_pixels(), + n_channels, tmplt_inst._str_shape, n_channels) + info_wid.children[1].children[7].value = "> {} shape components " \ + "({:.2f}% of variance).".\ + format(lvl_shape_mod.n_components, + lvl_shape_mod.variance_ratio() * 100) + info_wid.children[1].children[8].value = "> {} landmark points.".\ + format(instance.landmarks[group].lms.n_points) + info_wid.children[1].children[9].value = "> Instance: min={:.3f} , " \ + "max={:.3f}.".\ + format(instance.pixels.min(), instance.pixels.max()) # Plot shape eigenvalues function def plot_shape_eigenvalues(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) - # get parameters + # get level level = 0 if n_levels > 1: level = level_wid.value - # get the current figure id - figure_id = save_figure_wid.figure_id - - # show eigenvalues plots - new_figure_id = _plot_eigenvalues(figure_id, atm.shape_models[level], - figure_size, - figure_options_wid.x_scale, - figure_options_wid.y_scale) + # get the current figure id and plot the eigenvalues + new_figure_size = (viewer_options_wid.selected_values[0]['figure']['x_scale'] * 10, + viewer_options_wid.selected_values[0]['figure']['y_scale'] * 3) + plt.subplot(121) + atm.shape_models[level].plot_eigenvalues_ratio( + figure_id=save_figure_wid.renderer[0].figure_id) + plt.subplot(122) + renderer = atm.shape_models[level].plot_eigenvalues_cumulative_ratio( + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size) + plt.show() # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer - # create options widgets + # create parameters, channels nad landmarks options widgets shape_model_parameters_wid = model_parameters( n_shape_parameters[0], plot_function, params_str='param ', mode=mode, params_bounds=parameters_bounds, toggle_show_default=True, toggle_show_visible=False, toggle_show_name='Shape Parameters', plot_eig_visible=True, plot_eig_function=plot_shape_eigenvalues) - channel_options_wid = channel_options( - atm.warped_templates[0].n_channels, - isinstance(atm.warped_templates[0], MaskedImage), - plot_function, masked_default=True, toggle_show_default=True, - toggle_show_visible=False) - all_groups_keys, all_labels_keys = \ - _extract_groups_labels(atm.warped_templates[0]) - landmark_options_wid = landmark_options(all_groups_keys, all_labels_keys, - plot_function, + channel_options_wid = channel_options(channels_options_default, + plot_function=plot_function, + toggle_show_default=True, + toggle_show_visible=False) + landmark_options_wid = landmark_options(landmark_options_default, + plot_function=plot_function, toggle_show_default=True, - landmarks_default=True, - legend_default=False, - numbering_default=False, toggle_show_visible=False) - figure_options_wid = figure_options(plot_function, scale_default=1., - show_axes_default=True, - toggle_show_default=True, - toggle_show_visible=False) - info_wid = info_print(toggle_show_default=True, toggle_show_visible=False) - initial_figure_id = plt.figure() - save_figure_wid = save_figure_options(initial_figure_id, + # if the mean doesn't have landmarks, then landmarks checkbox should be + # disabled + landmark_options_wid.children[1].disabled = not template_has_landmarks + + # viewer options widget + viewer_options_wid = viewer_options(viewer_options_default, + ['lines', 'markers', 'figure_one', + 'image'], + objects_names=None, + plot_function=plot_function, + toggle_show_visible=False, + toggle_show_default=True) + info_wid = info_print(n_bullets=10, toggle_show_default=True, + toggle_show_visible=False) + + # save figure widget + initial_renderer = MatplotlibImageViewer2d(figure_id=None, new_figure=True, + image=np.zeros((10, 10))) + save_figure_wid = save_figure_options(initial_renderer, toggle_show_default=True, toggle_show_visible=False) @@ -1407,6 +1483,7 @@ def update_widgets(name, value): update_model_parameters(shape_model_parameters_wid, n_shape_parameters[value], plot_function, params_str='param ') + # update channel options update_channel_options(channel_options_wid, atm.warped_templates[value].n_channels, @@ -1424,35 +1501,36 @@ def update_widgets(name, value): radio_str["Level {} (high)".format(l)] = l else: radio_str["Level {}".format(l)] = l - level_wid = RadioButtonsWidget(values=radio_str, - description='Pyramid:', value=0) + level_wid = ipywidgets.RadioButtonsWidget(values=radio_str, + description='Pyramid:', + value=0) level_wid.on_trait_change(update_widgets, 'value') level_wid.on_trait_change(plot_function, 'value') tmp_children.insert(0, level_wid) - tmp_wid = ContainerWidget(children=tmp_children) - wid = TabWidget(children=[tmp_wid, channel_options_wid, - landmark_options_wid, figure_options_wid, - info_wid, save_figure_wid]) + tmp_wid = ipywidgets.ContainerWidget(children=tmp_children) + tab_wid = ipywidgets.TabWidget(children=[tmp_wid, channel_options_wid, + landmark_options_wid, + viewer_options_wid, + info_wid, save_figure_wid]) + logo_wid = logo() + wid = ipywidgets.ContainerWidget(children=[logo_wid, tab_wid]) if popup: - wid = PopupWidget(children=[wid], button_text='ATM Menu') + wid = ipywidgets.PopupWidget(children=[wid], + button_text='ATM Menu') # display final widget - display(wid) + ipydisplay.display(wid) # set final tab titles - tab_titles = ['Shape parameters', 'Channels options', 'Landmarks options', - 'Figure options', 'Model info', 'Save figure'] - if popup: - for (k, tl) in enumerate(tab_titles): - wid.children[0].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.set_title(k, tl) + tab_titles = ['Shape parameters', 'Channels options', + 'Landmarks options', 'Viewer options', 'Model info', + 'Save figure'] + for (k, tl) in enumerate(tab_titles): + tab_wid.set_title(k, tl) # align widgets - if n_levels > 1: - tmp_wid.remove_class('vbox') - tmp_wid.add_class('hbox') + tmp_wid.remove_class('vbox') + tmp_wid.add_class('hbox') format_model_parameters(shape_model_parameters_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', @@ -1468,12 +1546,13 @@ def update_widgets(name, value): container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) - format_figure_options(figure_options_wid, container_padding='6px', + format_viewer_options(viewer_options_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=False) - format_info_print(info_wid, font_size_in_pt='9pt', container_padding='6px', + border_visible=False, + suboptions_border_visible=True) + format_info_print(info_wid, font_size_in_pt='10pt', container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) @@ -1483,11 +1562,12 @@ def update_widgets(name, value): toggle_button_font_weight='bold', tab_top_margin='0cm', border_visible=False) - # update widgets' state for level 0 + # update widgets' state for image number 0 update_widgets('', 0) # Reset value to enable initial visualization - figure_options_wid.children[2].value = False + viewer_options_wid.children[1].children[1].children[2].children[2].value = \ + False def visualize_fitting_results(fitting_results, figure_size=(7, 7), popup=False, From ac7f81cd365b3be501df5029046948ee9b40a13d Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Sat, 27 Dec 2014 12:30:38 +0000 Subject: [PATCH 07/15] adds plot_erros and plot_displacements to fittingresult --- menpofit/fittingresult.py | 210 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 210 insertions(+) diff --git a/menpofit/fittingresult.py b/menpofit/fittingresult.py index 1b52f9e..74b2e20 100644 --- a/menpofit/fittingresult.py +++ b/menpofit/fittingresult.py @@ -211,6 +211,216 @@ def view_widget(self, popup=False): from menpofit.visualize import visualize_fitting_results visualize_fitting_results(self, figure_size=(7, 7), popup=popup) + def plot_errors(self, error_type='me_norm', figure_id=None, + new_figure=False, render_lines=True, line_colour='b', + line_style='-', line_width=2, render_markers=True, + marker_style='o', marker_size=4, marker_face_colour='b', + marker_edge_colour='k', marker_edge_width=1., + render_axes=True, axes_font_name='sans-serif', + axes_font_size=10, axes_font_style='normal', + axes_font_weight='normal', figure_size=(6, 4), + render_grid=True, grid_line_style='--', + grid_line_width=0.5): + r""" + Plot of the error evolution at each fitting iteration. + + Parameters + ---------- + error_type : {``me_norm``, ``me``, ``rmse``}, optional + Specifies the way in which the error between the fitted and + ground truth shapes is to be computed. + figure_id : `object`, optional + The id of the figure to be used. + new_figure : `bool`, optional + If ``True``, a new figure is created. + render_lines : `bool`, optional + If ``True``, the line will be rendered. + line_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} or + ``(3, )`` `ndarray`, optional + The colour of the lines. + line_style : {``-``, ``--``, ``-.``, ``:``}, optional + The style of the lines. + line_width : `float`, optional + The width of the lines. + render_markers : `bool`, optional + If ``True``, the markers will be rendered. + marker_style : {``.``, ``,``, ``o``, ``v``, ``^``, ``<``, ``>``, ``+``, + ``x``, ``D``, ``d``, ``s``, ``p``, ``*``, ``h``, ``H``, + ``1``, ``2``, ``3``, ``4``, ``8``}, optional + The style of the markers. + marker_size : `int`, optional + The size of the markers in points^2. + marker_face_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} + or ``(3, )`` `ndarray`, optional + The face (filling) colour of the markers. + marker_edge_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} + or ``(3, )`` `ndarray`, optional + The edge colour of the markers. + marker_edge_width : `float`, optional + The width of the markers' edge. + render_axes : `bool`, optional + If ``True``, the axes will be rendered. + axes_font_name : {``serif``, ``sans-serif``, ``cursive``, ``fantasy``, + ``monospace``}, optional + The font of the axes. + axes_font_size : `int`, optional + The font size of the axes. + axes_font_style : {``normal``, ``italic``, ``oblique``}, optional + The font style of the axes. + axes_font_weight : {``ultralight``, ``light``, ``normal``, ``regular``, + ``book``, ``medium``, ``roman``, ``semibold``, + ``demibold``, ``demi``, ``bold``, ``heavy``, + ``extra bold``, ``black``}, optional + The font weight of the axes. + figure_size : (`float`, `float`) or `None`, optional + The size of the figure in inches. + render_grid : `bool`, optional + If ``True``, the grid will be rendered. + grid_line_style : {``-``, ``--``, ``-.``, ``:``}, optional + The style of the grid lines. + grid_line_width : `float`, optional + The width of the grid lines. + + Returns + ------- + viewer : :map:`GraphPlotter` + The viewer object. + """ + from menpo.visualize import GraphPlotter + errors_list = self.errors(error_type=error_type) + return GraphPlotter(figure_id=figure_id, new_figure=new_figure, + x_axis=range(len(errors_list)), + y_axis=[errors_list], + title='Fitting Errors per Iteration', + x_label='Iteration', y_label='Fitting Error', + x_axis_limits=(0, len(errors_list)-1), + y_axis_limits=None).render( + render_lines=render_lines, line_colour=line_colour, + line_style=line_style, line_width=line_width, + render_markers=render_markers, marker_style=marker_style, + marker_size=marker_size, marker_face_colour=marker_face_colour, + marker_edge_colour=marker_edge_colour, + marker_edge_width=marker_edge_width, render_legend=False, + render_axes=render_axes, axes_font_name=axes_font_name, + axes_font_size=axes_font_size, axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, render_grid=render_grid, + grid_line_style=grid_line_style, grid_line_width=grid_line_width, + figure_size=figure_size) + + def plot_displacements(self, stat_type='mean', figure_id=None, + new_figure=False, render_lines=True, line_colour='b', + line_style='-', line_width=2, render_markers=True, + marker_style='o', marker_size=4, + marker_face_colour='b', marker_edge_colour='k', + marker_edge_width=1., render_axes=True, + axes_font_name='sans-serif', axes_font_size=10, + axes_font_style='normal', axes_font_weight='normal', + figure_size=(6, 4), render_grid=True, + grid_line_style='--', grid_line_width=0.5): + r""" + Plot of a statistical metric of the displacement between the shape of + each iteration and the shape of the previous one. + + Parameters + ---------- + stat_type : {``mean``, ``median``, ``min``, ``max``}, optional + Specifies a statistic metric to be extracted from the displacements + (see also `displacements_stats()` method). + figure_id : `object`, optional + The id of the figure to be used. + new_figure : `bool`, optional + If ``True``, a new figure is created. + render_lines : `bool`, optional + If ``True``, the line will be rendered. + line_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} or + ``(3, )`` `ndarray`, optional + The colour of the lines. + line_style : {``-``, ``--``, ``-.``, ``:``}, optional + The style of the lines. + line_width : `float`, optional + The width of the lines. + render_markers : `bool`, optional + If ``True``, the markers will be rendered. + marker_style : {``.``, ``,``, ``o``, ``v``, ``^``, ``<``, ``>``, ``+``, + ``x``, ``D``, ``d``, ``s``, ``p``, ``*``, ``h``, ``H``, + ``1``, ``2``, ``3``, ``4``, ``8``}, optional + The style of the markers. + marker_size : `int`, optional + The size of the markers in points^2. + marker_face_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} + or ``(3, )`` `ndarray`, optional + The face (filling) colour of the markers. + marker_edge_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} + or ``(3, )`` `ndarray`, optional + The edge colour of the markers. + marker_edge_width : `float`, optional + The width of the markers' edge. + render_axes : `bool`, optional + If ``True``, the axes will be rendered. + axes_font_name : {``serif``, ``sans-serif``, ``cursive``, ``fantasy``, + ``monospace``}, optional + The font of the axes. + axes_font_size : `int`, optional + The font size of the axes. + axes_font_style : {``normal``, ``italic``, ``oblique``}, optional + The font style of the axes. + axes_font_weight : {``ultralight``, ``light``, ``normal``, ``regular``, + ``book``, ``medium``, ``roman``, ``semibold``, + ``demibold``, ``demi``, ``bold``, ``heavy``, + ``extra bold``, ``black``}, optional + The font weight of the axes. + figure_size : (`float`, `float`) or `None`, optional + The size of the figure in inches. + render_grid : `bool`, optional + If ``True``, the grid will be rendered. + grid_line_style : {``-``, ``--``, ``-.``, ``:``}, optional + The style of the grid lines. + grid_line_width : `float`, optional + The width of the grid lines. + + Returns + ------- + viewer : :map:`GraphPlotter` + The viewer object. + """ + from menpo.visualize import GraphPlotter + # set labels + if stat_type == 'max': + ylabel = 'Maximum Displacement' + title = 'Maximum displacement per Iteration' + elif stat_type == 'min': + ylabel = 'Minimum Displacement' + title = 'Minimum displacement per Iteration' + elif stat_type == 'mean': + ylabel = 'Mean Displacement' + title = 'Mean displacement per Iteration' + elif stat_type == 'median': + ylabel = 'Median Displacement' + title = 'Median displacement per Iteration' + else: + raise ValueError('stat_type must be one of {max, min, mean, ' + 'median}.') + # plot + displacements_list = self.displacements_stats(stat_type=stat_type) + return GraphPlotter(figure_id=figure_id, new_figure=new_figure, + x_axis=range(len(displacements_list)), + y_axis=[displacements_list], + title=title, + x_label='Iteration', y_label=ylabel, + x_axis_limits=(0, len(displacements_list)-1), + y_axis_limits=None).render( + render_lines=render_lines, line_colour=line_colour, + line_style=line_style, line_width=line_width, + render_markers=render_markers, marker_style=marker_style, + marker_size=marker_size, marker_face_colour=marker_face_colour, + marker_edge_colour=marker_edge_colour, + marker_edge_width=marker_edge_width, render_legend=False, + render_axes=render_axes, axes_font_name=axes_font_name, + axes_font_size=axes_font_size, axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, render_grid=render_grid, + grid_line_style=grid_line_style, grid_line_width=grid_line_width, + figure_size=figure_size) + def as_serializable(self): r"""" Returns a serializable version of the fitting result. This is a much From cf5b5c8a61993fbf5b98412d0b692b78987f0c45 Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Sat, 27 Dec 2014 15:15:38 +0000 Subject: [PATCH 08/15] adds widgets options and updates final_result_options --- menpofit/visualize/widgets/options.py | 1626 +++++++++++++++++++++++++ 1 file changed, 1626 insertions(+) create mode 100644 menpofit/visualize/widgets/options.py diff --git a/menpofit/visualize/widgets/options.py b/menpofit/visualize/widgets/options.py new file mode 100644 index 0000000..a4646e2 --- /dev/null +++ b/menpofit/visualize/widgets/options.py @@ -0,0 +1,1626 @@ +from collections import OrderedDict + +from menpo.visualize.widgets.tools import (colour_selection, + format_colour_selection) +from menpo.visualize.widgets.options import _compare_groups_and_labels + + +def model_parameters(n_params, plot_function=None, params_str='', + mode='multiple', params_bounds=(-3., 3.), + plot_eig_visible=True, plot_eig_function=None, + toggle_show_default=True, toggle_show_visible=True, + toggle_show_name='Parameters'): + r""" + Creates a widget with Model Parameters. Specifically, it has: + 1) A slider for each parameter if mode is 'multiple'. + 2) A single slider and a drop down menu selection if mode is 'single'. + 3) A reset button. + 4) A button and two radio buttons for plotting the eigenvalues variance + ratio. + + The structure of the widgets is the following: + model_parameters_wid.children = [toggle_button, parameters_and_reset] + parameters_and_reset.children = [parameters_widgets, reset] + If plot_eig_visible is True: + reset = [plot_eigenvalues, reset_button] + Else: + reset = reset_button + If mode is single: + parameters_widgets.children = [drop_down_menu, slider] + If mode is multiple: + parameters_widgets.children = [all_sliders] + + The returned widget saves the selected values in the following fields: + model_parameters_wid.parameters_values + model_parameters_wid.mode + model_parameters_wid.plot_eig_visible + + To fix the alignment within this widget please refer to + `format_model_parameters()` function. + + To update the state of this widget, please refer to + `update_model_parameters()` function. + + Parameters + ---------- + n_params : `int` + The number of principal components to use for the sliders. + + plot_function : `function` or None, optional + The plot function that is executed when a widgets' value changes. + If None, then nothing is assigned. + + params_str : `str`, optional + The string that will be used for each parameters name. + + mode : 'single' or 'multiple', optional + If single, only a single slider is constructed along with a drop down + menu. + If multiple, a slider is constructed for each parameter. + + params_bounds : (`float`, `float`), optional + The minimum and maximum bounds, in std units, for the sliders. + + plot_eig_visible : `boolean`, optional + Defines whether the options for plotting the eigenvalues variance ratio + will be visible upon construction. + + plot_eig_function : `function` or None, optional + The plot function that is executed when the plot eigenvalues button is + clicked. If None, then nothing is assigned. + + toggle_show_default : `boolean`, optional + Defines whether the options will be visible upon construction. + + toggle_show_visible : `boolean`, optional + The visibility of the toggle button. + + toggle_show_name : `str`, optional + The name of the toggle button. + """ + import IPython.html.widgets as ipywidgets + # If only one slider requested, then set mode to multiple + if n_params == 1: + mode = 'multiple' + + # Create all necessary widgets + but = ipywidgets.ToggleButtonWidget(description=toggle_show_name, + value=toggle_show_default, + visible=toggle_show_visible) + reset_button = ipywidgets.ButtonWidget(description='Reset') + if mode == 'multiple': + sliders = [ipywidgets.FloatSliderWidget( + description="{}{}".format(params_str, p), + min=params_bounds[0], max=params_bounds[1], + value=0.) + for p in range(n_params)] + parameters_wid = ipywidgets.ContainerWidget(children=sliders) + else: + vals = OrderedDict() + for p in range(n_params): + vals["{}{}".format(params_str, p)] = p + slider = ipywidgets.FloatSliderWidget(description='', + min=params_bounds[0], + max=params_bounds[1], value=0.) + dropdown_params = ipywidgets.DropdownWidget(values=vals) + parameters_wid = ipywidgets.ContainerWidget( + children=[dropdown_params, slider]) + + # Group widgets + if plot_eig_visible: + plot_button = ipywidgets.ButtonWidget(description='Plot eigenvalues') + if plot_eig_function is not None: + plot_button.on_click(plot_eig_function) + plot_and_reset = ipywidgets.ContainerWidget( + children=[plot_button, reset_button]) + params_and_reset = ipywidgets.ContainerWidget(children=[parameters_wid, + plot_and_reset]) + else: + params_and_reset = ipywidgets.ContainerWidget(children=[parameters_wid, + reset_button]) + + # Widget container + model_parameters_wid = ipywidgets.ContainerWidget( + children=[but, params_and_reset]) + + # Save mode and parameters values + model_parameters_wid.parameters_values = [0.0] * n_params + model_parameters_wid.mode = mode + model_parameters_wid.plot_eig_visible = plot_eig_visible + + # set up functions + if mode == 'single': + # assign slider value to parameters values list + def save_slider_value(name, value): + model_parameters_wid.parameters_values[dropdown_params.value] = \ + value + slider.on_trait_change(save_slider_value, 'value') + + # set correct value to slider when drop down menu value changes + def set_slider_value(name, value): + slider.value = model_parameters_wid.parameters_values[value] + dropdown_params.on_trait_change(set_slider_value, 'value') + + # assign main plotting function when slider value changes + if plot_function is not None: + slider.on_trait_change(plot_function, 'value') + else: + # assign slider value to parameters values list + def save_slider_value_from_id(description, name, value): + i = int(description[len(params_str)::]) + model_parameters_wid.parameters_values[i] = value + + # partial function that helps get the widget's description str + def partial_widget(description): + return lambda name, value: save_slider_value_from_id(description, + name, value) + + # assign saving values and main plotting function to all sliders + for w in parameters_wid.children: + # The widget (w) is lexically scoped and so we need a way of + # ensuring that we don't just receive the final value of w at every + # iteration. Therefore we create another lambda function that + # creates a new lexical scoping so that we can ensure the value of w + # is maintained (as x) at each iteration. + # In JavaScript, we would just use the 'let' keyword... + w.on_trait_change(partial_widget(w.description), 'value') + if plot_function is not None: + w.on_trait_change(plot_function, 'value') + + # reset function + def reset_params(name): + model_parameters_wid.parameters_values = \ + [0.0] * len(model_parameters_wid.parameters_values) + if mode == 'multiple': + for ww in parameters_wid.children: + ww.value = 0. + else: + parameters_wid.children[0].value = 0 + parameters_wid.children[1].value = 0. + reset_button.on_click(reset_params) + + # Toggle button function + def show_options(name, value): + params_and_reset.visible = value + show_options('', toggle_show_default) + but.on_trait_change(show_options, 'value') + + return model_parameters_wid + + +def format_model_parameters(model_parameters_wid, container_padding='6px', + container_margin='6px', + container_border='1px solid black', + toggle_button_font_weight='bold', + border_visible=True): + r""" + Function that corrects the align (style format) of a given model_parameters + widget. Usage example: + model_parameters_wid = model_parameters() + display(model_parameters_wid) + format_model_parameters(model_parameters_wid) + + Parameters + ---------- + model_parameters_wid : + The widget object generated by the `model_parameters()` function. + + container_padding : `str`, optional + The padding around the widget, e.g. '6px' + + container_margin : `str`, optional + The margin around the widget, e.g. '6px' + + container_border : `str`, optional + The border around the widget, e.g. '1px solid black' + + toggle_button_font_weight : `str` + The font weight of the toggle button, e.g. 'bold' + + border_visible : `boolean`, optional + Defines whether to draw the border line around the widget. + """ + if model_parameters_wid.mode == 'single': + # align drop down menu and slider + model_parameters_wid.children[1].children[0].remove_class('vbox') + model_parameters_wid.children[1].children[0].add_class('hbox') + else: + # align sliders + model_parameters_wid.children[1].children[0].add_class('start') + + # align reset button to right + if model_parameters_wid.plot_eig_visible: + model_parameters_wid.children[1].children[1].remove_class('vbox') + model_parameters_wid.children[1].children[1].add_class('hbox') + model_parameters_wid.children[1].add_class('align-end') + + # set toggle button font bold + model_parameters_wid.children[0].set_css('font-weight', + toggle_button_font_weight) + + # margin and border around plot_eigenvalues widget + if model_parameters_wid.plot_eig_visible: + model_parameters_wid.children[1].children[1].children[0].set_css( + 'margin-right', container_margin) + + # margin and border around container widget + model_parameters_wid.set_css('padding', container_padding) + model_parameters_wid.set_css('margin', container_margin) + if border_visible: + model_parameters_wid.set_css('border', container_border) + + +def update_model_parameters(model_parameters_wid, n_params, plot_function=None, + params_str=''): + r""" + Function that updates the state of a given model_parameters widget if the + requested number of parameters has changed. Usage example: + model_parameters_wid = model_parameters(n_params=5) + display(model_parameters_wid) + format_model_parameters(model_parameters_wid) + update_model_parameters(model_parameters_wid, 3) + + Parameters + ---------- + model_parameters_wid : + The widget object generated by the `model_parameters()` function. + + n_params : `int` + The requested number of parameters. + + plot_function : `function` or None, optional + The plot function that is executed when a widgets' value changes. + If None, then nothing is assigned. + + params_str : `str`, optional + The string that will be used for each parameters name. + """ + import IPython.html.widgets as ipywidgets + + if model_parameters_wid.mode == 'multiple': + # get the number of enabled parameters (number of sliders) + enabled_params = len(model_parameters_wid.children[1].children[0].children) + if n_params != enabled_params: + # reset all parameters values + model_parameters_wid.parameters_values = [0.0] * n_params + # get params_bounds + pb = [model_parameters_wid.children[1].children[0].children[0].min, + model_parameters_wid.children[1].children[0].children[0].max] + # create sliders widgets + sliders = [ipywidgets.FloatSliderWidget( + description="{}{}".format(params_str, + p), + min=pb[0], max=pb[1], value=0.) + for p in range(n_params)] + # assign sliders to container + model_parameters_wid.children[1].children[0].children = sliders + + # assign slider value to parameters values list + def save_slider_value_from_id(description, name, value): + i = int(description[len(params_str)::]) + model_parameters_wid.parameters_values[i] = value + + # partial function that helps get the widget's description str + def partial_widget(description): + return lambda name, value: save_slider_value_from_id( + description, + name, value) + + # assign saving values and main plotting function to all sliders + for w in model_parameters_wid.children[1].children[0].children: + # The widget (w) is lexically scoped and so we need a way of + # ensuring that we don't just receive the final value of w at + # every iteration. Therefore we create another lambda function + # that creates a new lexical scoping so that we can ensure the + # value of w is maintained (as x) at each iteration + # In JavaScript, we would just use the 'let' keyword... + w.on_trait_change(partial_widget(w.description), 'value') + if plot_function is not None: + w.on_trait_change(plot_function, 'value') + else: + # get the number of enabled parameters (len of list of drop down menu) + enabled_params = len( + model_parameters_wid.children[1].children[0].children[0].values) + if n_params != enabled_params: + # reset all parameters values + model_parameters_wid.parameters_values = [0.0] * n_params + # change drop down menu values + vals = OrderedDict() + for p in range(n_params): + vals["{}{}".format(params_str, p)] = p + model_parameters_wid.children[1].children[0].children[0].values = \ + vals + # set initial value to the first and slider value to zero + model_parameters_wid.children[1].children[0].children[0].value = \ + vals["{}{}".format(params_str, 0)] + model_parameters_wid.children[1].children[0].children[1].value = 0. + + +def final_result_options(final_result_options_default, plot_function=None, + title='Final Result', toggle_show_default=True, + toggle_show_visible=True): + r""" + Creates a widget with Final Result Options. Specifically, it has: + 1) A set of toggle buttons representing usually the initial, final and + ground truth shapes. + 2) A checkbox that controls the rendering of the image. + 3) A set of radio buttons that define whether subplots are enabled. + 4) A toggle button that controls the visibility of all the above, i.e. + the final result options. + + The structure of the widgets is the following: + final_result_wid.children = [toggle_button, shapes_toggle_buttons, + options] + options.children = [plot_mode_radio_buttons, show_image_checkbox] + + The returned widget saves the selected values in the following fields: + final_result_wid.selected_values + + To fix the alignment within this widget please refer to + `format_final_result_options()` function. + + To update the state of this widget, please refer to + `update_final_result_options()` function. + + Parameters + ---------- + final_result_options_default : `dict` + The default options. For example: + final_result_options_default = {'all_groups': ['initial', 'final', + 'ground'], + 'selected_groups': ['final'], + 'render_image': True, + 'subplots_enabled': True} + plot_function : `function` or None, optional + The plot function that is executed when a widgets' value changes. + If None, then nothing is assigned. + title : `str`, optional + The title of the widget printed at the toggle button. + toggle_show_default : `bool`, optional + Defines whether the options will be visible upon construction. + toggle_show_visible : `bool`, optional + The visibility of the toggle button. + """ + import IPython.html.widgets as ipywidgets + # Toggle button that controls options' visibility + but = ipywidgets.ToggleButtonWidget(description=title, + value=toggle_show_default, + visible=toggle_show_visible) + + # Create widgets + shapes_checkboxes = [ipywidgets.LatexWidget(value='Select shape:')] + for group in final_result_options_default['all_groups']: + t = ipywidgets.ToggleButtonWidget( + description=group, + value=group in final_result_options_default['selected_groups']) + shapes_checkboxes.append(t) + render_image = ipywidgets.CheckboxWidget( + description='Render image', + value=final_result_options_default['render_image']) + mode = ipywidgets.RadioButtonsWidget( + description='Plot mode:', values={'Single': False, 'Multiple': True}, + value=final_result_options_default['subplots_enabled']) + + # Group widgets + shapes_wid = ipywidgets.ContainerWidget(children=shapes_checkboxes) + opts = ipywidgets.ContainerWidget(children=[mode, render_image]) + + # Widget container + final_result_wid = ipywidgets.ContainerWidget(children=[but, shapes_wid, + opts]) + + # Initialize variables + final_result_wid.selected_values = final_result_options_default + + # Groups function + def groups_fun(name, value): + final_result_wid.selected_values['selected_groups'] = [] + for i in shapes_wid.children[1::]: + if i.value: + final_result_wid.selected_values['selected_groups'].\ + append(str(i.description)) + for w in shapes_wid.children[1::]: + w.on_trait_change(groups_fun, 'value') + + # Render image function + def render_image_fun(name, value): + final_result_wid.selected_values['render_image'] = value + render_image.on_trait_change(render_image_fun, 'value') + + # Plot mode function + def plot_mode_fun(name, value): + final_result_wid.selected_values['subplots_enabled'] = value + mode.on_trait_change(plot_mode_fun, 'value') + + # Toggle button function + def show_options(name, value): + shapes_wid.visible = value + opts.visible = value + show_options('', toggle_show_default) + but.on_trait_change(show_options, 'value') + + # assign plot_function + if plot_function is not None: + render_image.on_trait_change(plot_function, 'value') + mode.on_trait_change(plot_function, 'value') + for w in shapes_wid.children[1::]: + w.on_trait_change(plot_function, 'value') + + return final_result_wid + + +def format_final_result_options(final_result_wid, container_padding='6px', + container_margin='6px', + container_border='1px solid black', + toggle_button_font_weight='bold', + border_visible=True): + r""" + Function that corrects the align (style format) of a given + final_result_options widget. Usage example: + final_result_options_default = {'all_groups': ['initial', 'final', + 'ground'], + 'selected_groups': ['final'], + 'render_image': True, + 'subplots_enabled': True} + final_result_wid = final_result_options(final_result_options_default) + display(final_result_wid) + format_final_result_options(final_result_wid) + + Parameters + ---------- + final_result_wid : + The widget object generated by the `final_result_options()` function. + container_padding : `str`, optional + The padding around the widget, e.g. '6px' + container_margin : `str`, optional + The margin around the widget, e.g. '6px' + container_border : `str`, optional + The border around the widget, e.g. '1px solid black' + toggle_button_font_weight : `str` + The font weight of the toggle button, e.g. 'bold' + border_visible : `boolean`, optional + Defines whether to draw the border line around the widget. + """ + # align shapes toggle buttons + final_result_wid.children[1].remove_class('vbox') + final_result_wid.children[1].add_class('hbox') + final_result_wid.children[1].add_class('align-center') + final_result_wid.children[1].children[0].set_css('margin-right', + container_margin) + + # align mode and legend options + final_result_wid.children[2].remove_class('vbox') + final_result_wid.children[2].add_class('hbox') + final_result_wid.children[2].children[0].set_css('margin-right', '20px') + + # set toggle button font bold + final_result_wid.children[0].set_css('font-weight', + toggle_button_font_weight) + final_result_wid.children[1].set_css('margin-top', container_margin) + + # margin and border around container widget + final_result_wid.set_css('padding', container_padding) + final_result_wid.set_css('margin', container_margin) + if border_visible: + final_result_wid.set_css('border', container_border) + + +def update_final_result_options(final_result_wid, group_keys, plot_function): + r""" + Function that updates the state of a given final_result_options widget if + the group keys of an image has changed. Usage example: + final_result_options_default = {'all_groups': ['group1', 'group2'], + 'selected_groups': ['group1'], + 'render_image': True, + 'subplots_enabled': True} + final_result_wid = final_result_options(final_result_options_default) + display(final_result_wid) + format_final_result_options(final_result_wid) + update_final_result_options(final_result_wid, group_keys=['group3']) + format_final_result_options(final_result_wid) + + Note that the `format_final_result_options()` function needs to be called + again after the `update_final_result_options()` function. + + Parameters + ---------- + final_result_wid : + The widget object generated by the `final_result_options()` function. + group_keys : `list` of `str` + A list of the available landmark groups. + plot_function : `function` or None + The plot function that is executed when a widgets' value changes. + If None, then nothing is assigned. + """ + import IPython.html.widgets as ipywidgets + # check if the new group_keys are the same as the old ones + if not _compare_groups_and_labels( + group_keys, [], final_result_wid.selected_values['all_groups'], []): + # Create all necessary widgets + shapes_checkboxes = [ipywidgets.LatexWidget(value='Select shape:')] + for group in group_keys: + t = ipywidgets.ToggleButtonWidget(description=group, value=True) + shapes_checkboxes.append(t) + + # Group widgets + final_result_wid.children[1].children = shapes_checkboxes + + # Initialize output variables + final_result_wid.selected_values['all_groups'] = group_keys + final_result_wid.selected_values['selected_groups'] = group_keys + + # Groups function + def groups_fun(name, value): + final_result_wid.selected_values['selected_groups'] = [] + for i in final_result_wid.children[1].children[1::]: + if i.value: + final_result_wid.selected_values['selected_groups'].append(str(i.description)) + for w in final_result_wid.children[1].children[1::]: + w.on_trait_change(groups_fun, 'value') + + # Toggle button function + def show_options(name, value): + final_result_wid.children[1].visible = value + final_result_wid.children[2].visible = value + show_options('', final_result_wid.children[0].value) + final_result_wid.children[0].on_trait_change(show_options, 'value') + + # assign plot_function + if plot_function is not None: + final_result_wid.children[2].children[0].on_trait_change( + plot_function, 'value') + final_result_wid.children[2].children[1].on_trait_change( + plot_function, 'value') + for w in final_result_wid.children[1].children[1::]: + w.on_trait_change(plot_function, 'value') + + +def iterations_result_options(n_iters, image_has_gt_shape, n_points, + plot_function=None, plot_errors_function=None, + plot_displacements_function=None, + iter_str='iter_', title='Iterations Result', + show_image_default=True, + subplots_enabled_default=False, + numbering_default=False, + legend_default=True, toggle_show_default=True, + toggle_show_visible=True): + r""" + Creates a widget with Iterations Result Options. Specifically, it has: + 1) Two radio buttons that select an options mode, depending on whether + the user wants to visualize iterations in "Animation" or "Static" mode. + 2) If mode is "Animation", an animation options widget appears. + If mode is "Static", the iterations range is selected by two + sliders and there is an update plot button. + 3) A checkbox that controls the visibility of the image. + 4) A set of radio buttons that define whether subplots are enabled. + 5) A checkbox that controls the legend's visibility. + 6) A checkbox that controls the numbering visibility. + 7) A button to plot the error evolution. + 8) A button to plot the landmark points' displacement. + 9) A drop down menu to select which displacement to plot. + 10) A toggle button that controls the visibility of all the above, i.e. + the final result options. + + The structure of the widgets is the following: + iterations_result_wid.children = [toggle_button, + iterations_mode_and_sliders, + options] + iterations_mode_and_sliders.children = [iterations_mode_radio_buttons, + all_sliders] + all_sliders.children = [animation_slider, first_slider, second_slider, + update_button] + options.children = [plot_mode_radio_buttons, show_image_checkbox, + show_legend_checkbox, plot_errors_button, + plot_displacements, show_numbering_checkbox] + plot_displacements.children = [plot_displacements_button, + plot_displacements_drop_down_menu] + + The returned widget saves the selected values in the following fields: + iterations_result_wid.groups + iterations_result_wid.image_has_gt_shape + iterations_result_wid.n_iters + iterations_result_wid.n_points + iterations_result_wid.show_image + iterations_result_wid.subplots_enabled + iterations_result_wid.legend_enabled + iterations_result_wid.numbering_enabled + iterations_result_wid.displacement_type + + To fix the alignment within this widget please refer to + `format_iterations_result_options()` function. + + To update the state of this widget, please refer to + `update_iterations_result_options()` function. + + Parameters + ---------- + n_iters : `int` + The number of iterations. + image_has_gt_shape : `bool` + Flag that defines whether the fitted image has a ground shape attached. + n_points : `int` + The number of the object's landmark points. It is required by the + displacement dorp down menu. + plot_function : `function` or None, optional + The plot function that is executed when a widgets' value changes. + If None, then nothing is assigned. + plot_errors_function : `function` or None, optional + The plot function that is executed when the 'Plot Errors' button is + pressed. + If None, then nothing is assigned. + plot_displacements_function : `function` or None, optional + The plot function that is executed when the 'Plot Displacements' button + is pressed. + If None, then nothing is assigned. + iter_str : `str`, optional + The str that is used in the landmark groups shapes. + E.g. if iter_str == "iter_" then the group label of iteration i has the + form "{}{}".format(iter_str, i) + title : `str`, optional + The title of the widget printed at the toggle button. + show_image_default : `bool`, optional + The initial value of the image's visibility checkbox. + subplots_enabled_default : `bool`, optional + The initial value of the plot options' radio buttons that determine + whether a single plot or subplots will be used. + legend_default : `bool`, optional + The initial value of the legend's visibility checkbox. + numbering_default : `bool`, optional + The initial value of the numbering visibility checkbox. + toggle_show_default : `bool`, optional + Defines whether the options will be visible upon construction. + toggle_show_visible : `bool`, optional + The visibility of the toggle button. + """ + import IPython.html.widgets as ipywidgets + # Create all necessary widgets + but = ipywidgets.ToggleButtonWidget(description=title, + value=toggle_show_default, + visible=toggle_show_visible) + iterations_mode = ipywidgets.RadioButtonsWidget( + values={'Animation': 0, 'Static': 1}, + value=0, + description='Iterations mode:', + visible=toggle_show_default) + # Don't assign the plot function to the animation_wid at this point. We + # first need to assign the get_groups function and then the plot_function() + # for synchronization reasons. + animation_wid = animation_options(index_min_val=0, + index_max_val=n_iters - 1, + plot_function=None, + update_function=None, + index_step=1, index_default=0, + index_description='Iteration', + index_style='slider', + loop_default=False, interval_default=0.2, + toggle_show_default=toggle_show_default, + toggle_show_visible=False) + first_slider_wid = ipywidgets.IntSliderWidget(min=0, max=n_iters - 1, + step=1, + value=0, description='From', + visible=False) + second_slider_wid = ipywidgets.IntSliderWidget(min=0, max=n_iters - 1, + step=1, + value=n_iters - 1, + description='To', + visible=False) + update_but = ipywidgets.ButtonWidget(description='Update Plot', + visible=False) + show_image = ipywidgets.CheckboxWidget(description='Show image', + value=show_image_default) + plot_errors_button = ipywidgets.ButtonWidget(description='Plot Errors') + plot_displacements_button = ipywidgets.ButtonWidget( + description='Plot Displacements') + dropdown_menu = OrderedDict() + dropdown_menu['mean'] = 'mean' + dropdown_menu['median'] = 'median' + dropdown_menu['max'] = 'max' + dropdown_menu['min'] = 'min' + for p in range(n_points): + dropdown_menu["point {}".format(p)] = p + plot_displacements_menu = ipywidgets.SelectWidget(values=dropdown_menu, + value='mean') + plot_mode = ipywidgets.RadioButtonsWidget(description='Plot mode:', + values={'Single': False, + 'Multiple': True}) + plot_mode.value = subplots_enabled_default + show_legend = ipywidgets.CheckboxWidget(description='Show legend', + value=legend_default) + show_numbering = ipywidgets.CheckboxWidget(description='Show numbering', + value=numbering_default) + # if just one iteration, disable multiple options + if n_iters == 1: + iterations_mode.value = 0 + iterations_mode.disabled = True + first_slider_wid.disabled = True + animation_wid.children[1].children[0].children[2].disabled = True + animation_wid.children[1].children[1].children[0].children[0]. \ + disabled = True + animation_wid.children[1].children[1].children[0].children[1]. \ + disabled = True + animation_wid.children[1].children[1].children[0].children[2]. \ + disabled = True + second_slider_wid.disabled = True + plot_errors_button.disabled = True + plot_displacements_button.disabled = True + plot_displacements_menu.disabled = True + + # Group widgets + sliders = ipywidgets.ContainerWidget( + children=[animation_wid, first_slider_wid, + second_slider_wid, update_but]) + iterations_mode_and_sliders = ipywidgets.ContainerWidget( + children=[iterations_mode, + sliders]) + plot_displacements = ipywidgets.ContainerWidget( + children=[plot_displacements_button, + plot_displacements_menu]) + opts = ipywidgets.ContainerWidget( + children=[plot_mode, show_image, show_legend, + show_numbering, plot_errors_button, + plot_displacements]) + + # Widget container + iterations_result_wid = ipywidgets.ContainerWidget(children=[ + but, iterations_mode_and_sliders, opts]) + + # Initialize variables + iterations_result_wid.groups = _convert_iterations_to_groups(0, 0, iter_str) + iterations_result_wid.image_has_gt_shape = image_has_gt_shape + iterations_result_wid.n_iters = n_iters + iterations_result_wid.n_points = n_points + iterations_result_wid.show_image = show_image_default + iterations_result_wid.subplots_enabled = subplots_enabled_default + iterations_result_wid.legend_enabled = legend_default + iterations_result_wid.numbering_enabled = numbering_default + iterations_result_wid.displacement_type = 'mean' + + # Define iterations mode visibility + def iterations_mode_selection(name, value): + if value == 0: + # get val that needs to be assigned + val = first_slider_wid.value + # update visibility + animation_wid.visible = True + first_slider_wid.visible = False + second_slider_wid.visible = False + update_but.visible = False + # set correct values + animation_wid.children[1].children[0].children[2].value = val + animation_wid.selected_index = val + first_slider_wid.value = 0 + second_slider_wid.value = n_iters - 1 + else: + # get val that needs to be assigned + val = animation_wid.selected_index + # update visibility + animation_wid.visible = False + first_slider_wid.visible = True + second_slider_wid.visible = True + update_but.visible = True + # set correct values + second_slider_wid.value = val + first_slider_wid.value = val + animation_wid.children[1].children[0].children[2].value = 0 + animation_wid.selected_index = 0 + iterations_mode.on_trait_change(iterations_mode_selection, 'value') + + # Check first slider's value + def first_slider_val(name, value): + if value > second_slider_wid.value: + first_slider_wid.value = second_slider_wid.value + first_slider_wid.on_trait_change(first_slider_val, 'value') + + # Check second slider's value + def second_slider_val(name, value): + if value < first_slider_wid.value: + second_slider_wid.value = first_slider_wid.value + second_slider_wid.on_trait_change(second_slider_val, 'value') + + # Convert slider values to groups + def get_groups(name, value): + if iterations_mode.value == 0: + iterations_result_wid.groups = _convert_iterations_to_groups( + animation_wid.selected_index, + animation_wid.selected_index, iter_str) + else: + iterations_result_wid.groups = _convert_iterations_to_groups( + first_slider_wid.value, second_slider_wid.value, iter_str) + first_slider_wid.on_trait_change(get_groups, 'value') + second_slider_wid.on_trait_change(get_groups, 'value') + + # assign get_groups() to the slider of animation_wid + animation_wid.children[1].children[0].children[2].\ + on_trait_change(get_groups, 'value') + + # Show image function + def show_image_fun(name, value): + iterations_result_wid.show_image = value + show_image.on_trait_change(show_image_fun, 'value') + + # Plot mode function + def plot_mode_fun(name, value): + iterations_result_wid.subplots_enabled = value + plot_mode.on_trait_change(plot_mode_fun, 'value') + + # Legend function + def legend_fun(name, value): + iterations_result_wid.legend_enabled = value + show_legend.on_trait_change(legend_fun, 'value') + + # Numbering function + def numbering_fun(name, value): + iterations_result_wid.numbering_enabled = value + show_numbering.on_trait_change(numbering_fun, 'value') + + # Displacement type function + def displacement_type_fun(name, value): + iterations_result_wid.displacement_type = value + plot_displacements_menu.on_trait_change(displacement_type_fun, 'value') + + # Toggle button function + def show_options(name, value): + iterations_mode.visible = value + plot_mode.visible = value + show_image.visible = value + show_legend.visible = value + show_numbering.visible = value + plot_errors_button.visible = image_has_gt_shape and value + plot_displacements.visible = value + if value: + if iterations_mode.value == 0: + animation_wid.visible = True + else: + first_slider_wid.visible = True + second_slider_wid.visible = True + else: + animation_wid.visible = False + first_slider_wid.visible = False + second_slider_wid.visible = False + show_options('', toggle_show_default) + but.on_trait_change(show_options, 'value') + + # assign general plot_function + if plot_function is not None: + def plot_function_but(name): + plot_function(name, 0) + + update_but.on_click(plot_function_but) + # Here we assign plot_function() to the slider of animation_wid, as + # we didn't do it at its creation. + animation_wid.children[1].children[0].children[2].on_trait_change( + plot_function, 'value') + show_image.on_trait_change(plot_function, 'value') + plot_mode.on_trait_change(plot_function, 'value') + show_legend.on_trait_change(plot_function, 'value') + show_numbering.on_trait_change(plot_function, 'value') + + # assign plot function of errors button + if plot_errors_function is not None: + plot_errors_button.on_click(plot_errors_function) + + # assign plot function of displacements button + if plot_displacements_function is not None: + plot_displacements_button.on_click(plot_displacements_function) + + return iterations_result_wid + + +def format_iterations_result_options(iterations_result_wid, + container_padding='6px', + container_margin='6px', + container_border='1px solid black', + toggle_button_font_weight='bold', + border_visible=True): + r""" + Function that corrects the align (style format) of a given + iterations_result_options widget. Usage example: + iterations_result_wid = iterations_result_options() + display(iterations_result_wid) + format_iterations_result_options(iterations_result_wid) + + Parameters + ---------- + iterations_result_wid : + The widget object generated by the `iterations_result_options()` + function. + container_padding : `str`, optional + The padding around the widget, e.g. '6px' + container_margin : `str`, optional + The margin around the widget, e.g. '6px' + container_border : `str`, optional + The border around the widget, e.g. '1px solid black' + toggle_button_font_weight : `str` + The font weight of the toggle button, e.g. 'bold' + border_visible : `bool`, optional + Defines whether to draw the border line around the widget. + """ + # format animations options + format_animation_options( + iterations_result_wid.children[1].children[1].children[0], + index_text_width='0.5cm', container_padding=container_padding, + container_margin=container_margin, container_border=container_border, + toggle_button_font_weight=toggle_button_font_weight, + border_visible=False) + + # align displacement button and drop down menu + iterations_result_wid.children[2].children[5].add_class('align-center') + iterations_result_wid.children[2].children[5].children[1].set_css('width', + '2.5cm') + iterations_result_wid.children[2].children[5].children[1].set_css('height', + '2cm') + + # align options + iterations_result_wid.children[2].remove_class('vbox') + iterations_result_wid.children[2].add_class('hbox') + iterations_result_wid.children[2].add_class('align-start') + iterations_result_wid.children[2].children[0].set_css('margin-right', + '20px') + iterations_result_wid.children[2].children[1].set_css('margin-right', + '10px') + iterations_result_wid.children[2].children[2].set_css('margin-right', + '20px') + iterations_result_wid.children[2].children[3].set_css('margin-right', + '20px') + iterations_result_wid.children[2].children[4].set_css('margin-right', + '10px') + + # align sliders + iterations_result_wid.children[1].children[1].add_class('align-end') + iterations_result_wid.children[1].children[1].set_css('margin-bottom', + '20px') + + # align sliders and iterations_mode + iterations_result_wid.children[1].remove_class('vbox') + iterations_result_wid.children[1].add_class('hbox') + iterations_result_wid.children[1].add_class('align-start') + + # set toggle button font bold + iterations_result_wid.children[0].set_css('font-weight', + toggle_button_font_weight) + iterations_result_wid.children[1].set_css('margin-top', container_margin) + + # margin and border around container widget + iterations_result_wid.set_css('padding', container_padding) + iterations_result_wid.set_css('margin', container_margin) + if border_visible: + iterations_result_wid.set_css('border', container_border) + + +def update_iterations_result_options(iterations_result_wid, n_iters, + image_has_gt_shape, n_points, + iter_str='iter_'): + r""" + Function that updates the state of a given iterations_result_options widget + if the number of iterations or the number of landmark points or the + image_has_gt_shape flag has changed. Usage example: + iterations_result_wid = iterations_result_options( + n_iters=50, image_has_gt_shape=True, n_points=68) + display(iterations_result_wid) + format_iterations_result_options(iterations_result_wid) + update_iterations_result_options(iterations_result_wid, n_iters=52, + image_has_gt_shape=False, n_points=68) + + Parameters + ---------- + iterations_result_wid : + The widget generated by `iterations_result_options()` function. + + n_iters : `int` + The number of iterations. + + image_has_gt_shape : `boolean` + Flag that defines whether the fitted image has a ground shape attached. + + n_points : `int` + The number of the object's landmark points. It is required by the + displacement dorp down menu. + + iter_str : `str`, optional + The str that is used in the landmark groups shapes. + E.g. if iter_str == "iter_" then the group label of iteration i has the + form "{}{}".format(iter_str, i) + """ + # if image_has_gt_shape flag has actually changed from the previous value + if image_has_gt_shape != iterations_result_wid.image_has_gt_shape: + # set the plot buttons visibility + iterations_result_wid.children[2].children[4].visible = \ + iterations_result_wid.children[0].value and image_has_gt_shape + iterations_result_wid.children[2].children[5].visible = \ + iterations_result_wid.children[0].value + # store the flag + iterations_result_wid.image_has_gt_shape = image_has_gt_shape + + # if n_points has actually changed from the previous value + if n_points != iterations_result_wid.n_points: + # change the contents of the displacement types + select_menu = OrderedDict() + select_menu['mean'] = 'mean' + select_menu['median'] = 'median' + select_menu['max'] = 'max' + select_menu['min'] = 'min' + for p in range(n_points): + select_menu["point {}".format(p + 1)] = p + iterations_result_wid.children[2].children[5].children[1].values = \ + select_menu + # store the number of points + iterations_result_wid.n_points = n_points + + # if n_iters are actually different from the previous value + if n_iters != iterations_result_wid.n_iters: + # change the iterations_result_wid output + iterations_result_wid.n_iters = n_iters + iterations_result_wid.groups = _convert_iterations_to_groups(0, 0, + iter_str) + + animation_options_wid = \ + iterations_result_wid.children[1].children[1].children[0] + # set the iterations options state + if n_iters == 1: + # set sliders values and visibility + for t in range(4): + if t == 0: + # first slider + iterations_result_wid.children[1].children[1].children[1]. \ + value = 0 + iterations_result_wid.children[1].children[1].children[1]. \ + max = 0 + iterations_result_wid.children[1].children[1].children[1]. \ + visible = False + elif t == 1: + # second slider + iterations_result_wid.children[1].children[1].children[2]. \ + value = 0 + iterations_result_wid.children[1].children[1].children[2]. \ + max = 0 + iterations_result_wid.children[1].children[1].children[2]. \ + visible = False + elif t == 2: + # animation slider + animation_options_wid.selected_index = 0 + animation_options_wid.index_max = 0 + animation_options_wid.children[1].children[0].children[2]. \ + value = 0 + animation_options_wid.children[1].children[0].children[2]. \ + max = 0 + animation_options_wid.children[1].children[0].children[2]. \ + disabled = True + animation_options_wid.children[1].children[1].children[0]. \ + children[0].disabled = True + animation_options_wid.children[1].children[1].children[0]. \ + children[1].disabled = True + animation_options_wid.children[1].children[1].children[0]. \ + children[2].disabled = True + else: + # iterations mode + iterations_result_wid.children[1].children[0].value = 0 + iterations_result_wid.groups = [iter_str + "0"] + iterations_result_wid.children[1].children[0]. \ + disabled = True + else: + # set sliders max and min values + for t in range(4): + if t == 0: + # first slider + iterations_result_wid.children[1].children[1].children[1]. \ + value = 0 + iterations_result_wid.children[1].children[1].children[1]. \ + max = n_iters - 1 + iterations_result_wid.children[1].children[1].children[1]. \ + visible = False + elif t == 1: + # second slider + iterations_result_wid.children[1].children[1].children[2]. \ + value = n_iters - 1 + iterations_result_wid.children[1].children[1].children[2]. \ + max = n_iters - 1 + iterations_result_wid.children[1].children[1].children[2]. \ + visible = False + elif t == 2: + # animation slider + animation_options_wid.children[1].children[0].children[2]. \ + value = 0 + animation_options_wid.children[1].children[0].children[2]. \ + max = n_iters - 1 + animation_options_wid.selected_index = 0 + animation_options_wid.index_max = n_iters - 1 + animation_options_wid.children[1].children[0].children[2]. \ + disabled = False + animation_options_wid.children[1].children[1].children[0]. \ + children[0].disabled = False + animation_options_wid.children[1].children[1].children[0]. \ + children[1].disabled = True + animation_options_wid.children[1].children[1].children[0]. \ + children[2].disabled = False + else: + # iterations mode + iterations_result_wid.children[1].children[0].value = 0 + iterations_result_wid.groups = [iter_str + "0"] + iterations_result_wid.children[1].children[0]. \ + disabled = False + + +def plot_options(plot_options_default, plot_function=None, + toggle_show_visible=True, toggle_show_default=True): + r""" + Creates a widget with Plot Options. Specifically, it has: + 1) A drop down menu for curve selection. + 2) A text area for the legend entry. + 3) A checkbox that controls line's visibility. + 4) A checkbox that controls markers' visibility. + 5) Options for line colour, style and width. + 6) Options for markers face colour, edge colour, size, edge width and + style. + 7) A toggle button that controls the visibility of all the above, i.e. + the plot options. + + The structure of the widgets is the following: + plot_options_wid.children = [toggle_button, options] + options.children = [curve_menu, per_curve_options_wid] + per_curve_options_wid = ipywidgets.ContainerWidget(children=[legend_entry, + line_marker_wid]) + line_marker_wid = ipywidgets.ContainerWidget(children=[line_widget, marker_widget]) + line_widget.children = [show_line_checkbox, line_options] + marker_widget.children = [show_marker_checkbox, marker_options] + line_options.children = [linestyle, linewidth, linecolour] + marker_options.children = [markerstyle, markersize, markeredgewidth, + markerfacecolour, markeredgecolour] + + The returned widget saves the selected values in the following dictionary: + plot_options_wid.selected_options + + To fix the alignment within this widget please refer to + `format_plot_options()` function. + + Parameters + ---------- + plot_options_default : list of `dict` + A list of dictionaries with the initial selected plot options per curve. + Example: + plot_options_1={'show_line':True, + 'linewidth':2, + 'linecolour':'r', + 'linestyle':'-', + 'show_marker':True, + 'markersize':20, + 'markerfacecolour':'r', + 'markeredgecolour':'b', + 'markerstyle':'o', + 'markeredgewidth':1, + 'legend_entry':'final errors'} + plot_options_2={'show_line':False, + 'linewidth':3, + 'linecolour':'r', + 'linestyle':'-', + 'show_marker':True, + 'markersize':60, + 'markerfacecolour':[0.1, 0.2, 0.3], + 'markeredgecolour':'k', + 'markerstyle':'x', + 'markeredgewidth':1, + 'legend_entry':'initial errors'} + plot_options_default = [plot_options_1, plot_options_2] + + plot_function : `function` or None, optional + The plot function that is executed when a widgets' value changes. + If None, then nothing is assigned. + + toggle_show_default : `boolean`, optional + Defines whether the options will be visible upon construction. + + toggle_show_visible : `boolean`, optional + The visibility of the toggle button. + """ + import IPython.html.widgets as ipywidgets + # make sure that plot_options_default is a list even with one member + if not isinstance(plot_options_default, list): + plot_options_default = [plot_options_default] + + # find number of curves + n_curves = len(plot_options_default) + + # Create widgets + # toggle button + but = ipywidgets.ToggleButtonWidget(description='Plot Options', + value=toggle_show_default, + visible=toggle_show_visible) + + # select curve drop down menu + curves_dict = OrderedDict() + for k in range(n_curves): + curves_dict['Curve ' + str(k)] = k + curve_selection = ipywidgets.DropdownWidget(values=curves_dict, + value=0, + description='Select curve', + visible=n_curves > 1) + + # legend entry + legend_entry = ipywidgets.TextWidget(description='Legend entry', + value=plot_options_default[0][ + 'legend_entry']) + + # show line, show markers checkboxes + show_line = ipywidgets.CheckboxWidget(description='Show line', + value=plot_options_default[0][ + 'show_line']) + show_marker = ipywidgets.CheckboxWidget(description='Show markers', + value=plot_options_default[0][ + 'show_marker']) + + # linewidth, markersize + linewidth = ipywidgets.FloatTextWidget(description='Width', + value=plot_options_default[0][ + 'linewidth']) + markersize = ipywidgets.IntTextWidget(description='Size', + value=plot_options_default[0][ + 'markersize']) + markeredgewidth = ipywidgets.FloatTextWidget( + description='Edge width', + value=plot_options_default[0]['markeredgewidth']) + + # markerstyle + markerstyle_dict = OrderedDict() + markerstyle_dict['point'] = '.' + markerstyle_dict['pixel'] = ',' + markerstyle_dict['circle'] = 'o' + markerstyle_dict['triangle down'] = 'v' + markerstyle_dict['triangle up'] = '^' + markerstyle_dict['triangle left'] = '<' + markerstyle_dict['triangle right'] = '>' + markerstyle_dict['tri down'] = '1' + markerstyle_dict['tri up'] = '2' + markerstyle_dict['tri left'] = '3' + markerstyle_dict['tri right'] = '4' + markerstyle_dict['octagon'] = '8' + markerstyle_dict['square'] = 's' + markerstyle_dict['pentagon'] = 'p' + markerstyle_dict['star'] = '*' + markerstyle_dict['hexagon 1'] = 'h' + markerstyle_dict['hexagon 2'] = 'H' + markerstyle_dict['plus'] = '+' + markerstyle_dict['x'] = 'x' + markerstyle_dict['diamond'] = 'D' + markerstyle_dict['thin diamond'] = 'd' + markerstyle = ipywidgets.DropdownWidget(values=markerstyle_dict, + value=plot_options_default[0][ + 'markerstyle'], + description='Style') + + # linestyle + linestyle_dict = OrderedDict() + linestyle_dict['solid'] = '-' + linestyle_dict['dashed'] = '--' + linestyle_dict['dash-dot'] = '-.' + linestyle_dict['dotted'] = ':' + linestyle = ipywidgets.DropdownWidget(values=linestyle_dict, + value=plot_options_default[0][ + 'linestyle'], + description='Style') + + # colours + # do not assign the plot_function here + linecolour = colour_selection(plot_options_default[0]['linecolour'], + title='Colour') + markerfacecolour = colour_selection( + plot_options_default[0]['markerfacecolour'], + title='Face Colour') + markeredgecolour = colour_selection( + plot_options_default[0]['markeredgecolour'], + title='Edge Colour') + + # Group widgets + line_options = ipywidgets.ContainerWidget( + children=[linestyle, linewidth, linecolour]) + marker_options = ipywidgets.ContainerWidget( + children=[markerstyle, markersize, + markeredgewidth, + markerfacecolour, + markeredgecolour]) + line_wid = ipywidgets.ContainerWidget(children=[show_line, line_options]) + marker_wid = ipywidgets.ContainerWidget( + children=[show_marker, marker_options]) + line_options_options_wid = ipywidgets.ContainerWidget( + children=[line_wid, marker_wid]) + options_wid = ipywidgets.ContainerWidget(children=[legend_entry, + line_options_options_wid]) + options_and_curve_wid = ipywidgets.ContainerWidget( + children=[curve_selection, + options_wid]) + plot_options_wid = ipywidgets.ContainerWidget( + children=[but, options_and_curve_wid]) + + # initialize output + plot_options_wid.selected_options = plot_options_default + + # line options visibility + def line_options_visible(name, value): + linestyle.disabled = not value + linewidth.disabled = not value + linecolour.children[0].disabled = not value + linecolour.children[1].children[0].disabled = not value + linecolour.children[1].children[1].disabled = not value + linecolour.children[1].children[2].disabled = not value + show_line.on_trait_change(line_options_visible, 'value') + + # marker options visibility + def marker_options_visible(name, value): + markerstyle.disabled = not value + markersize.disabled = not value + markeredgewidth.disabled = not value + markerfacecolour.children[0].disabled = not value + markerfacecolour.children[1].children[0].disabled = not value + markerfacecolour.children[1].children[1].disabled = not value + markerfacecolour.children[1].children[2].disabled = not value + markeredgecolour.children[0].disabled = not value + markeredgecolour.children[1].children[0].disabled = not value + markeredgecolour.children[1].children[1].disabled = not value + markeredgecolour.children[1].children[2].disabled = not value + show_marker.on_trait_change(marker_options_visible, 'value') + + # function that gets colour selection + def get_colour(colour_wid): + if colour_wid.children[0].value == 'custom': + return [float(colour_wid.children[1].children[0].value), + float(colour_wid.children[1].children[1].value), + float(colour_wid.children[1].children[2].value)] + else: + return colour_wid.children[0].value + + # assign options + def save_legend_entry(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'legend_entry'] = str(value) + + legend_entry.on_trait_change(save_legend_entry, 'value') + + def save_show_line(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'show_line'] = value + + show_line.on_trait_change(save_show_line, 'value') + + def save_show_marker(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'show_marker'] = value + + show_marker.on_trait_change(save_show_marker, 'value') + + def save_linewidth(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'linewidth'] = float(value) + + linewidth.on_trait_change(save_linewidth, 'value') + + def save_linestyle(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'linestyle'] = value + + linestyle.on_trait_change(save_linestyle, 'value') + + def save_markersize(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'markersize'] = int(value) + + markersize.on_trait_change(save_markersize, 'value') + + def save_markeredgewidth(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'markeredgewidth'] = float(value) + + markeredgewidth.on_trait_change(save_markeredgewidth, 'value') + + def save_markerstyle(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'markerstyle'] = value + + markerstyle.on_trait_change(save_markerstyle, 'value') + + def save_linecolour(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'linecolour'] = get_colour(linecolour) + + linecolour.children[0].on_trait_change(save_linecolour, 'value') + linecolour.children[1].children[0].on_trait_change(save_linecolour, 'value') + linecolour.children[1].children[1].on_trait_change(save_linecolour, 'value') + linecolour.children[1].children[2].on_trait_change(save_linecolour, 'value') + + def save_markerfacecolour(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'markerfacecolour'] = get_colour(markerfacecolour) + + markerfacecolour.children[0].on_trait_change(save_markerfacecolour, 'value') + markerfacecolour.children[1].children[0].on_trait_change( + save_markerfacecolour, 'value') + markerfacecolour.children[1].children[1].on_trait_change( + save_markerfacecolour, 'value') + markerfacecolour.children[1].children[2].on_trait_change( + save_markerfacecolour, 'value') + + def save_markeredgecolour(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'markeredgecolour'] = get_colour(markeredgecolour) + + markeredgecolour.children[0].on_trait_change(save_markeredgecolour, 'value') + markeredgecolour.children[1].children[0].on_trait_change( + save_markeredgecolour, 'value') + markeredgecolour.children[1].children[1].on_trait_change( + save_markeredgecolour, 'value') + markeredgecolour.children[1].children[2].on_trait_change( + save_markeredgecolour, 'value') + + # set correct value to slider when drop down menu value changes + def set_options(name, value): + legend_entry.value = plot_options_wid.selected_options[value][ + 'legend_entry'] + show_line.value = plot_options_wid.selected_options[value]['show_line'] + show_marker.value = plot_options_wid.selected_options[value][ + 'show_marker'] + linewidth.value = plot_options_wid.selected_options[value]['linewidth'] + linestyle.value = plot_options_wid.selected_options[value]['linestyle'] + markersize.value = plot_options_wid.selected_options[value][ + 'markersize'] + markerstyle.value = plot_options_wid.selected_options[value][ + 'markerstyle'] + markeredgewidth.value = plot_options_wid.selected_options[value][ + 'markeredgewidth'] + default_colour = plot_options_wid.selected_options[value]['linecolour'] + if not isinstance(default_colour, str): + r_val = default_colour[0] + g_val = default_colour[1] + b_val = default_colour[2] + default_colour = 'custom' + linecolour.children[1].children[0].value = r_val + linecolour.children[1].children[1].value = g_val + linecolour.children[1].children[2].value = b_val + linecolour.children[0].value = default_colour + default_colour = plot_options_wid.selected_options[value][ + 'markerfacecolour'] + if not isinstance(default_colour, str): + r_val = default_colour[0] + g_val = default_colour[1] + b_val = default_colour[2] + default_colour = 'custom' + markerfacecolour.children[1].children[0].value = r_val + markerfacecolour.children[1].children[1].value = g_val + markerfacecolour.children[1].children[2].value = b_val + markerfacecolour.children[0].value = default_colour + default_colour = plot_options_wid.selected_options[value][ + 'markeredgecolour'] + if not isinstance(default_colour, str): + r_val = default_colour[0] + g_val = default_colour[1] + b_val = default_colour[2] + default_colour = 'custom' + markeredgecolour.children[1].children[0].value = r_val + markeredgecolour.children[1].children[1].value = g_val + markeredgecolour.children[1].children[2].value = b_val + markeredgecolour.children[0].value = default_colour + curve_selection.on_trait_change(set_options, 'value') + + # Toggle button function + def toggle_fun(name, value): + options_and_curve_wid.visible = value + toggle_fun('', toggle_show_default) + but.on_trait_change(toggle_fun, 'value') + + # assign plot_function + if plot_function is not None: + legend_entry.on_trait_change(plot_function, 'value') + show_line.on_trait_change(plot_function, 'value') + linestyle.on_trait_change(plot_function, 'value') + linewidth.on_trait_change(plot_function, 'value') + show_marker.on_trait_change(plot_function, 'value') + markerstyle.on_trait_change(plot_function, 'value') + markeredgewidth.on_trait_change(plot_function, 'value') + markersize.on_trait_change(plot_function, 'value') + linecolour.children[0].on_trait_change(plot_function, 'value') + linecolour.children[1].children[0].on_trait_change(plot_function, + 'value') + linecolour.children[1].children[1].on_trait_change(plot_function, + 'value') + linecolour.children[1].children[2].on_trait_change(plot_function, + 'value') + markerfacecolour.children[0].on_trait_change(plot_function, 'value') + markerfacecolour.children[1].children[0].on_trait_change(plot_function, + 'value') + markerfacecolour.children[1].children[1].on_trait_change(plot_function, + 'value') + markerfacecolour.children[1].children[2].on_trait_change(plot_function, + 'value') + markeredgecolour.children[0].on_trait_change(plot_function, 'value') + markeredgecolour.children[1].children[0].on_trait_change(plot_function, + 'value') + markeredgecolour.children[1].children[1].on_trait_change(plot_function, + 'value') + markeredgecolour.children[1].children[2].on_trait_change(plot_function, + 'value') + + return plot_options_wid + + +def format_plot_options(plot_options_wid, container_padding='6px', + container_margin='6px', + container_border='1px solid black', + toggle_button_font_weight='bold', border_visible=True, + suboptions_border_visible=True): + r""" + Function that corrects the align (style format) of a given figure_options + widget. Usage example: + plot_options_wid = plot_options() + display(plot_options_wid) + format_plot_options(figure_options_wid) + + Parameters + ---------- + plot_options_wid : + The widget object generated by the `figure_options()` function. + + container_padding : `str`, optional + The padding around the widget, e.g. '6px' + + container_margin : `str`, optional + The margin around the widget, e.g. '6px' + + container_border : `str`, optional + The border around the widget, e.g. '1px solid black' + + toggle_button_font_weight : `str` + The font weight of the toggle button, e.g. 'bold' + + border_visible : `boolean`, optional + Defines whether to draw the border line around the widget. + + suboptions_border_visible : `boolean`, optional + Defines whether to draw the border line around the per curve options. + """ + # align line options with checkbox + plot_options_wid.children[1].children[1].children[1].children[0]. \ + add_class('align-end') + + # align marker options with checkbox + plot_options_wid.children[1].children[1].children[1].children[1]. \ + add_class('align-end') + + # set text boxes width + plot_options_wid.children[1].children[1].children[1].children[0].children[ + 1].children[1]. \ + set_css('width', '1cm') + plot_options_wid.children[1].children[1].children[1].children[1].children[ + 1].children[1]. \ + set_css('width', '1cm') + plot_options_wid.children[1].children[1].children[1].children[1].children[ + 1].children[2]. \ + set_css('width', '1cm') + + # align line and marker options + plot_options_wid.children[1].children[1].children[1].remove_class('vbox') + plot_options_wid.children[1].children[1].children[1].add_class('hbox') + if suboptions_border_visible: + plot_options_wid.children[1].children[1].set_css('margin', + container_margin) + plot_options_wid.children[1].children[1].set_css('border', + container_border) + + # align curve selection with line and marker options + plot_options_wid.children[1].add_class('align-start') + + # format colour options + format_colour_selection( + plot_options_wid.children[1].children[1].children[1].children[ + 0].children[1].children[2]) + format_colour_selection( + plot_options_wid.children[1].children[1].children[1].children[ + 1].children[1].children[3]) + format_colour_selection( + plot_options_wid.children[1].children[1].children[1].children[ + 1].children[1].children[4]) + + # set toggle button font bold + plot_options_wid.children[0].set_css('font-weight', + toggle_button_font_weight) + + # margin and border around container widget + plot_options_wid.set_css('padding', container_padding) + plot_options_wid.set_css('margin', container_margin) + if border_visible: + plot_options_wid.set_css('border', container_border) From 1648ebbd70dddb17ca7ffa4e8fb09c8dfe71be13 Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Mon, 29 Dec 2014 11:50:08 +0000 Subject: [PATCH 09/15] fixes update_final_result_options --- menpofit/visualize/widgets/base.py | 11 - menpofit/visualize/widgets/options.py | 582 +++++++++++++------------- 2 files changed, 293 insertions(+), 300 deletions(-) diff --git a/menpofit/visualize/widgets/base.py b/menpofit/visualize/widgets/base.py index e81ffe0..b1feb25 100644 --- a/menpofit/visualize/widgets/base.py +++ b/menpofit/visualize/widgets/base.py @@ -13,19 +13,8 @@ landmark_options, format_landmark_options, info_print, format_info_print, - model_parameters, - format_model_parameters, - update_model_parameters, - final_result_options, - format_final_result_options, - update_final_result_options, - iterations_result_options, - format_iterations_result_options, - update_iterations_result_options, animation_options, format_animation_options, - plot_options, - format_plot_options, save_figure_options, format_save_figure_options) from menpo.visualize.widgets.tools import logo, format_logo diff --git a/menpofit/visualize/widgets/options.py b/menpofit/visualize/widgets/options.py index a4646e2..210e0ae 100644 --- a/menpofit/visualize/widgets/options.py +++ b/menpofit/visualize/widgets/options.py @@ -2,7 +2,9 @@ from menpo.visualize.widgets.tools import (colour_selection, format_colour_selection) -from menpo.visualize.widgets.options import _compare_groups_and_labels +from menpo.visualize.widgets.options import (animation_options, + format_animation_options, + _compare_groups_and_labels) def model_parameters(n_params, plot_function=None, params_str='', @@ -575,56 +577,43 @@ def show_options(name, value): w.on_trait_change(plot_function, 'value') -def iterations_result_options(n_iters, image_has_gt_shape, n_points, +def iterations_result_options(iterations_result_options_default, plot_function=None, plot_errors_function=None, plot_displacements_function=None, - iter_str='iter_', title='Iterations Result', - show_image_default=True, - subplots_enabled_default=False, - numbering_default=False, - legend_default=True, toggle_show_default=True, + title='Iterations Result', + toggle_show_default=True, toggle_show_visible=True): r""" Creates a widget with Iterations Result Options. Specifically, it has: 1) Two radio buttons that select an options mode, depending on whether - the user wants to visualize iterations in "Animation" or "Static" mode. - 2) If mode is "Animation", an animation options widget appears. - If mode is "Static", the iterations range is selected by two + the user wants to visualize iterations in ``Animation`` or ``Static`` + mode. + 2) If mode is ``Animation``, an animation options widget appears. + If mode is ``Static``, the iterations range is selected by two sliders and there is an update plot button. 3) A checkbox that controls the visibility of the image. 4) A set of radio buttons that define whether subplots are enabled. - 5) A checkbox that controls the legend's visibility. - 6) A checkbox that controls the numbering visibility. - 7) A button to plot the error evolution. - 8) A button to plot the landmark points' displacement. - 9) A drop down menu to select which displacement to plot. - 10) A toggle button that controls the visibility of all the above, i.e. + 5) A button to plot the error evolution. + 6) A button to plot the landmark points' displacement. + 7) A drop down menu to select which displacement to plot. + 8) A toggle button that controls the visibility of all the above, i.e. the final result options. The structure of the widgets is the following: - iterations_result_wid.children = [toggle_button, - iterations_mode_and_sliders, - options] + iterations_result_wid.children = [toggle_button, all_options] + all_options.children = [iterations_mode_and_sliders, options] iterations_mode_and_sliders.children = [iterations_mode_radio_buttons, all_sliders] all_sliders.children = [animation_slider, first_slider, second_slider, - update_button] - options.children = [plot_mode_radio_buttons, show_image_checkbox, - show_legend_checkbox, plot_errors_button, - plot_displacements, show_numbering_checkbox] + update_and_axes] + update_and_axes.children = [same_axes_checkbox, update_button] + options.children = [render_image_checkbox, plot_errors_button, + plot_displacements] plot_displacements.children = [plot_displacements_button, plot_displacements_drop_down_menu] The returned widget saves the selected values in the following fields: - iterations_result_wid.groups - iterations_result_wid.image_has_gt_shape - iterations_result_wid.n_iters - iterations_result_wid.n_points - iterations_result_wid.show_image - iterations_result_wid.subplots_enabled - iterations_result_wid.legend_enabled - iterations_result_wid.numbering_enabled - iterations_result_wid.displacement_type + iterations_result_wid.selected_values To fix the alignment within this widget please refer to `format_iterations_result_options()` function. @@ -634,13 +623,16 @@ def iterations_result_options(n_iters, image_has_gt_shape, n_points, Parameters ---------- - n_iters : `int` - The number of iterations. - image_has_gt_shape : `bool` - Flag that defines whether the fitted image has a ground shape attached. - n_points : `int` - The number of the object's landmark points. It is required by the - displacement dorp down menu. + iterations_result_options_default : `dict` + The default options. For example: + iterations_result_options_default = {'n_iters': 10, + 'image_has_gt_shape': True, + 'n_points': 68, + 'iter_str': 'iter_', + 'selected_groups': [0], + 'render_image': True, + 'subplots_enabled': True, + 'displacement_type': 'mean'} plot_function : `function` or None, optional The plot function that is executed when a widgets' value changes. If None, then nothing is assigned. @@ -652,21 +644,8 @@ def iterations_result_options(n_iters, image_has_gt_shape, n_points, The plot function that is executed when the 'Plot Displacements' button is pressed. If None, then nothing is assigned. - iter_str : `str`, optional - The str that is used in the landmark groups shapes. - E.g. if iter_str == "iter_" then the group label of iteration i has the - form "{}{}".format(iter_str, i) title : `str`, optional The title of the widget printed at the toggle button. - show_image_default : `bool`, optional - The initial value of the image's visibility checkbox. - subplots_enabled_default : `bool`, optional - The initial value of the plot options' radio buttons that determine - whether a single plot or subplots will be used. - legend_default : `bool`, optional - The initial value of the legend's visibility checkbox. - numbering_default : `bool`, optional - The initial value of the numbering visibility checkbox. toggle_show_default : `bool`, optional Defines whether the options will be visible upon construction. toggle_show_visible : `bool`, optional @@ -678,36 +657,35 @@ def iterations_result_options(n_iters, image_has_gt_shape, n_points, value=toggle_show_default, visible=toggle_show_visible) iterations_mode = ipywidgets.RadioButtonsWidget( - values={'Animation': 0, 'Static': 1}, - value=0, - description='Iterations mode:', - visible=toggle_show_default) + values={'Animation': 'animation', 'Static': 'static'}, + value='animation', description='Mode:', visible=toggle_show_default) # Don't assign the plot function to the animation_wid at this point. We # first need to assign the get_groups function and then the plot_function() # for synchronization reasons. - animation_wid = animation_options(index_min_val=0, - index_max_val=n_iters - 1, - plot_function=None, - update_function=None, - index_step=1, index_default=0, - index_description='Iteration', - index_style='slider', - loop_default=False, interval_default=0.2, - toggle_show_default=toggle_show_default, - toggle_show_visible=False) - first_slider_wid = ipywidgets.IntSliderWidget(min=0, max=n_iters - 1, - step=1, - value=0, description='From', - visible=False) - second_slider_wid = ipywidgets.IntSliderWidget(min=0, max=n_iters - 1, - step=1, - value=n_iters - 1, - description='To', - visible=False) + index_selection_default = { + 'min': 0, 'max': iterations_result_options_default['n_iters'] - 1, + 'step': 1, 'index': 0} + animation_wid = animation_options( + index_selection_default, plot_function=None, update_function=None, + index_description='Iteration', index_style='slider', loop_default=False, + interval_default=0.2, toggle_show_default=toggle_show_default, + toggle_show_visible=False) + first_slider_wid = ipywidgets.IntSliderWidget( + min=0, max=iterations_result_options_default['n_iters'] - 1, step=1, + value=0, description='From', visible=False) + second_slider_wid = ipywidgets.IntSliderWidget( + min=0, max=iterations_result_options_default['n_iters'] - 1, step=1, + value=iterations_result_options_default['n_iters'] - 1, + description='To', visible=False) + same_axes = ipywidgets.CheckboxWidget( + description='Same axes', + value=not iterations_result_options_default['subplots_enabled'], + visible=False) update_but = ipywidgets.ButtonWidget(description='Update Plot', visible=False) - show_image = ipywidgets.CheckboxWidget(description='Show image', - value=show_image_default) + render_image = ipywidgets.CheckboxWidget( + description='Render image', + value=iterations_result_options_default['render_image']) plot_errors_button = ipywidgets.ButtonWidget(description='Plot Errors') plot_displacements_button = ipywidgets.ButtonWidget( description='Plot Displacements') @@ -716,29 +694,23 @@ def iterations_result_options(n_iters, image_has_gt_shape, n_points, dropdown_menu['median'] = 'median' dropdown_menu['max'] = 'max' dropdown_menu['min'] = 'min' - for p in range(n_points): + for p in range(iterations_result_options_default['n_points']): dropdown_menu["point {}".format(p)] = p - plot_displacements_menu = ipywidgets.SelectWidget(values=dropdown_menu, - value='mean') - plot_mode = ipywidgets.RadioButtonsWidget(description='Plot mode:', - values={'Single': False, - 'Multiple': True}) - plot_mode.value = subplots_enabled_default - show_legend = ipywidgets.CheckboxWidget(description='Show legend', - value=legend_default) - show_numbering = ipywidgets.CheckboxWidget(description='Show numbering', - value=numbering_default) + plot_displacements_menu = ipywidgets.DropdownWidget( + values=dropdown_menu, + value=iterations_result_options_default['displacement_type']) + # if just one iteration, disable multiple options - if n_iters == 1: - iterations_mode.value = 0 + if iterations_result_options_default['n_iters'] == 1: + iterations_mode.value = 'animation' iterations_mode.disabled = True first_slider_wid.disabled = True - animation_wid.children[1].children[0].children[2].disabled = True - animation_wid.children[1].children[1].children[0].children[0]. \ + animation_wid.children[1].children[0].disabled = True + animation_wid.children[1].children[1].children[0].children[0].\ disabled = True - animation_wid.children[1].children[1].children[0].children[1]. \ + animation_wid.children[1].children[1].children[0].children[1].\ disabled = True - animation_wid.children[1].children[1].children[0].children[2]. \ + animation_wid.children[1].children[1].children[0].children[2].\ disabled = True second_slider_wid.disabled = True plot_errors_button.disabled = True @@ -746,63 +718,61 @@ def iterations_result_options(n_iters, image_has_gt_shape, n_points, plot_displacements_menu.disabled = True # Group widgets + update_and_subplots = ipywidgets.ContainerWidget( + children=[same_axes, update_but]) sliders = ipywidgets.ContainerWidget( - children=[animation_wid, first_slider_wid, - second_slider_wid, update_but]) + children=[animation_wid, first_slider_wid, second_slider_wid, + update_and_subplots]) iterations_mode_and_sliders = ipywidgets.ContainerWidget( - children=[iterations_mode, - sliders]) + children=[iterations_mode, sliders]) plot_displacements = ipywidgets.ContainerWidget( - children=[plot_displacements_button, - plot_displacements_menu]) + children=[plot_displacements_button, plot_displacements_menu]) opts = ipywidgets.ContainerWidget( - children=[plot_mode, show_image, show_legend, - show_numbering, plot_errors_button, - plot_displacements]) + children=[render_image, plot_errors_button, plot_displacements]) + all_options = ipywidgets.ContainerWidget( + children=[iterations_mode_and_sliders, opts]) # Widget container - iterations_result_wid = ipywidgets.ContainerWidget(children=[ - but, iterations_mode_and_sliders, opts]) + iterations_result_wid = ipywidgets.ContainerWidget(children=[but, + all_options]) # Initialize variables - iterations_result_wid.groups = _convert_iterations_to_groups(0, 0, iter_str) - iterations_result_wid.image_has_gt_shape = image_has_gt_shape - iterations_result_wid.n_iters = n_iters - iterations_result_wid.n_points = n_points - iterations_result_wid.show_image = show_image_default - iterations_result_wid.subplots_enabled = subplots_enabled_default - iterations_result_wid.legend_enabled = legend_default - iterations_result_wid.numbering_enabled = numbering_default - iterations_result_wid.displacement_type = 'mean' + iterations_result_options_default['selected_groups'] = \ + _convert_iterations_to_groups( + 0, 0, iterations_result_options_default['iter_str']) + iterations_result_wid.selected_values = iterations_result_options_default # Define iterations mode visibility def iterations_mode_selection(name, value): - if value == 0: + if value == 'animation': # get val that needs to be assigned val = first_slider_wid.value # update visibility animation_wid.visible = True first_slider_wid.visible = False second_slider_wid.visible = False + same_axes.visible = False update_but.visible = False # set correct values - animation_wid.children[1].children[0].children[2].value = val - animation_wid.selected_index = val + animation_wid.children[1].children[0].value = val + animation_wid.selected_values['index'] = val first_slider_wid.value = 0 - second_slider_wid.value = n_iters - 1 + second_slider_wid.value = \ + iterations_result_wid.selected_values['n_iters'] - 1 else: # get val that needs to be assigned - val = animation_wid.selected_index + val = animation_wid.selected_values['index'] # update visibility animation_wid.visible = False first_slider_wid.visible = True second_slider_wid.visible = True + same_axes.visible = True update_but.visible = True # set correct values second_slider_wid.value = val first_slider_wid.value = val - animation_wid.children[1].children[0].children[2].value = 0 - animation_wid.selected_index = 0 + animation_wid.children[1].children[0].value = 0 + animation_wid.selected_values['index'] = 0 iterations_mode.on_trait_change(iterations_mode_selection, 'value') # Check first slider's value @@ -819,64 +789,59 @@ def second_slider_val(name, value): # Convert slider values to groups def get_groups(name, value): - if iterations_mode.value == 0: - iterations_result_wid.groups = _convert_iterations_to_groups( - animation_wid.selected_index, - animation_wid.selected_index, iter_str) + if iterations_mode.value == 'animation': + iterations_result_wid.selected_values['selected_groups'] = \ + _convert_iterations_to_groups( + animation_wid.selected_values['index'], + animation_wid.selected_values['index'], + iterations_result_wid.selected_values['iter_str']) else: - iterations_result_wid.groups = _convert_iterations_to_groups( - first_slider_wid.value, second_slider_wid.value, iter_str) + iterations_result_wid.selected_values['selected_groups'] = \ + _convert_iterations_to_groups( + first_slider_wid.value, second_slider_wid.value, + iterations_result_wid.selected_values['iter_str']) first_slider_wid.on_trait_change(get_groups, 'value') second_slider_wid.on_trait_change(get_groups, 'value') # assign get_groups() to the slider of animation_wid - animation_wid.children[1].children[0].children[2].\ - on_trait_change(get_groups, 'value') + animation_wid.children[1].children[0].on_trait_change(get_groups, 'value') - # Show image function - def show_image_fun(name, value): - iterations_result_wid.show_image = value - show_image.on_trait_change(show_image_fun, 'value') - - # Plot mode function - def plot_mode_fun(name, value): - iterations_result_wid.subplots_enabled = value - plot_mode.on_trait_change(plot_mode_fun, 'value') - - # Legend function - def legend_fun(name, value): - iterations_result_wid.legend_enabled = value - show_legend.on_trait_change(legend_fun, 'value') + # Render image function + def render_image_fun(name, value): + iterations_result_wid.selected_values['render_image'] = value + render_image.on_trait_change(render_image_fun, 'value') - # Numbering function - def numbering_fun(name, value): - iterations_result_wid.numbering_enabled = value - show_numbering.on_trait_change(numbering_fun, 'value') + # Same axes function + def same_axes_fun(name, value): + iterations_result_wid.selected_values['subplots_enabled'] = not value + same_axes.on_trait_change(same_axes_fun, 'value') # Displacement type function def displacement_type_fun(name, value): - iterations_result_wid.displacement_type = value + iterations_result_wid.selected_values['displacement_type'] = value plot_displacements_menu.on_trait_change(displacement_type_fun, 'value') # Toggle button function def show_options(name, value): iterations_mode.visible = value - plot_mode.visible = value - show_image.visible = value - show_legend.visible = value - show_numbering.visible = value - plot_errors_button.visible = image_has_gt_shape and value + render_image.visible = value + plot_errors_button.visible = \ + iterations_result_wid.selected_values['image_has_gt_shape'] and value plot_displacements.visible = value if value: - if iterations_mode.value == 0: + if iterations_mode.value == 'animation': animation_wid.visible = True else: first_slider_wid.visible = True second_slider_wid.visible = True + same_axes.visible = True + update_but.visible = True else: animation_wid.visible = False first_slider_wid.visible = False second_slider_wid.visible = False + same_axes.visible = False + update_but.visible = False show_options('', toggle_show_default) but.on_trait_change(show_options, 'value') @@ -888,12 +853,9 @@ def plot_function_but(name): update_but.on_click(plot_function_but) # Here we assign plot_function() to the slider of animation_wid, as # we didn't do it at its creation. - animation_wid.children[1].children[0].children[2].on_trait_change( - plot_function, 'value') - show_image.on_trait_change(plot_function, 'value') - plot_mode.on_trait_change(plot_function, 'value') - show_legend.on_trait_change(plot_function, 'value') - show_numbering.on_trait_change(plot_function, 'value') + animation_wid.children[1].children[0].on_trait_change(plot_function, + 'value') + render_image.on_trait_change(plot_function, 'value') # assign plot function of errors button if plot_errors_function is not None: @@ -937,48 +899,58 @@ def format_iterations_result_options(iterations_result_wid, """ # format animations options format_animation_options( - iterations_result_wid.children[1].children[1].children[0], + iterations_result_wid.children[1].children[0].children[1].children[0], index_text_width='0.5cm', container_padding=container_padding, container_margin=container_margin, container_border=container_border, toggle_button_font_weight=toggle_button_font_weight, border_visible=False) # align displacement button and drop down menu - iterations_result_wid.children[2].children[5].add_class('align-center') - iterations_result_wid.children[2].children[5].children[1].set_css('width', - '2.5cm') - iterations_result_wid.children[2].children[5].children[1].set_css('height', - '2cm') + iterations_result_wid.children[1].children[1].children[2].\ + remove_class('vbox') + iterations_result_wid.children[1].children[1].children[2].\ + add_class('hbox') + iterations_result_wid.children[1].children[1].children[2].\ + add_class('align-center') + iterations_result_wid.children[1].children[1].children[2].children[0].\ + set_css('margin-right', '0px') + iterations_result_wid.children[1].children[1].children[2].children[1].\ + set_css('margin-left', '0px') # align options - iterations_result_wid.children[2].remove_class('vbox') - iterations_result_wid.children[2].add_class('hbox') - iterations_result_wid.children[2].add_class('align-start') - iterations_result_wid.children[2].children[0].set_css('margin-right', - '20px') - iterations_result_wid.children[2].children[1].set_css('margin-right', - '10px') - iterations_result_wid.children[2].children[2].set_css('margin-right', - '20px') - iterations_result_wid.children[2].children[3].set_css('margin-right', - '20px') - iterations_result_wid.children[2].children[4].set_css('margin-right', - '10px') + iterations_result_wid.children[1].children[1].remove_class('vbox') + iterations_result_wid.children[1].children[1].add_class('hbox') + iterations_result_wid.children[1].children[1].add_class('align-center') + iterations_result_wid.children[1].children[1].children[0].\ + set_css('margin-right', '30px') + iterations_result_wid.children[1].children[1].children[1].\ + set_css('margin-right', '30px') + + # align update button and same axes checkbox + iterations_result_wid.children[1].children[0].children[1].children[3].\ + remove_class('vbox') + iterations_result_wid.children[1].children[0].children[1].children[3].\ + add_class('hbox') + iterations_result_wid.children[1].children[0].children[1].children[3].children[0].\ + set_css('margin-right', '20px') # align sliders - iterations_result_wid.children[1].children[1].add_class('align-end') - iterations_result_wid.children[1].children[1].set_css('margin-bottom', - '20px') + iterations_result_wid.children[1].children[0].children[1].\ + add_class('align-end') + iterations_result_wid.children[1].children[0].children[1].\ + set_css('margin-bottom', '20px') # align sliders and iterations_mode - iterations_result_wid.children[1].remove_class('vbox') - iterations_result_wid.children[1].add_class('hbox') - iterations_result_wid.children[1].add_class('align-start') + iterations_result_wid.children[1].children[0].remove_class('vbox') + iterations_result_wid.children[1].children[0].add_class('hbox') + iterations_result_wid.children[1].children[0].add_class('align-start') + + # align sliders and options + iterations_result_wid.children[1].add_class('align-end') # set toggle button font bold iterations_result_wid.children[0].set_css('font-weight', toggle_button_font_weight) - iterations_result_wid.children[1].set_css('margin-top', container_margin) # margin and border around container widget iterations_result_wid.set_css('padding', container_padding) @@ -987,156 +959,180 @@ def format_iterations_result_options(iterations_result_wid, iterations_result_wid.set_css('border', container_border) -def update_iterations_result_options(iterations_result_wid, n_iters, - image_has_gt_shape, n_points, - iter_str='iter_'): +def update_iterations_result_options(iterations_result_wid, + iterations_result_default): r""" - Function that updates the state of a given iterations_result_options widget - if the number of iterations or the number of landmark points or the - image_has_gt_shape flag has changed. Usage example: - iterations_result_wid = iterations_result_options( - n_iters=50, image_has_gt_shape=True, n_points=68) + Function that updates the state of a given iterations_result_options widget. Usage example: + iterations_result_options_default = {'n_iters': 10, + 'image_has_gt_shape': True, + 'n_points': 68, + 'iter_str': 'iter_', + 'selected_groups': [0], + 'render_image': True, + 'subplots_enabled': True, + 'displacement_type': 'mean'} + iterations_result_wid = iterations_result_options(iterations_result_options_default) display(iterations_result_wid) format_iterations_result_options(iterations_result_wid) - update_iterations_result_options(iterations_result_wid, n_iters=52, - image_has_gt_shape=False, n_points=68) + iterations_result_options_default = {'n_iters': 100, + 'image_has_gt_shape': False, + 'n_points': 15, + 'iter_str': 'iter_', + 'selected_groups': [0], + 'render_image': False, + 'subplots_enabled': False, + 'displacement_type': 'median'} + update_iterations_result_options(iterations_result_wid, iterations_result_options_default) Parameters ---------- iterations_result_wid : The widget generated by `iterations_result_options()` function. - - n_iters : `int` - The number of iterations. - - image_has_gt_shape : `boolean` - Flag that defines whether the fitted image has a ground shape attached. - - n_points : `int` - The number of the object's landmark points. It is required by the - displacement dorp down menu. - - iter_str : `str`, optional - The str that is used in the landmark groups shapes. - E.g. if iter_str == "iter_" then the group label of iteration i has the - form "{}{}".format(iter_str, i) + iterations_result_options_default : `dict` + The default options. For example: + iterations_result_options_default = {'n_iters': 10, + 'image_has_gt_shape': True, + 'n_points': 68, + 'iter_str': 'iter_', + 'selected_groups': [0], + 'render_image': True, + 'subplots_enabled': True, + 'displacement_type': 'mean'} """ # if image_has_gt_shape flag has actually changed from the previous value - if image_has_gt_shape != iterations_result_wid.image_has_gt_shape: - # set the plot buttons visibility - iterations_result_wid.children[2].children[4].visible = \ - iterations_result_wid.children[0].value and image_has_gt_shape - iterations_result_wid.children[2].children[5].visible = \ - iterations_result_wid.children[0].value + if ('image_has_gt_shape' in iterations_result_default and + iterations_result_default['image_has_gt_shape'] != + iterations_result_wid.selected_values['image_has_gt_shape']): + # set the plot errors visibility + iterations_result_wid.children[1].children[1].children[1].visible = \ + (iterations_result_wid.children[0].value and + iterations_result_default['image_has_gt_shape']) # store the flag - iterations_result_wid.image_has_gt_shape = image_has_gt_shape + iterations_result_wid.selected_values['image_has_gt_shape'] = \ + iterations_result_default['image_has_gt_shape'] # if n_points has actually changed from the previous value - if n_points != iterations_result_wid.n_points: + if ('n_points' in iterations_result_default and + iterations_result_default['n_points'] != + iterations_result_wid.selected_values['n_points']): # change the contents of the displacement types select_menu = OrderedDict() select_menu['mean'] = 'mean' select_menu['median'] = 'median' select_menu['max'] = 'max' select_menu['min'] = 'min' - for p in range(n_points): + for p in range(iterations_result_default['n_points']): select_menu["point {}".format(p + 1)] = p - iterations_result_wid.children[2].children[5].children[1].values = \ - select_menu + iterations_result_wid.children[1].children[1].children[2].children[1].\ + values = select_menu # store the number of points - iterations_result_wid.n_points = n_points + iterations_result_wid.selected_values['n_points'] = \ + iterations_result_default['n_points'] + + # if displacement_type has actually changed from the previous value + if ('displacement_type' in iterations_result_default and + iterations_result_default['displacement_type'] != + iterations_result_wid.selected_values['displacement_type']): + iterations_result_wid.children[1].children[1].children[2].children[1].\ + value = iterations_result_default['displacement_type'] + + # if iter_str are actually different from the previous value + if ('iter_str' in iterations_result_default and + iterations_result_default['iter_str'] != + iterations_result_wid.selected_values['iter_str']): + iterations_result_wid.selected_values['iter_str'] = \ + iterations_result_default['iter_str'] + + # if render_image are actually different from the previous value + if ('render_image' in iterations_result_default and + iterations_result_default['render_image'] != + iterations_result_wid.selected_values['render_image']): + iterations_result_wid.children[1].children[1].children[0].value = \ + iterations_result_default['render_image'] + + # if subplots_enabled are actually different from the previous value + if ('subplots_enabled' in iterations_result_default and + iterations_result_default['subplots_enabled'] != + iterations_result_wid.selected_values['subplots_enabled']): + iterations_result_wid.children[1].children[0].children[1].children[3].children[0].value = \ + not iterations_result_default['subplots_enabled'] # if n_iters are actually different from the previous value - if n_iters != iterations_result_wid.n_iters: + if ('n_iters' in iterations_result_default and + iterations_result_default['n_iters'] != + iterations_result_wid.selected_values['n_iters']): # change the iterations_result_wid output - iterations_result_wid.n_iters = n_iters - iterations_result_wid.groups = _convert_iterations_to_groups(0, 0, - iter_str) + iterations_result_wid.selected_values['n_iters'] = \ + iterations_result_default['n_iters'] + iterations_result_wid.selected_values['selected_groups'] = \ + _convert_iterations_to_groups( + 0, 0, iterations_result_wid.selected_values['iter_str']) - animation_options_wid = \ - iterations_result_wid.children[1].children[1].children[0] + animation_options_wid = iterations_result_wid.children[1].children[0].children[1].children[0] # set the iterations options state - if n_iters == 1: + if iterations_result_default['n_iters'] == 1: # set sliders values and visibility for t in range(4): if t == 0: # first slider - iterations_result_wid.children[1].children[1].children[1]. \ - value = 0 - iterations_result_wid.children[1].children[1].children[1]. \ - max = 0 - iterations_result_wid.children[1].children[1].children[1]. \ - visible = False + iterations_result_wid.children[1].children[0].children[1].children[1].value = 0 + iterations_result_wid.children[1].children[0].children[1].children[1].max = 0 + iterations_result_wid.children[1].children[0].children[1].children[1].visible = False elif t == 1: # second slider - iterations_result_wid.children[1].children[1].children[2]. \ - value = 0 - iterations_result_wid.children[1].children[1].children[2]. \ - max = 0 - iterations_result_wid.children[1].children[1].children[2]. \ - visible = False + iterations_result_wid.children[1].children[0].children[1].children[2].value = 0 + iterations_result_wid.children[1].children[0].children[1].children[2].max = 0 + iterations_result_wid.children[1].children[0].children[1].children[2].visible = False elif t == 2: # animation slider - animation_options_wid.selected_index = 0 - animation_options_wid.index_max = 0 - animation_options_wid.children[1].children[0].children[2]. \ - value = 0 - animation_options_wid.children[1].children[0].children[2]. \ - max = 0 - animation_options_wid.children[1].children[0].children[2]. \ - disabled = True - animation_options_wid.children[1].children[1].children[0]. \ - children[0].disabled = True - animation_options_wid.children[1].children[1].children[0]. \ - children[1].disabled = True - animation_options_wid.children[1].children[1].children[0]. \ - children[2].disabled = True + animation_options_wid.selected_values['index'] = 0 + animation_options_wid.selected_values['max'] = 0 + animation_options_wid.children[1].children[0].value = 0 + animation_options_wid.children[1].children[0]. max = 0 + animation_options_wid.children[1].children[0].disabled = True + animation_options_wid.children[1].children[1].children[0].children[0].disabled = True + animation_options_wid.children[1].children[1].children[0].children[1].disabled = True + animation_options_wid.children[1].children[1].children[0].children[2].disabled = True else: # iterations mode - iterations_result_wid.children[1].children[0].value = 0 - iterations_result_wid.groups = [iter_str + "0"] - iterations_result_wid.children[1].children[0]. \ - disabled = True + iterations_result_wid.children[1].children[0].children[0].value = 'animation' + #iterations_result_wid.groups = [iter_str + "0"] + iterations_result_wid.children[1].children[0].children[0].disabled = True else: # set sliders max and min values for t in range(4): if t == 0: # first slider - iterations_result_wid.children[1].children[1].children[1]. \ - value = 0 - iterations_result_wid.children[1].children[1].children[1]. \ - max = n_iters - 1 - iterations_result_wid.children[1].children[1].children[1]. \ - visible = False + iterations_result_wid.children[1].children[0].children[1].children[1].value = 0 + iterations_result_wid.children[1].children[0].children[1].children[1].max = \ + iterations_result_default['n_iters'] - 1 + iterations_result_wid.children[1].children[0].children[1].children[1].visible = False elif t == 1: # second slider - iterations_result_wid.children[1].children[1].children[2]. \ - value = n_iters - 1 - iterations_result_wid.children[1].children[1].children[2]. \ - max = n_iters - 1 - iterations_result_wid.children[1].children[1].children[2]. \ - visible = False + iterations_result_wid.children[1].children[0].children[1].children[2].value = \ + iterations_result_default['n_iters'] - 1 + iterations_result_wid.children[1].children[0].children[1].children[2].max = \ + iterations_result_default['n_iters'] - 1 + iterations_result_wid.children[1].children[0].children[1].children[2].visible = False elif t == 2: # animation slider - animation_options_wid.children[1].children[0].children[2]. \ - value = 0 - animation_options_wid.children[1].children[0].children[2]. \ - max = n_iters - 1 - animation_options_wid.selected_index = 0 - animation_options_wid.index_max = n_iters - 1 - animation_options_wid.children[1].children[0].children[2]. \ - disabled = False - animation_options_wid.children[1].children[1].children[0]. \ - children[0].disabled = False - animation_options_wid.children[1].children[1].children[0]. \ - children[1].disabled = True - animation_options_wid.children[1].children[1].children[0]. \ - children[2].disabled = False + animation_options_wid.children[1].children[0].value = 0 + animation_options_wid.children[1].children[0].max = \ + iterations_result_default['n_iters'] - 1 + animation_options_wid.selected_values['index'] = 0 + animation_options_wid.selected_values['max'] = \ + iterations_result_default['n_iters'] - 1 + animation_options_wid.children[1].children[0].disabled = \ + False + animation_options_wid.children[1].children[1].children[0].children[0].disabled = False + animation_options_wid.children[1].children[1].children[0].children[1].disabled = True + animation_options_wid.children[1].children[1].children[0].children[2].disabled = False else: # iterations mode - iterations_result_wid.children[1].children[0].value = 0 - iterations_result_wid.groups = [iter_str + "0"] - iterations_result_wid.children[1].children[0]. \ + iterations_result_wid.children[1].children[0].children[0].\ + value = 'animation' + #iterations_result_wid.groups = [iter_str + "0"] + iterations_result_wid.children[1].children[0].children[0].\ disabled = False @@ -1624,3 +1620,11 @@ def format_plot_options(plot_options_wid, container_padding='6px', plot_options_wid.set_css('margin', container_margin) if border_visible: plot_options_wid.set_css('border', container_border) + + +def _convert_iterations_to_groups(from_iter, to_iter, iter_str): + r""" + Function that generates a list of group labels given the range bounds and + the str to be used. + """ + return ["{}{}".format(iter_str, i) for i in range(from_iter, to_iter + 1)] From 8b392b616ff005b321dde7fa2e14462d4af9aa6b Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Fri, 2 Jan 2015 21:29:13 +0000 Subject: [PATCH 10/15] adds plot_cumulative_error_distribution() in fittingresult.py --- menpofit/fittingresult.py | 259 +++++++++++++++++++++++++- menpofit/visualize/widgets/base.py | 35 ++-- menpofit/visualize/widgets/options.py | 1 + 3 files changed, 275 insertions(+), 20 deletions(-) diff --git a/menpofit/fittingresult.py b/menpofit/fittingresult.py index 74b2e20..0a2d127 100644 --- a/menpofit/fittingresult.py +++ b/menpofit/fittingresult.py @@ -984,6 +984,259 @@ def compute_cumulative_error(errors, x_axis): r""" """ n_errors = len(errors) - cumulative_error = [np.count_nonzero([errors <= x]) - for x in x_axis] - return np.array(cumulative_error) / n_errors + return [np.count_nonzero([errors <= x]) / n_errors for x in x_axis] + + +def plot_cumulative_error_distribution(errors, errors_max=0.055, + errors_step=0.005, figure_id=None, + new_figure=False, + title='Cumulative Error Distribution', + x_label='Normalized Point-to-Point Error', + y_label='Images Proportion', + legend_entries=None, render_lines=True, + line_colour=None, line_style='-', + line_width=2, render_markers=True, + marker_style='s', marker_size=10, + marker_face_colour='w', + marker_edge_colour=None, + marker_edge_width=2, render_legend=True, + legend_title=None, + legend_font_name='sans-serif', + legend_font_style='normal', + legend_font_size=10, + legend_font_weight='normal', + legend_marker_scale=1., + legend_location=2, + legend_bbox_to_anchor=(1.05, 1.), + legend_border_axes_pad=1., + legend_n_columns=1, + legend_horizontal_spacing=1., + legend_vertical_spacing=1., + legend_border=True, + legend_border_padding=0.5, + legend_shadow=False, + legend_rounded_corners=False, + render_axes=True, + axes_font_name='sans-serif', + axes_font_size=10, + axes_font_style='normal', + axes_font_weight='normal', + axes_x_limits=None, axes_y_limits=None, + figure_size=(6, 4), render_grid=True, + grid_line_style='--', + grid_line_width=0.5): + r""" + Plot the cumulative error distribution (CED) of the provided fitting errors. + + Parameters + ---------- + errors : `list` of `lists` + A `list` with `lists` of fitting errors. A separate CED curve will be + rendered for each errors `list`. + errors_max : `float`, optional + The maximum error value for which to compute the distribution. Note that + it depends on the error type. + errors_step : `float`, optional + The step of the error values for which to compute the distribution. Note + that it depends on the error type. + figure_id : `object`, optional + The id of the figure to be used. + new_figure : `bool`, optional + If ``True``, a new figure is created. + title : `str`, optional + The figure's title. + x_label : `str`, optional + The label of the horizontal axis. + y_label : `str`, optional + The label of the vertical axis. + legend_entries : `list of `str` or ``None``, optional + If `list` of `str`, it must have the same length as `errors` `list` and + each `str` will be used to name each curve. If ``None``, the CED curves + will be named as `'Curve %d'`. + render_lines : `bool` or `list` of `bool`, optional + If ``True``, the line will be rendered. If `bool`, this value will be + used for all curves. If `list`, a value must be specified for each + fitting errors curve, thus it must have the same length as `errors`. + line_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} or + ``(3, )`` `ndarray` or `list` of those or ``None``, optional + The colour of the lines. If not a `list`, this value will be + used for all curves. If `list`, a value must be specified for each + fitting errors curve, thus it must have the same length as `errors`. If + ``None``, the colours will be linearly sampled from jet colormap. + line_style : {``-``, ``--``, ``-.``, ``:``} or `list` of those, optional + The style of the lines. If not a `list`, this value will be used for all + curves. If `list`, a value must be specified for each fitting errors + curve, thus it must have the same length as `errors`. + line_width : `float` or `list` of `float`, optional + The width of the lines. If `float`, this value will be used for all + curves. If `list`, a value must be specified for each fitting errors + curve, thus it must have the same length as `errors`. + render_markers : `bool` or `list` of `bool`, optional + If ``True``, the markers will be rendered. If `bool`, this value will be + used for all curves. If `list`, a value must be specified for each + fitting errors curve, thus it must have the same length as `errors`. + marker_style : {``.``, ``,``, ``o``, ``v``, ``^``, ``<``, ``>``, ``+``, + ``x``, ``D``, ``d``, ``s``, ``p``, ``*``, ``h``, ``H``, + ``1``, ``2``, ``3``, ``4``, ``8``} or `list` of those, optional + The style of the markers. If not a `list`, this value will be used for + all curves. If `list`, a value must be specified for each fitting errors + curve, thus it must have the same length as `errors`. + marker_size : `int` or `list` of `int`, optional + The size of the markers in points^2. If `int`, this value will be used + for all curves. If `list`, a value must be specified for each fitting + errors curve, thus it must have the same length as `errors`. + marker_face_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} + or ``(3, )`` `ndarray` or `list` of those or ``None``, optional + The face (filling) colour of the markers. If not a `list`, this value + will be used for all curves. If `list`, a value must be specified for + each fitting errors curve, thus it must have the same length as + `errors`. If ``None``, the colours will be linearly sampled from jet + colormap. + marker_edge_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} + or ``(3, )`` `ndarray` or `list` of those or ``None``, optional + The edge colour of the markers. If not a `list`, this value will be used + for all curves. If `list`, a value must be specified for each fitting + errors curve, thus it must have the same length as `errors`. If + ``None``, the colours will be linearly sampled from jet colormap. + marker_edge_width : `float` or `list` of `float`, optional + The width of the markers' edge. If `float`, this value will be used for + all curves. If `list`, a value must be specified for each fitting errors + curve, thus it must have the same length as `errors`. + render_legend : `bool`, optional + If ``True``, the legend will be rendered. + legend_title : `str`, optional + The title of the legend. + legend_font_name : {``serif``, ``sans-serif``, ``cursive``, ``fantasy``, + ``monospace``}, optional + The font of the legend. + legend_font_style : {``normal``, ``italic``, ``oblique``}, optional + The font style of the legend. + legend_font_size : `int`, optional + The font size of the legend. + legend_font_weight : {``ultralight``, ``light``, ``normal``, + ``regular``, ``book``, ``medium``, ``roman``, + ``semibold``, ``demibold``, ``demi``, ``bold``, + ``heavy``, ``extra bold``, ``black``}, optional + The font weight of the legend. + legend_marker_scale : `float`, optional + The relative size of the legend markers with respect to the original + legend_location : `int`, optional + The location of the legend. The predefined values are: + + =============== === + 'best' 0 + 'upper right' 1 + 'upper left' 2 + 'lower left' 3 + 'lower right' 4 + 'right' 5 + 'center left' 6 + 'center right' 7 + 'lower center' 8 + 'upper center' 9 + 'center' 10 + =============== === + + legend_bbox_to_anchor : (`float`, `float`), optional + The bbox that the legend will be anchored. + legend_border_axes_pad : `float`, optional + The pad between the axes and legend border. + legend_n_columns : `int`, optional + The number of the legend's columns. + legend_horizontal_spacing : `float`, optional + The spacing between the columns. + legend_vertical_spacing : `float`, optional + The vertical space between the legend entries. + legend_border : `bool`, optional + If ``True``, a frame will be drawn around the legend. + legend_border_padding : `float`, optional + The fractional whitespace inside the legend border. + legend_shadow : `bool`, optional + If ``True``, a shadow will be drawn behind legend. + legend_rounded_corners : `bool`, optional + If ``True``, the frame's corners will be rounded (fancybox). + render_axes : `bool`, optional + If ``True``, the axes will be rendered. + axes_font_name : {``serif``, ``sans-serif``, ``cursive``, ``fantasy``, + ``monospace``}, optional + The font of the axes. + axes_font_size : `int`, optional + The font size of the axes. + axes_font_style : {``normal``, ``italic``, ``oblique``}, optional + The font style of the axes. + axes_font_weight : {``ultralight``, ``light``, ``normal``, ``regular``, + ``book``, ``medium``, ``roman``, ``semibold``, + ``demibold``, ``demi``, ``bold``, ``heavy``, + ``extra bold``, ``black``}, optional + The font weight of the axes. + axes_x_limits : (`float`, `float`) or ``None``, optional + The limits of the x axis. If ``None``, it is set to + ``(0., 'errors_max')``. + axes_y_limits : (`float`, `float`) or ``None``, optional + The limits of the y axis. If ``None``, it is set to ``(0., 1.)``. + figure_size : (`float`, `float`) or ``None``, optional + The size of the figure in inches. + render_grid : `bool`, optional + If ``True``, the grid will be rendered. + grid_line_style : {``-``, ``--``, ``-.``, ``:``}, optional + The style of the grid lines. + grid_line_width : `float`, optional + The width of the grid lines. + + Raises + ------ + ValueError + legend_entries list has different length than errors list + + Returns + ------- + viewer : :map:`GraphPlotter` + The viewer object. + """ + from menpo.visualize import GraphPlotter + + # create x and y axes lists + x_axis = list(np.arange(0, errors_max, errors_step)) + ceds = [compute_cumulative_error(e, x_axis) for e in errors] + + # parse legend_entries, axes_x_limits and axes_y_limits + if legend_entries is None: + legend_entries = ["Curve {}".format(k) for k in range(len(ceds))] + if len(legend_entries) != len(ceds): + raise ValueError('legend_entries list has different length than errors ' + 'list') + if axes_x_limits is None: + axes_x_limits = (0., x_axis[-1]) + if axes_y_limits is None: + axes_y_limits = (0., 1.) + + # render + return GraphPlotter(figure_id=figure_id, new_figure=new_figure, + x_axis=x_axis, y_axis=ceds, title=title, + legend_entries=legend_entries, x_label=x_label, + y_label=y_label, x_axis_limits=axes_x_limits, + y_axis_limits=axes_y_limits).render( + render_lines=render_lines, line_colour=line_colour, + line_style=line_style, line_width=line_width, + render_markers=render_markers, marker_style=marker_style, + marker_size=marker_size, marker_face_colour=marker_face_colour, + marker_edge_colour=marker_edge_colour, + marker_edge_width=marker_edge_width, render_legend=render_legend, + legend_title=legend_title, legend_font_name=legend_font_name, + legend_font_style=legend_font_style, legend_font_size=legend_font_size, + legend_font_weight=legend_font_weight, + legend_marker_scale=legend_marker_scale, + legend_location=legend_location, + legend_bbox_to_anchor=legend_bbox_to_anchor, + legend_border_axes_pad=legend_border_axes_pad, + legend_n_columns=legend_n_columns, + legend_horizontal_spacing=legend_horizontal_spacing, + legend_vertical_spacing=legend_vertical_spacing, + legend_border=legend_border, + legend_border_padding=legend_border_padding, + legend_shadow=legend_shadow, + legend_rounded_corners=legend_rounded_corners, render_axes=render_axes, + axes_font_name=axes_font_name, axes_font_size=axes_font_size, + axes_font_style=axes_font_style, axes_font_weight=axes_font_weight, + figure_size=figure_size, render_grid=render_grid, + grid_line_style=grid_line_style, grid_line_width=grid_line_width) diff --git a/menpofit/visualize/widgets/base.py b/menpofit/visualize/widgets/base.py index b1feb25..53b2201 100644 --- a/menpofit/visualize/widgets/base.py +++ b/menpofit/visualize/widgets/base.py @@ -17,11 +17,15 @@ format_animation_options, save_figure_options, format_save_figure_options) -from menpo.visualize.widgets.tools import logo, format_logo -from menpo.visualize.widgets.base import (_visualize, _extract_groups_labels) +from menpo.visualize.widgets.tools import logo +from menpo.visualize.widgets.base import _visualize as _visualize_menpo +from menpo.visualize.widgets.base import _extract_groups_labels from menpo.visualize.viewmatplotlib import (MatplotlibImageViewer2d, sample_colours_from_colourmap) +from .options import (model_parameters, format_model_parameters, + update_model_parameters) + # This glyph import is called frequently during visualisation, so we ensure # that we only import it once glyph = None @@ -536,8 +540,8 @@ def plot_function(name, value): tmp4 = viewer_options_wid.selected_values[0]['image'] new_figure_size = (tmp3['x_scale'] * figure_size[0], tmp3['y_scale'] * figure_size[1]) - renderer = _visualize( - instance, save_figure_wid.renderer[0], True, + renderer = _visualize_menpo( + instance, save_figure_wid.renderer[0], landmark_options_wid.selected_values['render_landmarks'], channel_options_wid.selected_values['image_is_masked'], channel_options_wid.selected_values['masked_enabled'], @@ -546,9 +550,8 @@ def plot_function(name, value): channel_options_wid.selected_values['glyph_block_size'], channel_options_wid.selected_values['glyph_use_negative'], channel_options_wid.selected_values['sum_enabled'], - [landmark_options_wid.selected_values['group']], - [landmark_options_wid.selected_values['with_labels']], - False, dict(), True, False, + landmark_options_wid.selected_values['group'], + landmark_options_wid.selected_values['with_labels'], tmp1['render_lines'], tmp1['line_style'], tmp1['line_width'], tmp1['line_colour'][:n_labels], tmp2['render_markers'], tmp2['marker_style'], tmp2['marker_size'], tmp2['marker_edge_width'], @@ -890,8 +893,8 @@ def plot_function(name, value): tmp4 = viewer_options_wid.selected_values[0]['image'] new_figure_size = (tmp3['x_scale'] * figure_size[0], tmp3['y_scale'] * figure_size[1]) - renderer = _visualize( - instance, save_figure_wid.renderer[0], True, + renderer = _visualize_menpo( + instance, save_figure_wid.renderer[0], landmark_options_wid.selected_values['render_landmarks'], channel_options_wid.selected_values['image_is_masked'], channel_options_wid.selected_values['masked_enabled'], @@ -900,9 +903,8 @@ def plot_function(name, value): channel_options_wid.selected_values['glyph_block_size'], channel_options_wid.selected_values['glyph_use_negative'], channel_options_wid.selected_values['sum_enabled'], - [landmark_options_wid.selected_values['group']], - [landmark_options_wid.selected_values['with_labels']], - False, dict(), True, False, + landmark_options_wid.selected_values['group'], + landmark_options_wid.selected_values['with_labels'], tmp1['render_lines'], tmp1['line_style'], tmp1['line_width'], tmp1['line_colour'][:n_labels], tmp2['render_markers'], tmp2['marker_style'], tmp2['marker_size'], tmp2['marker_edge_width'], @@ -1316,8 +1318,8 @@ def plot_function(name, value): tmp4 = viewer_options_wid.selected_values[0]['image'] new_figure_size = (tmp3['x_scale'] * figure_size[0], tmp3['y_scale'] * figure_size[1]) - renderer = _visualize( - instance, save_figure_wid.renderer[0], True, + renderer = _visualize_menpo( + instance, save_figure_wid.renderer[0], landmark_options_wid.selected_values['render_landmarks'], channel_options_wid.selected_values['image_is_masked'], channel_options_wid.selected_values['masked_enabled'], @@ -1326,9 +1328,8 @@ def plot_function(name, value): channel_options_wid.selected_values['glyph_block_size'], channel_options_wid.selected_values['glyph_use_negative'], channel_options_wid.selected_values['sum_enabled'], - [landmark_options_wid.selected_values['group']], - [landmark_options_wid.selected_values['with_labels']], - False, dict(), True, False, + landmark_options_wid.selected_values['group'], + landmark_options_wid.selected_values['with_labels'], tmp1['render_lines'], tmp1['line_style'], tmp1['line_width'], tmp1['line_colour'][:n_labels], tmp2['render_markers'], tmp2['marker_style'], tmp2['marker_size'], tmp2['marker_edge_width'], diff --git a/menpofit/visualize/widgets/options.py b/menpofit/visualize/widgets/options.py index 210e0ae..9d3c019 100644 --- a/menpofit/visualize/widgets/options.py +++ b/menpofit/visualize/widgets/options.py @@ -856,6 +856,7 @@ def plot_function_but(name): animation_wid.children[1].children[0].on_trait_change(plot_function, 'value') render_image.on_trait_change(plot_function, 'value') + iterations_mode.on_trait_change(plot_function, 'value') # assign plot function of errors button if plot_errors_function is not None: From 89565507c7ef32bbce84c4bd7e41fcd1cb23bff2 Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Sat, 3 Jan 2015 00:09:25 +0000 Subject: [PATCH 11/15] fixes plot_ced --- menpofit/fittingresult.py | 26 +- menpofit/visualize/widgets/base.py | 382 +++++++++++++++++------------ 2 files changed, 238 insertions(+), 170 deletions(-) diff --git a/menpofit/fittingresult.py b/menpofit/fittingresult.py index 0a2d127..134518b 100644 --- a/menpofit/fittingresult.py +++ b/menpofit/fittingresult.py @@ -987,8 +987,7 @@ def compute_cumulative_error(errors, x_axis): return [np.count_nonzero([errors <= x]) / n_errors for x in x_axis] -def plot_cumulative_error_distribution(errors, errors_max=0.055, - errors_step=0.005, figure_id=None, +def plot_cumulative_error_distribution(errors, error_range=None, figure_id=None, new_figure=False, title='Cumulative Error Distribution', x_label='Normalized Point-to-Point Error', @@ -1033,12 +1032,16 @@ def plot_cumulative_error_distribution(errors, errors_max=0.055, errors : `list` of `lists` A `list` with `lists` of fitting errors. A separate CED curve will be rendered for each errors `list`. - errors_max : `float`, optional - The maximum error value for which to compute the distribution. Note that - it depends on the error type. - errors_step : `float`, optional - The step of the error values for which to compute the distribution. Note - that it depends on the error type. + error_range : `list` of `float` with length 3, optional + Specifies the horizontal axis range, i.e. + + :: + + error_range[0] = min_error + error_range[1] = max_error + error_range[2] = error_step + + If ``None``, then ``'error_range = [0., 0.101, 0.005]'``. figure_id : `object`, optional The id of the figure to be used. new_figure : `bool`, optional @@ -1195,8 +1198,12 @@ def plot_cumulative_error_distribution(errors, errors_max=0.055, """ from menpo.visualize import GraphPlotter + # make sure that errors is a list even with one list member + if not isinstance(errors[0], list): + errors = [errors] + # create x and y axes lists - x_axis = list(np.arange(0, errors_max, errors_step)) + x_axis = list(np.arange(error_range[0], error_range[1], error_range[2])) ceds = [compute_cumulative_error(e, x_axis) for e in errors] # parse legend_entries, axes_x_limits and axes_y_limits @@ -1240,3 +1247,4 @@ def plot_cumulative_error_distribution(errors, errors_max=0.055, axes_font_style=axes_font_style, axes_font_weight=axes_font_weight, figure_size=figure_size, render_grid=render_grid, grid_line_style=grid_line_style, grid_line_width=grid_line_width) + diff --git a/menpofit/visualize/widgets/base.py b/menpofit/visualize/widgets/base.py index 53b2201..6ca1856 100644 --- a/menpofit/visualize/widgets/base.py +++ b/menpofit/visualize/widgets/base.py @@ -5,8 +5,6 @@ format_viewer_options, figure_options, format_figure_options, - figure_options_two_scales, - format_figure_options_two_scales, channel_options, format_channel_options, update_channel_options, @@ -2040,7 +2038,7 @@ def close_plot_ced_fun_2(name, value): figure_options_wid.children[2].value = False -def plot_ced(errors, figure_size=(9, 5), popup=False, error_type='me_norm', +def plot_ced(errors, figure_size=(6, 4), popup=False, error_type='me_norm', error_range=None, legend_entries=None, return_widget=False): r""" Widget for visualizing the cumulative error curves of the provided errors. @@ -2050,38 +2048,43 @@ def plot_ced(errors, figure_size=(9, 5), popup=False, error_type='me_norm', ----------- errors : `list` of `list` of `float` The list of errors to be used. - figure_size : (`int`, `int`), optional The initial size of the plotted figures. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. - - error_type : `str` ``{'me_norm', 'me', 'rmse'}``, optional + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. + error_type : {``me_norm``, ``me``, ``rmse``}, optional Specifies the type of the provided errors. - error_range : `list` of `float` with length 3, optional Specifies the horizontal axis range, i.e. + + :: + error_range[0] = min_error error_range[1] = max_error error_range[2] = error_step - If None, then + + If ``None``, then + + :: + error_range = [0., 0.101, 0.005] for error_type = 'me_norm' error_range = [0., 20., 1.] for error_type = 'me' error_range = [0., 20., 1.] for error_type = 'rmse' legend_entries : `list` of `str` The entries of the legend. The list must have the same length as errors. - If None, the entries will have the form 'Curve %d'. - - return_widget : `boolean`, optional - If True, the widget object will be returned so that it can be used as a - part of a bigger widget. If False, the widget object is not returned, it - is just visualized. + If ``None``, the entries will have the form ``'Curve %d'``. + return_widget : `bool`, optional + If ``True``, the widget object will be returned so that it can be used + as part of a bigger widget. If ``False``, the widget object is not + returned, it is just visualized. """ - from menpofit.fittingresult import compute_cumulative_error + import IPython.html.widgets as ipywidgets + import IPython.display as ipydisplay + from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap + from menpofit.fittingresult import plot_cumulative_error_distribution - # make sure that images is a list even with one image member + # make sure that errors is a list even with one list member if not isinstance(errors[0], list): errors = [errors] @@ -2095,137 +2098,203 @@ def plot_ced(errors, figure_size=(9, 5), popup=False, error_type='me_norm', # get horizontal axis errors x_label_initial_value = 'Error' x_axis_limit_initial_value = 0 + x_axis_step_initial_value = 0 if error_range is None: if error_type == 'me_norm': error_range = [0., 0.101, 0.005] x_axis_limit_initial_value = 0.05 + x_axis_step_initial_value = 0.005 x_label_initial_value = 'Normalized Point-to-Point Error' elif error_type == 'me' or error_type == 'rmse': error_range = [0., 20., 0.5] x_axis_limit_initial_value = 5. + x_axis_step_initial_value = 0.5 x_label_initial_value = 'Point-to-Point Error' else: x_axis_limit_initial_value = (error_range[1] + error_range[0]) / 2 - x_axis = np.arange(error_range[0], error_range[1], error_range[2]) - - # compute cumulative error curves - ceds = [compute_cumulative_error(e, x_axis) for e in errors] - x_axis = [x_axis] * len(ceds) - - # initialize plot options dictionaries and legend entries - colors = [np.random.random((3,)) for _ in range(n_curves)] - plot_options_list = [] - for k in range(n_curves): - plot_options_list.append({'show_line':True, - 'linewidth':2, - 'linecolor':colors[k], - 'linestyle':'-', - 'show_marker':True, - 'markersize':10, - 'markerfacecolor':'w', - 'markeredgecolor':colors[k], - 'markerstyle':'s', - 'markeredgewidth':1, - 'legend_entry':legend_entries[k]}) + + # initial options dictionaries + figure_options = {'x_scale': 1., + 'y_scale': 1., + 'render_axes': True, + 'axes_font_name': 'sans-serif', + 'axes_font_size': 10, + 'axes_font_style': 'normal', + 'axes_font_weight': 'normal', + 'axes_x_limits': None, + 'axes_y_limits': (0., 1.)} + legend_options = {'render_legend': True, + 'legend_title': '', + 'legend_font_name': 'sans-serif', + 'legend_font_style': 'normal', + 'legend_font_size': 10, + 'legend_font_weight': 'normal', + 'legend_marker_scale': 1., + 'legend_location': 2, + 'legend_bbox_to_anchor': (1.05, 1.), + 'legend_border_axes_pad': 1., + 'legend_n_columns': 1, + 'legend_horizontal_spacing': 1., + 'legend_vertical_spacing': 1., + 'legend_border': True, + 'legend_border_padding': 0.5, + 'legend_shadow': False, + 'legend_rounded_corners': False} + grid_options = {'render_grid': True, + 'grid_line_style': '--', + 'grid_line_width': 0.5} + + colours = sample_colours_from_colourmap(n_curves, 'jet') + viewer_options_default = [] + for i in range(n_curves): + lines_options_default = {'render_lines': True, + 'line_width': 2, + 'line_colour': [colours[i]], + 'line_style': '-'} + markers_options = {'render_markers': True, + 'marker_size': 10, + 'marker_face_colour': ['w'], + 'marker_edge_colour': [colours[i]], + 'marker_style': 's', + 'marker_edge_width': 1} + tmp = {'lines': lines_options_default, + 'markers': markers_options, + 'legend': legend_options, + 'figure': figure_options, + 'grid': grid_options} + viewer_options_default.append(tmp) # define plot function def plot_function(name, value): + import matplotlib.pyplot as plt # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) - - # get the current figure id - figure_id = save_figure_wid.figure_id + ipydisplay.clear_output(wait=True) - # plot the graph with the selected options - new_figure_id = _plot_graph( - figure_id, horizontal_axis_values=x_axis, vertical_axis_values=ceds, - plot_options_list=plot_options_wid.selected_options, - legend_visible=legend_visible.value, - grid_visible=grid_visible.value, gridlinestyle=gridlinestyle.value, - x_limit=x_axis_limit.value, y_limit=y_axis_limit.value, + # get options that need to be a list + render_lines = [] + line_colour = [] + line_style = [] + line_width = [] + render_markers = [] + marker_style = [] + marker_size = [] + marker_face_colour = [] + marker_edge_colour = [] + marker_edge_width = [] + for idx in range(n_curves): + tmp1 = viewer_options_wid.selected_values[idx]['lines'] + tmp2 = viewer_options_wid.selected_values[idx]['markers'] + render_lines.append(tmp1['render_lines']) + line_colour.append(tmp1['line_colour'][0]) + line_style.append(tmp1['line_style']) + line_width.append(tmp1['line_width']) + render_markers.append(tmp2['render_markers']) + marker_style.append(tmp2['marker_style']) + marker_size.append(tmp2['marker_size']) + marker_face_colour.append(tmp2['marker_face_colour'][0]) + marker_edge_colour.append(tmp2['marker_edge_colour'][0]) + marker_edge_width.append(tmp2['marker_edge_width']) + + # rest of options + tmp3 = viewer_options_wid.selected_values[0]['legend'] + tmp4 = viewer_options_wid.selected_values[0]['figure'] + tmp5 = viewer_options_wid.selected_values[0]['grid'] + new_figure_size = (tmp4['x_scale'] * figure_size[0], + tmp4['y_scale'] * figure_size[1]) + + # horizontal axis limits + x_axis_limits = (0, + np.arange(0, errors_max.value, errors_step.value)[-1]) + + # render + renderer = plot_cumulative_error_distribution( + errors, error_range=[0., errors_max.value, errors_step.value], + figure_id=save_figure_wid.renderer[0].figure_id, new_figure=False, title=title.value, x_label=x_label.value, y_label=y_label.value, - x_scale=fig.x_scale, y_scale=fig.y_scale, figure_size=figure_size, - axes_fontsize=axes_fontsize.value, - labels_fontsize=labels_fontsize.value) + legend_entries=str(legend_entries_wid.value).split('\n')[:n_curves], + render_lines=render_lines, line_colour=line_colour, + line_style=line_style, line_width=line_width, + render_markers=render_markers, marker_style=marker_style, + marker_size=marker_size, marker_face_colour=marker_face_colour, + marker_edge_colour=marker_edge_colour, + marker_edge_width=marker_edge_width, + render_legend=tmp3['render_legend'], + legend_title=tmp3['legend_title'], + legend_font_name=tmp3['legend_font_name'], + legend_font_style=tmp3['legend_font_style'], + legend_font_size=tmp3['legend_font_size'], + legend_font_weight=tmp3['legend_font_weight'], + legend_marker_scale=tmp3['legend_marker_scale'], + legend_location=tmp3['legend_location'], + legend_bbox_to_anchor=tmp3['legend_bbox_to_anchor'], + legend_border_axes_pad=tmp3['legend_border_axes_pad'], + legend_n_columns=tmp3['legend_n_columns'], + legend_horizontal_spacing=tmp3['legend_horizontal_spacing'], + legend_vertical_spacing=tmp3['legend_vertical_spacing'], + legend_border=tmp3['legend_border'], + legend_border_padding=tmp3['legend_border_padding'], + legend_shadow=tmp3['legend_shadow'], + legend_rounded_corners=tmp3['legend_rounded_corners'], + render_axes=tmp4['render_axes'], + axes_font_name=tmp4['axes_font_name'], + axes_font_size=tmp4['axes_font_size'], + axes_font_style=tmp4['axes_font_style'], + axes_font_weight=tmp4['axes_font_weight'], + axes_x_limits=x_axis_limits, + axes_y_limits=viewer_options_wid.selected_values[0]['figure']['axes_y_limits'], + figure_size=new_figure_size, render_grid=tmp5['render_grid'], + grid_line_style=tmp5['grid_line_style'], + grid_line_width=tmp5['grid_line_width']) + + plt.show() # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer # create options widgets - # x label, y label, title container - x_label = TextWidget(description='Horizontal axis label', - value=x_label_initial_value) - y_label = TextWidget(description='Vertical axis label', - value='Images Proportion') - title = TextWidget(description='Figure title', - value='Cumulative error ditribution') - labels_wid = ContainerWidget(children=[x_label, y_label, title]) - - # figure size - fig = figure_options_two_scales(plot_function, x_scale_default=1., - y_scale_default=1., coupled_default=False, - show_axes_default=True, - toggle_show_default=True, - figure_scales_bounds=(0.1, 2), - figure_scales_step=0.1, - figure_scales_visible=True, - show_axes_visible=False, - toggle_show_visible=False) - # fontsizes - labels_fontsize = FloatTextWidget(description='Labels fontsize', value=12.) - axes_fontsize = FloatTextWidget(description='Axes fontsize', value=12.) - fontsize_wid = ContainerWidget(children=[labels_fontsize, axes_fontsize]) - - # checkboxes - grid_visible = CheckboxWidget(description='Grid visible', value=False) - gridlinestyle_dict = OrderedDict() - gridlinestyle_dict['solid'] = '-' - gridlinestyle_dict['dashed'] = '--' - gridlinestyle_dict['dash-dot'] = '-.' - gridlinestyle_dict['dotted'] = ':' - gridlinestyle = DropdownWidget(values=gridlinestyle_dict, - value=':', - description='Grid style', disabled=False) - - def gridlinestyle_visibility(name, value): - gridlinestyle.disabled = not value - grid_visible.on_trait_change(gridlinestyle_visibility, 'value') - legend_visible = CheckboxWidget(description='Legend visible', value=True) - checkbox_wid = ContainerWidget(children=[grid_visible, gridlinestyle, - legend_visible]) - - # container of various options - tmp_various_wid = ContainerWidget(children=[fontsize_wid, checkbox_wid]) - various_wid = ContainerWidget(children=[fig, tmp_various_wid]) - - # axis limits - y_axis_limit = FloatSliderWidget(min=0., max=1.1, step=0.1, - description='Y axis limit', value=1.) - x_axis_limit = FloatSliderWidget(min=error_range[0] + error_range[2], - max=error_range[1], - step=error_range[2], - description='X axis limit', - value=x_axis_limit_initial_value) - axis_limits_wid = ContainerWidget(children=[x_axis_limit, y_axis_limit]) - - # accordion widget - figure_wid = AccordionWidget(children=[axis_limits_wid, labels_wid, - various_wid]) - figure_wid.set_title(0, 'Axes Limits') - figure_wid.set_title(1, 'Labels and Title') - figure_wid.set_title(2, 'Figure Size, Grid and Legend') - - # per curve options - plot_options_wid = plot_options(plot_options_list, - plot_function=plot_function, - toggle_show_visible=False, - toggle_show_default=True) - - # save figure options widget - # create figure and store its id - initial_figure_id = plt.figure() - save_figure_wid = save_figure_options(initial_figure_id, + # error_range + errors_max = ipywidgets.FloatSliderWidget( + min=error_range[0] + error_range[2], max=error_range[1], + step=error_range[2], description='Error axis max', + value=x_axis_limit_initial_value) + if error_type == 'me_norm': + errors_step = ipywidgets.FloatSliderWidget( + min=0., max=0.05, step=0.001, description='Error axis step', + value=x_axis_step_initial_value) + else: + errors_step = ipywidgets.FloatSliderWidget( + min=0., max=error_range[1], step=error_range[2] / 10., + description='Error axis step', value=x_axis_step_initial_value) + error_range_wid = ipywidgets.ContainerWidget(children=[errors_max, + errors_step]) + + # legend_entries, x label, y label, title container + legend_entries_wid = ipywidgets.TextareaWidget( + description='Legend entries', value="\n".join(legend_entries)) + x_label = ipywidgets.TextWidget(description='Horizontal axis label', + value=x_label_initial_value) + y_label = ipywidgets.TextWidget(description='Vertical axis label', + value='Images Proportion') + title = ipywidgets.TextWidget(description='Figure title', + value=' ') + labels_wid = ipywidgets.ContainerWidget(children=[legend_entries_wid, + x_label, y_label, title]) + + # viewer options widget + viewer_options_wid = viewer_options(viewer_options_default, + ['lines', 'markers', 'legend', + 'figure_two', 'grid'], + objects_names=legend_entries, + plot_function=plot_function, + toggle_show_visible=False, + toggle_show_default=True, + labels=None) + + # save figure widget + initial_renderer = MatplotlibImageViewer2d(figure_id=None, new_figure=True, + image=np.zeros((10, 10))) + save_figure_wid = save_figure_options(initial_renderer, toggle_show_default=True, toggle_show_visible=False) @@ -2233,29 +2302,24 @@ def gridlinestyle_visibility(name, value): x_label.on_trait_change(plot_function, 'value') y_label.on_trait_change(plot_function, 'value') title.on_trait_change(plot_function, 'value') - grid_visible.on_trait_change(plot_function, 'value') - gridlinestyle.on_trait_change(plot_function, 'value') - legend_visible.on_trait_change(plot_function, 'value') - y_axis_limit.on_trait_change(plot_function, 'value') - x_axis_limit.on_trait_change(plot_function, 'value') - labels_fontsize.on_trait_change(plot_function, 'value') - axes_fontsize.on_trait_change(plot_function, 'value') + legend_entries_wid.on_trait_change(plot_function, 'value') + errors_max.on_trait_change(plot_function, 'value') + errors_step.on_trait_change(plot_function, 'value') # create final widget - wid = TabWidget(children=[figure_wid, plot_options_wid, - save_figure_wid]) + wid = ipywidgets.TabWidget(children=[error_range_wid, labels_wid, + viewer_options_wid, save_figure_wid]) # create popup widget if asked if popup: - wid = PopupWidget(children=[wid], button_text='CED Menu') + wid = ipywidgets.PopupWidget(children=[wid], button_text='CED Menu') # display final widget - display(wid) + ipydisplay.display(wid) # set final tab titles - tab_titles = ['Figure options', 'Per Curve options', 'Save figure'] - if n_curves == 1: - tab_titles[1] = 'Curve options' + tab_titles = ['Error axis options', 'Labels options', 'Viewer options', + 'Save figure'] if popup: for (k, tl) in enumerate(tab_titles): wid.children[0].set_title(k, tl) @@ -2265,31 +2329,27 @@ def gridlinestyle_visibility(name, value): # format options' widgets labels_wid.add_class('align-end') - axis_limits_wid.add_class('align-start') - fontsize_wid.add_class('align-end') - fontsize_wid.set_css('margin-right', '1cm') - checkbox_wid.add_class('align-end') - tmp_various_wid.remove_class('vbox') - tmp_various_wid.add_class('hbox') - format_plot_options(plot_options_wid, container_padding='1px', - container_margin='1px', - container_border='1px solid black', - toggle_button_font_weight='bold', border_visible=False, - suboptions_border_visible=True) - format_figure_options_two_scales(fig, container_padding='6px', - container_margin='6px', - container_border='1px solid black', - toggle_button_font_weight='bold', - border_visible=False) + legend_entries_wid.set_css('width', '6cm') + legend_entries_wid.set_css('height', '2cm') + x_label.set_css('width', '6cm') + y_label.set_css('width', '6cm') + title.set_css('width', '6cm') + errors_max.set_css('width', '6cm') + errors_step.set_css('width', '6cm') + format_viewer_options(viewer_options_wid, container_padding='6px', + container_margin='6px', + container_border='1px solid black', + toggle_button_font_weight='bold', + border_visible=False, + suboptions_border_visible=True) format_save_figure_options(save_figure_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - tab_top_margin='0cm', - border_visible=False) + tab_top_margin='0cm', border_visible=False) # Reset value to trigger initial visualization - grid_visible.value = True + title.value = 'Cumulative error distribution' # return widget object if asked if return_widget: From e304ae00355b19db184dd77ae09d019d8e5a5e4d Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Sat, 3 Jan 2015 01:00:10 +0000 Subject: [PATCH 12/15] fixes visualize_fitting_results, adds _visualize --- menpofit/visualize/widgets/base.py | 911 +++++++++++++++++++++-------- 1 file changed, 668 insertions(+), 243 deletions(-) diff --git a/menpofit/visualize/widgets/base.py b/menpofit/visualize/widgets/base.py index 6ca1856..9a4a238 100644 --- a/menpofit/visualize/widgets/base.py +++ b/menpofit/visualize/widgets/base.py @@ -3,8 +3,6 @@ from menpo.visualize.widgets.options import (viewer_options, format_viewer_options, - figure_options, - format_figure_options, channel_options, format_channel_options, update_channel_options, @@ -19,10 +17,15 @@ from menpo.visualize.widgets.base import _visualize as _visualize_menpo from menpo.visualize.widgets.base import _extract_groups_labels from menpo.visualize.viewmatplotlib import (MatplotlibImageViewer2d, - sample_colours_from_colourmap) + sample_colours_from_colourmap, + MatplotlibSubplots) from .options import (model_parameters, format_model_parameters, - update_model_parameters) + update_model_parameters, final_result_options, + format_final_result_options, update_final_result_options, + iterations_result_options, + format_iterations_result_options, + update_iterations_result_options) # This glyph import is called frequently during visualisation, so we ensure # that we only import it once @@ -1558,28 +1561,31 @@ def update_widgets(name, value): False -def visualize_fitting_results(fitting_results, figure_size=(7, 7), popup=False, - **kwargs): +def visualize_fitting_results(fitting_results, figure_size=(6, 4), + browser_style='buttons', popup=False): r""" Widget that allows browsing through a list of fitting results. Parameters ----------- fitting_results : `list` of :map:`FittingResult` or subclass - The list of fitting results to be displayed. Note that the fitting + The `list` of fitting results to be displayed. Note that the fitting results can have different attributes between them, i.e. different number of iterations, number of channels etc. - figure_size : (`int`, `int`), optional The initial size of the plotted figures. - + browser_style : {``buttons``, ``slider``}, optional + It defines whether the selector of the fitting results will have the form of + plus/minus buttons or a slider. popup : `boolean`, optional - If enabled, the widget will appear as a popup window. - - kwargs : `dict`, optional - Passed through to the viewer. + If ``True``, the widget will appear as a popup window. """ + import IPython.html.widgets as ipywidgets + import IPython.display as ipydisplay + import matplotlib.pyplot as plt from menpo.image import MaskedImage + from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap + print 'Initializing...' # make sure that fitting_results is a list even with one fitting_result if not isinstance(fitting_results, list): @@ -1592,117 +1598,200 @@ def visualize_fitting_results(fitting_results, figure_size=(7, 7), popup=False, n_fitting_results = len(fitting_results) # create dictionaries - iter_str = 'iter_' + all_groups = ['final', 'initial', 'ground', 'iterations'] groups_final_dict = dict() colour_final_dict = dict() groups_final_dict['initial'] = 'Initial shape' - colour_final_dict['initial'] = 'r' + colour_final_dict['initial'] = 'b' groups_final_dict['final'] = 'Final shape' - colour_final_dict['final'] = 'b' + colour_final_dict['final'] = 'r' groups_final_dict['ground'] = 'Ground-truth shape' colour_final_dict['ground'] = 'y' + groups_final_dict['iterations'] = 'Iterations' + colour_final_dict['iterations'] = 'r' + + # initial options dictionaries + channels_default = 0 + if fitting_results[0].fitted_image.n_channels == 3: + channels_default = None + channels_options_default = \ + {'n_channels': fitting_results[0].fitted_image.n_channels, + 'image_is_masked': isinstance(fitting_results[0].fitted_image, + MaskedImage), + 'channels': channels_default, + 'glyph_enabled': False, + 'glyph_block_size': 3, + 'glyph_use_negative': False, + 'sum_enabled': False, + 'masked_enabled': False} + all_groups_keys, _ = _extract_groups_labels(fitting_results[0].fitted_image) + final_result_options_default = {'all_groups': all_groups_keys, + 'render_image': True, + 'selected_groups': ['final'], + 'subplots_enabled': True} + iterations_result_options_default = \ + {'n_iters': fitting_results[0].n_iters, + 'image_has_gt_shape': not fitting_results[0].gt_shape is None, + 'n_points': fitting_results[0].fitted_image.landmarks['final'].lms.n_points, + 'iter_str': 'iter_', + 'selected_groups': ['iter_0'], + 'render_image': True, + 'subplots_enabled': True, + 'displacement_type': 'mean'} + markers_options = {'render_markers': True, + 'marker_size': 20, + 'marker_face_colour': ['r'], + 'marker_edge_colour': ['k'], + 'marker_style': 'o', + 'marker_edge_width': 1} + lines_options_default = {'render_lines': True, + 'line_width': 1, + 'line_colour': ['r'], + 'line_style': '-'} + figure_options = {'x_scale': 1., + 'y_scale': 1., + 'render_axes': True, + 'axes_font_name': 'sans-serif', + 'axes_font_size': 10, + 'axes_font_style': 'normal', + 'axes_font_weight': 'normal', + 'axes_x_limits': None, + 'axes_y_limits': None} + numbering_options = {'render_numbering': False, + 'numbers_font_name': 'serif', + 'numbers_font_size': 10, + 'numbers_font_style': 'normal', + 'numbers_font_weight': 'normal', + 'numbers_font_colour': ['k'], + 'numbers_horizontal_align': 'center', + 'numbers_vertical_align': 'bottom'} + legend_options = {'render_legend': True, + 'legend_title': '', + 'legend_font_name': 'sans-serif', + 'legend_font_style': 'normal', + 'legend_font_size': 11, + 'legend_font_weight': 'normal', + 'legend_marker_scale': 1., + 'legend_location': 2, + 'legend_bbox_to_anchor': (1.05, 1.), + 'legend_border_axes_pad': 1., + 'legend_n_columns': 1, + 'legend_horizontal_spacing': 1., + 'legend_vertical_spacing': 1., + 'legend_border': True, + 'legend_border_padding': 0.5, + 'legend_shadow': False, + 'legend_rounded_corners': True} + image_options = {'interpolation': 'bilinear', + 'alpha': 1.0} + viewer_options_default = [] + for group in all_groups: + tmp_lines = lines_options_default.copy() + tmp_lines['line_colour'] = [colour_final_dict[group]] + tmp_markers = markers_options.copy() + tmp_markers['marker_face_colour'] = [colour_final_dict[group]] + tmp = {'markers': tmp_markers, + 'lines': tmp_lines, + 'figure': figure_options, + 'legend': legend_options, + 'numbering': numbering_options, + 'image': image_options} + viewer_options_default.append(tmp) + index_selection_default = {'min': 0, + 'max': n_fitting_results - 1, + 'step': 1, + 'index': 0} # define function that plots errors curve def plot_errors_function(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) # get selected image im = 0 if n_fitting_results > 1: - im = image_number_wid.selected_index + im = image_number_wid.selected_values['index'] - # select figure - figure_id = plt.figure(save_figure_wid.figure_id.number) + # get figure size + new_figure_size = (viewer_options_wid.selected_values[0]['figure']['x_scale'] * figure_size[0], + viewer_options_wid.selected_values[0]['figure']['y_scale'] * figure_size[1]) # plot errors curve - plt.plot(range(len(fitting_results[im].errors())), - fitting_results[im].errors(), '-bo') - plt.gca().set_xlim(0, len(fitting_results[im].errors())-1) - plt.xlabel('Iteration') - plt.ylabel('Fitting Error') - plt.title("Fitting error evolution of Image {}".format(im)) - plt.grid("on") - - # set figure size - x_scale = figure_options_wid.x_scale - y_scale = figure_options_wid.y_scale - plt.gcf().set_size_inches([x_scale, y_scale] * np.asarray(figure_size)) + renderer = fitting_results[im].plot_errors( + error_type=error_type_wid.value, + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size) # show figure plt.show() # save the current figure id - save_figure_wid.figure_id = figure_id + save_figure_wid.renderer[0] = renderer # define function that plots displacements curve def plot_displacements_function(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) # get selected image im = 0 if n_fitting_results > 1: - im = image_number_wid.selected_index + im = image_number_wid.selected_values['index'] - # select figure - figure_id = plt.figure(save_figure_wid.figure_id.number) + # get figure size + new_figure_size = (viewer_options_wid.selected_values[0]['figure'][ + 'x_scale'] * figure_size[0], + viewer_options_wid.selected_values[0]['figure'][ + 'y_scale'] * figure_size[1]) - # plot displacements curve - d_type = iterations_wid.displacement_type + # plot errors curve + d_type = iterations_wid.selected_values['displacement_type'] if (d_type == 'max' or d_type == 'min' or d_type == 'mean' or d_type == 'median'): - d_curve = fitting_results[im].displacements_stats(stat_type=d_type) + renderer = fitting_results[im].plot_displacements( + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size, stat_type=d_type) else: all_displacements = fitting_results[im].displacements() d_curve = [iteration_displacements[d_type] for iteration_displacements in all_displacements] - plt.plot(range(len(d_curve)), d_curve, '-bo') - plt.gca().set_xlim(0, len(d_curve)-1) - plt.grid("on") - plt.xlabel('Iteration') - - # set labels - if d_type == 'max': - plt.ylabel('Maximum Displacement') - plt.title("Maximum displacement evolution of Image {}".format(im)) - elif d_type == 'min': - plt.ylabel('Minimum Displacement') - plt.title("Minimum displacement evolution of Image {}".format(im)) - elif d_type == 'mean': - plt.ylabel('Mean Displacement') - plt.title("Mean displacement evolution of Image {}".format(im)) - elif d_type == 'median': - plt.ylabel('Median Displacement') - plt.title("Median displacement evolution of Image {}".format(im)) - else: - plt.ylabel("Displacement of Point {}".format(d_type)) - plt.title("Point {} displacement evolution of Image {}".format( - d_type, im)) - - # set figure size - x_scale = figure_options_wid.x_scale - y_scale = figure_options_wid.y_scale - plt.gcf().set_size_inches([x_scale, y_scale] * np.asarray(figure_size)) + from menpo.visualize import GraphPlotter + ylabel = "Displacement of Point {}".format(d_type) + title = "Point {} displacement per " \ + "iteration of Image {}".format(d_type, im) + renderer = GraphPlotter( + figure_id=save_figure_wid.renderer[0].figure_id, + new_figure=False, x_axis=range(len(d_curve)), y_axis=[d_curve], + title=title, x_label='Iteration', y_label=ylabel, + x_axis_limits=(0, len(d_curve)-1), y_axis_limits=None).render( + render_lines=True, line_colour='b', line_style='-', + line_width=2, render_markers=True, marker_style='o', + marker_size=4, marker_face_colour='b', + marker_edge_colour='k', marker_edge_width=1., + render_legend=False, render_axes=True, + axes_font_name='sans-serif', axes_font_size=10, + axes_font_style='normal', axes_font_weight='normal', + render_grid=True, grid_line_style='--', grid_line_width=0.5, + figure_size=new_figure_size) # show figure plt.show() # save the current figure id - save_figure_wid.figure_id = figure_id + save_figure_wid.renderer[0] = renderer # define plot function def plot_function(name, value): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) # get selected image im = 0 if n_fitting_results > 1: - im = image_number_wid.selected_index + im = image_number_wid.selected_values['index'] # selected mode: final or iterations final_enabled = False @@ -1712,144 +1801,220 @@ def plot_function(name, value): # update info text widget update_info('', error_type_wid.value) - # get the current figure id - figure_id = save_figure_wid.figure_id - - # call helper _plot_figure + # get selected options if final_enabled: - new_figure_id = _plot_figure( - image=fitting_results[im].fitted_image, figure_id=figure_id, - image_enabled=final_result_wid.show_image, - landmarks_enabled=True, image_is_masked=False, - masked_enabled=False, channels=channel_options_wid.channels, - glyph_enabled=channel_options_wid.glyph_enabled, - glyph_block_size=channel_options_wid.glyph_block_size, - glyph_use_negative=channel_options_wid.glyph_use_negative, - sum_enabled=channel_options_wid.sum_enabled, - groups=final_result_wid.groups, - with_labels=[None] * len(final_result_wid.groups), - groups_colours=colour_final_dict, - subplots_enabled=final_result_wid.subplots_enabled, - subplots_titles=groups_final_dict, image_axes_mode=True, - legend_enabled=final_result_wid.legend_enabled, - numbering_enabled=final_result_wid.numbering_enabled, - x_scale=figure_options_wid.x_scale, - y_scale=figure_options_wid.y_scale, - axes_visible=figure_options_wid.axes_visible, - figure_size=figure_size, **kwargs) + # image object + image = fitting_results[im].fitted_image + render_image = final_result_wid.selected_values['render_image'] + # groups + groups = final_result_wid.selected_values['selected_groups'] + # subplots + subplots_enabled = final_result_wid.selected_values[ + 'subplots_enabled'] + subplots_titles = groups_final_dict + # lines and markers options + render_lines = [] + line_colour = [] + line_style = [] + line_width = [] + render_markers = [] + marker_style = [] + marker_size = [] + marker_face_colour = [] + marker_edge_colour = [] + marker_edge_width = [] + for g in groups: + group_idx = all_groups.index(g) + tmp1 = viewer_options_wid.selected_values[group_idx]['lines'] + tmp2 = viewer_options_wid.selected_values[group_idx]['markers'] + render_lines.append(tmp1['render_lines']) + line_colour.append(tmp1['line_colour']) + line_style.append(tmp1['line_style']) + line_width.append(tmp1['line_width']) + render_markers.append(tmp2['render_markers']) + marker_style.append(tmp2['marker_style']) + marker_size.append(tmp2['marker_size']) + marker_face_colour.append(tmp2['marker_face_colour']) + marker_edge_colour.append(tmp2['marker_edge_colour']) + marker_edge_width.append(tmp2['marker_edge_width']) else: - # create subplot titles dict and colours dict - groups_dict = dict() - colour_dict = dict() - cols = np.random.random([3, len(iterations_wid.groups)]) - for i, group in enumerate(iterations_wid.groups): - iter_num = group[len(iter_str)::] - groups_dict[iter_str + iter_num] = "Iteration " + iter_num - colour_dict[iter_str + iter_num] = cols[:, i] + # image object + image = fitting_results[im].iter_image + render_image = iterations_wid.selected_values['render_image'] + # groups + groups = iterations_wid.selected_values['selected_groups'] + # subplots + subplots_enabled = iterations_wid.selected_values[ + 'subplots_enabled'] + subplots_titles = dict() + iter_str = iterations_wid.selected_values['iter_str'] + for i, g in enumerate(groups): + iter_num = g[len(iter_str)::] + subplots_titles[iter_str + iter_num] = "Iteration " + iter_num + # lines and markers options + group_idx = all_groups.index('iterations') + tmp1 = viewer_options_wid.selected_values[group_idx]['lines'] + tmp2 = viewer_options_wid.selected_values[group_idx]['markers'] + render_lines = [tmp1['render_lines']] * len(groups) + line_style = [tmp1['line_style']] * len(groups) + line_width = [tmp1['line_width']] * len(groups) + render_markers = [tmp2['render_markers']] * len(groups) + marker_style = [tmp2['marker_style']] * len(groups) + marker_size = [tmp2['marker_size']] * len(groups) + marker_edge_colour = [tmp2['marker_edge_colour']] * len(groups) + marker_edge_width = [tmp2['marker_edge_width']] * len(groups) + if (subplots_enabled or + iterations_wid.children[1].children[0].children[0].value == + 'animation'): + line_colour = [tmp1['line_colour']] * len(groups) + marker_face_colour = [tmp2['marker_face_colour']] * len(groups) + else: + cols = sample_colours_from_colourmap(len(groups), 'jet') + line_colour = cols + marker_face_colour = cols - # plot - new_figure_id = _plot_figure( - image=fitting_results[im].iter_image, figure_id=figure_id, - image_enabled=iterations_wid.show_image, landmarks_enabled=True, - image_is_masked=False, masked_enabled=False, - channels=channel_options_wid.channels, - glyph_enabled=channel_options_wid.glyph_enabled, - glyph_block_size=channel_options_wid.glyph_block_size, - glyph_use_negative=channel_options_wid.glyph_use_negative, - sum_enabled=channel_options_wid.sum_enabled, - groups=iterations_wid.groups, - with_labels=[None] * len(iterations_wid.groups), - groups_colours=colour_dict, - subplots_enabled=iterations_wid.subplots_enabled, - subplots_titles=groups_dict, image_axes_mode=True, - legend_enabled=iterations_wid.legend_enabled, - numbering_enabled=iterations_wid.numbering_enabled, - x_scale=figure_options_wid.x_scale, - y_scale=figure_options_wid.y_scale, - axes_visible=figure_options_wid.axes_visible, - figure_size=figure_size, **kwargs) + tmp1 = viewer_options_wid.selected_values[0]['numbering'] + tmp2 = viewer_options_wid.selected_values[0]['legend'] + tmp3 = viewer_options_wid.selected_values[0]['figure'] + tmp4 = viewer_options_wid.selected_values[0]['image'] + new_figure_size = (tmp3['x_scale'] * figure_size[0], + tmp3['y_scale'] * figure_size[1]) + + # call helper _visualize + renderer = _visualize( + image=image, renderer=save_figure_wid.renderer[0], + render_image=render_image, render_landmarks=True, + image_is_masked=False, masked_enabled=False, + channels=channel_options_wid.selected_values['channels'], + glyph_enabled=channel_options_wid.selected_values['glyph_enabled'], + glyph_block_size=channel_options_wid.selected_values['glyph_block_size'], + glyph_use_negative=channel_options_wid.selected_values['glyph_use_negative'], + sum_enabled=channel_options_wid.selected_values['sum_enabled'], + groups=groups, with_labels=[None] * len(groups), + subplots_enabled=subplots_enabled, subplots_titles=subplots_titles, + image_axes_mode=True, render_lines=render_lines, + line_style=line_style, line_width=line_width, + line_colour=line_colour, render_markers=render_markers, + marker_style=marker_style, marker_size=marker_size, + marker_edge_width=marker_edge_width, + marker_edge_colour=marker_edge_colour, + marker_face_colour=marker_face_colour, + render_numbering=tmp1['render_numbering'], + numbers_horizontal_align=tmp1['numbers_horizontal_align'], + numbers_vertical_align=tmp1['numbers_vertical_align'], + numbers_font_name=tmp1['numbers_font_name'], + numbers_font_size=tmp1['numbers_font_size'], + numbers_font_style=tmp1['numbers_font_style'], + numbers_font_weight=tmp1['numbers_font_weight'], + numbers_font_colour=tmp1['numbers_font_colour'], + render_legend=tmp2['render_legend'], + legend_title=tmp2['legend_title'], + legend_font_name=tmp2['legend_font_name'], + legend_font_style=tmp2['legend_font_style'], + legend_font_size=tmp2['legend_font_size'], + legend_font_weight=tmp2['legend_font_weight'], + legend_marker_scale=tmp2['legend_marker_scale'], + legend_location=tmp2['legend_location'], + legend_bbox_to_anchor=tmp2['legend_bbox_to_anchor'], + legend_border_axes_pad=tmp2['legend_border_axes_pad'], + legend_n_columns=tmp2['legend_n_columns'], + legend_horizontal_spacing=tmp2['legend_horizontal_spacing'], + legend_vertical_spacing=tmp2['legend_vertical_spacing'], + legend_border=tmp2['legend_border'], + legend_border_padding=tmp2['legend_border_padding'], + legend_shadow=tmp2['legend_shadow'], + legend_rounded_corners=tmp2['legend_rounded_corners'], + render_axes=tmp3['render_axes'], + axes_font_name=tmp3['axes_font_name'], + axes_font_size=tmp3['axes_font_size'], + axes_font_style=tmp3['axes_font_style'], + axes_font_weight=tmp3['axes_font_weight'], + axes_x_limits=tmp3['axes_x_limits'], + axes_y_limits=tmp3['axes_y_limits'], + interpolation=tmp4['interpolation'], + alpha=tmp4['alpha'], figure_size=new_figure_size) # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer # define function that updates info text def update_info(name, value): # get selected image im = 0 if n_fitting_results > 1: - im = image_number_wid.selected_index + im = image_number_wid.selected_values['index'] # create output str if fitting_results[im].gt_shape is not None: - info_txt = r""" - Initial error: {:.4f} - Final error: {:.4f} - {} iterations - """.format(fitting_results[im].initial_error(error_type=value), - fitting_results[im].final_error(error_type=value), - fitting_results[im].n_iters) + info_wid.children[1].children[0].value = \ + "> Initial error: {:.4f}".format( + fitting_results[im].initial_error(error_type=value)) + info_wid.children[1].children[0].visible = True + info_wid.children[1].children[1].value = \ + "> Final error: {:.4f}".format( + fitting_results[im].final_error(error_type=value)) + info_wid.children[1].children[1].visible = True + info_wid.children[1].children[2].value = \ + "> {} iterations".format(fitting_results[im].n_iters) else: - info_txt = r""" - {} iterations - """.format(fitting_results[im].n_iters) + info_wid.children[1].children[0].value = '' + info_wid.children[1].children[0].visible = False + info_wid.children[1].children[1].value = '' + info_wid.children[1].children[1].visible = False + info_wid.children[1].children[2].value = "> {} iterations".format( + fitting_results[im].n_iters) if hasattr(fitting_results[im], 'n_levels'): # Multilevel result - info_txt += r""" - {} levels with downscale of {:.1f} - """.format(fitting_results[im].n_levels, - fitting_results[im].downscale) - - info_wid.children[1].value = _raw_info_string_to_latex(info_txt) + info_wid.children[1].children[3].value = \ + "> {} levels with downscale of {:.1f}".format( + fitting_results[im].n_levels, fitting_results[im].downscale) + info_wid.children[1].children[1].visible = True + else: + info_wid.children[1].children[1].value = '' + info_wid.children[1].children[1].visible = True # Create options widgets channel_options_wid = channel_options( - fitting_results[0].fitted_image.n_channels, - isinstance(fitting_results[0].fitted_image, MaskedImage), plot_function, - masked_default=False, toggle_show_default=True, - toggle_show_visible=False) - figure_options_wid = figure_options(plot_function, scale_default=1., - show_axes_default=True, - toggle_show_default=True, - toggle_show_visible=False) - info_wid = info_print(toggle_show_default=True, toggle_show_visible=False) - initial_figure_id = plt.figure() - save_figure_wid = save_figure_options(initial_figure_id, + channels_options_default, plot_function=plot_function, + toggle_show_default=True, toggle_show_visible=False) + + # viewer options widget + viewer_options_wid = viewer_options( + viewer_options_default, + ['markers', 'lines', 'figure_one', 'legend', 'numbering', 'image'], + objects_names=all_groups, plot_function=plot_function, + toggle_show_visible=False, toggle_show_default=True) + info_wid = info_print(n_bullets=4, toggle_show_default=True, + toggle_show_visible=False) + + # save figure widget + initial_renderer = MatplotlibImageViewer2d(figure_id=None, new_figure=True, + image=np.zeros((10, 10))) + save_figure_wid = save_figure_options(initial_renderer, toggle_show_default=True, toggle_show_visible=False) - # Create landmark groups checkboxes - all_groups_keys, all_labels_keys = _extract_groups_labels( - fitting_results[0].fitted_image) - final_result_wid = final_result_options(all_groups_keys, plot_function, - title='Final', - show_image_default=True, - subplots_enabled_default=True, - legend_default=True, - numbering_default=False, - toggle_show_default=True, - toggle_show_visible=False) + # final result and iterations options + final_result_wid = final_result_options( + final_result_options_default, plot_function=plot_function, + title='Final', toggle_show_default=True, toggle_show_visible=False) iterations_wid = iterations_result_options( - fitting_results[0].n_iters, not fitting_results[0].gt_shape is None, - fitting_results[0].fitted_image.landmarks['final'].lms.n_points, - plot_function, plot_errors_function, plot_displacements_function, - iter_str=iter_str, title='Iterations', show_image_default=True, - subplots_enabled_default=False, legend_default=True, - numbering_default=False, toggle_show_default=True, - toggle_show_visible=False) - iterations_wid.children[2].children[4].on_click(plot_errors_function) - iterations_wid.children[2].children[5].children[0].on_click( - plot_displacements_function) + iterations_result_options_default, plot_function=plot_function, + plot_errors_function=plot_errors_function, + plot_displacements_function=plot_displacements_function, + title='Iterations', toggle_show_default=True, toggle_show_visible=False) # Create error type radio buttons error_type_values = OrderedDict() error_type_values['Point-to-point Normalized Mean Error'] = 'me_norm' error_type_values['Point-to-point Mean Error'] = 'me' error_type_values['RMS Error'] = 'rmse' - error_type_wid = RadioButtonsWidget(values=error_type_values, - value='me_norm', - description='Error type') + error_type_wid = ipywidgets.RadioButtonsWidget( + values=error_type_values, value='me_norm', description='Error type') error_type_wid.on_trait_change(update_info, 'value') - plot_ced_but = ButtonWidget(description='Plot CED', visible=show_ced) - error_wid = ContainerWidget(children=[error_type_wid, plot_ced_but]) + plot_ced_but = ipywidgets.ButtonWidget(description='Plot CED', + visible=show_ced) + error_wid = ipywidgets.ContainerWidget(children=[error_type_wid, + plot_ced_but]) # define function that updates options' widgets state def update_widgets(name, value): @@ -1859,34 +2024,40 @@ def update_widgets(name, value): # update channel options update_channel_options( channel_options_wid, - n_channels=fitting_results[value].fitted_image.n_channels, - image_is_masked=isinstance(fitting_results[value].fitted_image, - MaskedImage)) + fitting_results[value].fitted_image.n_channels, + isinstance(fitting_results[value].fitted_image, + MaskedImage)) # update final result's options update_final_result_options(final_result_wid, group_keys, plot_function) # update iterations result's options - update_iterations_result_options( - iterations_wid, fitting_results[value].n_iters, - not fitting_results[value].gt_shape is None, - fitting_results[value].fitted_image.landmarks['final'].lms.n_points, - iter_str=iter_str) + iterations_result_options_default = \ + {'n_iters': fitting_results[value].n_iters, + 'image_has_gt_shape': not fitting_results[value].gt_shape is None, + 'n_points': fitting_results[value].fitted_image.landmarks['final'].lms.n_points} + update_iterations_result_options(iterations_wid, + iterations_result_options_default) # Create final widget - options_wid = TabWidget(children=[channel_options_wid, figure_options_wid]) - result_wid = TabWidget(children=[final_result_wid, iterations_wid]) + options_wid = ipywidgets.TabWidget(children=[channel_options_wid, + viewer_options_wid]) + result_wid = ipywidgets.TabWidget(children=[final_result_wid, + iterations_wid]) result_wid.on_trait_change(plot_function, 'selected_index') if n_fitting_results > 1: # image selection slider image_number_wid = animation_options( - index_min_val=0, index_max_val=n_fitting_results-1, - plot_function=plot_function, update_function=update_widgets, - index_step=1, index_default=0, - index_description='Image Number', index_minus_description='<', - index_plus_description='>', index_style='buttons', - index_text_editable=True, loop_default=True, interval_default=0.3, + index_selection_default, plot_function=plot_function, + update_function=update_widgets, index_description='Image Number', + index_minus_description='<', index_plus_description='>', + index_style=browser_style, index_text_editable=True, + loop_default=True, interval_default=0.3, toggle_show_title='Image Options', toggle_show_default=True, toggle_show_visible=False) + # final widget + logo_wid = ipywidgets.ContainerWidget(children=[logo(), + image_number_wid]) + # define function that combines the results' tab widget with the # animation # If animation is activated and the user selects the iterations tab, @@ -1905,10 +2076,6 @@ def save_fig_tab_fun(name, value): image_number_wid.children[1].children[1].children[0].children[1].value = True # final widget - tab_wid = TabWidget(children=[info_wid, result_wid, options_wid, - error_wid, save_figure_wid]) - tab_wid.on_trait_change(save_fig_tab_fun, 'selected_index') - wid = ContainerWidget(children=[image_number_wid, tab_wid]) if show_ced: tab_titles = ['Info', 'Result', 'Options', 'CED', 'Save figure'] else: @@ -1920,14 +2087,23 @@ def save_fig_tab_fun(name, value): plot_ced_but.visible = False # final widget - wid = TabWidget(children=[info_wid, result_wid, options_wid, error_wid, - save_figure_wid]) + logo_wid = logo() tab_titles = ['Image info', 'Result', 'Options', 'Error type', 'Save figure'] button_title = 'Fitting Result Menu' + + # final widget + cont_wid = ipywidgets.TabWidget(children=[info_wid, result_wid, options_wid, + error_wid, save_figure_wid]) + if n_fitting_results > 1: + cont_wid.on_trait_change(save_fig_tab_fun, 'selected_index') + # create popup widget if asked if popup: - wid = PopupWidget(children=[wid], button_text=button_title) + wid = ipywidgets.PopupWidget(children=[logo_wid, cont_wid], + button_text=button_title) + else: + wid = ipywidgets.ContainerWidget(children=[logo_wid, cont_wid]) # invoke plot_ced widget def plot_ced_fun(name): @@ -1946,25 +2122,18 @@ def plot_ced_fun(name): errors = [fit_errors, initial_errors] # call plot_ced - plot_ced_widget = plot_ced(errors, figure_size=(9, 5), popup=True, - error_type=error_type, error_range=None, - legend_entries=['Final Fitting', - 'Initialization'], - return_widget=True) + plot_ced_widget = plot_ced( + errors, figure_size=(9, 5), popup=True, error_type=error_type, + error_range=None, legend_entries=['Final Fitting', + 'Initialization'], + return_widget=True) # If another tab is selected, then close the widget. def close_plot_ced_fun(name, value): if value != 3: plot_ced_widget.close() plot_ced_but.visible = True - if n_fitting_results > 1: - tab_wid.on_trait_change(close_plot_ced_fun, 'selected_index') - else: - if popup: - wid.children[0].on_trait_change(close_plot_ced_fun, - 'selected_index') - else: - wid.on_trait_change(close_plot_ced_fun, 'selected_index') + cont_wid.on_trait_change(close_plot_ced_fun, 'selected_index') # If another error type, then close the widget def close_plot_ced_fun_2(name, value): @@ -1974,31 +2143,22 @@ def close_plot_ced_fun_2(name, value): plot_ced_but.on_click(plot_ced_fun) # display final widget - display(wid) + ipydisplay.display(wid) # set final tab titles - if popup: - if n_fitting_results > 1: - for (k, tl) in enumerate(tab_titles): - wid.children[0].children[1].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.children[0].set_title(k, tl) - else: - if n_fitting_results > 1: - for (k, tl) in enumerate(tab_titles): - wid.children[1].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.set_title(k, tl) + for (k, tl) in enumerate(tab_titles): + wid.children[1].set_title(k, tl) + result_wid.set_title(0, 'Final Fitting') result_wid.set_title(1, 'Iterations') options_wid.set_title(0, 'Channels') - options_wid.set_title(1, 'Figure') + options_wid.set_title(1, 'Viewer') # format options' widgets if n_fitting_results > 1: - format_animation_options(image_number_wid, index_text_width='0.5cm', + wid.children[0].remove_class('vbox') + wid.children[0].add_class('hbox') + format_animation_options(image_number_wid, index_text_width='1.0cm', container_padding='6px', container_margin='6px', container_border='1px solid black', @@ -2009,11 +2169,12 @@ def close_plot_ced_fun_2(name, value): container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) - format_figure_options(figure_options_wid, container_padding='6px', + format_viewer_options(viewer_options_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=False) + border_visible=False, + suboptions_border_visible=True) format_info_print(info_wid, font_size_in_pt='9pt', container_padding='6px', container_margin='6px', container_border='1px solid black', @@ -2035,7 +2196,8 @@ def close_plot_ced_fun_2(name, value): tab_top_margin='0cm', border_visible=False) # Reset value to enable initial visualization - figure_options_wid.children[2].value = False + viewer_options_wid.children[1].children[1].children[2].children[2].value = \ + False def plot_ced(errors, figure_size=(6, 4), popup=False, error_type='me_norm', @@ -2307,12 +2469,17 @@ def plot_function(name, value): errors_step.on_trait_change(plot_function, 'value') # create final widget - wid = ipywidgets.TabWidget(children=[error_range_wid, labels_wid, - viewer_options_wid, save_figure_wid]) + tab_wid = ipywidgets.TabWidget(children=[error_range_wid, labels_wid, + viewer_options_wid, + save_figure_wid]) # create popup widget if asked if popup: - wid = ipywidgets.PopupWidget(children=[wid], button_text='CED Menu') + wid = ipywidgets.PopupWidget(children=[logo(), tab_wid], + button_text='CED Menu') + else: + wid = ipywidgets.ContainerWidget(children=[logo(), tab_wid], + button_text='CED Menu') # display final widget ipydisplay.display(wid) @@ -2320,12 +2487,8 @@ def plot_function(name, value): # set final tab titles tab_titles = ['Error axis options', 'Labels options', 'Viewer options', 'Save figure'] - if popup: - for (k, tl) in enumerate(tab_titles): - wid.children[0].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.set_title(k, tl) + for (k, tl) in enumerate(tab_titles): + tab_wid.set_title(k, tl) # format options' widgets labels_wid.add_class('align-end') @@ -2356,6 +2519,268 @@ def plot_function(name, value): return wid +def _visualize(image, renderer, render_image, render_landmarks, image_is_masked, + masked_enabled, channels, glyph_enabled, glyph_block_size, + glyph_use_negative, sum_enabled, groups, with_labels, + subplots_enabled, subplots_titles, image_axes_mode, + render_lines, line_style, line_width, line_colour, + render_markers, marker_style, marker_size, marker_edge_width, + marker_edge_colour, marker_face_colour, render_numbering, + numbers_horizontal_align, numbers_vertical_align, + numbers_font_name, numbers_font_size, numbers_font_style, + numbers_font_weight, numbers_font_colour, render_legend, + legend_title, legend_font_name, legend_font_style, + legend_font_size, legend_font_weight, legend_marker_scale, + legend_location, legend_bbox_to_anchor, legend_border_axes_pad, + legend_n_columns, legend_horizontal_spacing, + legend_vertical_spacing, legend_border, legend_border_padding, + legend_shadow, legend_rounded_corners, render_axes, + axes_font_name, axes_font_size, axes_font_style, + axes_font_weight, axes_x_limits, axes_y_limits, interpolation, + alpha, figure_size): + import matplotlib.pyplot as plt + + global glyph + if glyph is None: + from menpo.visualize.image import glyph + + # plot + if render_image: + # image will be displayed + if render_landmarks and len(groups) > 0: + # there are selected landmark groups and they will be displayed + if subplots_enabled: + # calculate subplots structure + subplots = MatplotlibSubplots()._subplot_layout(len(groups)) + # show image with landmarks + for k, group in enumerate(groups): + if subplots_enabled: + # create subplot + plt.subplot(subplots[0], subplots[1], k + 1) + if render_legend: + # set subplot's title + plt.title(subplots_titles[group], + fontname=legend_font_name, + fontstyle=legend_font_style, + fontweight=legend_font_weight, + fontsize=legend_font_size) + if glyph_enabled or sum_enabled: + # image, landmarks, masked, glyph + renderer = glyph(image, vectors_block_size=glyph_block_size, + use_negative=glyph_use_negative, + channels=channels).\ + view_landmarks( + masked=masked_enabled, group=group, + with_labels=with_labels[k], without_labels=None, + figure_id=renderer.figure_id, new_figure=False, + render_lines=render_lines[k], + line_style=line_style[k], + line_width=line_width[k], + line_colour=line_colour[k], + render_markers=render_markers[k], + marker_style=marker_style[k], + marker_size=marker_size[k], + marker_edge_width=marker_edge_width[k], + marker_edge_colour=marker_edge_colour[k], + marker_face_colour=marker_face_colour[k], + render_numbering=render_numbering, + numbers_horizontal_align=numbers_horizontal_align, + numbers_vertical_align=numbers_vertical_align, + numbers_font_name=numbers_font_name, + numbers_font_size=numbers_font_size, + numbers_font_style=numbers_font_style, + numbers_font_weight=numbers_font_weight, + numbers_font_colour=numbers_font_colour, + render_legend=render_legend and not subplots_enabled, + legend_title=legend_title, + legend_font_name=legend_font_name, + legend_font_style=legend_font_style, + legend_font_size=legend_font_size, + legend_font_weight=legend_font_weight, + legend_marker_scale=legend_marker_scale, + legend_location=legend_location, + legend_bbox_to_anchor=legend_bbox_to_anchor, + legend_border_axes_pad=legend_border_axes_pad, + legend_n_columns=legend_n_columns, + legend_horizontal_spacing=legend_horizontal_spacing, + legend_vertical_spacing=legend_vertical_spacing, + legend_border=legend_border, + legend_border_padding=legend_border_padding, + legend_shadow=legend_shadow, + legend_rounded_corners=legend_rounded_corners, + render_axes=render_axes, + axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, + axes_y_limits=axes_y_limits, + interpolation=interpolation, alpha=alpha, + figure_size=figure_size) + else: + # image, landmarks, masked, not glyph + renderer = image.view_landmarks( + channels=channels, masked=masked_enabled, group=group, + with_labels=with_labels[k], without_labels=None, + figure_id=renderer.figure_id, new_figure=False, + render_lines=render_lines[k], line_style=line_style[k], + line_width=line_width[k], line_colour=line_colour[k], + render_markers=render_markers[k], + marker_style=marker_style[k], + marker_size=marker_size[k], + marker_edge_width=marker_edge_width[k], + marker_edge_colour=marker_edge_colour[k], + marker_face_colour=marker_face_colour[k], + render_numbering=render_numbering, + numbers_horizontal_align=numbers_horizontal_align, + numbers_vertical_align=numbers_vertical_align, + numbers_font_name=numbers_font_name, + numbers_font_size=numbers_font_size, + numbers_font_style=numbers_font_style, + numbers_font_weight=numbers_font_weight, + numbers_font_colour=numbers_font_colour, + render_legend=render_legend and not subplots_enabled, + legend_title=legend_title, + legend_font_name=legend_font_name, + legend_font_style=legend_font_style, + legend_font_size=legend_font_size, + legend_font_weight=legend_font_weight, + legend_marker_scale=legend_marker_scale, + legend_location=legend_location, + legend_bbox_to_anchor=legend_bbox_to_anchor, + legend_border_axes_pad=legend_border_axes_pad, + legend_n_columns=legend_n_columns, + legend_horizontal_spacing=legend_horizontal_spacing, + legend_vertical_spacing=legend_vertical_spacing, + legend_border=legend_border, + legend_border_padding=legend_border_padding, + legend_shadow=legend_shadow, + legend_rounded_corners=legend_rounded_corners, + render_axes=render_axes, axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, + axes_y_limits=axes_y_limits, + interpolation=interpolation, alpha=alpha, + figure_size=figure_size) + else: + # either there are not any landmark groups selected or they won't + # be displayed + if image_is_masked: + if glyph_enabled or sum_enabled: + # image, not landmarks, masked, glyph + renderer = glyph(image, vectors_block_size=glyph_block_size, + use_negative=glyph_use_negative, + channels=channels).view( + masked=masked_enabled, render_axes=render_axes, + axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, + axes_y_limits=axes_y_limits, + figure_size=figure_size, interpolation=interpolation, + alpha=alpha) + else: + # image, not landmarks, masked, not glyph + renderer = image.view( + masked=masked_enabled, channels=channels, + render_axes=render_axes, axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, + axes_y_limits=axes_y_limits, figure_size=figure_size, + interpolation=interpolation, alpha=alpha) + else: + if glyph_enabled or sum_enabled: + # image, not landmarks, not masked, glyph + renderer = glyph(image, vectors_block_size=glyph_block_size, + use_negative=glyph_use_negative, + channels=channels).view( + render_axes=render_axes, axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, + axes_y_limits=axes_y_limits, figure_size=figure_size, + interpolation=interpolation, alpha=alpha) + else: + # image, not landmarks, not masked, not glyph + renderer = image.view( + channels=channels, render_axes=render_axes, + axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, + axes_y_limits=axes_y_limits, figure_size=figure_size, + interpolation=interpolation, alpha=alpha) + else: + # image won't be displayed + if render_landmarks and len(groups) > 0: + # there are selected landmark groups and they will be displayed + if subplots_enabled: + # calculate subplots structure + subplots = MatplotlibSubplots()._subplot_layout(len(groups)) + # not image, landmarks + for k, group in enumerate(groups): + if subplots_enabled: + # create subplot + plt.subplot(subplots[0], subplots[1], k + 1) + if render_legend: + # set subplot's title + plt.title(subplots_titles[group], + fontname=legend_font_name, + fontstyle=legend_font_style, + fontweight=legend_font_weight, + fontsize=legend_font_size) + image.landmarks[group].lms.view( + image_view=image_axes_mode, render_lines=render_lines[k], + line_style=line_style[k], line_width=line_width[k], + line_colour=line_colour[k], + render_markers=render_markers[k], + marker_style=marker_style[k], marker_size=marker_size[k], + marker_edge_width=marker_edge_width[k], + marker_edge_colour=marker_edge_colour[k], + marker_face_colour=marker_face_colour[k], + render_axes=render_axes, axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, axes_y_limits=axes_y_limits, + figure_size=figure_size) + if not subplots_enabled: + if len(groups) % 2 == 0: + plt.gca().invert_yaxis() + if render_legend: + # Options related to legend's font + prop = {'family': legend_font_name, + 'size': legend_font_size, + 'style': legend_font_style, + 'weight': legend_font_weight} + + # display legend on side + plt.gca().legend(groups, title=legend_title, prop=prop, + loc=legend_location, + bbox_to_anchor=legend_bbox_to_anchor, + borderaxespad=legend_border_axes_pad, + ncol=legend_n_columns, + columnspacing=legend_horizontal_spacing, + labelspacing=legend_vertical_spacing, + frameon=legend_border, + borderpad=legend_border_padding, + shadow=legend_shadow, + fancybox=legend_rounded_corners, + markerscale=legend_marker_scale) + + # show plot + plt.show() + + return renderer + + def _check_n_parameters(n_params, n_levels, max_n_params): r""" Checks the maximum number of components per level either of the shape From 4a3e0400c3897e2ecf242b62f2e34db57ddaa0fd Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Sat, 3 Jan 2015 01:24:41 +0000 Subject: [PATCH 13/15] adds view_{}_widget() to aam, atm and clm --- menpofit/aam/base.py | 109 +++++++++++++++++++++++++-------- menpofit/atm/base.py | 62 ++++++++++++++----- menpofit/clm/base.py | 39 ++++++------ menpofit/fittingresult.py | 14 +++-- menpofit/visualize/__init__.py | 1 - 5 files changed, 159 insertions(+), 66 deletions(-) diff --git a/menpofit/aam/base.py b/menpofit/aam/base.py index e97cdbd..294f89a 100644 --- a/menpofit/aam/base.py +++ b/menpofit/aam/base.py @@ -181,46 +181,107 @@ def _str_title(self): """ return 'Active Appearance Model' - def view_widget(self, n_shape_parameters=5, n_appearance_parameters=5, - parameters_bounds=(-3.0, 3.0), mode='multiple', - popup=False): + def view_shape_models_widget(self, n_parameters=5, + parameters_bounds=(-3.0, 3.0), mode='multiple', + popup=False): r""" - Visualizes the AAM object using the - menpo.visualize.widgets.visualize_aam widget. + Visualizes the shape models of the AAM object using the + `menpo.visualize.widgets.visualize_shape_model` widget. Parameters ----------- - n_shape_parameters : `int` or `list` of `int` or None, optional + n_parameters : `int` or `list` of `int` or ``None``, optional The number of shape principal components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. + parameters_bounds : (`float`, `float`), optional + The minimum and maximum bounds, in std units, for the sliders. + mode : {``single``, ``multiple``}, optional + If ``'single'``, only a single slider is constructed along with a + drop down menu. + If ``'multiple'``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. + """ + from menpofit.visualize import visualize_shape_model + visualize_shape_model(self.shape_models, n_parameters=n_parameters, + parameters_bounds=parameters_bounds, + figure_size=(6, 4), mode=mode, popup=popup) + + def view_appearance_models_widget(self, n_parameters=5, + parameters_bounds=(-3.0, 3.0), + mode='multiple', popup=False): + r""" + Visualizes the appearance models of the AAM object using the + `menpo.visualize.widgets.visualize_appearance_model` widget. - n_appearance_parameters : `int` or `list` of `int` or None, optional + Parameters + ----------- + n_parameters : `int` or `list` of `int` or ``None``, optional The number of appearance principal components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. + mode : {``single``, ``multiple``}, optional + If ``'single'``, only a single slider is constructed along with a + drop down menu. + If ``'multiple'``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. + """ + from menpofit.visualize import visualize_appearance_model + visualize_appearance_model(self.appearance_models, + n_parameters=n_parameters, + parameters_bounds=parameters_bounds, + figure_size=(6, 4), mode=mode, popup=popup) + + def view_aam_widget(self, n_shape_parameters=5, n_appearance_parameters=5, + parameters_bounds=(-3.0, 3.0), mode='multiple', + popup=False): + r""" + Visualizes both the shape and appearance models of the AAM object using + the `menpo.visualize.widgets.visualize_aam` widget. - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. + Parameters + ----------- + n_shape_parameters : `int` or `list` of `int` or None, optional + The number of shape principal components to be used for the + parameters sliders. + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. + n_appearance_parameters : `int` or `list` of `int` or None, optional + The number of appearance principal components to be used for the + parameters sliders. + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. + parameters_bounds : (`float`, `float`), optional + The minimum and maximum bounds, in std units, for the sliders. + mode : {``single``, ``multiple``}, optional + If ``'single'``, only a single slider is constructed along with a + drop down menu. + If ``'multiple'``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. """ from menpofit.visualize import visualize_aam visualize_aam(self, n_shape_parameters=n_shape_parameters, n_appearance_parameters=n_appearance_parameters, - parameters_bounds=parameters_bounds, figure_size=(7, 7), + parameters_bounds=parameters_bounds, figure_size=(6, 4), mode=mode, popup=popup) def __str__(self): diff --git a/menpofit/atm/base.py b/menpofit/atm/base.py index da9ed11..139098e 100644 --- a/menpofit/atm/base.py +++ b/menpofit/atm/base.py @@ -166,8 +166,40 @@ def _str_title(self): """ return 'Active Template Model' - def view_widget(self, n_shape_parameters=5, parameters_bounds=(-3.0, 3.0), - mode='multiple', popup=False): + def view_shape_models_widget(self, n_parameters=5, + parameters_bounds=(-3.0, 3.0), mode='multiple', + popup=False): + r""" + Visualizes the shape models of the AAM object using the + `menpo.visualize.widgets.visualize_shape_model` widget. + + Parameters + ----------- + n_parameters : `int` or `list` of `int` or ``None``, optional + The number of shape principal components to be used for the + parameters sliders. + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. + parameters_bounds : (`float`, `float`), optional + The minimum and maximum bounds, in std units, for the sliders. + mode : {``single``, ``multiple``}, optional + If ``'single'``, only a single slider is constructed along with a + drop down menu. + If ``'multiple'``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. + """ + from menpofit.visualize import visualize_shape_model + visualize_shape_model(self.shape_models, n_parameters=n_parameters, + parameters_bounds=parameters_bounds, + figure_size=(6, 4), mode=mode, popup=popup) + + def view_atm_widget(self, n_shape_parameters=5, + parameters_bounds=(-3.0, 3.0), mode='multiple', + popup=False): r""" Visualizes the ATM object using the menpo.visualize.widgets.visualize_atm widget. @@ -177,25 +209,23 @@ def view_widget(self, n_shape_parameters=5, parameters_bounds=(-3.0, 3.0), n_shape_parameters : `int` or `list` of `int` or None, optional The number of shape principal components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. - - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. + mode : {``single``, ``multiple``}, optional + If ``'single'``, only a single slider is constructed along with a + drop down menu. + If ``'multiple'``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. """ from menpofit.visualize import visualize_atm visualize_atm(self, n_shape_parameters=n_shape_parameters, - parameters_bounds=parameters_bounds, figure_size=(7, 7), + parameters_bounds=parameters_bounds, figure_size=(6, 4), mode=mode, popup=popup) def __str__(self): diff --git a/menpofit/clm/base.py b/menpofit/clm/base.py index 28ddc76..117a515 100644 --- a/menpofit/clm/base.py +++ b/menpofit/clm/base.py @@ -201,37 +201,36 @@ def _str_title(self): """ return 'Constrained Local Model' - def view_widget(self, n_parameters=5, parameters_bounds=(-3.0, 3.0), - mode='multiple', popup=False): + def view_shape_models_widget(self, n_parameters=5, + parameters_bounds=(-3.0, 3.0), mode='multiple', + popup=False): r""" Visualizes the shape models of the CLM object using the - menpo.visualize.widgets.visualize_shape_model widget. + `menpo.visualize.widgets.visualize_shape_model` widget. Parameters ----------- - n_parameters : `int` or `list` of `int` or None, optional - The number of principal components to be used for the parameters - sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + n_parameters : `int` or `list` of `int` or ``None``, optional + The number of shape principal components to be used for the + parameters sliders. + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. - - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. + mode : {``single``, ``multiple``}, optional + If ``'single'``, only a single slider is constructed along with a + drop down menu. + If ``'multiple'``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. """ from menpofit.visualize import visualize_shape_model visualize_shape_model(self.shape_models, n_parameters=n_parameters, parameters_bounds=parameters_bounds, - figure_size=(7, 7), mode=mode, popup=popup) + figure_size=(6, 4), mode=mode, popup=popup) def __str__(self): from menpofit.base import name_of_callable diff --git a/menpofit/fittingresult.py b/menpofit/fittingresult.py index 134518b..f53e4a0 100644 --- a/menpofit/fittingresult.py +++ b/menpofit/fittingresult.py @@ -198,18 +198,22 @@ def initial_error(self, error_type='me_norm'): raise ValueError('Ground truth has not been set, final error ' 'cannot be computed') - def view_widget(self, popup=False): + def view_widget(self, popup=False, browser_style='buttons'): r""" Visualizes the multilevel fitting result object using the - menpo.visualize.widgets.visualize_fitting_results widget. + `menpo.visualize.widgets.visualize_fitting_results` widget. Parameters ----------- - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. + browser_style : {``buttons``, ``slider``}, optional + It defines whether the selector of the fitting results will have the + form of plus/minus buttons or a slider. """ from menpofit.visualize import visualize_fitting_results - visualize_fitting_results(self, figure_size=(7, 7), popup=popup) + visualize_fitting_results(self, figure_size=(6, 4), popup=popup, + browser_style=browser_style) def plot_errors(self, error_type='me_norm', figure_id=None, new_figure=False, render_lines=True, line_colour='b', diff --git a/menpofit/visualize/__init__.py b/menpofit/visualize/__init__.py index 729256a..90a1bf3 100644 --- a/menpofit/visualize/__init__.py +++ b/menpofit/visualize/__init__.py @@ -1,4 +1,3 @@ from .widgets import (visualize_shape_model, visualize_appearance_model, visualize_aam, visualize_atm, visualize_fitting_results, plot_ced) - From 76f2083d02ddcdc42488e1ab0a0dc1a55fe7f955 Mon Sep 17 00:00:00 2001 From: Epameinondas Antonakos Date: Tue, 27 Jan 2015 15:49:59 +0000 Subject: [PATCH 14/15] minor changes (figure_size) --- menpofit/aam/base.py | 22 ++++-- menpofit/atm/base.py | 14 ++-- menpofit/clm/base.py | 6 +- menpofit/fittingresult.py | 8 +- menpofit/visualize/widgets/base.py | 119 +++++++++++------------------ 5 files changed, 77 insertions(+), 92 deletions(-) diff --git a/menpofit/aam/base.py b/menpofit/aam/base.py index 294f89a..fde1033 100644 --- a/menpofit/aam/base.py +++ b/menpofit/aam/base.py @@ -183,7 +183,7 @@ def _str_title(self): def view_shape_models_widget(self, n_parameters=5, parameters_bounds=(-3.0, 3.0), mode='multiple', - popup=False): + popup=False, figure_size=(10, 8)): r""" Visualizes the shape models of the AAM object using the `menpo.visualize.widgets.visualize_shape_model` widget. @@ -206,15 +206,18 @@ def view_shape_models_widget(self, n_parameters=5, If ``'multiple'``, a slider is constructed for each parameter. popup : `bool`, optional If ``True``, the widget will appear as a popup window. + figure_size : (`int`, `int`), optional + The size of the plotted figures. """ from menpofit.visualize import visualize_shape_model visualize_shape_model(self.shape_models, n_parameters=n_parameters, parameters_bounds=parameters_bounds, - figure_size=(6, 4), mode=mode, popup=popup) + figure_size=figure_size, mode=mode, popup=popup) def view_appearance_models_widget(self, n_parameters=5, parameters_bounds=(-3.0, 3.0), - mode='multiple', popup=False): + mode='multiple', popup=False, + figure_size=(10, 8)): r""" Visualizes the appearance models of the AAM object using the `menpo.visualize.widgets.visualize_appearance_model` widget. @@ -237,16 +240,19 @@ def view_appearance_models_widget(self, n_parameters=5, If ``'multiple'``, a slider is constructed for each parameter. popup : `bool`, optional If ``True``, the widget will appear as a popup window. + figure_size : (`int`, `int`), optional + The size of the plotted figures. """ from menpofit.visualize import visualize_appearance_model visualize_appearance_model(self.appearance_models, n_parameters=n_parameters, parameters_bounds=parameters_bounds, - figure_size=(6, 4), mode=mode, popup=popup) + figure_size=figure_size, mode=mode, + popup=popup) def view_aam_widget(self, n_shape_parameters=5, n_appearance_parameters=5, parameters_bounds=(-3.0, 3.0), mode='multiple', - popup=False): + popup=False, figure_size=(10, 8)): r""" Visualizes both the shape and appearance models of the AAM object using the `menpo.visualize.widgets.visualize_aam` widget. @@ -277,12 +283,14 @@ def view_aam_widget(self, n_shape_parameters=5, n_appearance_parameters=5, If ``'multiple'``, a slider is constructed for each parameter. popup : `bool`, optional If ``True``, the widget will appear as a popup window. + figure_size : (`int`, `int`), optional + The size of the plotted figures. """ from menpofit.visualize import visualize_aam visualize_aam(self, n_shape_parameters=n_shape_parameters, n_appearance_parameters=n_appearance_parameters, - parameters_bounds=parameters_bounds, figure_size=(6, 4), - mode=mode, popup=popup) + parameters_bounds=parameters_bounds, + figure_size=figure_size, mode=mode, popup=popup) def __str__(self): out = "{}\n - {} training images.\n".format(self._str_title, diff --git a/menpofit/atm/base.py b/menpofit/atm/base.py index 139098e..c56a7bd 100644 --- a/menpofit/atm/base.py +++ b/menpofit/atm/base.py @@ -168,7 +168,7 @@ def _str_title(self): def view_shape_models_widget(self, n_parameters=5, parameters_bounds=(-3.0, 3.0), mode='multiple', - popup=False): + popup=False, figure_size=(10, 8)): r""" Visualizes the shape models of the AAM object using the `menpo.visualize.widgets.visualize_shape_model` widget. @@ -191,15 +191,17 @@ def view_shape_models_widget(self, n_parameters=5, If ``'multiple'``, a slider is constructed for each parameter. popup : `bool`, optional If ``True``, the widget will appear as a popup window. + figure_size : (`int`, `int`), optional + The size of the plotted figures. """ from menpofit.visualize import visualize_shape_model visualize_shape_model(self.shape_models, n_parameters=n_parameters, parameters_bounds=parameters_bounds, - figure_size=(6, 4), mode=mode, popup=popup) + figure_size=figure_size, mode=mode, popup=popup) def view_atm_widget(self, n_shape_parameters=5, parameters_bounds=(-3.0, 3.0), mode='multiple', - popup=False): + popup=False, figure_size=(10, 8)): r""" Visualizes the ATM object using the menpo.visualize.widgets.visualize_atm widget. @@ -222,11 +224,13 @@ def view_atm_widget(self, n_shape_parameters=5, If ``'multiple'``, a slider is constructed for each parameter. popup : `bool`, optional If ``True``, the widget will appear as a popup window. + figure_size : (`int`, `int`), optional + The size of the plotted figures. """ from menpofit.visualize import visualize_atm visualize_atm(self, n_shape_parameters=n_shape_parameters, - parameters_bounds=parameters_bounds, figure_size=(6, 4), - mode=mode, popup=popup) + parameters_bounds=parameters_bounds, + figure_size=figure_size, mode=mode, popup=popup) def __str__(self): out = "{}\n - {} training shapes.\n".format(self._str_title, diff --git a/menpofit/clm/base.py b/menpofit/clm/base.py index 117a515..658d721 100644 --- a/menpofit/clm/base.py +++ b/menpofit/clm/base.py @@ -203,7 +203,7 @@ def _str_title(self): def view_shape_models_widget(self, n_parameters=5, parameters_bounds=(-3.0, 3.0), mode='multiple', - popup=False): + popup=False, figure_size=(10, 8)): r""" Visualizes the shape models of the CLM object using the `menpo.visualize.widgets.visualize_shape_model` widget. @@ -226,11 +226,13 @@ def view_shape_models_widget(self, n_parameters=5, If ``'multiple'``, a slider is constructed for each parameter. popup : `bool`, optional If ``True``, the widget will appear as a popup window. + figure_size : (`int`, `int`), optional + The size of the plotted figures. """ from menpofit.visualize import visualize_shape_model visualize_shape_model(self.shape_models, n_parameters=n_parameters, parameters_bounds=parameters_bounds, - figure_size=(6, 4), mode=mode, popup=popup) + figure_size=figure_size, mode=mode, popup=popup) def __str__(self): from menpofit.base import name_of_callable diff --git a/menpofit/fittingresult.py b/menpofit/fittingresult.py index f53e4a0..15365aa 100644 --- a/menpofit/fittingresult.py +++ b/menpofit/fittingresult.py @@ -212,7 +212,7 @@ def view_widget(self, popup=False, browser_style='buttons'): form of plus/minus buttons or a slider. """ from menpofit.visualize import visualize_fitting_results - visualize_fitting_results(self, figure_size=(6, 4), popup=popup, + visualize_fitting_results(self, figure_size=(10, 8), popup=popup, browser_style=browser_style) def plot_errors(self, error_type='me_norm', figure_id=None, @@ -222,7 +222,7 @@ def plot_errors(self, error_type='me_norm', figure_id=None, marker_edge_colour='k', marker_edge_width=1., render_axes=True, axes_font_name='sans-serif', axes_font_size=10, axes_font_style='normal', - axes_font_weight='normal', figure_size=(6, 4), + axes_font_weight='normal', figure_size=(10, 6), render_grid=True, grid_line_style='--', grid_line_width=0.5): r""" @@ -319,7 +319,7 @@ def plot_displacements(self, stat_type='mean', figure_id=None, marker_edge_width=1., render_axes=True, axes_font_name='sans-serif', axes_font_size=10, axes_font_style='normal', axes_font_weight='normal', - figure_size=(6, 4), render_grid=True, + figure_size=(10, 6), render_grid=True, grid_line_style='--', grid_line_width=0.5): r""" Plot of a statistical metric of the displacement between the shape of @@ -1025,7 +1025,7 @@ def plot_cumulative_error_distribution(errors, error_range=None, figure_id=None, axes_font_style='normal', axes_font_weight='normal', axes_x_limits=None, axes_y_limits=None, - figure_size=(6, 4), render_grid=True, + figure_size=(10, 8), render_grid=True, grid_line_style='--', grid_line_width=0.5): r""" diff --git a/menpofit/visualize/widgets/base.py b/menpofit/visualize/widgets/base.py index 9a4a238..5eba460 100644 --- a/menpofit/visualize/widgets/base.py +++ b/menpofit/visualize/widgets/base.py @@ -33,7 +33,7 @@ def visualize_shape_model(shape_models, n_parameters=5, - parameters_bounds=(-3.0, 3.0), figure_size=(6, 4), + parameters_bounds=(-3.0, 3.0), figure_size=(10, 8), mode='multiple', popup=False): r""" Allows the dynamic visualization of a multilevel shape model. @@ -366,8 +366,7 @@ def update_widgets(name, value): ipydisplay.display(wid) # set final tab titles - tab_titles = ['Shape parameters', 'Viewer options', 'Model info', - 'Save figure'] + tab_titles = ['Shape parameters', 'Viewer options', 'Info', 'Save figure'] for (k, tl) in enumerate(tab_titles): tab_wid.set_title(k, tl) @@ -403,7 +402,7 @@ def update_widgets(name, value): def visualize_appearance_model(appearance_models, n_parameters=5, - parameters_bounds=(-3.0, 3.0), figure_size=(6, 4), + parameters_bounds=(-3.0, 3.0), figure_size=(10, 8), mode='multiple', popup=False): r""" Allows the dynamic visualization of a multilevel appearance model. @@ -700,8 +699,7 @@ def update_widgets(name, value): # set final tab titles tab_titles = ['Appearance parameters', 'Channels options', - 'Landmarks options', 'Viewer options', 'Model info', - 'Save figure'] + 'Landmarks options', 'Viewer options', 'Info', 'Save figure'] for (k, tl) in enumerate(tab_titles): tab_wid.set_title(k, tl) @@ -748,7 +746,7 @@ def update_widgets(name, value): def visualize_aam(aam, n_shape_parameters=5, n_appearance_parameters=5, - parameters_bounds=(-3.0, 3.0), figure_size=(6, 4), + parameters_bounds=(-3.0, 3.0), figure_size=(10, 8), mode='multiple', popup=False): r""" Allows the dynamic visualization of a multilevel AAM. @@ -1130,8 +1128,7 @@ def update_widgets(name, value): # set final tab titles tab_titles = ['AAM parameters', 'Channels options', - 'Landmarks options', 'Viewer options', 'Model info', - 'Save figure'] + 'Landmarks options', 'Viewer options', 'Info', 'Save figure'] for (k, tl) in enumerate(tab_titles): tab_wid.set_title(k, tl) tab_titles = ['Shape parameters', 'Appearance parameters'] @@ -1186,7 +1183,7 @@ def update_widgets(name, value): def visualize_atm(atm, n_shape_parameters=5, parameters_bounds=(-3.0, 3.0), - figure_size=(6, 4), mode='multiple', popup=False): + figure_size=(10, 8), mode='multiple', popup=False): r""" Allows the dynamic visualization of a multilevel ATM. @@ -1514,8 +1511,7 @@ def update_widgets(name, value): # set final tab titles tab_titles = ['Shape parameters', 'Channels options', - 'Landmarks options', 'Viewer options', 'Model info', - 'Save figure'] + 'Landmarks options', 'Viewer options', 'Info', 'Save figure'] for (k, tl) in enumerate(tab_titles): tab_wid.set_title(k, tl) @@ -1561,7 +1557,7 @@ def update_widgets(name, value): False -def visualize_fitting_results(fitting_results, figure_size=(6, 4), +def visualize_fitting_results(fitting_results, figure_size=(10, 8), browser_style='buttons', popup=False): r""" Widget that allows browsing through a list of fitting results. @@ -2088,8 +2084,7 @@ def save_fig_tab_fun(name, value): # final widget logo_wid = logo() - tab_titles = ['Image info', 'Result', 'Options', 'Error type', - 'Save figure'] + tab_titles = ['Info', 'Result', 'Options', 'Error type', 'Save figure'] button_title = 'Fitting Result Menu' # final widget @@ -2200,7 +2195,7 @@ def close_plot_ced_fun_2(name, value): False -def plot_ced(errors, figure_size=(6, 4), popup=False, error_type='me_norm', +def plot_ced(errors, figure_size=(10, 8), popup=False, error_type='me_norm', error_range=None, legend_entries=None, return_widget=False): r""" Widget for visualizing the cumulative error curves of the provided errors. @@ -2544,6 +2539,11 @@ def _visualize(image, renderer, render_image, render_landmarks, image_is_masked, if glyph is None: from menpo.visualize.image import glyph + # This makes the code shorter for dealing with masked images vs non-masked + # images + mask_arguments = ({'masked': masked_enabled} + if image_is_masked else {}) + # plot if render_image: # image will be displayed @@ -2570,12 +2570,10 @@ def _visualize(image, renderer, render_image, render_landmarks, image_is_masked, use_negative=glyph_use_negative, channels=channels).\ view_landmarks( - masked=masked_enabled, group=group, - with_labels=with_labels[k], without_labels=None, - figure_id=renderer.figure_id, new_figure=False, - render_lines=render_lines[k], - line_style=line_style[k], - line_width=line_width[k], + group=group, with_labels=with_labels[k], + without_labels=None, figure_id=renderer.figure_id, + new_figure=False, render_lines=render_lines[k], + line_style=line_style[k], line_width=line_width[k], line_colour=line_colour[k], render_markers=render_markers[k], marker_style=marker_style[k], @@ -2616,11 +2614,11 @@ def _visualize(image, renderer, render_image, render_landmarks, image_is_masked, axes_x_limits=axes_x_limits, axes_y_limits=axes_y_limits, interpolation=interpolation, alpha=alpha, - figure_size=figure_size) + figure_size=figure_size, **mask_arguments) else: # image, landmarks, masked, not glyph renderer = image.view_landmarks( - channels=channels, masked=masked_enabled, group=group, + channels=channels, group=group, with_labels=with_labels[k], without_labels=None, figure_id=renderer.figure_id, new_figure=False, render_lines=render_lines[k], line_style=line_style[k], @@ -2663,60 +2661,33 @@ def _visualize(image, renderer, render_image, render_landmarks, image_is_masked, axes_x_limits=axes_x_limits, axes_y_limits=axes_y_limits, interpolation=interpolation, alpha=alpha, - figure_size=figure_size) + figure_size=figure_size, **mask_arguments) else: # either there are not any landmark groups selected or they won't # be displayed - if image_is_masked: - if glyph_enabled or sum_enabled: - # image, not landmarks, masked, glyph - renderer = glyph(image, vectors_block_size=glyph_block_size, - use_negative=glyph_use_negative, - channels=channels).view( - masked=masked_enabled, render_axes=render_axes, - axes_font_name=axes_font_name, - axes_font_size=axes_font_size, - axes_font_style=axes_font_style, - axes_font_weight=axes_font_weight, - axes_x_limits=axes_x_limits, - axes_y_limits=axes_y_limits, - figure_size=figure_size, interpolation=interpolation, - alpha=alpha) - else: - # image, not landmarks, masked, not glyph - renderer = image.view( - masked=masked_enabled, channels=channels, - render_axes=render_axes, axes_font_name=axes_font_name, - axes_font_size=axes_font_size, - axes_font_style=axes_font_style, - axes_font_weight=axes_font_weight, - axes_x_limits=axes_x_limits, - axes_y_limits=axes_y_limits, figure_size=figure_size, - interpolation=interpolation, alpha=alpha) + if glyph_enabled or sum_enabled: + # image, not landmarks, masked, glyph + renderer = glyph(image, vectors_block_size=glyph_block_size, + use_negative=glyph_use_negative, + channels=channels).view( + render_axes=render_axes, axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, axes_y_limits=axes_y_limits, + figure_size=figure_size, interpolation=interpolation, + alpha=alpha, **mask_arguments) else: - if glyph_enabled or sum_enabled: - # image, not landmarks, not masked, glyph - renderer = glyph(image, vectors_block_size=glyph_block_size, - use_negative=glyph_use_negative, - channels=channels).view( - render_axes=render_axes, axes_font_name=axes_font_name, - axes_font_size=axes_font_size, - axes_font_style=axes_font_style, - axes_font_weight=axes_font_weight, - axes_x_limits=axes_x_limits, - axes_y_limits=axes_y_limits, figure_size=figure_size, - interpolation=interpolation, alpha=alpha) - else: - # image, not landmarks, not masked, not glyph - renderer = image.view( - channels=channels, render_axes=render_axes, - axes_font_name=axes_font_name, - axes_font_size=axes_font_size, - axes_font_style=axes_font_style, - axes_font_weight=axes_font_weight, - axes_x_limits=axes_x_limits, - axes_y_limits=axes_y_limits, figure_size=figure_size, - interpolation=interpolation, alpha=alpha) + # image, not landmarks, masked, not glyph + renderer = image.view( + channels=channels, render_axes=render_axes, + axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, + axes_y_limits=axes_y_limits, figure_size=figure_size, + interpolation=interpolation, alpha=alpha, **mask_arguments) else: # image won't be displayed if render_landmarks and len(groups) > 0: From 7c4cae5170bf2d9569f7aa2da96424b8c1301aca Mon Sep 17 00:00:00 2001 From: Patrick Snape Date: Wed, 28 Jan 2015 13:10:37 +0000 Subject: [PATCH 15/15] Update tests for new LJSON change lenna is now an ljson and therefore all the tests that assumed every image was landmarked with PTS started failing. Also, unpin menpofit from menpo 0.4.0a3 and let if find the master branch so that the tests have a chance of passing. --- .travis.yml | 2 +- appveyor.yml | 2 +- conda/meta.yaml | 2 +- menpofit/test/aam_builder_test.py | 40 ++++++++++++++----------------- menpofit/test/aam_fitter_test.py | 19 +++++++-------- menpofit/test/atm_builder_test.py | 25 +++++++++---------- menpofit/test/atm_fitter_test.py | 18 +++++++------- menpofit/test/clm_builder_test.py | 38 +++++++++++++---------------- menpofit/test/clm_fitter_test.py | 6 ++--- menpofit/test/sdm_test.py | 30 +++++++++++------------ 10 files changed, 82 insertions(+), 100 deletions(-) diff --git a/.travis.yml b/.travis.yml index 84c8814..54dde0e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,7 +20,7 @@ install: - wget https://raw.githubusercontent.com/jabooth/condaci/v0.2.0/condaci.py -O condaci.py - python condaci.py setup $PYTHON_VERSION --channel $BINSTAR_USER - export PATH=$HOME/miniconda/bin:$PATH -#- conda config --add channels $BINSTAR_USER/channel/master +- conda config --add channels $BINSTAR_USER/channel/master script: - python condaci.py auto ./conda --binstaruser $BINSTAR_USER --binstarkey $BINSTAR_KEY diff --git a/appveyor.yml b/appveyor.yml index 83d7ea8..80c49b0 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -25,7 +25,7 @@ platform: init: - ps: Start-FileDownload 'https://raw.githubusercontent.com/jabooth/condaci/v0.2.0/condaci.py' C:\\condaci.py; echo "Done" - cmd: python C:\\condaci.py setup %PYTHON_VERSION% --channel %BINSTAR_USER% -#- cmd: C:\\Miniconda\\Scripts\\conda config --add channels %BINSTAR_USER%/channel/master +- cmd: C:\\Miniconda\\Scripts\\conda config --add channels %BINSTAR_USER%/channel/master install: - cmd: C:\\Miniconda\\python C:\\condaci.py auto ./conda --binstaruser %BINSTAR_USER% --binstarkey %BINSTAR_KEY% diff --git a/conda/meta.yaml b/conda/meta.yaml index a427d0a..6fbd043 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -9,7 +9,7 @@ requirements: run: - python - - menpo 0.4.0a3 + - menpo - numpy 1.9.0 - scipy 0.14.0 - scikit-learn 0.15.2 diff --git a/menpofit/test/aam_builder_test.py b/menpofit/test/aam_builder_test.py index 31a1119..a8ce97d 100644 --- a/menpofit/test/aam_builder_test.py +++ b/menpofit/test/aam_builder_test.py @@ -9,7 +9,7 @@ from menpo.feature import sparse_hog, igo, lbp, no_op import menpo.io as mio -from menpo.landmark import labeller, ibug_face_68_trimesh +from menpo.landmark import ibug_face_68_trimesh from menpofit.aam import AAMBuilder, PatchBasedAAMBuilder @@ -19,23 +19,23 @@ for i in range(4): im = mio.import_builtin_asset(filenames[i]) im.crop_to_landmarks_proportion_inplace(0.1) - labeller(im, 'PTS', ibug_face_68_trimesh) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') training.append(im) # build aams +template_trilist_image = training[0].landmarks[None] +trilist = ibug_face_68_trimesh(template_trilist_image)[1].lms.trilist aam1 = AAMBuilder(features=[igo, sparse_hog, no_op], transform=PiecewiseAffine, - trilist=training[0].landmarks['ibug_face_68_trimesh']. - lms.trilist, + trilist=trilist, normalization_diagonal=150, n_levels=3, downscale=2, scaled_shape_models=False, max_shape_components=[1, 2, 3], max_appearance_components=[3, 3, 3], - boundary=3).build(training, group='PTS') + boundary=3).build(training) aam2 = AAMBuilder(features=[no_op, no_op], transform=ThinPlateSplines, @@ -46,7 +46,7 @@ scaled_shape_models=True, max_shape_components=None, max_appearance_components=1, - boundary=0).build(training, group='PTS') + boundary=0).build(training) aam3 = AAMBuilder(features=igo, transform=ThinPlateSplines, @@ -57,7 +57,7 @@ scaled_shape_models=True, max_shape_components=[2], max_appearance_components=10, - boundary=2).build(training, group='PTS') + boundary=2).build(training) aam4 = PatchBasedAAMBuilder(features=lbp, patch_shape=(10, 13), @@ -67,55 +67,51 @@ scaled_shape_models=True, max_shape_components=1, max_appearance_components=None, - boundary=2).build(training, group='PTS') + boundary=2).build(training) @raises(ValueError) def test_features_exception(): - AAMBuilder(features=[igo, sparse_hog]).build(training, group='PTS') + AAMBuilder(features=[igo, sparse_hog]).build(training) @raises(ValueError) def test_n_levels_exception(): - AAMBuilder(n_levels=0).build(training, group='PTS') + AAMBuilder(n_levels=0).build(training) @raises(ValueError) def test_downscale_exception(): - aam = AAMBuilder(downscale=1).build(training, - group='PTS') + aam = AAMBuilder(downscale=1).build(training) assert (aam.downscale == 1) - AAMBuilder(downscale=0).build(training, group='PTS') + AAMBuilder(downscale=0).build(training) @raises(ValueError) def test_normalization_diagonal_exception(): - aam = AAMBuilder(normalization_diagonal=100).build(training, - group='PTS') + aam = AAMBuilder(normalization_diagonal=100).build(training) assert (aam.appearance_models[0].n_features == 382) - AAMBuilder(normalization_diagonal=10).build(training, group='PTS') + AAMBuilder(normalization_diagonal=10).build(training) @raises(ValueError) def test_max_shape_components_exception(): - AAMBuilder(max_shape_components=[1, 0.2, 'a']).build(training, - group='PTS') + AAMBuilder(max_shape_components=[1, 0.2, 'a']).build(training) @raises(ValueError) def test_max_appearance_components_exception(): - AAMBuilder(max_appearance_components=[1, 2]).build(training, - group='PTS') + AAMBuilder(max_appearance_components=[1, 2]).build(training) @raises(ValueError) def test_boundary_exception(): - AAMBuilder(boundary=-1).build(training, group='PTS') + AAMBuilder(boundary=-1).build(training) @patch('sys.stdout', new_callable=StringIO) def test_verbose_mock(mock_stdout): - AAMBuilder().build(training, group='PTS', verbose=True) + AAMBuilder().build(training, verbose=True) @patch('sys.stdout', new_callable=StringIO) diff --git a/menpofit/test/aam_fitter_test.py b/menpofit/test/aam_fitter_test.py index 438b9c5..1402e14 100644 --- a/menpofit/test/aam_fitter_test.py +++ b/menpofit/test/aam_fitter_test.py @@ -11,7 +11,7 @@ import menpo.io as mio from menpo.shape.pointcloud import PointCloud -from menpo.landmark import labeller, ibug_face_68_trimesh +from menpo.landmark import ibug_face_68_trimesh from menpofit.aam import AAMBuilder, LucasKanadeAAMFitter from menpofit.lucaskanade.appearance import ( AlternatingForwardAdditive, AlternatingForwardCompositional, @@ -307,35 +307,34 @@ for i in range(4): im = mio.import_builtin_asset(filenames[i]) im.crop_to_landmarks_proportion_inplace(0.1) - labeller(im, 'PTS', ibug_face_68_trimesh) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') training_images.append(im) # build aam +template_trilist_image = training_images[0].landmarks[None] +trilist = ibug_face_68_trimesh(template_trilist_image)[1].lms.trilist aam = AAMBuilder(features=igo, transform=DifferentiablePiecewiseAffine, - trilist=training_images[0].landmarks['ibug_face_68_trimesh']. - lms.trilist, + trilist=trilist, normalization_diagonal=150, n_levels=3, downscale=2, scaled_shape_models=True, max_shape_components=[1, 2, 3], max_appearance_components=[3, 2, 1], - boundary=3).build(training_images, group='PTS') + boundary=3).build(training_images) aam2 = AAMBuilder(features=igo, transform=DifferentiablePiecewiseAffine, - trilist=training_images[0].landmarks['ibug_face_68_trimesh']. - lms.trilist, + trilist=trilist, normalization_diagonal=150, n_levels=1, downscale=2, scaled_shape_models=True, max_shape_components=[1], max_appearance_components=[1], - boundary=3).build(training_images, group='PTS') + boundary=3).build(training_images) def test_aam(): @@ -368,7 +367,7 @@ def test_n_appearance_exception(): def test_pertrurb_shape(): fitter = LucasKanadeAAMFitter(aam) - s = fitter.perturb_shape(training_images[0].landmarks['PTS'].lms, + s = fitter.perturb_shape(training_images[0].landmarks[None].lms, noise_std=0.08, rotation=False) assert (s.n_dims == 2) assert (s.n_landmark_groups == 0) @@ -410,7 +409,7 @@ def aam_helper(aam, algorithm, im_number, max_iters, initial_error, fitter = LucasKanadeAAMFitter(aam, algorithm=algorithm) fitting_result = fitter.fit( training_images[im_number], initial_shape[im_number], - gt_shape=training_images[im_number].landmarks['PTS'].lms, + gt_shape=training_images[im_number].landmarks[None].lms, max_iters=max_iters) assert_allclose( np.around(fitting_result.initial_error(error_type=error_type), 5), diff --git a/menpofit/test/atm_builder_test.py b/menpofit/test/atm_builder_test.py index 1f0244f..437a999 100644 --- a/menpofit/test/atm_builder_test.py +++ b/menpofit/test/atm_builder_test.py @@ -20,7 +20,7 @@ im = mio.import_builtin_asset(filenames[i]) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') - training.append(im.landmarks['PTS']['all']) + training.append(im.landmarks[None].lms) templates.append(im) # build atms @@ -31,7 +31,7 @@ downscale=2, scaled_shape_models=False, max_shape_components=[1, 2, 3], - boundary=3).build(training, templates[0], group='PTS') + boundary=3).build(training, templates[0]) atm2 = ATMBuilder(features=[no_op, no_op], transform=ThinPlateSplines, @@ -41,7 +41,7 @@ downscale=1.2, scaled_shape_models=True, max_shape_components=None, - boundary=0).build(training, templates[1], group='PTS') + boundary=0).build(training, templates[1]) atm3 = ATMBuilder(features=igo, transform=ThinPlateSplines, @@ -51,7 +51,7 @@ downscale=3, scaled_shape_models=True, max_shape_components=[2], - boundary=2).build(training, templates[2], group='PTS') + boundary=2).build(training, templates[2]) atm4 = PatchBasedATMBuilder(features=lbp, patch_shape=(10, 13), @@ -60,8 +60,7 @@ downscale=1.2, scaled_shape_models=True, max_shape_components=1, - boundary=2).build(training, templates[3], - group='PTS') + boundary=2).build(training, templates[3]) @raises(ValueError) @@ -76,23 +75,21 @@ def test_n_levels_exception(): @raises(ValueError) def test_downscale_exception(): - atm = ATMBuilder(downscale=1).build(training, templates[2], group='PTS') + atm = ATMBuilder(downscale=1).build(training, templates[2]) assert (atm.downscale == 1) - ATMBuilder(downscale=0).build(training, templates[2], group='PTS') + ATMBuilder(downscale=0).build(training, templates[2]) @raises(ValueError) def test_normalization_diagonal_exception(): - atm = ATMBuilder(normalization_diagonal=100).build(training, templates[3], - group='PTS') + atm = ATMBuilder(normalization_diagonal=100).build(training, templates[3]) assert (atm.warped_templates[0].n_true_pixels() == 1246) ATMBuilder(normalization_diagonal=10).build(training, templates[3]) @raises(ValueError) def test_max_shape_components_exception(): - ATMBuilder(max_shape_components=[1, 0.2, 'a']).build(training, templates[0], - group='PTS') + ATMBuilder(max_shape_components=[1, 0.2, 'a']).build(training, templates[0]) @raises(ValueError) @@ -102,12 +99,12 @@ def test_max_shape_components_exception_2(): @raises(ValueError) def test_boundary_exception(): - ATMBuilder(boundary=-1).build(training, templates[1], group='PTS') + ATMBuilder(boundary=-1).build(training, templates[1]) @patch('sys.stdout', new_callable=StringIO) def test_verbose_mock(mock_stdout): - ATMBuilder().build(training, templates[2], group='PTS', verbose=True) + ATMBuilder().build(training, templates[2], verbose=True) @patch('sys.stdout', new_callable=StringIO) diff --git a/menpofit/test/atm_fitter_test.py b/menpofit/test/atm_fitter_test.py index 4cf91d1..0a1459e 100644 --- a/menpofit/test/atm_fitter_test.py +++ b/menpofit/test/atm_fitter_test.py @@ -12,8 +12,8 @@ from menpo.shape.pointcloud import PointCloud from menpofit.atm import ATMBuilder, LucasKanadeATMFitter from menpofit.lucaskanade.image import (ImageInverseCompositional, - ImageForwardAdditive, - ImageForwardCompositional) + ImageForwardAdditive, + ImageForwardCompositional) initial_shape = [] @@ -302,7 +302,7 @@ im.crop_to_landmarks_proportion_inplace(0.1) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') - training_shapes.append(im.landmarks['PTS']['all']) + training_shapes.append(im.landmarks[None].lms) templates.append(im) @@ -314,7 +314,7 @@ downscale=2, scaled_shape_models=True, max_shape_components=[1, 2, 3], - boundary=3).build(training_shapes, templates[0], group='PTS') + boundary=3).build(training_shapes, templates[0]) atm2 = ATMBuilder(features=igo, transform=DifferentiablePiecewiseAffine, @@ -323,7 +323,7 @@ downscale=2, scaled_shape_models=True, max_shape_components=[1], - boundary=3).build(training_shapes, templates[1], group='PTS') + boundary=3).build(training_shapes, templates[1]) atm3 = ATMBuilder(features=igo, transform=DifferentiablePiecewiseAffine, @@ -332,7 +332,7 @@ downscale=2, scaled_shape_models=True, max_shape_components=[1, 2, 3], - boundary=3).build(training_shapes, templates[2], group='PTS') + boundary=3).build(training_shapes, templates[2]) atm4 = ATMBuilder(features=igo, transform=DifferentiablePiecewiseAffine, @@ -341,7 +341,7 @@ downscale=2, scaled_shape_models=True, max_shape_components=[1], - boundary=3).build(training_shapes, templates[3], group='PTS') + boundary=3).build(training_shapes, templates[3]) def test_atm1(): @@ -371,7 +371,7 @@ def test_n_shape_exception_2(): def test_pertrurb_shape(): fitter = LucasKanadeATMFitter(atm1) - s = fitter.perturb_shape(templates[0].landmarks['PTS'].lms, + s = fitter.perturb_shape(templates[0].landmarks[None].lms, noise_std=0.08, rotation=False) assert (s.n_dims == 2) assert (s.n_landmark_groups == 0) @@ -412,7 +412,7 @@ def atm_helper(atm, algorithm, im_number, max_iters, initial_error, fitter = LucasKanadeATMFitter(atm, algorithm=algorithm) fitting_result = fitter.fit( templates[im_number], initial_shape[im_number], - gt_shape=templates[im_number].landmarks['PTS'].lms, + gt_shape=templates[im_number].landmarks[None].lms, max_iters=max_iters) assert_allclose( np.around(fitting_result.initial_error(error_type=error_type), 5), diff --git a/menpofit/test/clm_builder_test.py b/menpofit/test/clm_builder_test.py index c9fbf52..7219c3f 100644 --- a/menpofit/test/clm_builder_test.py +++ b/menpofit/test/clm_builder_test.py @@ -8,7 +8,6 @@ from menpo.feature import sparse_hog, igo, no_op import menpo.io as mio -from menpo.landmark import labeller, ibug_face_68_trimesh from menpofit.clm import CLMBuilder from menpofit.clm.classifier import linear_svm_lr from menpofit.base import name_of_callable @@ -29,7 +28,6 @@ def random_forest_predict(x): for i in range(4): im = mio.import_builtin_asset(filenames[i]) im.crop_to_landmarks_proportion_inplace(0.1) - labeller(im, 'PTS', ibug_face_68_trimesh) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') training_images.append(im) @@ -43,7 +41,7 @@ def random_forest_predict(x): downscale=2, scaled_shape_models=False, max_shape_components=[1, 2, 3], - boundary=3).build(training_images, group='PTS') + boundary=3).build(training_images) clm2 = CLMBuilder(classifier_trainers=[random_forest, linear_svm_lr], patch_shape=(3, 10), @@ -53,7 +51,7 @@ def random_forest_predict(x): downscale=1.2, scaled_shape_models=True, max_shape_components=None, - boundary=0).build(training_images, group='PTS') + boundary=0).build(training_images) clm3 = CLMBuilder(classifier_trainers=[linear_svm_lr], patch_shape=(2, 3), @@ -63,69 +61,65 @@ def random_forest_predict(x): downscale=3, scaled_shape_models=True, max_shape_components=[1], - boundary=2).build(training_images, group='PTS') + boundary=2).build(training_images) @raises(ValueError) def test_classifier_type_1_exception(): CLMBuilder(classifier_trainers=[linear_svm_lr, linear_svm_lr]).build( - training_images, group='PTS') + training_images) @raises(ValueError) def test_classifier_type_2_exception(): - CLMBuilder(classifier_trainers=['linear_svm_lr']).build(training_images, - group='PTS') + CLMBuilder(classifier_trainers=['linear_svm_lr']).build(training_images) @raises(ValueError) def test_patch_shape_1_exception(): - CLMBuilder(patch_shape=(5, 1)).build(training_images, group='PTS') + CLMBuilder(patch_shape=(5, 1)).build(training_images) @raises(ValueError) def test_patch_shape_2_exception(): - CLMBuilder(patch_shape=(5, 6, 7)).build(training_images, group='PTS') + CLMBuilder(patch_shape=(5, 6, 7)).build(training_images) @raises(ValueError) def test_features_exception(): - CLMBuilder(features=[igo, sparse_hog]).build(training_images, group='PTS') + CLMBuilder(features=[igo, sparse_hog]).build(training_images) @raises(ValueError) def test_n_levels_exception(): - clm = CLMBuilder(n_levels=0).build(training_images, group='PTS') + clm = CLMBuilder(n_levels=0).build(training_images) @raises(ValueError) def test_downscale_exception(): - clm = CLMBuilder(downscale=1).build(training_images, group='PTS') + clm = CLMBuilder(downscale=1).build(training_images) assert (clm.downscale == 1) - CLMBuilder(downscale=0).build(training_images, group='PTS') + CLMBuilder(downscale=0).build(training_images) @raises(ValueError) def test_normalization_diagonal_exception(): - CLMBuilder(normalization_diagonal=10).build(training_images, - group='PTS') + CLMBuilder(normalization_diagonal=10).build(training_images) @raises(ValueError) def test_max_shape_components_1_exception(): - CLMBuilder(max_shape_components=[1, 0.2, 'a']).build(training_images, - group='PTS') + CLMBuilder(max_shape_components=[1, 0.2, 'a']).build(training_images) @raises(ValueError) def test_max_shape_components_2_exception(): - CLMBuilder(max_shape_components=[1, 2]).build(training_images, - group='PTS') + CLMBuilder(max_shape_components=[1, 2]).build(training_images) @raises(ValueError) def test_boundary_exception(): - CLMBuilder(boundary=-1).build(training_images, group='PTS') + CLMBuilder(boundary=-1).build(training_images) @patch('sys.stdout', new_callable=StringIO) def test_verbose_mock(mock_stdout): - CLMBuilder().build(training_images, group='PTS', verbose=True) + CLMBuilder().build(training_images, verbose=True) @patch('sys.stdout', new_callable=StringIO) diff --git a/menpofit/test/clm_fitter_test.py b/menpofit/test/clm_fitter_test.py index d442bc0..fd11932 100644 --- a/menpofit/test/clm_fitter_test.py +++ b/menpofit/test/clm_fitter_test.py @@ -8,7 +8,6 @@ import menpo.io as mio from menpo.shape.pointcloud import PointCloud -from menpo.landmark import labeller, ibug_face_68_trimesh from menpofit.clm import CLMBuilder from menpofit.clm import GradientDescentCLMFitter from menpofit.gradientdescent import RegularizedLandmarkMeanShift @@ -299,7 +298,6 @@ for i in range(4): im = mio.import_builtin_asset(filenames[i]) im.crop_to_landmarks_proportion_inplace(0.1) - labeller(im, 'PTS', ibug_face_68_trimesh) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') training_images.append(im) @@ -313,7 +311,7 @@ downscale=1.1, scaled_shape_models=True, max_shape_components=[1, 2, 3], - boundary=3).build(training_images, group='PTS') + boundary=3).build(training_images) def test_clm(): @@ -353,7 +351,7 @@ def test_n_shape_2_exception(): def test_perturb_shape(): fitter = GradientDescentCLMFitter(clm) - s = fitter.perturb_shape(training_images[0].landmarks['PTS'].lms, + s = fitter.perturb_shape(training_images[0].landmarks[None].lms, noise_std=0.08, rotation=False) assert (s.n_dims == 2) assert (s.n_landmark_groups == 0) diff --git a/menpofit/test/sdm_test.py b/menpofit/test/sdm_test.py index b1189f2..6a08ed5 100644 --- a/menpofit/test/sdm_test.py +++ b/menpofit/test/sdm_test.py @@ -21,7 +21,6 @@ for i in range(4): im = mio.import_builtin_asset(filenames[i]) im.crop_to_landmarks_proportion_inplace(0.1) - labeller(im, 'PTS', ibug_face_68_trimesh) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') training_images.append(im) @@ -29,17 +28,18 @@ # Seed the random number generator np.random.seed(seed=1000) +template_trilist_image = training_images[0].landmarks[None] +trilist = ibug_face_68_trimesh(template_trilist_image)[1].lms.trilist aam = AAMBuilder(features=sparse_hog, transform=PiecewiseAffine, - trilist=training_images[0].landmarks['ibug_face_68_trimesh']. - lms.trilist, + trilist=trilist, normalization_diagonal=150, n_levels=3, downscale=1.2, scaled_shape_models=False, max_shape_components=None, max_appearance_components=3, - boundary=3).build(training_images, group='PTS') + boundary=3).build(training_images) clm = CLMBuilder(classifier_trainers=linear_svm_lr, features=sparse_hog, @@ -49,53 +49,51 @@ downscale=1.1, scaled_shape_models=True, max_shape_components=25, - boundary=3).build(training_images, group='PTS') + boundary=3).build(training_images) @raises(ValueError) def test_features_exception(): sdm = SDMTrainer(features=[igo, sparse_hog], - n_levels=3).train(training_images, group='PTS') + n_levels=3).train(training_images) @raises(ValueError) def test_regression_features_sdmtrainer_exception_1(): sdm = SDMTrainer(n_levels=2, regression_features=[no_op, no_op, no_op]).\ - train(training_images, group='PTS') + train(training_images) @raises(ValueError) def test_regression_features_sdmtrainer_exception_2(): sdm = SDMTrainer(n_levels=3, regression_features=[no_op, sparse_hog, 1]).\ - train(training_images, group='PTS') + train(training_images) @raises(ValueError) def test_regression_features_sdaamtrainer_exception_1(): sdm = SDAAMTrainer(aam, regression_features=[no_op, sparse_hog]).\ - train(training_images, group='PTS') + train(training_images) @raises(ValueError) def test_regression_features_sdaamtrainer_exception_2(): sdm = SDAAMTrainer(aam, regression_features=7).\ - train(training_images, group='PTS') + train(training_images) @raises(ValueError) def test_n_levels_exception(): - sdm = SDMTrainer(n_levels=0).train(training_images, group='PTS') + sdm = SDMTrainer(n_levels=0).train(training_images) @raises(ValueError) def test_downscale_exception(): - sdm = SDMTrainer(downscale=0).train(training_images, - group='PTS') + sdm = SDMTrainer(downscale=0).train(training_images) @raises(ValueError) def test_n_perturbations_exception(): - sdm = SDAAMTrainer(aam, n_perturbations=-10).train(training_images, - group='PTS') + sdm = SDAAMTrainer(aam, n_perturbations=-10).train(training_images) @patch('sys.stdout', new_callable=StringIO) @@ -109,5 +107,5 @@ def test_verbose_mock(mock_stdout): downscale=1.3, noise_std=0.04, rotation=False, - n_perturbations=2).train(training_images, group='PTS', + n_perturbations=2).train(training_images, verbose=True)