diff --git a/.travis.yml b/.travis.yml index 84c8814..54dde0e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,7 +20,7 @@ install: - wget https://raw.githubusercontent.com/jabooth/condaci/v0.2.0/condaci.py -O condaci.py - python condaci.py setup $PYTHON_VERSION --channel $BINSTAR_USER - export PATH=$HOME/miniconda/bin:$PATH -#- conda config --add channels $BINSTAR_USER/channel/master +- conda config --add channels $BINSTAR_USER/channel/master script: - python condaci.py auto ./conda --binstaruser $BINSTAR_USER --binstarkey $BINSTAR_KEY diff --git a/appveyor.yml b/appveyor.yml index 83d7ea8..80c49b0 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -25,7 +25,7 @@ platform: init: - ps: Start-FileDownload 'https://raw.githubusercontent.com/jabooth/condaci/v0.2.0/condaci.py' C:\\condaci.py; echo "Done" - cmd: python C:\\condaci.py setup %PYTHON_VERSION% --channel %BINSTAR_USER% -#- cmd: C:\\Miniconda\\Scripts\\conda config --add channels %BINSTAR_USER%/channel/master +- cmd: C:\\Miniconda\\Scripts\\conda config --add channels %BINSTAR_USER%/channel/master install: - cmd: C:\\Miniconda\\python C:\\condaci.py auto ./conda --binstaruser %BINSTAR_USER% --binstarkey %BINSTAR_KEY% diff --git a/conda/meta.yaml b/conda/meta.yaml index a427d0a..6fbd043 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -9,7 +9,7 @@ requirements: run: - python - - menpo 0.4.0a3 + - menpo - numpy 1.9.0 - scipy 0.14.0 - scikit-learn 0.15.2 diff --git a/menpofit/aam/base.py b/menpofit/aam/base.py index e97cdbd..fde1033 100644 --- a/menpofit/aam/base.py +++ b/menpofit/aam/base.py @@ -181,47 +181,116 @@ def _str_title(self): """ return 'Active Appearance Model' - def view_widget(self, n_shape_parameters=5, n_appearance_parameters=5, - parameters_bounds=(-3.0, 3.0), mode='multiple', - popup=False): + def view_shape_models_widget(self, n_parameters=5, + parameters_bounds=(-3.0, 3.0), mode='multiple', + popup=False, figure_size=(10, 8)): r""" - Visualizes the AAM object using the - menpo.visualize.widgets.visualize_aam widget. + Visualizes the shape models of the AAM object using the + `menpo.visualize.widgets.visualize_shape_model` widget. Parameters ----------- - n_shape_parameters : `int` or `list` of `int` or None, optional + n_parameters : `int` or `list` of `int` or ``None``, optional The number of shape principal components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. + parameters_bounds : (`float`, `float`), optional + The minimum and maximum bounds, in std units, for the sliders. + mode : {``single``, ``multiple``}, optional + If ``'single'``, only a single slider is constructed along with a + drop down menu. + If ``'multiple'``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. + figure_size : (`int`, `int`), optional + The size of the plotted figures. + """ + from menpofit.visualize import visualize_shape_model + visualize_shape_model(self.shape_models, n_parameters=n_parameters, + parameters_bounds=parameters_bounds, + figure_size=figure_size, mode=mode, popup=popup) + + def view_appearance_models_widget(self, n_parameters=5, + parameters_bounds=(-3.0, 3.0), + mode='multiple', popup=False, + figure_size=(10, 8)): + r""" + Visualizes the appearance models of the AAM object using the + `menpo.visualize.widgets.visualize_appearance_model` widget. - n_appearance_parameters : `int` or `list` of `int` or None, optional + Parameters + ----------- + n_parameters : `int` or `list` of `int` or ``None``, optional The number of appearance principal components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. + mode : {``single``, ``multiple``}, optional + If ``'single'``, only a single slider is constructed along with a + drop down menu. + If ``'multiple'``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. + figure_size : (`int`, `int`), optional + The size of the plotted figures. + """ + from menpofit.visualize import visualize_appearance_model + visualize_appearance_model(self.appearance_models, + n_parameters=n_parameters, + parameters_bounds=parameters_bounds, + figure_size=figure_size, mode=mode, + popup=popup) + + def view_aam_widget(self, n_shape_parameters=5, n_appearance_parameters=5, + parameters_bounds=(-3.0, 3.0), mode='multiple', + popup=False, figure_size=(10, 8)): + r""" + Visualizes both the shape and appearance models of the AAM object using + the `menpo.visualize.widgets.visualize_aam` widget. - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. + Parameters + ----------- + n_shape_parameters : `int` or `list` of `int` or None, optional + The number of shape principal components to be used for the + parameters sliders. + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. + n_appearance_parameters : `int` or `list` of `int` or None, optional + The number of appearance principal components to be used for the + parameters sliders. + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. + parameters_bounds : (`float`, `float`), optional + The minimum and maximum bounds, in std units, for the sliders. + mode : {``single``, ``multiple``}, optional + If ``'single'``, only a single slider is constructed along with a + drop down menu. + If ``'multiple'``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. + figure_size : (`int`, `int`), optional + The size of the plotted figures. """ from menpofit.visualize import visualize_aam visualize_aam(self, n_shape_parameters=n_shape_parameters, n_appearance_parameters=n_appearance_parameters, - parameters_bounds=parameters_bounds, figure_size=(7, 7), - mode=mode, popup=popup) + parameters_bounds=parameters_bounds, + figure_size=figure_size, mode=mode, popup=popup) def __str__(self): out = "{}\n - {} training images.\n".format(self._str_title, diff --git a/menpofit/atm/base.py b/menpofit/atm/base.py index da9ed11..c56a7bd 100644 --- a/menpofit/atm/base.py +++ b/menpofit/atm/base.py @@ -166,8 +166,42 @@ def _str_title(self): """ return 'Active Template Model' - def view_widget(self, n_shape_parameters=5, parameters_bounds=(-3.0, 3.0), - mode='multiple', popup=False): + def view_shape_models_widget(self, n_parameters=5, + parameters_bounds=(-3.0, 3.0), mode='multiple', + popup=False, figure_size=(10, 8)): + r""" + Visualizes the shape models of the AAM object using the + `menpo.visualize.widgets.visualize_shape_model` widget. + + Parameters + ----------- + n_parameters : `int` or `list` of `int` or ``None``, optional + The number of shape principal components to be used for the + parameters sliders. + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. + parameters_bounds : (`float`, `float`), optional + The minimum and maximum bounds, in std units, for the sliders. + mode : {``single``, ``multiple``}, optional + If ``'single'``, only a single slider is constructed along with a + drop down menu. + If ``'multiple'``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. + figure_size : (`int`, `int`), optional + The size of the plotted figures. + """ + from menpofit.visualize import visualize_shape_model + visualize_shape_model(self.shape_models, n_parameters=n_parameters, + parameters_bounds=parameters_bounds, + figure_size=figure_size, mode=mode, popup=popup) + + def view_atm_widget(self, n_shape_parameters=5, + parameters_bounds=(-3.0, 3.0), mode='multiple', + popup=False, figure_size=(10, 8)): r""" Visualizes the ATM object using the menpo.visualize.widgets.visualize_atm widget. @@ -177,26 +211,26 @@ def view_widget(self, n_shape_parameters=5, parameters_bounds=(-3.0, 3.0), n_shape_parameters : `int` or `list` of `int` or None, optional The number of shape principal components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. - - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. + mode : {``single``, ``multiple``}, optional + If ``'single'``, only a single slider is constructed along with a + drop down menu. + If ``'multiple'``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. + figure_size : (`int`, `int`), optional + The size of the plotted figures. """ from menpofit.visualize import visualize_atm visualize_atm(self, n_shape_parameters=n_shape_parameters, - parameters_bounds=parameters_bounds, figure_size=(7, 7), - mode=mode, popup=popup) + parameters_bounds=parameters_bounds, + figure_size=figure_size, mode=mode, popup=popup) def __str__(self): out = "{}\n - {} training shapes.\n".format(self._str_title, diff --git a/menpofit/clm/base.py b/menpofit/clm/base.py index 28ddc76..658d721 100644 --- a/menpofit/clm/base.py +++ b/menpofit/clm/base.py @@ -201,37 +201,38 @@ def _str_title(self): """ return 'Constrained Local Model' - def view_widget(self, n_parameters=5, parameters_bounds=(-3.0, 3.0), - mode='multiple', popup=False): + def view_shape_models_widget(self, n_parameters=5, + parameters_bounds=(-3.0, 3.0), mode='multiple', + popup=False, figure_size=(10, 8)): r""" Visualizes the shape models of the CLM object using the - menpo.visualize.widgets.visualize_shape_model widget. + `menpo.visualize.widgets.visualize_shape_model` widget. Parameters ----------- - n_parameters : `int` or `list` of `int` or None, optional - The number of principal components to be used for the parameters - sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + n_parameters : `int` or `list` of `int` or ``None``, optional + The number of shape principal components to be used for the + parameters sliders. + If `int`, then the number of sliders per level is the minimum + between `n_parameters` and the number of active components per + level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. - - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. + mode : {``single``, ``multiple``}, optional + If ``'single'``, only a single slider is constructed along with a + drop down menu. + If ``'multiple'``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. + figure_size : (`int`, `int`), optional + The size of the plotted figures. """ from menpofit.visualize import visualize_shape_model visualize_shape_model(self.shape_models, n_parameters=n_parameters, parameters_bounds=parameters_bounds, - figure_size=(7, 7), mode=mode, popup=popup) + figure_size=figure_size, mode=mode, popup=popup) def __str__(self): from menpofit.base import name_of_callable diff --git a/menpofit/fittingresult.py b/menpofit/fittingresult.py index 1b52f9e..15365aa 100644 --- a/menpofit/fittingresult.py +++ b/menpofit/fittingresult.py @@ -198,18 +198,232 @@ def initial_error(self, error_type='me_norm'): raise ValueError('Ground truth has not been set, final error ' 'cannot be computed') - def view_widget(self, popup=False): + def view_widget(self, popup=False, browser_style='buttons'): r""" Visualizes the multilevel fitting result object using the - menpo.visualize.widgets.visualize_fitting_results widget. + `menpo.visualize.widgets.visualize_fitting_results` widget. Parameters ----------- - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. + browser_style : {``buttons``, ``slider``}, optional + It defines whether the selector of the fitting results will have the + form of plus/minus buttons or a slider. """ from menpofit.visualize import visualize_fitting_results - visualize_fitting_results(self, figure_size=(7, 7), popup=popup) + visualize_fitting_results(self, figure_size=(10, 8), popup=popup, + browser_style=browser_style) + + def plot_errors(self, error_type='me_norm', figure_id=None, + new_figure=False, render_lines=True, line_colour='b', + line_style='-', line_width=2, render_markers=True, + marker_style='o', marker_size=4, marker_face_colour='b', + marker_edge_colour='k', marker_edge_width=1., + render_axes=True, axes_font_name='sans-serif', + axes_font_size=10, axes_font_style='normal', + axes_font_weight='normal', figure_size=(10, 6), + render_grid=True, grid_line_style='--', + grid_line_width=0.5): + r""" + Plot of the error evolution at each fitting iteration. + + Parameters + ---------- + error_type : {``me_norm``, ``me``, ``rmse``}, optional + Specifies the way in which the error between the fitted and + ground truth shapes is to be computed. + figure_id : `object`, optional + The id of the figure to be used. + new_figure : `bool`, optional + If ``True``, a new figure is created. + render_lines : `bool`, optional + If ``True``, the line will be rendered. + line_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} or + ``(3, )`` `ndarray`, optional + The colour of the lines. + line_style : {``-``, ``--``, ``-.``, ``:``}, optional + The style of the lines. + line_width : `float`, optional + The width of the lines. + render_markers : `bool`, optional + If ``True``, the markers will be rendered. + marker_style : {``.``, ``,``, ``o``, ``v``, ``^``, ``<``, ``>``, ``+``, + ``x``, ``D``, ``d``, ``s``, ``p``, ``*``, ``h``, ``H``, + ``1``, ``2``, ``3``, ``4``, ``8``}, optional + The style of the markers. + marker_size : `int`, optional + The size of the markers in points^2. + marker_face_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} + or ``(3, )`` `ndarray`, optional + The face (filling) colour of the markers. + marker_edge_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} + or ``(3, )`` `ndarray`, optional + The edge colour of the markers. + marker_edge_width : `float`, optional + The width of the markers' edge. + render_axes : `bool`, optional + If ``True``, the axes will be rendered. + axes_font_name : {``serif``, ``sans-serif``, ``cursive``, ``fantasy``, + ``monospace``}, optional + The font of the axes. + axes_font_size : `int`, optional + The font size of the axes. + axes_font_style : {``normal``, ``italic``, ``oblique``}, optional + The font style of the axes. + axes_font_weight : {``ultralight``, ``light``, ``normal``, ``regular``, + ``book``, ``medium``, ``roman``, ``semibold``, + ``demibold``, ``demi``, ``bold``, ``heavy``, + ``extra bold``, ``black``}, optional + The font weight of the axes. + figure_size : (`float`, `float`) or `None`, optional + The size of the figure in inches. + render_grid : `bool`, optional + If ``True``, the grid will be rendered. + grid_line_style : {``-``, ``--``, ``-.``, ``:``}, optional + The style of the grid lines. + grid_line_width : `float`, optional + The width of the grid lines. + + Returns + ------- + viewer : :map:`GraphPlotter` + The viewer object. + """ + from menpo.visualize import GraphPlotter + errors_list = self.errors(error_type=error_type) + return GraphPlotter(figure_id=figure_id, new_figure=new_figure, + x_axis=range(len(errors_list)), + y_axis=[errors_list], + title='Fitting Errors per Iteration', + x_label='Iteration', y_label='Fitting Error', + x_axis_limits=(0, len(errors_list)-1), + y_axis_limits=None).render( + render_lines=render_lines, line_colour=line_colour, + line_style=line_style, line_width=line_width, + render_markers=render_markers, marker_style=marker_style, + marker_size=marker_size, marker_face_colour=marker_face_colour, + marker_edge_colour=marker_edge_colour, + marker_edge_width=marker_edge_width, render_legend=False, + render_axes=render_axes, axes_font_name=axes_font_name, + axes_font_size=axes_font_size, axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, render_grid=render_grid, + grid_line_style=grid_line_style, grid_line_width=grid_line_width, + figure_size=figure_size) + + def plot_displacements(self, stat_type='mean', figure_id=None, + new_figure=False, render_lines=True, line_colour='b', + line_style='-', line_width=2, render_markers=True, + marker_style='o', marker_size=4, + marker_face_colour='b', marker_edge_colour='k', + marker_edge_width=1., render_axes=True, + axes_font_name='sans-serif', axes_font_size=10, + axes_font_style='normal', axes_font_weight='normal', + figure_size=(10, 6), render_grid=True, + grid_line_style='--', grid_line_width=0.5): + r""" + Plot of a statistical metric of the displacement between the shape of + each iteration and the shape of the previous one. + + Parameters + ---------- + stat_type : {``mean``, ``median``, ``min``, ``max``}, optional + Specifies a statistic metric to be extracted from the displacements + (see also `displacements_stats()` method). + figure_id : `object`, optional + The id of the figure to be used. + new_figure : `bool`, optional + If ``True``, a new figure is created. + render_lines : `bool`, optional + If ``True``, the line will be rendered. + line_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} or + ``(3, )`` `ndarray`, optional + The colour of the lines. + line_style : {``-``, ``--``, ``-.``, ``:``}, optional + The style of the lines. + line_width : `float`, optional + The width of the lines. + render_markers : `bool`, optional + If ``True``, the markers will be rendered. + marker_style : {``.``, ``,``, ``o``, ``v``, ``^``, ``<``, ``>``, ``+``, + ``x``, ``D``, ``d``, ``s``, ``p``, ``*``, ``h``, ``H``, + ``1``, ``2``, ``3``, ``4``, ``8``}, optional + The style of the markers. + marker_size : `int`, optional + The size of the markers in points^2. + marker_face_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} + or ``(3, )`` `ndarray`, optional + The face (filling) colour of the markers. + marker_edge_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} + or ``(3, )`` `ndarray`, optional + The edge colour of the markers. + marker_edge_width : `float`, optional + The width of the markers' edge. + render_axes : `bool`, optional + If ``True``, the axes will be rendered. + axes_font_name : {``serif``, ``sans-serif``, ``cursive``, ``fantasy``, + ``monospace``}, optional + The font of the axes. + axes_font_size : `int`, optional + The font size of the axes. + axes_font_style : {``normal``, ``italic``, ``oblique``}, optional + The font style of the axes. + axes_font_weight : {``ultralight``, ``light``, ``normal``, ``regular``, + ``book``, ``medium``, ``roman``, ``semibold``, + ``demibold``, ``demi``, ``bold``, ``heavy``, + ``extra bold``, ``black``}, optional + The font weight of the axes. + figure_size : (`float`, `float`) or `None`, optional + The size of the figure in inches. + render_grid : `bool`, optional + If ``True``, the grid will be rendered. + grid_line_style : {``-``, ``--``, ``-.``, ``:``}, optional + The style of the grid lines. + grid_line_width : `float`, optional + The width of the grid lines. + + Returns + ------- + viewer : :map:`GraphPlotter` + The viewer object. + """ + from menpo.visualize import GraphPlotter + # set labels + if stat_type == 'max': + ylabel = 'Maximum Displacement' + title = 'Maximum displacement per Iteration' + elif stat_type == 'min': + ylabel = 'Minimum Displacement' + title = 'Minimum displacement per Iteration' + elif stat_type == 'mean': + ylabel = 'Mean Displacement' + title = 'Mean displacement per Iteration' + elif stat_type == 'median': + ylabel = 'Median Displacement' + title = 'Median displacement per Iteration' + else: + raise ValueError('stat_type must be one of {max, min, mean, ' + 'median}.') + # plot + displacements_list = self.displacements_stats(stat_type=stat_type) + return GraphPlotter(figure_id=figure_id, new_figure=new_figure, + x_axis=range(len(displacements_list)), + y_axis=[displacements_list], + title=title, + x_label='Iteration', y_label=ylabel, + x_axis_limits=(0, len(displacements_list)-1), + y_axis_limits=None).render( + render_lines=render_lines, line_colour=line_colour, + line_style=line_style, line_width=line_width, + render_markers=render_markers, marker_style=marker_style, + marker_size=marker_size, marker_face_colour=marker_face_colour, + marker_edge_colour=marker_edge_colour, + marker_edge_width=marker_edge_width, render_legend=False, + render_axes=render_axes, axes_font_name=axes_font_name, + axes_font_size=axes_font_size, axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, render_grid=render_grid, + grid_line_style=grid_line_style, grid_line_width=grid_line_width, + figure_size=figure_size) def as_serializable(self): r"""" @@ -774,6 +988,267 @@ def compute_cumulative_error(errors, x_axis): r""" """ n_errors = len(errors) - cumulative_error = [np.count_nonzero([errors <= x]) - for x in x_axis] - return np.array(cumulative_error) / n_errors + return [np.count_nonzero([errors <= x]) / n_errors for x in x_axis] + + +def plot_cumulative_error_distribution(errors, error_range=None, figure_id=None, + new_figure=False, + title='Cumulative Error Distribution', + x_label='Normalized Point-to-Point Error', + y_label='Images Proportion', + legend_entries=None, render_lines=True, + line_colour=None, line_style='-', + line_width=2, render_markers=True, + marker_style='s', marker_size=10, + marker_face_colour='w', + marker_edge_colour=None, + marker_edge_width=2, render_legend=True, + legend_title=None, + legend_font_name='sans-serif', + legend_font_style='normal', + legend_font_size=10, + legend_font_weight='normal', + legend_marker_scale=1., + legend_location=2, + legend_bbox_to_anchor=(1.05, 1.), + legend_border_axes_pad=1., + legend_n_columns=1, + legend_horizontal_spacing=1., + legend_vertical_spacing=1., + legend_border=True, + legend_border_padding=0.5, + legend_shadow=False, + legend_rounded_corners=False, + render_axes=True, + axes_font_name='sans-serif', + axes_font_size=10, + axes_font_style='normal', + axes_font_weight='normal', + axes_x_limits=None, axes_y_limits=None, + figure_size=(10, 8), render_grid=True, + grid_line_style='--', + grid_line_width=0.5): + r""" + Plot the cumulative error distribution (CED) of the provided fitting errors. + + Parameters + ---------- + errors : `list` of `lists` + A `list` with `lists` of fitting errors. A separate CED curve will be + rendered for each errors `list`. + error_range : `list` of `float` with length 3, optional + Specifies the horizontal axis range, i.e. + + :: + + error_range[0] = min_error + error_range[1] = max_error + error_range[2] = error_step + + If ``None``, then ``'error_range = [0., 0.101, 0.005]'``. + figure_id : `object`, optional + The id of the figure to be used. + new_figure : `bool`, optional + If ``True``, a new figure is created. + title : `str`, optional + The figure's title. + x_label : `str`, optional + The label of the horizontal axis. + y_label : `str`, optional + The label of the vertical axis. + legend_entries : `list of `str` or ``None``, optional + If `list` of `str`, it must have the same length as `errors` `list` and + each `str` will be used to name each curve. If ``None``, the CED curves + will be named as `'Curve %d'`. + render_lines : `bool` or `list` of `bool`, optional + If ``True``, the line will be rendered. If `bool`, this value will be + used for all curves. If `list`, a value must be specified for each + fitting errors curve, thus it must have the same length as `errors`. + line_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} or + ``(3, )`` `ndarray` or `list` of those or ``None``, optional + The colour of the lines. If not a `list`, this value will be + used for all curves. If `list`, a value must be specified for each + fitting errors curve, thus it must have the same length as `errors`. If + ``None``, the colours will be linearly sampled from jet colormap. + line_style : {``-``, ``--``, ``-.``, ``:``} or `list` of those, optional + The style of the lines. If not a `list`, this value will be used for all + curves. If `list`, a value must be specified for each fitting errors + curve, thus it must have the same length as `errors`. + line_width : `float` or `list` of `float`, optional + The width of the lines. If `float`, this value will be used for all + curves. If `list`, a value must be specified for each fitting errors + curve, thus it must have the same length as `errors`. + render_markers : `bool` or `list` of `bool`, optional + If ``True``, the markers will be rendered. If `bool`, this value will be + used for all curves. If `list`, a value must be specified for each + fitting errors curve, thus it must have the same length as `errors`. + marker_style : {``.``, ``,``, ``o``, ``v``, ``^``, ``<``, ``>``, ``+``, + ``x``, ``D``, ``d``, ``s``, ``p``, ``*``, ``h``, ``H``, + ``1``, ``2``, ``3``, ``4``, ``8``} or `list` of those, optional + The style of the markers. If not a `list`, this value will be used for + all curves. If `list`, a value must be specified for each fitting errors + curve, thus it must have the same length as `errors`. + marker_size : `int` or `list` of `int`, optional + The size of the markers in points^2. If `int`, this value will be used + for all curves. If `list`, a value must be specified for each fitting + errors curve, thus it must have the same length as `errors`. + marker_face_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} + or ``(3, )`` `ndarray` or `list` of those or ``None``, optional + The face (filling) colour of the markers. If not a `list`, this value + will be used for all curves. If `list`, a value must be specified for + each fitting errors curve, thus it must have the same length as + `errors`. If ``None``, the colours will be linearly sampled from jet + colormap. + marker_edge_colour : {``r``, ``g``, ``b``, ``c``, ``m``, ``k``, ``w``} + or ``(3, )`` `ndarray` or `list` of those or ``None``, optional + The edge colour of the markers. If not a `list`, this value will be used + for all curves. If `list`, a value must be specified for each fitting + errors curve, thus it must have the same length as `errors`. If + ``None``, the colours will be linearly sampled from jet colormap. + marker_edge_width : `float` or `list` of `float`, optional + The width of the markers' edge. If `float`, this value will be used for + all curves. If `list`, a value must be specified for each fitting errors + curve, thus it must have the same length as `errors`. + render_legend : `bool`, optional + If ``True``, the legend will be rendered. + legend_title : `str`, optional + The title of the legend. + legend_font_name : {``serif``, ``sans-serif``, ``cursive``, ``fantasy``, + ``monospace``}, optional + The font of the legend. + legend_font_style : {``normal``, ``italic``, ``oblique``}, optional + The font style of the legend. + legend_font_size : `int`, optional + The font size of the legend. + legend_font_weight : {``ultralight``, ``light``, ``normal``, + ``regular``, ``book``, ``medium``, ``roman``, + ``semibold``, ``demibold``, ``demi``, ``bold``, + ``heavy``, ``extra bold``, ``black``}, optional + The font weight of the legend. + legend_marker_scale : `float`, optional + The relative size of the legend markers with respect to the original + legend_location : `int`, optional + The location of the legend. The predefined values are: + + =============== === + 'best' 0 + 'upper right' 1 + 'upper left' 2 + 'lower left' 3 + 'lower right' 4 + 'right' 5 + 'center left' 6 + 'center right' 7 + 'lower center' 8 + 'upper center' 9 + 'center' 10 + =============== === + + legend_bbox_to_anchor : (`float`, `float`), optional + The bbox that the legend will be anchored. + legend_border_axes_pad : `float`, optional + The pad between the axes and legend border. + legend_n_columns : `int`, optional + The number of the legend's columns. + legend_horizontal_spacing : `float`, optional + The spacing between the columns. + legend_vertical_spacing : `float`, optional + The vertical space between the legend entries. + legend_border : `bool`, optional + If ``True``, a frame will be drawn around the legend. + legend_border_padding : `float`, optional + The fractional whitespace inside the legend border. + legend_shadow : `bool`, optional + If ``True``, a shadow will be drawn behind legend. + legend_rounded_corners : `bool`, optional + If ``True``, the frame's corners will be rounded (fancybox). + render_axes : `bool`, optional + If ``True``, the axes will be rendered. + axes_font_name : {``serif``, ``sans-serif``, ``cursive``, ``fantasy``, + ``monospace``}, optional + The font of the axes. + axes_font_size : `int`, optional + The font size of the axes. + axes_font_style : {``normal``, ``italic``, ``oblique``}, optional + The font style of the axes. + axes_font_weight : {``ultralight``, ``light``, ``normal``, ``regular``, + ``book``, ``medium``, ``roman``, ``semibold``, + ``demibold``, ``demi``, ``bold``, ``heavy``, + ``extra bold``, ``black``}, optional + The font weight of the axes. + axes_x_limits : (`float`, `float`) or ``None``, optional + The limits of the x axis. If ``None``, it is set to + ``(0., 'errors_max')``. + axes_y_limits : (`float`, `float`) or ``None``, optional + The limits of the y axis. If ``None``, it is set to ``(0., 1.)``. + figure_size : (`float`, `float`) or ``None``, optional + The size of the figure in inches. + render_grid : `bool`, optional + If ``True``, the grid will be rendered. + grid_line_style : {``-``, ``--``, ``-.``, ``:``}, optional + The style of the grid lines. + grid_line_width : `float`, optional + The width of the grid lines. + + Raises + ------ + ValueError + legend_entries list has different length than errors list + + Returns + ------- + viewer : :map:`GraphPlotter` + The viewer object. + """ + from menpo.visualize import GraphPlotter + + # make sure that errors is a list even with one list member + if not isinstance(errors[0], list): + errors = [errors] + + # create x and y axes lists + x_axis = list(np.arange(error_range[0], error_range[1], error_range[2])) + ceds = [compute_cumulative_error(e, x_axis) for e in errors] + + # parse legend_entries, axes_x_limits and axes_y_limits + if legend_entries is None: + legend_entries = ["Curve {}".format(k) for k in range(len(ceds))] + if len(legend_entries) != len(ceds): + raise ValueError('legend_entries list has different length than errors ' + 'list') + if axes_x_limits is None: + axes_x_limits = (0., x_axis[-1]) + if axes_y_limits is None: + axes_y_limits = (0., 1.) + + # render + return GraphPlotter(figure_id=figure_id, new_figure=new_figure, + x_axis=x_axis, y_axis=ceds, title=title, + legend_entries=legend_entries, x_label=x_label, + y_label=y_label, x_axis_limits=axes_x_limits, + y_axis_limits=axes_y_limits).render( + render_lines=render_lines, line_colour=line_colour, + line_style=line_style, line_width=line_width, + render_markers=render_markers, marker_style=marker_style, + marker_size=marker_size, marker_face_colour=marker_face_colour, + marker_edge_colour=marker_edge_colour, + marker_edge_width=marker_edge_width, render_legend=render_legend, + legend_title=legend_title, legend_font_name=legend_font_name, + legend_font_style=legend_font_style, legend_font_size=legend_font_size, + legend_font_weight=legend_font_weight, + legend_marker_scale=legend_marker_scale, + legend_location=legend_location, + legend_bbox_to_anchor=legend_bbox_to_anchor, + legend_border_axes_pad=legend_border_axes_pad, + legend_n_columns=legend_n_columns, + legend_horizontal_spacing=legend_horizontal_spacing, + legend_vertical_spacing=legend_vertical_spacing, + legend_border=legend_border, + legend_border_padding=legend_border_padding, + legend_shadow=legend_shadow, + legend_rounded_corners=legend_rounded_corners, render_axes=render_axes, + axes_font_name=axes_font_name, axes_font_size=axes_font_size, + axes_font_style=axes_font_style, axes_font_weight=axes_font_weight, + figure_size=figure_size, render_grid=render_grid, + grid_line_style=grid_line_style, grid_line_width=grid_line_width) + diff --git a/menpofit/test/aam_builder_test.py b/menpofit/test/aam_builder_test.py index 31a1119..a8ce97d 100644 --- a/menpofit/test/aam_builder_test.py +++ b/menpofit/test/aam_builder_test.py @@ -9,7 +9,7 @@ from menpo.feature import sparse_hog, igo, lbp, no_op import menpo.io as mio -from menpo.landmark import labeller, ibug_face_68_trimesh +from menpo.landmark import ibug_face_68_trimesh from menpofit.aam import AAMBuilder, PatchBasedAAMBuilder @@ -19,23 +19,23 @@ for i in range(4): im = mio.import_builtin_asset(filenames[i]) im.crop_to_landmarks_proportion_inplace(0.1) - labeller(im, 'PTS', ibug_face_68_trimesh) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') training.append(im) # build aams +template_trilist_image = training[0].landmarks[None] +trilist = ibug_face_68_trimesh(template_trilist_image)[1].lms.trilist aam1 = AAMBuilder(features=[igo, sparse_hog, no_op], transform=PiecewiseAffine, - trilist=training[0].landmarks['ibug_face_68_trimesh']. - lms.trilist, + trilist=trilist, normalization_diagonal=150, n_levels=3, downscale=2, scaled_shape_models=False, max_shape_components=[1, 2, 3], max_appearance_components=[3, 3, 3], - boundary=3).build(training, group='PTS') + boundary=3).build(training) aam2 = AAMBuilder(features=[no_op, no_op], transform=ThinPlateSplines, @@ -46,7 +46,7 @@ scaled_shape_models=True, max_shape_components=None, max_appearance_components=1, - boundary=0).build(training, group='PTS') + boundary=0).build(training) aam3 = AAMBuilder(features=igo, transform=ThinPlateSplines, @@ -57,7 +57,7 @@ scaled_shape_models=True, max_shape_components=[2], max_appearance_components=10, - boundary=2).build(training, group='PTS') + boundary=2).build(training) aam4 = PatchBasedAAMBuilder(features=lbp, patch_shape=(10, 13), @@ -67,55 +67,51 @@ scaled_shape_models=True, max_shape_components=1, max_appearance_components=None, - boundary=2).build(training, group='PTS') + boundary=2).build(training) @raises(ValueError) def test_features_exception(): - AAMBuilder(features=[igo, sparse_hog]).build(training, group='PTS') + AAMBuilder(features=[igo, sparse_hog]).build(training) @raises(ValueError) def test_n_levels_exception(): - AAMBuilder(n_levels=0).build(training, group='PTS') + AAMBuilder(n_levels=0).build(training) @raises(ValueError) def test_downscale_exception(): - aam = AAMBuilder(downscale=1).build(training, - group='PTS') + aam = AAMBuilder(downscale=1).build(training) assert (aam.downscale == 1) - AAMBuilder(downscale=0).build(training, group='PTS') + AAMBuilder(downscale=0).build(training) @raises(ValueError) def test_normalization_diagonal_exception(): - aam = AAMBuilder(normalization_diagonal=100).build(training, - group='PTS') + aam = AAMBuilder(normalization_diagonal=100).build(training) assert (aam.appearance_models[0].n_features == 382) - AAMBuilder(normalization_diagonal=10).build(training, group='PTS') + AAMBuilder(normalization_diagonal=10).build(training) @raises(ValueError) def test_max_shape_components_exception(): - AAMBuilder(max_shape_components=[1, 0.2, 'a']).build(training, - group='PTS') + AAMBuilder(max_shape_components=[1, 0.2, 'a']).build(training) @raises(ValueError) def test_max_appearance_components_exception(): - AAMBuilder(max_appearance_components=[1, 2]).build(training, - group='PTS') + AAMBuilder(max_appearance_components=[1, 2]).build(training) @raises(ValueError) def test_boundary_exception(): - AAMBuilder(boundary=-1).build(training, group='PTS') + AAMBuilder(boundary=-1).build(training) @patch('sys.stdout', new_callable=StringIO) def test_verbose_mock(mock_stdout): - AAMBuilder().build(training, group='PTS', verbose=True) + AAMBuilder().build(training, verbose=True) @patch('sys.stdout', new_callable=StringIO) diff --git a/menpofit/test/aam_fitter_test.py b/menpofit/test/aam_fitter_test.py index 438b9c5..1402e14 100644 --- a/menpofit/test/aam_fitter_test.py +++ b/menpofit/test/aam_fitter_test.py @@ -11,7 +11,7 @@ import menpo.io as mio from menpo.shape.pointcloud import PointCloud -from menpo.landmark import labeller, ibug_face_68_trimesh +from menpo.landmark import ibug_face_68_trimesh from menpofit.aam import AAMBuilder, LucasKanadeAAMFitter from menpofit.lucaskanade.appearance import ( AlternatingForwardAdditive, AlternatingForwardCompositional, @@ -307,35 +307,34 @@ for i in range(4): im = mio.import_builtin_asset(filenames[i]) im.crop_to_landmarks_proportion_inplace(0.1) - labeller(im, 'PTS', ibug_face_68_trimesh) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') training_images.append(im) # build aam +template_trilist_image = training_images[0].landmarks[None] +trilist = ibug_face_68_trimesh(template_trilist_image)[1].lms.trilist aam = AAMBuilder(features=igo, transform=DifferentiablePiecewiseAffine, - trilist=training_images[0].landmarks['ibug_face_68_trimesh']. - lms.trilist, + trilist=trilist, normalization_diagonal=150, n_levels=3, downscale=2, scaled_shape_models=True, max_shape_components=[1, 2, 3], max_appearance_components=[3, 2, 1], - boundary=3).build(training_images, group='PTS') + boundary=3).build(training_images) aam2 = AAMBuilder(features=igo, transform=DifferentiablePiecewiseAffine, - trilist=training_images[0].landmarks['ibug_face_68_trimesh']. - lms.trilist, + trilist=trilist, normalization_diagonal=150, n_levels=1, downscale=2, scaled_shape_models=True, max_shape_components=[1], max_appearance_components=[1], - boundary=3).build(training_images, group='PTS') + boundary=3).build(training_images) def test_aam(): @@ -368,7 +367,7 @@ def test_n_appearance_exception(): def test_pertrurb_shape(): fitter = LucasKanadeAAMFitter(aam) - s = fitter.perturb_shape(training_images[0].landmarks['PTS'].lms, + s = fitter.perturb_shape(training_images[0].landmarks[None].lms, noise_std=0.08, rotation=False) assert (s.n_dims == 2) assert (s.n_landmark_groups == 0) @@ -410,7 +409,7 @@ def aam_helper(aam, algorithm, im_number, max_iters, initial_error, fitter = LucasKanadeAAMFitter(aam, algorithm=algorithm) fitting_result = fitter.fit( training_images[im_number], initial_shape[im_number], - gt_shape=training_images[im_number].landmarks['PTS'].lms, + gt_shape=training_images[im_number].landmarks[None].lms, max_iters=max_iters) assert_allclose( np.around(fitting_result.initial_error(error_type=error_type), 5), diff --git a/menpofit/test/atm_builder_test.py b/menpofit/test/atm_builder_test.py index 1f0244f..437a999 100644 --- a/menpofit/test/atm_builder_test.py +++ b/menpofit/test/atm_builder_test.py @@ -20,7 +20,7 @@ im = mio.import_builtin_asset(filenames[i]) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') - training.append(im.landmarks['PTS']['all']) + training.append(im.landmarks[None].lms) templates.append(im) # build atms @@ -31,7 +31,7 @@ downscale=2, scaled_shape_models=False, max_shape_components=[1, 2, 3], - boundary=3).build(training, templates[0], group='PTS') + boundary=3).build(training, templates[0]) atm2 = ATMBuilder(features=[no_op, no_op], transform=ThinPlateSplines, @@ -41,7 +41,7 @@ downscale=1.2, scaled_shape_models=True, max_shape_components=None, - boundary=0).build(training, templates[1], group='PTS') + boundary=0).build(training, templates[1]) atm3 = ATMBuilder(features=igo, transform=ThinPlateSplines, @@ -51,7 +51,7 @@ downscale=3, scaled_shape_models=True, max_shape_components=[2], - boundary=2).build(training, templates[2], group='PTS') + boundary=2).build(training, templates[2]) atm4 = PatchBasedATMBuilder(features=lbp, patch_shape=(10, 13), @@ -60,8 +60,7 @@ downscale=1.2, scaled_shape_models=True, max_shape_components=1, - boundary=2).build(training, templates[3], - group='PTS') + boundary=2).build(training, templates[3]) @raises(ValueError) @@ -76,23 +75,21 @@ def test_n_levels_exception(): @raises(ValueError) def test_downscale_exception(): - atm = ATMBuilder(downscale=1).build(training, templates[2], group='PTS') + atm = ATMBuilder(downscale=1).build(training, templates[2]) assert (atm.downscale == 1) - ATMBuilder(downscale=0).build(training, templates[2], group='PTS') + ATMBuilder(downscale=0).build(training, templates[2]) @raises(ValueError) def test_normalization_diagonal_exception(): - atm = ATMBuilder(normalization_diagonal=100).build(training, templates[3], - group='PTS') + atm = ATMBuilder(normalization_diagonal=100).build(training, templates[3]) assert (atm.warped_templates[0].n_true_pixels() == 1246) ATMBuilder(normalization_diagonal=10).build(training, templates[3]) @raises(ValueError) def test_max_shape_components_exception(): - ATMBuilder(max_shape_components=[1, 0.2, 'a']).build(training, templates[0], - group='PTS') + ATMBuilder(max_shape_components=[1, 0.2, 'a']).build(training, templates[0]) @raises(ValueError) @@ -102,12 +99,12 @@ def test_max_shape_components_exception_2(): @raises(ValueError) def test_boundary_exception(): - ATMBuilder(boundary=-1).build(training, templates[1], group='PTS') + ATMBuilder(boundary=-1).build(training, templates[1]) @patch('sys.stdout', new_callable=StringIO) def test_verbose_mock(mock_stdout): - ATMBuilder().build(training, templates[2], group='PTS', verbose=True) + ATMBuilder().build(training, templates[2], verbose=True) @patch('sys.stdout', new_callable=StringIO) diff --git a/menpofit/test/atm_fitter_test.py b/menpofit/test/atm_fitter_test.py index 4cf91d1..0a1459e 100644 --- a/menpofit/test/atm_fitter_test.py +++ b/menpofit/test/atm_fitter_test.py @@ -12,8 +12,8 @@ from menpo.shape.pointcloud import PointCloud from menpofit.atm import ATMBuilder, LucasKanadeATMFitter from menpofit.lucaskanade.image import (ImageInverseCompositional, - ImageForwardAdditive, - ImageForwardCompositional) + ImageForwardAdditive, + ImageForwardCompositional) initial_shape = [] @@ -302,7 +302,7 @@ im.crop_to_landmarks_proportion_inplace(0.1) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') - training_shapes.append(im.landmarks['PTS']['all']) + training_shapes.append(im.landmarks[None].lms) templates.append(im) @@ -314,7 +314,7 @@ downscale=2, scaled_shape_models=True, max_shape_components=[1, 2, 3], - boundary=3).build(training_shapes, templates[0], group='PTS') + boundary=3).build(training_shapes, templates[0]) atm2 = ATMBuilder(features=igo, transform=DifferentiablePiecewiseAffine, @@ -323,7 +323,7 @@ downscale=2, scaled_shape_models=True, max_shape_components=[1], - boundary=3).build(training_shapes, templates[1], group='PTS') + boundary=3).build(training_shapes, templates[1]) atm3 = ATMBuilder(features=igo, transform=DifferentiablePiecewiseAffine, @@ -332,7 +332,7 @@ downscale=2, scaled_shape_models=True, max_shape_components=[1, 2, 3], - boundary=3).build(training_shapes, templates[2], group='PTS') + boundary=3).build(training_shapes, templates[2]) atm4 = ATMBuilder(features=igo, transform=DifferentiablePiecewiseAffine, @@ -341,7 +341,7 @@ downscale=2, scaled_shape_models=True, max_shape_components=[1], - boundary=3).build(training_shapes, templates[3], group='PTS') + boundary=3).build(training_shapes, templates[3]) def test_atm1(): @@ -371,7 +371,7 @@ def test_n_shape_exception_2(): def test_pertrurb_shape(): fitter = LucasKanadeATMFitter(atm1) - s = fitter.perturb_shape(templates[0].landmarks['PTS'].lms, + s = fitter.perturb_shape(templates[0].landmarks[None].lms, noise_std=0.08, rotation=False) assert (s.n_dims == 2) assert (s.n_landmark_groups == 0) @@ -412,7 +412,7 @@ def atm_helper(atm, algorithm, im_number, max_iters, initial_error, fitter = LucasKanadeATMFitter(atm, algorithm=algorithm) fitting_result = fitter.fit( templates[im_number], initial_shape[im_number], - gt_shape=templates[im_number].landmarks['PTS'].lms, + gt_shape=templates[im_number].landmarks[None].lms, max_iters=max_iters) assert_allclose( np.around(fitting_result.initial_error(error_type=error_type), 5), diff --git a/menpofit/test/clm_builder_test.py b/menpofit/test/clm_builder_test.py index c9fbf52..7219c3f 100644 --- a/menpofit/test/clm_builder_test.py +++ b/menpofit/test/clm_builder_test.py @@ -8,7 +8,6 @@ from menpo.feature import sparse_hog, igo, no_op import menpo.io as mio -from menpo.landmark import labeller, ibug_face_68_trimesh from menpofit.clm import CLMBuilder from menpofit.clm.classifier import linear_svm_lr from menpofit.base import name_of_callable @@ -29,7 +28,6 @@ def random_forest_predict(x): for i in range(4): im = mio.import_builtin_asset(filenames[i]) im.crop_to_landmarks_proportion_inplace(0.1) - labeller(im, 'PTS', ibug_face_68_trimesh) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') training_images.append(im) @@ -43,7 +41,7 @@ def random_forest_predict(x): downscale=2, scaled_shape_models=False, max_shape_components=[1, 2, 3], - boundary=3).build(training_images, group='PTS') + boundary=3).build(training_images) clm2 = CLMBuilder(classifier_trainers=[random_forest, linear_svm_lr], patch_shape=(3, 10), @@ -53,7 +51,7 @@ def random_forest_predict(x): downscale=1.2, scaled_shape_models=True, max_shape_components=None, - boundary=0).build(training_images, group='PTS') + boundary=0).build(training_images) clm3 = CLMBuilder(classifier_trainers=[linear_svm_lr], patch_shape=(2, 3), @@ -63,69 +61,65 @@ def random_forest_predict(x): downscale=3, scaled_shape_models=True, max_shape_components=[1], - boundary=2).build(training_images, group='PTS') + boundary=2).build(training_images) @raises(ValueError) def test_classifier_type_1_exception(): CLMBuilder(classifier_trainers=[linear_svm_lr, linear_svm_lr]).build( - training_images, group='PTS') + training_images) @raises(ValueError) def test_classifier_type_2_exception(): - CLMBuilder(classifier_trainers=['linear_svm_lr']).build(training_images, - group='PTS') + CLMBuilder(classifier_trainers=['linear_svm_lr']).build(training_images) @raises(ValueError) def test_patch_shape_1_exception(): - CLMBuilder(patch_shape=(5, 1)).build(training_images, group='PTS') + CLMBuilder(patch_shape=(5, 1)).build(training_images) @raises(ValueError) def test_patch_shape_2_exception(): - CLMBuilder(patch_shape=(5, 6, 7)).build(training_images, group='PTS') + CLMBuilder(patch_shape=(5, 6, 7)).build(training_images) @raises(ValueError) def test_features_exception(): - CLMBuilder(features=[igo, sparse_hog]).build(training_images, group='PTS') + CLMBuilder(features=[igo, sparse_hog]).build(training_images) @raises(ValueError) def test_n_levels_exception(): - clm = CLMBuilder(n_levels=0).build(training_images, group='PTS') + clm = CLMBuilder(n_levels=0).build(training_images) @raises(ValueError) def test_downscale_exception(): - clm = CLMBuilder(downscale=1).build(training_images, group='PTS') + clm = CLMBuilder(downscale=1).build(training_images) assert (clm.downscale == 1) - CLMBuilder(downscale=0).build(training_images, group='PTS') + CLMBuilder(downscale=0).build(training_images) @raises(ValueError) def test_normalization_diagonal_exception(): - CLMBuilder(normalization_diagonal=10).build(training_images, - group='PTS') + CLMBuilder(normalization_diagonal=10).build(training_images) @raises(ValueError) def test_max_shape_components_1_exception(): - CLMBuilder(max_shape_components=[1, 0.2, 'a']).build(training_images, - group='PTS') + CLMBuilder(max_shape_components=[1, 0.2, 'a']).build(training_images) @raises(ValueError) def test_max_shape_components_2_exception(): - CLMBuilder(max_shape_components=[1, 2]).build(training_images, - group='PTS') + CLMBuilder(max_shape_components=[1, 2]).build(training_images) @raises(ValueError) def test_boundary_exception(): - CLMBuilder(boundary=-1).build(training_images, group='PTS') + CLMBuilder(boundary=-1).build(training_images) @patch('sys.stdout', new_callable=StringIO) def test_verbose_mock(mock_stdout): - CLMBuilder().build(training_images, group='PTS', verbose=True) + CLMBuilder().build(training_images, verbose=True) @patch('sys.stdout', new_callable=StringIO) diff --git a/menpofit/test/clm_fitter_test.py b/menpofit/test/clm_fitter_test.py index d442bc0..fd11932 100644 --- a/menpofit/test/clm_fitter_test.py +++ b/menpofit/test/clm_fitter_test.py @@ -8,7 +8,6 @@ import menpo.io as mio from menpo.shape.pointcloud import PointCloud -from menpo.landmark import labeller, ibug_face_68_trimesh from menpofit.clm import CLMBuilder from menpofit.clm import GradientDescentCLMFitter from menpofit.gradientdescent import RegularizedLandmarkMeanShift @@ -299,7 +298,6 @@ for i in range(4): im = mio.import_builtin_asset(filenames[i]) im.crop_to_landmarks_proportion_inplace(0.1) - labeller(im, 'PTS', ibug_face_68_trimesh) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') training_images.append(im) @@ -313,7 +311,7 @@ downscale=1.1, scaled_shape_models=True, max_shape_components=[1, 2, 3], - boundary=3).build(training_images, group='PTS') + boundary=3).build(training_images) def test_clm(): @@ -353,7 +351,7 @@ def test_n_shape_2_exception(): def test_perturb_shape(): fitter = GradientDescentCLMFitter(clm) - s = fitter.perturb_shape(training_images[0].landmarks['PTS'].lms, + s = fitter.perturb_shape(training_images[0].landmarks[None].lms, noise_std=0.08, rotation=False) assert (s.n_dims == 2) assert (s.n_landmark_groups == 0) diff --git a/menpofit/test/sdm_test.py b/menpofit/test/sdm_test.py index b1189f2..6a08ed5 100644 --- a/menpofit/test/sdm_test.py +++ b/menpofit/test/sdm_test.py @@ -21,7 +21,6 @@ for i in range(4): im = mio.import_builtin_asset(filenames[i]) im.crop_to_landmarks_proportion_inplace(0.1) - labeller(im, 'PTS', ibug_face_68_trimesh) if im.n_channels == 3: im = im.as_greyscale(mode='luminosity') training_images.append(im) @@ -29,17 +28,18 @@ # Seed the random number generator np.random.seed(seed=1000) +template_trilist_image = training_images[0].landmarks[None] +trilist = ibug_face_68_trimesh(template_trilist_image)[1].lms.trilist aam = AAMBuilder(features=sparse_hog, transform=PiecewiseAffine, - trilist=training_images[0].landmarks['ibug_face_68_trimesh']. - lms.trilist, + trilist=trilist, normalization_diagonal=150, n_levels=3, downscale=1.2, scaled_shape_models=False, max_shape_components=None, max_appearance_components=3, - boundary=3).build(training_images, group='PTS') + boundary=3).build(training_images) clm = CLMBuilder(classifier_trainers=linear_svm_lr, features=sparse_hog, @@ -49,53 +49,51 @@ downscale=1.1, scaled_shape_models=True, max_shape_components=25, - boundary=3).build(training_images, group='PTS') + boundary=3).build(training_images) @raises(ValueError) def test_features_exception(): sdm = SDMTrainer(features=[igo, sparse_hog], - n_levels=3).train(training_images, group='PTS') + n_levels=3).train(training_images) @raises(ValueError) def test_regression_features_sdmtrainer_exception_1(): sdm = SDMTrainer(n_levels=2, regression_features=[no_op, no_op, no_op]).\ - train(training_images, group='PTS') + train(training_images) @raises(ValueError) def test_regression_features_sdmtrainer_exception_2(): sdm = SDMTrainer(n_levels=3, regression_features=[no_op, sparse_hog, 1]).\ - train(training_images, group='PTS') + train(training_images) @raises(ValueError) def test_regression_features_sdaamtrainer_exception_1(): sdm = SDAAMTrainer(aam, regression_features=[no_op, sparse_hog]).\ - train(training_images, group='PTS') + train(training_images) @raises(ValueError) def test_regression_features_sdaamtrainer_exception_2(): sdm = SDAAMTrainer(aam, regression_features=7).\ - train(training_images, group='PTS') + train(training_images) @raises(ValueError) def test_n_levels_exception(): - sdm = SDMTrainer(n_levels=0).train(training_images, group='PTS') + sdm = SDMTrainer(n_levels=0).train(training_images) @raises(ValueError) def test_downscale_exception(): - sdm = SDMTrainer(downscale=0).train(training_images, - group='PTS') + sdm = SDMTrainer(downscale=0).train(training_images) @raises(ValueError) def test_n_perturbations_exception(): - sdm = SDAAMTrainer(aam, n_perturbations=-10).train(training_images, - group='PTS') + sdm = SDAAMTrainer(aam, n_perturbations=-10).train(training_images) @patch('sys.stdout', new_callable=StringIO) @@ -109,5 +107,5 @@ def test_verbose_mock(mock_stdout): downscale=1.3, noise_std=0.04, rotation=False, - n_perturbations=2).train(training_images, group='PTS', + n_perturbations=2).train(training_images, verbose=True) diff --git a/menpofit/visualize/__init__.py b/menpofit/visualize/__init__.py index 729256a..90a1bf3 100644 --- a/menpofit/visualize/__init__.py +++ b/menpofit/visualize/__init__.py @@ -1,4 +1,3 @@ from .widgets import (visualize_shape_model, visualize_appearance_model, visualize_aam, visualize_atm, visualize_fitting_results, plot_ced) - diff --git a/menpofit/visualize/widgets/base.py b/menpofit/visualize/widgets/base.py index cb052a3..5eba460 100644 --- a/menpofit/visualize/widgets/base.py +++ b/menpofit/visualize/widgets/base.py @@ -1,41 +1,31 @@ -from menpo.visualize.widgets.helpers import (figure_options, - format_figure_options, - figure_options_two_scales, - format_figure_options_two_scales, +import numpy as np +from collections import OrderedDict + +from menpo.visualize.widgets.options import (viewer_options, + format_viewer_options, channel_options, format_channel_options, update_channel_options, landmark_options, format_landmark_options, info_print, format_info_print, - model_parameters, - format_model_parameters, - update_model_parameters, - final_result_options, - format_final_result_options, - update_final_result_options, - iterations_result_options, - format_iterations_result_options, - update_iterations_result_options, animation_options, format_animation_options, - plot_options, - format_plot_options, save_figure_options, format_save_figure_options) -from menpo.visualize.widgets.base import (_plot_figure, _plot_graph, - _plot_eigenvalues, - _check_n_parameters, - _raw_info_string_to_latex, - _extract_groups_labels) -from IPython.html.widgets import (FloatTextWidget, TextWidget, PopupWidget, - ContainerWidget, TabWidget, FloatSliderWidget, - RadioButtonsWidget, CheckboxWidget, - DropdownWidget, AccordionWidget, ButtonWidget) -from IPython.display import display, clear_output -import matplotlib.pylab as plt -import numpy as np -from collections import OrderedDict +from menpo.visualize.widgets.tools import logo +from menpo.visualize.widgets.base import _visualize as _visualize_menpo +from menpo.visualize.widgets.base import _extract_groups_labels +from menpo.visualize.viewmatplotlib import (MatplotlibImageViewer2d, + sample_colours_from_colourmap, + MatplotlibSubplots) + +from .options import (model_parameters, format_model_parameters, + update_model_parameters, final_result_options, + format_final_result_options, update_final_result_options, + iterations_result_options, + format_iterations_result_options, + update_iterations_result_options) # This glyph import is called frequently during visualisation, so we ensure # that we only import it once @@ -43,8 +33,8 @@ def visualize_shape_model(shape_models, n_parameters=5, - parameters_bounds=(-3.0, 3.0), figure_size=(7, 7), - mode='multiple', popup=False, **kwargs): + parameters_bounds=(-3.0, 3.0), figure_size=(10, 8), + mode='multiple', popup=False): r""" Allows the dynamic visualization of a multilevel shape model. @@ -53,32 +43,28 @@ def visualize_shape_model(shape_models, n_parameters=5, shape_models : `list` of :map:`PCAModel` or subclass The multilevel shape model to be displayed. Note that each level can have different number of components. - n_parameters : `int` or `list` of `int` or None, optional The number of principal components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + If `int`, then the number of sliders per level is the minimum between + `n_parameters` and the number of active components per level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. - figure_size : (`int`, `int`), optional The size of the plotted figures. - - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. - - kwargs : `dict`, optional - Passed through to the viewer. + mode : {``single``, ``multiple``}, optional + If ``single``, only a single slider is constructed along with a drop + down menu. If ``multiple``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. """ + import IPython.html.widgets as ipywidgets + import IPython.display as ipydisplay + import matplotlib.pyplot as plt + from matplotlib import collections as mc + # make sure that shape_models is a list even with one member if not isinstance(shape_models, list): shape_models = [shape_models] @@ -93,52 +79,104 @@ def visualize_shape_model(shape_models, n_parameters=5, # the returned n_parameters is a list of len n_levels n_parameters = _check_n_parameters(n_parameters, n_levels, max_n_params) + # initial options dictionaries + lines_options = {'render_lines': True, + 'line_width': 1, + 'line_colour': ['r'], + 'line_style': '-'} + markers_options = {'render_markers': True, + 'marker_size': 20, + 'marker_face_colour': ['r'], + 'marker_edge_colour': ['k'], + 'marker_style': 'o', + 'marker_edge_width': 1} + figure_options = {'x_scale': 1., + 'y_scale': 1., + 'render_axes': False, + 'axes_font_name': 'sans-serif', + 'axes_font_size': 10, + 'axes_font_style': 'normal', + 'axes_font_weight': 'normal', + 'axes_x_limits': None, + 'axes_y_limits': None} + viewer_options_default = {'lines': lines_options, + 'markers': markers_options, + 'figure': figure_options} + # Define plot function def plot_function(name, value): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) - # get params + # get selected level level = 0 if n_levels > 1: level = level_wid.value - def_mode = mode_wid.value - axis_mode = axes_mode_wid.value - parameters_values = model_parameters_wid.parameters_values - x_scale = figure_options_wid.x_scale - y_scale = figure_options_wid.y_scale - axes_visible = figure_options_wid.axes_visible # compute weights + parameters_values = model_parameters_wid.parameters_values weights = (parameters_values * - shape_models[level].eigenvalues[:len(parameters_values)] ** 0.5) + shape_models[level].eigenvalues[:len(parameters_values)] ** + 0.5) # compute the mean mean = shape_models[level].mean() - # select figure - figure_id = plt.figure(save_figure_wid.figure_id.number) - - # invert axis if image mode is enabled - if axis_mode == 1: - plt.gca().invert_yaxis() + tmp1 = viewer_options_wid.selected_values[0]['lines'] + tmp2 = viewer_options_wid.selected_values[0]['markers'] + tmp3 = viewer_options_wid.selected_values[0]['figure'] + new_figure_size = (tmp3['x_scale'] * figure_size[0], + tmp3['y_scale'] * figure_size[1]) # compute and show instance - if def_mode == 1: + if mode_wid.value == 1: # Deformation mode # compute instance instance = shape_models[level].instance(weights) # plot if mean_wid.value: - mean.view(image_view=axis_mode == 1, colour_array='y', - **kwargs) - plt.hold(True) - instance.view(image_view=axis_mode == 1, **kwargs) + mean.view( + figure_id=save_figure_wid.renderer[0].figure_id, + new_figure=False, image_view=axes_mode_wid.value == 1, + render_lines=tmp1['render_lines'], + line_colour='y', + line_style='solid', line_width=tmp1['line_width'], + render_markers=tmp2['render_markers'], + marker_style=tmp2['marker_style'], + marker_size=tmp2['marker_size'], marker_face_colour='y', + marker_edge_colour='y', + marker_edge_width=tmp2['marker_edge_width'], + render_axes=False, figure_size=None) + + renderer = instance.view( + figure_id=save_figure_wid.renderer[0].figure_id, + new_figure=False, image_view=axes_mode_wid.value==1, + render_lines=tmp1['render_lines'], + line_colour=tmp1['line_colour'][0], + line_style=tmp1['line_style'], line_width=tmp1['line_width'], + render_markers=tmp2['render_markers'], + marker_style=tmp2['marker_style'], + marker_size=tmp2['marker_size'], + marker_face_colour=tmp2['marker_face_colour'], + marker_edge_colour=tmp2['marker_edge_colour'], + marker_edge_width=tmp2['marker_edge_width'], + render_axes=tmp3['render_axes'], + axes_font_name=tmp3['axes_font_name'], + axes_font_size=tmp3['axes_font_size'], + axes_font_style=tmp3['axes_font_style'], + axes_font_weight=tmp3['axes_font_weight'], + axes_x_limits=tmp3['axes_x_limits'], + axes_y_limits=tmp3['axes_y_limits'], + figure_size=new_figure_size, + label=None) + + if mean_wid.value and axes_mode_wid.value == 1: + plt.gca().invert_yaxis() # instance range - tmp_range = instance.range() + instance_range = instance.range() else: # Vectors mode # compute instance @@ -146,8 +184,28 @@ def plot_function(name, value): instance_upper = shape_models[level].instance(weights) # plot - mean.view(image_view=axis_mode == 1, **kwargs) - plt.hold(True) + renderer = mean.view( + figure_id=save_figure_wid.renderer[0].figure_id, + new_figure=False, image_view=axes_mode_wid.value == 1, + render_lines=tmp1['render_lines'], + line_colour=tmp1['line_colour'][0], + line_style=tmp1['line_style'], line_width=tmp1['line_width'], + render_markers=tmp2['render_markers'], + marker_style=tmp2['marker_style'], + marker_size=tmp2['marker_size'], + marker_face_colour=tmp2['marker_face_colour'], + marker_edge_colour=tmp2['marker_edge_colour'], + marker_edge_width=tmp2['marker_edge_width'], + render_axes=tmp3['render_axes'], + axes_font_name=tmp3['axes_font_name'], + axes_font_size=tmp3['axes_font_size'], + axes_font_style=tmp3['axes_font_style'], + axes_font_weight=tmp3['axes_font_weight'], + axes_x_limits=tmp3['axes_x_limits'], + axes_y_limits=tmp3['axes_y_limits'], + figure_size=new_figure_size, label=None) + + ax = plt.gca() for p in range(mean.n_points): xm = mean.points[p, 0] ym = mean.points[p, 1] @@ -155,77 +213,80 @@ def plot_function(name, value): yl = instance_lower.points[p, 1] xu = instance_upper.points[p, 0] yu = instance_upper.points[p, 1] - if axis_mode == 1: + if axes_mode_wid.value == 1: # image mode - plt.plot([ym, yl], [xm, xl], 'r-', lw=2) - plt.plot([ym, yu], [xm, xu], 'g-', lw=2) + lines = [[(ym, xm), (yl, xl)], [(ym, xm), (yu, xu)]] else: # point cloud mode - plt.plot([xm, xl], [ym, yl], 'r-', lw=2) - plt.plot([xm, xu], [ym, yu], 'g-', lw=2) + lines = [[(xm, ym), (xl, yl)], [(xm, ym), (xu, yu)]] + lc = mc.LineCollection(lines, colors=('g', 'b'), + linestyles='solid', linewidths=2) + ax.add_collection(lc) # instance range - tmp_range = mean.range() - - plt.hold(False) - plt.gca().axis('equal') - # set figure size - plt.gcf().set_size_inches([x_scale, y_scale] * np.asarray(figure_size)) - # turn axis on/off - if not axes_visible: - plt.axis('off') + instance_range = mean.range() + plt.show() # save the current figure id - save_figure_wid.figure_id = figure_id - - # info_wid string - info_txt = r""" - Level: {} out of {}. - {} components in total. - {} active components. - {:.1f} % variance kept. - Instance range: {:.1f} x {:.1f}. - {} landmark points, {} features. - """.format(level + 1, n_levels, shape_models[level].n_components, - shape_models[level].n_active_components, - shape_models[level].variance_ratio() * 100, tmp_range[0], - tmp_range[1], mean.n_points, - shape_models[level].n_features) - - info_wid.children[1].value = _raw_info_string_to_latex(info_txt) + save_figure_wid.renderer[0] = renderer + + # update info text widget + update_info(level, instance_range) + + # define function that updates info text + def update_info(level, instance_range): + lvl_sha_mod = shape_models[level] + info_wid.children[1].children[0].value = "> Level: {} out of {}.".\ + format(level + 1, n_levels) + info_wid.children[1].children[1].value = "> {} components in total.".\ + format(lvl_sha_mod.n_components) + info_wid.children[1].children[2].value = "> {} active components.".\ + format(lvl_sha_mod.n_active_components) + info_wid.children[1].children[3].value = "> {:.1f}% variance kept.".\ + format(lvl_sha_mod.variance_ratio() * 100) + info_wid.children[1].children[4].value = "> Instance range: {:.1f} " \ + "x {:.1f}.".\ + format(instance_range[0], instance_range[1]) + info_wid.children[1].children[5].value = "> {} landmark points, " \ + "{} features.".\ + format(lvl_sha_mod.mean().n_points, lvl_sha_mod.n_features) # Plot eigenvalues function def plot_eigenvalues(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) - # get parameters + # get level level = 0 if n_levels > 1: level = level_wid.value - # get the current figure id - figure_id = save_figure_wid.figure_id - - # show eigenvalues plots - new_figure_id = _plot_eigenvalues(figure_id, shape_models[level], - figure_size, - figure_options_wid.x_scale, - figure_options_wid.y_scale) + # get the current figure id and plot the eigenvalues + new_figure_size = (viewer_options_wid.selected_values[0]['figure']['x_scale'] * 10, + viewer_options_wid.selected_values[0]['figure']['y_scale'] * 3) + plt.subplot(121) + shape_models[level].plot_eigenvalues_ratio( + figure_id=save_figure_wid.renderer[0].figure_id) + plt.subplot(122) + renderer = shape_models[level].plot_eigenvalues_cumulative_ratio( + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size) + plt.show() # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer # create options widgets mode_dict = OrderedDict() mode_dict['Deformation'] = 1 mode_dict['Vectors'] = 2 - mode_wid = RadioButtonsWidget(values=mode_dict, description='Mode:', - value=1) + mode_wid = ipywidgets.RadioButtonsWidget(values=mode_dict, + description='Mode:', value=1) mode_wid.on_trait_change(plot_function, 'value') - mean_wid = CheckboxWidget(value=False, description='Show mean shape') + mean_wid = ipywidgets.CheckboxWidget(value=False, + description='Show mean shape') mean_wid.on_trait_change(plot_function, 'value') # controls mean shape checkbox visibility @@ -243,19 +304,27 @@ def mean_visible(name, value): toggle_show_visible=False, plot_eig_visible=True, plot_eig_function=plot_eigenvalues) - figure_options_wid = figure_options(plot_function, scale_default=1., - show_axes_default=True, - toggle_show_default=True, - toggle_show_visible=False) - axes_mode_wid = RadioButtonsWidget(values={'Image': 1, 'Point cloud': 2}, - description='Axes mode:', value=1) + + # viewer options widget + axes_mode_wid = ipywidgets.RadioButtonsWidget( + values={'Image': 1, 'Point cloud': 2}, description='Axes mode:', + value=2) axes_mode_wid.on_trait_change(plot_function, 'value') - ch = list(figure_options_wid.children) - ch.insert(3, axes_mode_wid) - figure_options_wid.children = ch - info_wid = info_print(toggle_show_default=True, toggle_show_visible=False) - initial_figure_id = plt.figure() - save_figure_wid = save_figure_options(initial_figure_id, + viewer_options_wid = viewer_options(viewer_options_default, + ['lines', 'markers', 'figure_one'], + objects_names=None, + plot_function=plot_function, + toggle_show_visible=False, + toggle_show_default=True) + viewer_options_all = ipywidgets.ContainerWidget(children=[axes_mode_wid, + viewer_options_wid]) + info_wid = info_print(n_bullets=6, toggle_show_default=True, + toggle_show_visible=False) + + # save figure widget + initial_renderer = MatplotlibImageViewer2d(figure_id=None, new_figure=True, + image=np.zeros((10, 10))) + save_figure_wid = save_figure_options(initial_renderer, toggle_show_default=True, toggle_show_visible=False) @@ -274,32 +343,32 @@ def update_widgets(name, value): radio_str["Level {} (high)".format(l)] = l else: radio_str["Level {}".format(l)] = l - level_wid = RadioButtonsWidget(values=radio_str, - description='Pyramid:', value=0) + level_wid = ipywidgets.RadioButtonsWidget(values=radio_str, + description='Pyramid:', + value=0) level_wid.on_trait_change(update_widgets, 'value') level_wid.on_trait_change(plot_function, 'value') radio_children = [level_wid, mode_wid, mean_wid] else: radio_children = [mode_wid, mean_wid] - radio_wids = ContainerWidget(children=radio_children) - tmp_wid = ContainerWidget(children=[radio_wids, model_parameters_wid]) - wid = TabWidget(children=[tmp_wid, figure_options_wid, info_wid, - save_figure_wid]) + radio_wids = ipywidgets.ContainerWidget(children=radio_children) + tmp_wid = ipywidgets.ContainerWidget(children=[radio_wids, + model_parameters_wid]) + tab_wid = ipywidgets.TabWidget(children=[tmp_wid, viewer_options_all, + info_wid, save_figure_wid]) + logo_wid = logo() + wid = ipywidgets.ContainerWidget(children=[logo_wid, tab_wid]) if popup: - wid = PopupWidget(children=[wid], button_text='Shape Model Menu') + wid = ipywidgets.PopupWidget(children=[wid], + button_text='Shape Model Menu') # display final widget - display(wid) + ipydisplay.display(wid) # set final tab titles - tab_titles = ['Shape parameters', 'Figure options', 'Model info', - 'Save figure'] - if popup: - for (k, tl) in enumerate(tab_titles): - wid.children[0].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.set_title(k, tl) + tab_titles = ['Shape parameters', 'Viewer options', 'Info', 'Save figure'] + for (k, tl) in enumerate(tab_titles): + tab_wid.set_title(k, tl) # align widgets tmp_wid.remove_class('vbox') @@ -309,12 +378,13 @@ def update_widgets(name, value): container_border='1px solid black', toggle_button_font_weight='bold', border_visible=True) - format_figure_options(figure_options_wid, container_padding='6px', + format_viewer_options(viewer_options_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=False) - format_info_print(info_wid, font_size_in_pt='9pt', container_padding='6px', + border_visible=False, + suboptions_border_visible=True) + format_info_print(info_wid, font_size_in_pt='10pt', container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) @@ -328,48 +398,40 @@ def update_widgets(name, value): update_widgets('', 0) # Reset value to enable initial visualization - figure_options_wid.children[2].value = False + axes_mode_wid.value = 1 def visualize_appearance_model(appearance_models, n_parameters=5, - parameters_bounds=(-3.0, 3.0), - figure_size=(7, 7), mode='multiple', - popup=False, **kwargs): + parameters_bounds=(-3.0, 3.0), figure_size=(10, 8), + mode='multiple', popup=False): r""" Allows the dynamic visualization of a multilevel appearance model. Parameters ----------- appearance_models : `list` of :map:`PCAModel` or subclass - The multilevel appearance model to be displayed. Note that each level - can have different attributes, e.g. number of parameters, feature type, - number of channels. - + The multilevel appearance model to be displayed. Note that each level can + have different number of components. n_parameters : `int` or `list` of `int` or None, optional The number of principal components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + If `int`, then the number of sliders per level is the minimum between + `n_parameters` and the number of active components per level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. - figure_size : (`int`, `int`), optional The size of the plotted figures. - - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. - - kwargs : `dict`, optional - Passed through to the viewer. + mode : {``single``, ``multiple``}, optional + If ``single``, only a single slider is constructed along with a drop + down menu. If ``multiple``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. """ + import IPython.html.widgets as ipywidgets + import IPython.display as ipydisplay + import matplotlib.pyplot as plt from menpo.image import MaskedImage # make sure that appearance_models is a list even with one member @@ -386,102 +448,175 @@ def visualize_appearance_model(appearance_models, n_parameters=5, # the returned n_parameters is a list of len n_levels n_parameters = _check_n_parameters(n_parameters, n_levels, max_n_params) - # define plot function + # find initial groups and labels that will be passed to the landmark options + # widget creation + mean_has_landmarks = appearance_models[0].mean().landmarks.n_groups != 0 + if mean_has_landmarks: + all_groups_keys, all_labels_keys = _extract_groups_labels( + appearance_models[0].mean()) + else: + all_groups_keys = [' '] + all_labels_keys = [[' ']] + + # get initial line colours for each available label + if len(all_labels_keys[0]) == 1: + line_colours = ['r'] + else: + line_colours = sample_colours_from_colourmap(len(all_labels_keys[0]), + 'jet') + + # initial options dictionaries + channels_default = 0 + if appearance_models[0].mean().n_channels == 3: + channels_default = None + channels_options_default = \ + {'n_channels': appearance_models[0].mean().n_channels, + 'image_is_masked': isinstance(appearance_models[0].mean(), + MaskedImage), + 'channels': channels_default, + 'glyph_enabled': False, + 'glyph_block_size': 3, + 'glyph_use_negative': False, + 'sum_enabled': False, + 'masked_enabled': isinstance(appearance_models[0].mean(), MaskedImage)} + landmark_options_default = {'render_landmarks': mean_has_landmarks, + 'group_keys': all_groups_keys, + 'labels_keys': all_labels_keys, + 'group': None, + 'with_labels': None} + lines_options = {'render_lines': True, + 'line_width': 1, + 'line_colour': line_colours, + 'line_style': '-'} + markers_options = {'render_markers': True, + 'marker_size': 20, + 'marker_face_colour': ['r'], + 'marker_edge_colour': ['k'], + 'marker_style': 'o', + 'marker_edge_width': 1} + figure_options = {'x_scale': 1., + 'y_scale': 1., + 'render_axes': True, + 'axes_font_name': 'sans-serif', + 'axes_font_size': 10, + 'axes_font_style': 'normal', + 'axes_font_weight': 'normal', + 'axes_x_limits': None, + 'axes_y_limits': None} + image_options = {'interpolation': 'none', + 'alpha': 1.} + viewer_options_default = {'lines': lines_options, + 'markers': markers_options, + 'figure': figure_options, + 'image': image_options} + + # Define plot function def plot_function(name, value): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) # get selected level level = 0 if n_levels > 1: level = level_wid.value - # get parameters values + # compute weights and instance parameters_values = model_parameters_wid.parameters_values - - # compute instance - weights = parameters_values * appearance_models[level].eigenvalues[:len(parameters_values)] ** 0.5 + weights = (parameters_values * + appearance_models[level].eigenvalues[:len(parameters_values)] ** + 0.5) instance = appearance_models[level].instance(weights) - # get the current figure id - figure_id = save_figure_wid.figure_id - - # show image with selected options - new_figure_id = _plot_figure( - image=instance, figure_id=figure_id, image_enabled=True, - landmarks_enabled=landmark_options_wid.landmarks_enabled, - image_is_masked=channel_options_wid.image_is_masked, - masked_enabled=channel_options_wid.masked_enabled, - channels=channel_options_wid.channels, - glyph_enabled=channel_options_wid.glyph_enabled, - glyph_block_size=channel_options_wid.glyph_block_size, - glyph_use_negative=channel_options_wid.glyph_use_negative, - sum_enabled=channel_options_wid.sum_enabled, - groups=[landmark_options_wid.group], - with_labels=[landmark_options_wid.with_labels], - groups_colours=dict(), subplots_enabled=False, - subplots_titles=dict(), image_axes_mode=True, - legend_enabled=landmark_options_wid.legend_enabled, - numbering_enabled=landmark_options_wid.numbering_enabled, - x_scale=figure_options_wid.x_scale, - y_scale=figure_options_wid.y_scale, - axes_visible=figure_options_wid.axes_visible, - figure_size=figure_size, **kwargs) + # update info text widget + update_info(instance, level, + landmark_options_wid.selected_values['group']) + n_labels = len(landmark_options_wid.selected_values['with_labels']) - # save the current figure id - save_figure_wid.figure_id = new_figure_id + # compute the mean + tmp1 = viewer_options_wid.selected_values[0]['lines'] + tmp2 = viewer_options_wid.selected_values[0]['markers'] + tmp3 = viewer_options_wid.selected_values[0]['figure'] + tmp4 = viewer_options_wid.selected_values[0]['image'] + new_figure_size = (tmp3['x_scale'] * figure_size[0], + tmp3['y_scale'] * figure_size[1]) + renderer = _visualize_menpo( + instance, save_figure_wid.renderer[0], + landmark_options_wid.selected_values['render_landmarks'], + channel_options_wid.selected_values['image_is_masked'], + channel_options_wid.selected_values['masked_enabled'], + channel_options_wid.selected_values['channels'], + channel_options_wid.selected_values['glyph_enabled'], + channel_options_wid.selected_values['glyph_block_size'], + channel_options_wid.selected_values['glyph_use_negative'], + channel_options_wid.selected_values['sum_enabled'], + landmark_options_wid.selected_values['group'], + landmark_options_wid.selected_values['with_labels'], + tmp1['render_lines'], tmp1['line_style'], tmp1['line_width'], + tmp1['line_colour'][:n_labels], tmp2['render_markers'], + tmp2['marker_style'], tmp2['marker_size'], tmp2['marker_edge_width'], + tmp2['marker_edge_colour'], tmp2['marker_face_colour'], + False, None, None, None, None, None, None, None, None, None, None, + None, None, None, None, None, None, None, None, None, None, None, + False, None, None, new_figure_size, tmp3['render_axes'], + tmp3['axes_font_name'], tmp3['axes_font_size'], + tmp3['axes_font_style'], tmp3['axes_x_limits'], + tmp3['axes_y_limits'], tmp3['axes_font_weight'], + tmp4['interpolation'], tmp4['alpha']) - # update info text widget - update_info(instance, level, landmark_options_wid.group) + # save the current figure id + save_figure_wid.renderer[0] = renderer # define function that updates info text def update_info(image, level, group): lvl_app_mod = appearance_models[level] - - info_txt = r""" - Level: {} out of {}. - {} components in total. - {} active components. - {:.1f}% variance kept. - Reference shape of size {} with {} channel{}. - {} features. - {} landmark points. - Instance: min={:.3f}, max={:.3f} - """.format(level + 1, n_levels, lvl_app_mod.n_components, - lvl_app_mod.n_active_components, - lvl_app_mod.variance_ratio() * 100, image._str_shape, - image.n_channels, 's' * (image.n_channels > 1), - lvl_app_mod.n_features, image.landmarks[group].lms.n_points, - image.pixels.min(), image.pixels.max()) - - # update info widget text - info_wid.children[1].value = _raw_info_string_to_latex(info_txt) + info_wid.children[1].children[0].value = "> Level: {} out of {}.".\ + format(level + 1, n_levels) + info_wid.children[1].children[1].value = "> {} components in total.".\ + format(lvl_app_mod.n_components) + info_wid.children[1].children[2].value = "> {} active components.".\ + format(lvl_app_mod.n_active_components) + info_wid.children[1].children[3].value = "> {:.1f}% variance kept.".\ + format(lvl_app_mod.variance_ratio() * 100) + info_wid.children[1].children[4].value = "> Reference shape of size " \ + "{} with {} channel{}.".\ + format(image._str_shape, + image.n_channels, 's' * (image.n_channels > 1)) + info_wid.children[1].children[5].value = "> {} features.".\ + format(lvl_app_mod.n_features) + info_wid.children[1].children[6].value = "> {} landmark points.".\ + format(image.landmarks[group].lms.n_points) + info_wid.children[1].children[7].value = "> Instance: min={:.3f}, " \ + "max={:.3f}".\ + format(image.pixels.min(), image.pixels.max()) # Plot eigenvalues function def plot_eigenvalues(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) - # get parameters + # get level level = 0 if n_levels > 1: level = level_wid.value - # get the current figure id - figure_id = save_figure_wid.figure_id - - # show eigenvalues plots - new_figure_id = _plot_eigenvalues(figure_id, appearance_models[level], - figure_size, - figure_options_wid.x_scale, - figure_options_wid.y_scale) + # get the current figure id and plot the eigenvalues + new_figure_size = (viewer_options_wid.selected_values[0]['figure']['x_scale'] * 10, + viewer_options_wid.selected_values[0]['figure']['y_scale'] * 3) + plt.subplot(121) + appearance_models[level].plot_eigenvalues_ratio( + figure_id=save_figure_wid.renderer[0].figure_id) + plt.subplot(122) + renderer = appearance_models[level].plot_eigenvalues_cumulative_ratio( + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size) + plt.show() # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer - # create options widgets + # create parameters, channels nad landmarks options widgets model_parameters_wid = model_parameters(n_parameters[0], plot_function, params_str='param ', mode=mode, params_bounds=parameters_bounds, @@ -489,37 +624,33 @@ def plot_eigenvalues(name): toggle_show_visible=False, plot_eig_visible=True, plot_eig_function=plot_eigenvalues) - channel_options_wid = channel_options( - appearance_models[0].mean().n_channels, - isinstance(appearance_models[0].mean(), MaskedImage), plot_function, - masked_default=True, toggle_show_default=True, - toggle_show_visible=False) - - # find initial groups and labels that will be passed to the landmark options - # widget creation - mean_has_landmarks = appearance_models[0].mean().landmarks.n_groups != 0 - if mean_has_landmarks: - all_groups_keys, all_labels_keys = _extract_groups_labels( - appearance_models[0].mean()) - else: - all_groups_keys = [' '] - all_labels_keys = [[' ']] - landmark_options_wid = landmark_options( - all_groups_keys, all_labels_keys, plot_function, - toggle_show_default=True, landmarks_default=mean_has_landmarks, - legend_default=False, numbering_default=False, - toggle_show_visible=False) + channel_options_wid = channel_options(channels_options_default, + plot_function=plot_function, + toggle_show_default=True, + toggle_show_visible=False) + landmark_options_wid = landmark_options(landmark_options_default, + plot_function=plot_function, + toggle_show_default=True, + toggle_show_visible=False) # if the mean doesn't have landmarks, then landmarks checkbox should be # disabled - landmark_options_wid.children[1].children[0].disabled = \ - not mean_has_landmarks - figure_options_wid = figure_options(plot_function, scale_default=1., - show_axes_default=True, - toggle_show_default=True, - toggle_show_visible=False) - info_wid = info_print(toggle_show_default=True, toggle_show_visible=False) - initial_figure_id = plt.figure() - save_figure_wid = save_figure_options(initial_figure_id, + landmark_options_wid.children[1].disabled = not mean_has_landmarks + + # viewer options widget + viewer_options_wid = viewer_options(viewer_options_default, + ['lines', 'markers', 'figure_one', + 'image'], + objects_names=None, + plot_function=plot_function, + toggle_show_visible=False, + toggle_show_default=True) + info_wid = info_print(n_bullets=8, toggle_show_default=True, + toggle_show_visible=False) + + # save figure widget + initial_renderer = MatplotlibImageViewer2d(figure_id=None, new_figure=True, + image=np.zeros((10, 10))) + save_figure_wid = save_figure_options(initial_renderer, toggle_show_default=True, toggle_show_visible=False) @@ -528,6 +659,7 @@ def update_widgets(name, value): # update model parameters update_model_parameters(model_parameters_wid, n_parameters[value], plot_function, params_str='param ') + # update channel options update_channel_options(channel_options_wid, appearance_models[value].mean().n_channels, @@ -545,31 +677,31 @@ def update_widgets(name, value): radio_str["Level {} (high)".format(l)] = l else: radio_str["Level {}".format(l)] = l - level_wid = RadioButtonsWidget(values=radio_str, - description='Pyramid:', value=0) + level_wid = ipywidgets.RadioButtonsWidget(values=radio_str, + description='Pyramid:', + value=0) level_wid.on_trait_change(update_widgets, 'value') level_wid.on_trait_change(plot_function, 'value') tmp_children.insert(0, level_wid) - tmp_wid = ContainerWidget(children=tmp_children) - wid = TabWidget(children=[tmp_wid, channel_options_wid, - landmark_options_wid, figure_options_wid, - info_wid, save_figure_wid]) + tmp_wid = ipywidgets.ContainerWidget(children=tmp_children) + tab_wid = ipywidgets.TabWidget(children=[tmp_wid, channel_options_wid, + landmark_options_wid, + viewer_options_wid, + info_wid, save_figure_wid]) + logo_wid = logo() + wid = ipywidgets.ContainerWidget(children=[logo_wid, tab_wid]) if popup: - wid = PopupWidget(children=[wid], button_text='Appearance Model Menu') + wid = ipywidgets.PopupWidget(children=[wid], + button_text='Appearance Model Menu') # display final widget - display(wid) + ipydisplay.display(wid) # set final tab titles tab_titles = ['Appearance parameters', 'Channels options', - 'Landmarks options', 'Figure options', 'Model info', - 'Save figure'] - if popup: - for (k, tl) in enumerate(tab_titles): - wid.children[0].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.set_title(k, tl) + 'Landmarks options', 'Viewer options', 'Info', 'Save figure'] + for (k, tl) in enumerate(tab_titles): + tab_wid.set_title(k, tl) # align widgets tmp_wid.remove_class('vbox') @@ -589,12 +721,13 @@ def update_widgets(name, value): container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) - format_figure_options(figure_options_wid, container_padding='6px', + format_viewer_options(viewer_options_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=False) - format_info_print(info_wid, font_size_in_pt='9pt', container_padding='6px', + border_visible=False, + suboptions_border_visible=True) + format_info_print(info_wid, font_size_in_pt='10pt', container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) @@ -604,16 +737,17 @@ def update_widgets(name, value): toggle_button_font_weight='bold', tab_top_margin='0cm', border_visible=False) - # update widgets' state for level 0 + # update widgets' state for image number 0 update_widgets('', 0) # Reset value to enable initial visualization - figure_options_wid.children[2].value = False + viewer_options_wid.children[1].children[1].children[2].children[2].value = \ + False def visualize_aam(aam, n_shape_parameters=5, n_appearance_parameters=5, - parameters_bounds=(-3.0, 3.0), figure_size=(7, 7), - mode='multiple', popup=False, **kwargs): + parameters_bounds=(-3.0, 3.0), figure_size=(10, 8), + mode='multiple', popup=False): r""" Allows the dynamic visualization of a multilevel AAM. @@ -623,40 +757,33 @@ def visualize_aam(aam, n_shape_parameters=5, n_appearance_parameters=5, The multilevel AAM to be displayed. Note that each level can have different attributes, e.g. number of active components, feature type, number of channels. - n_shape_parameters : `int` or `list` of `int` or None, optional - The number of shape principal components to be used for the parameters + The number of shape components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + If `int`, then the number of sliders per level is the minimum between + `n_parameters` and the number of active components per level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. n_appearance_parameters : `int` or `list` of `int` or None, optional - The number of appearance principal components to be used for the - parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + The number of appearance components to be used for the parameters + sliders. + If `int`, then the number of sliders per level is the minimum between + `n_parameters` and the number of active components per level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. - figure_size : (`int`, `int`), optional The size of the plotted figures. - - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. - - kwargs : `dict`, optional - Passed through to the viewer. + mode : {``single``, ``multiple``}, optional + If ``single``, only a single slider is constructed along with a drop + down menu. If ``multiple``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. """ + import IPython.html.widgets as ipywidgets + import IPython.display as ipydisplay + import matplotlib.pyplot as plt from menpo.image import MaskedImage # find number of levels @@ -673,53 +800,124 @@ def visualize_aam(aam, n_shape_parameters=5, n_appearance_parameters=5, n_appearance_parameters = _check_n_parameters(n_appearance_parameters, n_levels, max_n_appearance) - # define plot function + # find initial groups and labels that will be passed to the landmark options + # widget creation + mean_has_landmarks = aam.appearance_models[0].mean().landmarks.n_groups != 0 + if mean_has_landmarks: + all_groups_keys, all_labels_keys = _extract_groups_labels( + aam.appearance_models[0].mean()) + else: + all_groups_keys = [' '] + all_labels_keys = [[' ']] + + # get initial line colours for each available label + if len(all_labels_keys[0]) == 1: + line_colours = ['r'] + else: + line_colours = sample_colours_from_colourmap(len(all_labels_keys[0]), + 'jet') + + # initial options dictionaries + channels_default = 0 + if aam.appearance_models[0].mean().n_channels == 3: + channels_default = None + channels_options_default = \ + {'n_channels': aam.appearance_models[0].mean().n_channels, + 'image_is_masked': isinstance(aam.appearance_models[0].mean(), + MaskedImage), + 'channels': channels_default, + 'glyph_enabled': False, + 'glyph_block_size': 3, + 'glyph_use_negative': False, + 'sum_enabled': False, + 'masked_enabled': isinstance(aam.appearance_models[0].mean(), + MaskedImage)} + landmark_options_default = {'render_landmarks': mean_has_landmarks, + 'group_keys': all_groups_keys, + 'labels_keys': all_labels_keys, + 'group': None, + 'with_labels': None} + lines_options = {'render_lines': True, + 'line_width': 1, + 'line_colour': line_colours, + 'line_style': '-'} + markers_options = {'render_markers': True, + 'marker_size': 20, + 'marker_face_colour': ['r'], + 'marker_edge_colour': ['k'], + 'marker_style': 'o', + 'marker_edge_width': 1} + figure_options = {'x_scale': 1., + 'y_scale': 1., + 'render_axes': True, + 'axes_font_name': 'sans-serif', + 'axes_font_size': 10, + 'axes_font_style': 'normal', + 'axes_font_weight': 'normal', + 'axes_x_limits': None, + 'axes_y_limits': None} + image_options = {'interpolation': 'none', + 'alpha': 1.0} + viewer_options_default = {'lines': lines_options, + 'markers': markers_options, + 'figure': figure_options, + 'image': image_options} + + # Define plot function def plot_function(name, value): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) # get selected level level = 0 if n_levels > 1: level = level_wid.value - # get weights and compute instance + # compute weights and instance shape_weights = shape_model_parameters_wid.parameters_values appearance_weights = appearance_model_parameters_wid.parameters_values instance = aam.instance(level=level, shape_weights=shape_weights, appearance_weights=appearance_weights) - # get the current figure id - figure_id = save_figure_wid.figure_id - - # show image with selected options - new_figure_id = _plot_figure( - image=instance, figure_id=figure_id, image_enabled=True, - landmarks_enabled=landmark_options_wid.landmarks_enabled, - image_is_masked=channel_options_wid.image_is_masked, - masked_enabled=channel_options_wid.masked_enabled, - channels=channel_options_wid.channels, - glyph_enabled=channel_options_wid.glyph_enabled, - glyph_block_size=channel_options_wid.glyph_block_size, - glyph_use_negative=channel_options_wid.glyph_use_negative, - sum_enabled=channel_options_wid.sum_enabled, - groups=[landmark_options_wid.group], - with_labels=[landmark_options_wid.with_labels], - groups_colours=dict(), subplots_enabled=False, - subplots_titles=dict(), image_axes_mode=True, - legend_enabled=landmark_options_wid.legend_enabled, - numbering_enabled=landmark_options_wid.numbering_enabled, - x_scale=figure_options_wid.x_scale, - y_scale=figure_options_wid.y_scale, - axes_visible=figure_options_wid.axes_visible, - figure_size=figure_size, **kwargs) + # update info text widget + update_info(aam, instance, level, + landmark_options_wid.selected_values['group']) + n_labels = len(landmark_options_wid.selected_values['with_labels']) + + # plot + tmp1 = viewer_options_wid.selected_values[0]['lines'] + tmp2 = viewer_options_wid.selected_values[0]['markers'] + tmp3 = viewer_options_wid.selected_values[0]['figure'] + tmp4 = viewer_options_wid.selected_values[0]['image'] + new_figure_size = (tmp3['x_scale'] * figure_size[0], + tmp3['y_scale'] * figure_size[1]) + renderer = _visualize_menpo( + instance, save_figure_wid.renderer[0], + landmark_options_wid.selected_values['render_landmarks'], + channel_options_wid.selected_values['image_is_masked'], + channel_options_wid.selected_values['masked_enabled'], + channel_options_wid.selected_values['channels'], + channel_options_wid.selected_values['glyph_enabled'], + channel_options_wid.selected_values['glyph_block_size'], + channel_options_wid.selected_values['glyph_use_negative'], + channel_options_wid.selected_values['sum_enabled'], + landmark_options_wid.selected_values['group'], + landmark_options_wid.selected_values['with_labels'], + tmp1['render_lines'], tmp1['line_style'], tmp1['line_width'], + tmp1['line_colour'][:n_labels], tmp2['render_markers'], + tmp2['marker_style'], tmp2['marker_size'], tmp2['marker_edge_width'], + tmp2['marker_edge_colour'], tmp2['marker_face_colour'], + False, None, None, None, None, None, None, None, None, None, None, + None, None, None, None, None, None, None, None, None, None, None, + False, None, None, new_figure_size, tmp3['render_axes'], + tmp3['axes_font_name'], tmp3['axes_font_size'], + tmp3['axes_font_style'], tmp3['axes_x_limits'], + tmp3['axes_y_limits'], tmp3['axes_font_weight'], + tmp4['interpolation'], tmp4['alpha']) # save the current figure id - save_figure_wid.figure_id = new_figure_id - - # update info text widget - update_info(aam, instance, level, landmark_options_wid.group) + save_figure_wid.renderer[0] = renderer # define function that updates info text def update_info(aam, instance, level, group): @@ -756,113 +954,125 @@ def update_info(aam, instance, level, group): else: tmp_pyramid = "Features were extracted at each pyramid level." - # Formatting is a bit ugly but this is MUCH easier to read. - info_txt = r""" - {} training images. - Warp using {} transform. - Level {}/{} (downscale={:.1f}). - {} - {} - {} - Reference frame of length {} ({} x {}C, {} x {}C). - {} shape components ({:.2f}% of variance) - {} appearance components ({:.2f}% of variance) - {} landmark points. - Instance: min={:.3f} , max={:.3f} - """.format(aam.n_training_images, aam.transform.__name__, - level + 1, - aam.n_levels, aam.downscale, tmp_shape_models, - tmp_pyramid, tmp_feat, lvl_app_mod.n_features, - tmplt_inst.n_true_pixels(), n_channels, - tmplt_inst._str_shape, n_channels, - lvl_shape_mod.n_components, - lvl_shape_mod.variance_ratio() * 100, - lvl_app_mod.n_components, - lvl_app_mod.variance_ratio() * 100, - instance.landmarks[group].lms.n_points, - instance.pixels.min(), instance.pixels.max()) - - info_wid.children[1].value = _raw_info_string_to_latex(info_txt) + # update info widgets + info_wid.children[1].children[0].value = "> {} training images.".\ + format(aam.n_training_images) + info_wid.children[1].children[1].value = "> Warp using {} transform.".\ + format(aam.transform.__name__) + info_wid.children[1].children[2].value = "> Level {}/{} " \ + "(downscale={:.1f}).".\ + format(level + 1, aam.n_levels, aam.downscale) + info_wid.children[1].children[3].value = "> {}".format(tmp_shape_models) + info_wid.children[1].children[4].value = "> {}".format(tmp_pyramid) + info_wid.children[1].children[5].value = "> {}".format(tmp_feat) + info_wid.children[1].children[6].value = "> Reference frame of " \ + "length {} ({} x {}C, {} x " \ + "{}C).".\ + format(lvl_app_mod.n_features, tmplt_inst.n_true_pixels(), + n_channels, tmplt_inst._str_shape, n_channels) + info_wid.children[1].children[7].value = "> {} shape components " \ + "({:.2f}% of variance).".\ + format(lvl_shape_mod.n_components, + lvl_shape_mod.variance_ratio() * 100) + info_wid.children[1].children[8].value = "> {} appearance components " \ + "({:.2f}% of variance).".\ + format(lvl_app_mod.n_components, lvl_app_mod.variance_ratio() * 100) + info_wid.children[1].children[9].value = "> {} landmark points.".\ + format(instance.landmarks[group].lms.n_points) + info_wid.children[1].children[10].value = "> Instance: min={:.3f} , " \ + "max={:.3f}.".\ + format(instance.pixels.min(), instance.pixels.max()) # Plot shape eigenvalues function def plot_shape_eigenvalues(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) - # get parameters + # get level level = 0 if n_levels > 1: level = level_wid.value - # get the current figure id - figure_id = save_figure_wid.figure_id - - # show eigenvalues plots - new_figure_id = _plot_eigenvalues(figure_id, aam.shape_models[level], - figure_size, - figure_options_wid.x_scale, - figure_options_wid.y_scale) + # get the current figure id and plot the eigenvalues + new_figure_size = (viewer_options_wid.selected_values[0]['figure']['x_scale'] * 10, + viewer_options_wid.selected_values[0]['figure']['y_scale'] * 3) + plt.subplot(121) + aam.shape_models[level].plot_eigenvalues_ratio( + figure_id=save_figure_wid.renderer[0].figure_id) + plt.subplot(122) + renderer = aam.shape_models[level].plot_eigenvalues_cumulative_ratio( + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size) + plt.show() # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer # Plot appearance eigenvalues function def plot_appearance_eigenvalues(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) - # get parameters + # get level level = 0 if n_levels > 1: level = level_wid.value - # get the current figure id - figure_id = save_figure_wid.figure_id - - # show eigenvalues plots - new_figure_id = _plot_eigenvalues(figure_id, - aam.appearance_models[level], - figure_size, - figure_options_wid.x_scale, - figure_options_wid.y_scale) + # get the current figure id and plot the eigenvalues + new_figure_size = (viewer_options_wid.selected_values[0]['figure']['x_scale'] * 10, + viewer_options_wid.selected_values[0]['figure']['y_scale'] * 3) + plt.subplot(121) + aam.appearance_models[level].plot_eigenvalues_ratio( + figure_id=save_figure_wid.renderer[0].figure_id) + plt.subplot(122) + renderer = aam.appearance_models[level].plot_eigenvalues_cumulative_ratio( + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size) + plt.show() # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer - # create options widgets + # create parameters, channels nad landmarks options widgets shape_model_parameters_wid = model_parameters( n_shape_parameters[0], plot_function, params_str='param ', mode=mode, - params_bounds=parameters_bounds, toggle_show_default=False, - toggle_show_visible=True, toggle_show_name='Shape Parameters', + params_bounds=parameters_bounds, toggle_show_default=True, + toggle_show_visible=False, toggle_show_name='Shape Parameters', plot_eig_visible=True, plot_eig_function=plot_shape_eigenvalues) appearance_model_parameters_wid = model_parameters( n_appearance_parameters[0], plot_function, params_str='param ', - mode=mode, params_bounds=parameters_bounds, toggle_show_default=False, - toggle_show_visible=True, toggle_show_name='Appearance Parameters', + mode=mode, params_bounds=parameters_bounds, toggle_show_default=True, + toggle_show_visible=False, toggle_show_name='Appearance Parameters', plot_eig_visible=True, plot_eig_function=plot_appearance_eigenvalues) - channel_options_wid = channel_options( - aam.appearance_models[0].mean().n_channels, - isinstance(aam.appearance_models[0].mean(), MaskedImage), plot_function, - masked_default=True, toggle_show_default=True, - toggle_show_visible=False) - all_groups_keys, all_labels_keys = \ - _extract_groups_labels(aam.appearance_models[0].mean()) - landmark_options_wid = landmark_options(all_groups_keys, all_labels_keys, - plot_function, + channel_options_wid = channel_options(channels_options_default, + plot_function=plot_function, + toggle_show_default=True, + toggle_show_visible=False) + landmark_options_wid = landmark_options(landmark_options_default, + plot_function=plot_function, toggle_show_default=True, - landmarks_default=True, - legend_default=False, - numbering_default=False, toggle_show_visible=False) - figure_options_wid = figure_options(plot_function, scale_default=1., - show_axes_default=True, - toggle_show_default=True, - toggle_show_visible=False) - info_wid = info_print(toggle_show_default=True, toggle_show_visible=False) - initial_figure_id = plt.figure() - save_figure_wid = save_figure_options(initial_figure_id, + # if the mean doesn't have landmarks, then landmarks checkbox should be + # disabled + landmark_options_wid.children[1].disabled = not mean_has_landmarks + + # viewer options widget + viewer_options_wid = viewer_options(viewer_options_default, + ['lines', 'markers', 'figure_one', + 'image'], + objects_names=None, + plot_function=plot_function, + toggle_show_visible=False, + toggle_show_default=True) + info_wid = info_print(n_bullets=11, toggle_show_default=True, + toggle_show_visible=False) + + # save figure widget + initial_renderer = MatplotlibImageViewer2d(figure_id=None, new_figure=True, + image=np.zeros((10, 10))) + save_figure_wid = save_figure_options(initial_renderer, toggle_show_default=True, toggle_show_visible=False) @@ -876,6 +1086,7 @@ def update_widgets(name, value): update_model_parameters(appearance_model_parameters_wid, n_appearance_parameters[value], plot_function, params_str='param ') + # update channel options update_channel_options(channel_options_wid, aam.appearance_models[value].mean().n_channels, @@ -883,7 +1094,7 @@ def update_widgets(name, value): MaskedImage)) # create final widget - model_parameters_wid = ContainerWidget( + model_parameters_wid = ipywidgets.AccordionWidget( children=[shape_model_parameters_wid, appearance_model_parameters_wid]) tmp_children = [model_parameters_wid] if n_levels > 1: @@ -895,45 +1106,48 @@ def update_widgets(name, value): radio_str["Level {} (high)".format(l)] = l else: radio_str["Level {}".format(l)] = l - level_wid = RadioButtonsWidget(values=radio_str, - description='Pyramid:', value=0) + level_wid = ipywidgets.RadioButtonsWidget(values=radio_str, + description='Pyramid:', + value=0) level_wid.on_trait_change(update_widgets, 'value') level_wid.on_trait_change(plot_function, 'value') tmp_children.insert(0, level_wid) - tmp_wid = ContainerWidget(children=tmp_children) - wid = TabWidget(children=[tmp_wid, channel_options_wid, - landmark_options_wid, figure_options_wid, - info_wid, save_figure_wid]) + tmp_wid = ipywidgets.ContainerWidget(children=tmp_children) + tab_wid = ipywidgets.TabWidget(children=[tmp_wid, channel_options_wid, + landmark_options_wid, + viewer_options_wid, + info_wid, save_figure_wid]) + logo_wid = logo() + wid = ipywidgets.ContainerWidget(children=[logo_wid, tab_wid]) if popup: - wid = PopupWidget(children=[wid], button_text='AAM Menu') + wid = ipywidgets.PopupWidget(children=[wid], + button_text='AAM Menu') # display final widget - display(wid) + ipydisplay.display(wid) # set final tab titles - tab_titles = ['AAM parameters', 'Channels options', 'Landmarks options', - 'Figure options', 'Model info', 'Save figure'] - if popup: - for (k, tl) in enumerate(tab_titles): - wid.children[0].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.set_title(k, tl) + tab_titles = ['AAM parameters', 'Channels options', + 'Landmarks options', 'Viewer options', 'Info', 'Save figure'] + for (k, tl) in enumerate(tab_titles): + tab_wid.set_title(k, tl) + tab_titles = ['Shape parameters', 'Appearance parameters'] + for (k, tl) in enumerate(tab_titles): + model_parameters_wid.set_title(k, tl) # align widgets - if n_levels > 1: - tmp_wid.remove_class('vbox') - tmp_wid.add_class('hbox') + tmp_wid.remove_class('vbox') + tmp_wid.add_class('hbox') format_model_parameters(shape_model_parameters_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=True) + border_visible=False) format_model_parameters(appearance_model_parameters_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=True) + border_visible=False) format_channel_options(channel_options_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', @@ -944,12 +1158,13 @@ def update_widgets(name, value): container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) - format_figure_options(figure_options_wid, container_padding='6px', + format_viewer_options(viewer_options_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=False) - format_info_print(info_wid, font_size_in_pt='9pt', container_padding='6px', + border_visible=False, + suboptions_border_visible=True) + format_info_print(info_wid, font_size_in_pt='10pt', container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) @@ -959,15 +1174,16 @@ def update_widgets(name, value): toggle_button_font_weight='bold', tab_top_margin='0cm', border_visible=False) - # update widgets' state for level 0 + # update widgets' state for image number 0 update_widgets('', 0) # Reset value to enable initial visualization - figure_options_wid.children[2].value = False + viewer_options_wid.children[1].children[1].children[2].children[2].value = \ + False def visualize_atm(atm, n_shape_parameters=5, parameters_bounds=(-3.0, 3.0), - figure_size=(7, 7), mode='multiple', popup=False, **kwargs): + figure_size=(10, 8), mode='multiple', popup=False): r""" Allows the dynamic visualization of a multilevel ATM. @@ -977,32 +1193,26 @@ def visualize_atm(atm, n_shape_parameters=5, parameters_bounds=(-3.0, 3.0), The multilevel ATM to be displayed. Note that each level can have different attributes, e.g. number of active components, feature type, number of channels. - n_shape_parameters : `int` or `list` of `int` or None, optional - The number of shape principal components to be used for the parameters + The number of shape components to be used for the parameters sliders. - If int, then the number of sliders per level is the minimum between - n_parameters and the number of active components per level. - If list of int, then a number of sliders is defined per level. - If None, all the active components per level will have a slider. - + If `int`, then the number of sliders per level is the minimum between + `n_parameters` and the number of active components per level. + If `list` of `int`, then a number of sliders is defined per level. + If ``None``, all the active components per level will have a slider. parameters_bounds : (`float`, `float`), optional The minimum and maximum bounds, in std units, for the sliders. - figure_size : (`int`, `int`), optional The size of the plotted figures. - - mode : 'single' or 'multiple', optional - If single, only a single slider is constructed along with a drop down - menu. - If multiple, a slider is constructed for each parameter. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. - - kwargs : `dict`, optional - Passed through to the viewer. + mode : {``single``, ``multiple``}, optional + If ``single``, only a single slider is constructed along with a drop + down menu. If ``multiple``, a slider is constructed for each parameter. + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. """ + import IPython.html.widgets as ipywidgets + import IPython.display as ipydisplay + import matplotlib.pyplot as plt from menpo.image import MaskedImage # find number of levels @@ -1016,51 +1226,122 @@ def visualize_atm(atm, n_shape_parameters=5, parameters_bounds=(-3.0, 3.0), n_shape_parameters = _check_n_parameters(n_shape_parameters, n_levels, max_n_shape) - # define plot function + # find initial groups and labels that will be passed to the landmark options + # widget creation + template_has_landmarks = atm.warped_templates[0].landmarks.n_groups != 0 + if template_has_landmarks: + all_groups_keys, all_labels_keys = _extract_groups_labels( + atm.warped_templates[0]) + else: + all_groups_keys = [' '] + all_labels_keys = [[' ']] + + # get initial line colours for each available label + if len(all_labels_keys[0]) == 1: + line_colours = ['r'] + else: + line_colours = sample_colours_from_colourmap(len(all_labels_keys[0]), + 'jet') + + # initial options dictionaries + channels_default = 0 + if atm.warped_templates[0].n_channels == 3: + channels_default = None + channels_options_default = \ + {'n_channels': atm.warped_templates[0].n_channels, + 'image_is_masked': isinstance(atm.warped_templates[0], + MaskedImage), + 'channels': channels_default, + 'glyph_enabled': False, + 'glyph_block_size': 3, + 'glyph_use_negative': False, + 'sum_enabled': False, + 'masked_enabled': isinstance(atm.warped_templates[0], + MaskedImage)} + landmark_options_default = {'render_landmarks': template_has_landmarks, + 'group_keys': all_groups_keys, + 'labels_keys': all_labels_keys, + 'group': None, + 'with_labels': None} + lines_options = {'render_lines': True, + 'line_width': 1, + 'line_colour': line_colours, + 'line_style': '-'} + markers_options = {'render_markers': True, + 'marker_size': 20, + 'marker_face_colour': ['r'], + 'marker_edge_colour': ['k'], + 'marker_style': 'o', + 'marker_edge_width': 1} + figure_options = {'x_scale': 1., + 'y_scale': 1., + 'render_axes': True, + 'axes_font_name': 'sans-serif', + 'axes_font_size': 10, + 'axes_font_style': 'normal', + 'axes_font_weight': 'normal', + 'axes_x_limits': None, + 'axes_y_limits': None} + image_options = {'interpolation': 'none', + 'alpha': 1.0} + viewer_options_default = {'lines': lines_options, + 'markers': markers_options, + 'figure': figure_options, + 'image': image_options} + + # Define plot function def plot_function(name, value): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) # get selected level level = 0 if n_levels > 1: level = level_wid.value - # get weights and compute instance + # compute weights and instance shape_weights = shape_model_parameters_wid.parameters_values instance = atm.instance(level=level, shape_weights=shape_weights) - # get the current figure id - figure_id = save_figure_wid.figure_id - - # show image with selected options - new_figure_id = _plot_figure( - image=instance, figure_id=figure_id, image_enabled=True, - landmarks_enabled=landmark_options_wid.landmarks_enabled, - image_is_masked=channel_options_wid.image_is_masked, - masked_enabled=channel_options_wid.masked_enabled, - channels=channel_options_wid.channels, - glyph_enabled=channel_options_wid.glyph_enabled, - glyph_block_size=channel_options_wid.glyph_block_size, - glyph_use_negative=channel_options_wid.glyph_use_negative, - sum_enabled=channel_options_wid.sum_enabled, - groups=[landmark_options_wid.group], - with_labels=[landmark_options_wid.with_labels], - groups_colours=dict(), subplots_enabled=False, - subplots_titles=dict(), image_axes_mode=True, - legend_enabled=landmark_options_wid.legend_enabled, - numbering_enabled=landmark_options_wid.numbering_enabled, - x_scale=figure_options_wid.x_scale, - y_scale=figure_options_wid.y_scale, - axes_visible=figure_options_wid.axes_visible, - figure_size=figure_size, **kwargs) + # update info text widget + update_info(atm, instance, level, + landmark_options_wid.selected_values['group']) + n_labels = len(landmark_options_wid.selected_values['with_labels']) + + # plot + tmp1 = viewer_options_wid.selected_values[0]['lines'] + tmp2 = viewer_options_wid.selected_values[0]['markers'] + tmp3 = viewer_options_wid.selected_values[0]['figure'] + tmp4 = viewer_options_wid.selected_values[0]['image'] + new_figure_size = (tmp3['x_scale'] * figure_size[0], + tmp3['y_scale'] * figure_size[1]) + renderer = _visualize_menpo( + instance, save_figure_wid.renderer[0], + landmark_options_wid.selected_values['render_landmarks'], + channel_options_wid.selected_values['image_is_masked'], + channel_options_wid.selected_values['masked_enabled'], + channel_options_wid.selected_values['channels'], + channel_options_wid.selected_values['glyph_enabled'], + channel_options_wid.selected_values['glyph_block_size'], + channel_options_wid.selected_values['glyph_use_negative'], + channel_options_wid.selected_values['sum_enabled'], + landmark_options_wid.selected_values['group'], + landmark_options_wid.selected_values['with_labels'], + tmp1['render_lines'], tmp1['line_style'], tmp1['line_width'], + tmp1['line_colour'][:n_labels], tmp2['render_markers'], + tmp2['marker_style'], tmp2['marker_size'], tmp2['marker_edge_width'], + tmp2['marker_edge_colour'], tmp2['marker_face_colour'], + False, None, None, None, None, None, None, None, None, None, None, + None, None, None, None, None, None, None, None, None, None, None, + False, None, None, new_figure_size, tmp3['render_axes'], + tmp3['axes_font_name'], tmp3['axes_font_size'], + tmp3['axes_font_style'], tmp3['axes_x_limits'], + tmp3['axes_y_limits'], tmp3['axes_font_weight'], + tmp4['interpolation'], tmp4['alpha']) # save the current figure id - save_figure_wid.figure_id = new_figure_id - - # update info text widget - update_info(atm, instance, level, landmark_options_wid.group) + save_figure_wid.renderer[0] = renderer # define function that updates info text def update_info(atm, instance, level, group): @@ -1095,82 +1376,92 @@ def update_info(atm, instance, level, group): else: tmp_pyramid = "Features were extracted at each pyramid level." - # Formatting is a bit ugly but this is MUCH easier to read. - info_txt = r""" - {} training shapes. - Warp using {} transform. - Level {}/{} (downscale={:.1f}). - {} - {} - {} - Reference frame of length {} ({} x {}C, {} x {}C). - {} shape components ({:.2f}% of variance) - {} landmark points. - Instance: min={:.3f} , max={:.3f} - """.format(atm.n_training_shapes, atm.transform.__name__, - level + 1, - atm.n_levels, atm.downscale, tmp_shape_models, - tmp_pyramid, tmp_feat, - tmplt_inst.n_true_pixels() * n_channels, - tmplt_inst.n_true_pixels(), n_channels, - tmplt_inst._str_shape, n_channels, - lvl_shape_mod.n_components, - lvl_shape_mod.variance_ratio() * 100, - instance.landmarks[group].lms.n_points, - instance.pixels.min(), instance.pixels.max()) - - info_wid.children[1].value = _raw_info_string_to_latex(info_txt) + # update info widgets + info_wid.children[1].children[0].value = "> {} training shapes.".\ + format(atm.n_training_shapes) + info_wid.children[1].children[1].value = "> Warp using {} transform.".\ + format(atm.transform.__name__) + info_wid.children[1].children[2].value = "> Level {}/{} " \ + "(downscale={:.1f}).".\ + format(level + 1, atm.n_levels, atm.downscale) + info_wid.children[1].children[3].value = "> {}".format(tmp_shape_models) + info_wid.children[1].children[4].value = "> {}".format(tmp_pyramid) + info_wid.children[1].children[5].value = "> {}".format(tmp_feat) + info_wid.children[1].children[6].value = "> Reference frame of " \ + "length {} ({} x {}C, {} x " \ + "{}C).".\ + format(tmplt_inst.n_true_pixels() * n_channels, + tmplt_inst.n_true_pixels(), + n_channels, tmplt_inst._str_shape, n_channels) + info_wid.children[1].children[7].value = "> {} shape components " \ + "({:.2f}% of variance).".\ + format(lvl_shape_mod.n_components, + lvl_shape_mod.variance_ratio() * 100) + info_wid.children[1].children[8].value = "> {} landmark points.".\ + format(instance.landmarks[group].lms.n_points) + info_wid.children[1].children[9].value = "> Instance: min={:.3f} , " \ + "max={:.3f}.".\ + format(instance.pixels.min(), instance.pixels.max()) # Plot shape eigenvalues function def plot_shape_eigenvalues(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) - # get parameters + # get level level = 0 if n_levels > 1: level = level_wid.value - # get the current figure id - figure_id = save_figure_wid.figure_id - - # show eigenvalues plots - new_figure_id = _plot_eigenvalues(figure_id, atm.shape_models[level], - figure_size, - figure_options_wid.x_scale, - figure_options_wid.y_scale) + # get the current figure id and plot the eigenvalues + new_figure_size = (viewer_options_wid.selected_values[0]['figure']['x_scale'] * 10, + viewer_options_wid.selected_values[0]['figure']['y_scale'] * 3) + plt.subplot(121) + atm.shape_models[level].plot_eigenvalues_ratio( + figure_id=save_figure_wid.renderer[0].figure_id) + plt.subplot(122) + renderer = atm.shape_models[level].plot_eigenvalues_cumulative_ratio( + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size) + plt.show() # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer - # create options widgets + # create parameters, channels nad landmarks options widgets shape_model_parameters_wid = model_parameters( n_shape_parameters[0], plot_function, params_str='param ', mode=mode, params_bounds=parameters_bounds, toggle_show_default=True, toggle_show_visible=False, toggle_show_name='Shape Parameters', plot_eig_visible=True, plot_eig_function=plot_shape_eigenvalues) - channel_options_wid = channel_options( - atm.warped_templates[0].n_channels, - isinstance(atm.warped_templates[0], MaskedImage), - plot_function, masked_default=True, toggle_show_default=True, - toggle_show_visible=False) - all_groups_keys, all_labels_keys = \ - _extract_groups_labels(atm.warped_templates[0]) - landmark_options_wid = landmark_options(all_groups_keys, all_labels_keys, - plot_function, + channel_options_wid = channel_options(channels_options_default, + plot_function=plot_function, + toggle_show_default=True, + toggle_show_visible=False) + landmark_options_wid = landmark_options(landmark_options_default, + plot_function=plot_function, toggle_show_default=True, - landmarks_default=True, - legend_default=False, - numbering_default=False, toggle_show_visible=False) - figure_options_wid = figure_options(plot_function, scale_default=1., - show_axes_default=True, - toggle_show_default=True, - toggle_show_visible=False) - info_wid = info_print(toggle_show_default=True, toggle_show_visible=False) - initial_figure_id = plt.figure() - save_figure_wid = save_figure_options(initial_figure_id, + # if the mean doesn't have landmarks, then landmarks checkbox should be + # disabled + landmark_options_wid.children[1].disabled = not template_has_landmarks + + # viewer options widget + viewer_options_wid = viewer_options(viewer_options_default, + ['lines', 'markers', 'figure_one', + 'image'], + objects_names=None, + plot_function=plot_function, + toggle_show_visible=False, + toggle_show_default=True) + info_wid = info_print(n_bullets=10, toggle_show_default=True, + toggle_show_visible=False) + + # save figure widget + initial_renderer = MatplotlibImageViewer2d(figure_id=None, new_figure=True, + image=np.zeros((10, 10))) + save_figure_wid = save_figure_options(initial_renderer, toggle_show_default=True, toggle_show_visible=False) @@ -1180,6 +1471,7 @@ def update_widgets(name, value): update_model_parameters(shape_model_parameters_wid, n_shape_parameters[value], plot_function, params_str='param ') + # update channel options update_channel_options(channel_options_wid, atm.warped_templates[value].n_channels, @@ -1197,35 +1489,35 @@ def update_widgets(name, value): radio_str["Level {} (high)".format(l)] = l else: radio_str["Level {}".format(l)] = l - level_wid = RadioButtonsWidget(values=radio_str, - description='Pyramid:', value=0) + level_wid = ipywidgets.RadioButtonsWidget(values=radio_str, + description='Pyramid:', + value=0) level_wid.on_trait_change(update_widgets, 'value') level_wid.on_trait_change(plot_function, 'value') tmp_children.insert(0, level_wid) - tmp_wid = ContainerWidget(children=tmp_children) - wid = TabWidget(children=[tmp_wid, channel_options_wid, - landmark_options_wid, figure_options_wid, - info_wid, save_figure_wid]) + tmp_wid = ipywidgets.ContainerWidget(children=tmp_children) + tab_wid = ipywidgets.TabWidget(children=[tmp_wid, channel_options_wid, + landmark_options_wid, + viewer_options_wid, + info_wid, save_figure_wid]) + logo_wid = logo() + wid = ipywidgets.ContainerWidget(children=[logo_wid, tab_wid]) if popup: - wid = PopupWidget(children=[wid], button_text='ATM Menu') + wid = ipywidgets.PopupWidget(children=[wid], + button_text='ATM Menu') # display final widget - display(wid) + ipydisplay.display(wid) # set final tab titles - tab_titles = ['Shape parameters', 'Channels options', 'Landmarks options', - 'Figure options', 'Model info', 'Save figure'] - if popup: - for (k, tl) in enumerate(tab_titles): - wid.children[0].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.set_title(k, tl) + tab_titles = ['Shape parameters', 'Channels options', + 'Landmarks options', 'Viewer options', 'Info', 'Save figure'] + for (k, tl) in enumerate(tab_titles): + tab_wid.set_title(k, tl) # align widgets - if n_levels > 1: - tmp_wid.remove_class('vbox') - tmp_wid.add_class('hbox') + tmp_wid.remove_class('vbox') + tmp_wid.add_class('hbox') format_model_parameters(shape_model_parameters_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', @@ -1241,12 +1533,13 @@ def update_widgets(name, value): container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) - format_figure_options(figure_options_wid, container_padding='6px', + format_viewer_options(viewer_options_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=False) - format_info_print(info_wid, font_size_in_pt='9pt', container_padding='6px', + border_visible=False, + suboptions_border_visible=True) + format_info_print(info_wid, font_size_in_pt='10pt', container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) @@ -1256,35 +1549,39 @@ def update_widgets(name, value): toggle_button_font_weight='bold', tab_top_margin='0cm', border_visible=False) - # update widgets' state for level 0 + # update widgets' state for image number 0 update_widgets('', 0) # Reset value to enable initial visualization - figure_options_wid.children[2].value = False + viewer_options_wid.children[1].children[1].children[2].children[2].value = \ + False -def visualize_fitting_results(fitting_results, figure_size=(7, 7), popup=False, - **kwargs): +def visualize_fitting_results(fitting_results, figure_size=(10, 8), + browser_style='buttons', popup=False): r""" Widget that allows browsing through a list of fitting results. Parameters ----------- fitting_results : `list` of :map:`FittingResult` or subclass - The list of fitting results to be displayed. Note that the fitting + The `list` of fitting results to be displayed. Note that the fitting results can have different attributes between them, i.e. different number of iterations, number of channels etc. - figure_size : (`int`, `int`), optional The initial size of the plotted figures. - + browser_style : {``buttons``, ``slider``}, optional + It defines whether the selector of the fitting results will have the form of + plus/minus buttons or a slider. popup : `boolean`, optional - If enabled, the widget will appear as a popup window. - - kwargs : `dict`, optional - Passed through to the viewer. + If ``True``, the widget will appear as a popup window. """ + import IPython.html.widgets as ipywidgets + import IPython.display as ipydisplay + import matplotlib.pyplot as plt from menpo.image import MaskedImage + from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap + print 'Initializing...' # make sure that fitting_results is a list even with one fitting_result if not isinstance(fitting_results, list): @@ -1297,117 +1594,200 @@ def visualize_fitting_results(fitting_results, figure_size=(7, 7), popup=False, n_fitting_results = len(fitting_results) # create dictionaries - iter_str = 'iter_' + all_groups = ['final', 'initial', 'ground', 'iterations'] groups_final_dict = dict() colour_final_dict = dict() groups_final_dict['initial'] = 'Initial shape' - colour_final_dict['initial'] = 'r' + colour_final_dict['initial'] = 'b' groups_final_dict['final'] = 'Final shape' - colour_final_dict['final'] = 'b' + colour_final_dict['final'] = 'r' groups_final_dict['ground'] = 'Ground-truth shape' colour_final_dict['ground'] = 'y' + groups_final_dict['iterations'] = 'Iterations' + colour_final_dict['iterations'] = 'r' + + # initial options dictionaries + channels_default = 0 + if fitting_results[0].fitted_image.n_channels == 3: + channels_default = None + channels_options_default = \ + {'n_channels': fitting_results[0].fitted_image.n_channels, + 'image_is_masked': isinstance(fitting_results[0].fitted_image, + MaskedImage), + 'channels': channels_default, + 'glyph_enabled': False, + 'glyph_block_size': 3, + 'glyph_use_negative': False, + 'sum_enabled': False, + 'masked_enabled': False} + all_groups_keys, _ = _extract_groups_labels(fitting_results[0].fitted_image) + final_result_options_default = {'all_groups': all_groups_keys, + 'render_image': True, + 'selected_groups': ['final'], + 'subplots_enabled': True} + iterations_result_options_default = \ + {'n_iters': fitting_results[0].n_iters, + 'image_has_gt_shape': not fitting_results[0].gt_shape is None, + 'n_points': fitting_results[0].fitted_image.landmarks['final'].lms.n_points, + 'iter_str': 'iter_', + 'selected_groups': ['iter_0'], + 'render_image': True, + 'subplots_enabled': True, + 'displacement_type': 'mean'} + markers_options = {'render_markers': True, + 'marker_size': 20, + 'marker_face_colour': ['r'], + 'marker_edge_colour': ['k'], + 'marker_style': 'o', + 'marker_edge_width': 1} + lines_options_default = {'render_lines': True, + 'line_width': 1, + 'line_colour': ['r'], + 'line_style': '-'} + figure_options = {'x_scale': 1., + 'y_scale': 1., + 'render_axes': True, + 'axes_font_name': 'sans-serif', + 'axes_font_size': 10, + 'axes_font_style': 'normal', + 'axes_font_weight': 'normal', + 'axes_x_limits': None, + 'axes_y_limits': None} + numbering_options = {'render_numbering': False, + 'numbers_font_name': 'serif', + 'numbers_font_size': 10, + 'numbers_font_style': 'normal', + 'numbers_font_weight': 'normal', + 'numbers_font_colour': ['k'], + 'numbers_horizontal_align': 'center', + 'numbers_vertical_align': 'bottom'} + legend_options = {'render_legend': True, + 'legend_title': '', + 'legend_font_name': 'sans-serif', + 'legend_font_style': 'normal', + 'legend_font_size': 11, + 'legend_font_weight': 'normal', + 'legend_marker_scale': 1., + 'legend_location': 2, + 'legend_bbox_to_anchor': (1.05, 1.), + 'legend_border_axes_pad': 1., + 'legend_n_columns': 1, + 'legend_horizontal_spacing': 1., + 'legend_vertical_spacing': 1., + 'legend_border': True, + 'legend_border_padding': 0.5, + 'legend_shadow': False, + 'legend_rounded_corners': True} + image_options = {'interpolation': 'bilinear', + 'alpha': 1.0} + viewer_options_default = [] + for group in all_groups: + tmp_lines = lines_options_default.copy() + tmp_lines['line_colour'] = [colour_final_dict[group]] + tmp_markers = markers_options.copy() + tmp_markers['marker_face_colour'] = [colour_final_dict[group]] + tmp = {'markers': tmp_markers, + 'lines': tmp_lines, + 'figure': figure_options, + 'legend': legend_options, + 'numbering': numbering_options, + 'image': image_options} + viewer_options_default.append(tmp) + index_selection_default = {'min': 0, + 'max': n_fitting_results - 1, + 'step': 1, + 'index': 0} # define function that plots errors curve def plot_errors_function(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) # get selected image im = 0 if n_fitting_results > 1: - im = image_number_wid.selected_index + im = image_number_wid.selected_values['index'] - # select figure - figure_id = plt.figure(save_figure_wid.figure_id.number) + # get figure size + new_figure_size = (viewer_options_wid.selected_values[0]['figure']['x_scale'] * figure_size[0], + viewer_options_wid.selected_values[0]['figure']['y_scale'] * figure_size[1]) # plot errors curve - plt.plot(range(len(fitting_results[im].errors())), - fitting_results[im].errors(), '-bo') - plt.gca().set_xlim(0, len(fitting_results[im].errors())-1) - plt.xlabel('Iteration') - plt.ylabel('Fitting Error') - plt.title("Fitting error evolution of Image {}".format(im)) - plt.grid("on") - - # set figure size - x_scale = figure_options_wid.x_scale - y_scale = figure_options_wid.y_scale - plt.gcf().set_size_inches([x_scale, y_scale] * np.asarray(figure_size)) + renderer = fitting_results[im].plot_errors( + error_type=error_type_wid.value, + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size) # show figure plt.show() # save the current figure id - save_figure_wid.figure_id = figure_id + save_figure_wid.renderer[0] = renderer # define function that plots displacements curve def plot_displacements_function(name): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) # get selected image im = 0 if n_fitting_results > 1: - im = image_number_wid.selected_index + im = image_number_wid.selected_values['index'] - # select figure - figure_id = plt.figure(save_figure_wid.figure_id.number) + # get figure size + new_figure_size = (viewer_options_wid.selected_values[0]['figure'][ + 'x_scale'] * figure_size[0], + viewer_options_wid.selected_values[0]['figure'][ + 'y_scale'] * figure_size[1]) - # plot displacements curve - d_type = iterations_wid.displacement_type + # plot errors curve + d_type = iterations_wid.selected_values['displacement_type'] if (d_type == 'max' or d_type == 'min' or d_type == 'mean' or d_type == 'median'): - d_curve = fitting_results[im].displacements_stats(stat_type=d_type) + renderer = fitting_results[im].plot_displacements( + figure_id=save_figure_wid.renderer[0].figure_id, + figure_size=new_figure_size, stat_type=d_type) else: all_displacements = fitting_results[im].displacements() d_curve = [iteration_displacements[d_type] for iteration_displacements in all_displacements] - plt.plot(range(len(d_curve)), d_curve, '-bo') - plt.gca().set_xlim(0, len(d_curve)-1) - plt.grid("on") - plt.xlabel('Iteration') - - # set labels - if d_type == 'max': - plt.ylabel('Maximum Displacement') - plt.title("Maximum displacement evolution of Image {}".format(im)) - elif d_type == 'min': - plt.ylabel('Minimum Displacement') - plt.title("Minimum displacement evolution of Image {}".format(im)) - elif d_type == 'mean': - plt.ylabel('Mean Displacement') - plt.title("Mean displacement evolution of Image {}".format(im)) - elif d_type == 'median': - plt.ylabel('Median Displacement') - plt.title("Median displacement evolution of Image {}".format(im)) - else: - plt.ylabel("Displacement of Point {}".format(d_type)) - plt.title("Point {} displacement evolution of Image {}".format( - d_type, im)) - - # set figure size - x_scale = figure_options_wid.x_scale - y_scale = figure_options_wid.y_scale - plt.gcf().set_size_inches([x_scale, y_scale] * np.asarray(figure_size)) + from menpo.visualize import GraphPlotter + ylabel = "Displacement of Point {}".format(d_type) + title = "Point {} displacement per " \ + "iteration of Image {}".format(d_type, im) + renderer = GraphPlotter( + figure_id=save_figure_wid.renderer[0].figure_id, + new_figure=False, x_axis=range(len(d_curve)), y_axis=[d_curve], + title=title, x_label='Iteration', y_label=ylabel, + x_axis_limits=(0, len(d_curve)-1), y_axis_limits=None).render( + render_lines=True, line_colour='b', line_style='-', + line_width=2, render_markers=True, marker_style='o', + marker_size=4, marker_face_colour='b', + marker_edge_colour='k', marker_edge_width=1., + render_legend=False, render_axes=True, + axes_font_name='sans-serif', axes_font_size=10, + axes_font_style='normal', axes_font_weight='normal', + render_grid=True, grid_line_style='--', grid_line_width=0.5, + figure_size=new_figure_size) # show figure plt.show() # save the current figure id - save_figure_wid.figure_id = figure_id + save_figure_wid.renderer[0] = renderer # define plot function def plot_function(name, value): # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) + ipydisplay.clear_output(wait=True) # get selected image im = 0 if n_fitting_results > 1: - im = image_number_wid.selected_index + im = image_number_wid.selected_values['index'] # selected mode: final or iterations final_enabled = False @@ -1417,144 +1797,220 @@ def plot_function(name, value): # update info text widget update_info('', error_type_wid.value) - # get the current figure id - figure_id = save_figure_wid.figure_id - - # call helper _plot_figure + # get selected options if final_enabled: - new_figure_id = _plot_figure( - image=fitting_results[im].fitted_image, figure_id=figure_id, - image_enabled=final_result_wid.show_image, - landmarks_enabled=True, image_is_masked=False, - masked_enabled=False, channels=channel_options_wid.channels, - glyph_enabled=channel_options_wid.glyph_enabled, - glyph_block_size=channel_options_wid.glyph_block_size, - glyph_use_negative=channel_options_wid.glyph_use_negative, - sum_enabled=channel_options_wid.sum_enabled, - groups=final_result_wid.groups, - with_labels=[None] * len(final_result_wid.groups), - groups_colours=colour_final_dict, - subplots_enabled=final_result_wid.subplots_enabled, - subplots_titles=groups_final_dict, image_axes_mode=True, - legend_enabled=final_result_wid.legend_enabled, - numbering_enabled=final_result_wid.numbering_enabled, - x_scale=figure_options_wid.x_scale, - y_scale=figure_options_wid.y_scale, - axes_visible=figure_options_wid.axes_visible, - figure_size=figure_size, **kwargs) + # image object + image = fitting_results[im].fitted_image + render_image = final_result_wid.selected_values['render_image'] + # groups + groups = final_result_wid.selected_values['selected_groups'] + # subplots + subplots_enabled = final_result_wid.selected_values[ + 'subplots_enabled'] + subplots_titles = groups_final_dict + # lines and markers options + render_lines = [] + line_colour = [] + line_style = [] + line_width = [] + render_markers = [] + marker_style = [] + marker_size = [] + marker_face_colour = [] + marker_edge_colour = [] + marker_edge_width = [] + for g in groups: + group_idx = all_groups.index(g) + tmp1 = viewer_options_wid.selected_values[group_idx]['lines'] + tmp2 = viewer_options_wid.selected_values[group_idx]['markers'] + render_lines.append(tmp1['render_lines']) + line_colour.append(tmp1['line_colour']) + line_style.append(tmp1['line_style']) + line_width.append(tmp1['line_width']) + render_markers.append(tmp2['render_markers']) + marker_style.append(tmp2['marker_style']) + marker_size.append(tmp2['marker_size']) + marker_face_colour.append(tmp2['marker_face_colour']) + marker_edge_colour.append(tmp2['marker_edge_colour']) + marker_edge_width.append(tmp2['marker_edge_width']) else: - # create subplot titles dict and colours dict - groups_dict = dict() - colour_dict = dict() - cols = np.random.random([3, len(iterations_wid.groups)]) - for i, group in enumerate(iterations_wid.groups): - iter_num = group[len(iter_str)::] - groups_dict[iter_str + iter_num] = "Iteration " + iter_num - colour_dict[iter_str + iter_num] = cols[:, i] - - # plot - new_figure_id = _plot_figure( - image=fitting_results[im].iter_image, figure_id=figure_id, - image_enabled=iterations_wid.show_image, landmarks_enabled=True, - image_is_masked=False, masked_enabled=False, - channels=channel_options_wid.channels, - glyph_enabled=channel_options_wid.glyph_enabled, - glyph_block_size=channel_options_wid.glyph_block_size, - glyph_use_negative=channel_options_wid.glyph_use_negative, - sum_enabled=channel_options_wid.sum_enabled, - groups=iterations_wid.groups, - with_labels=[None] * len(iterations_wid.groups), - groups_colours=colour_dict, - subplots_enabled=iterations_wid.subplots_enabled, - subplots_titles=groups_dict, image_axes_mode=True, - legend_enabled=iterations_wid.legend_enabled, - numbering_enabled=iterations_wid.numbering_enabled, - x_scale=figure_options_wid.x_scale, - y_scale=figure_options_wid.y_scale, - axes_visible=figure_options_wid.axes_visible, - figure_size=figure_size, **kwargs) + # image object + image = fitting_results[im].iter_image + render_image = iterations_wid.selected_values['render_image'] + # groups + groups = iterations_wid.selected_values['selected_groups'] + # subplots + subplots_enabled = iterations_wid.selected_values[ + 'subplots_enabled'] + subplots_titles = dict() + iter_str = iterations_wid.selected_values['iter_str'] + for i, g in enumerate(groups): + iter_num = g[len(iter_str)::] + subplots_titles[iter_str + iter_num] = "Iteration " + iter_num + # lines and markers options + group_idx = all_groups.index('iterations') + tmp1 = viewer_options_wid.selected_values[group_idx]['lines'] + tmp2 = viewer_options_wid.selected_values[group_idx]['markers'] + render_lines = [tmp1['render_lines']] * len(groups) + line_style = [tmp1['line_style']] * len(groups) + line_width = [tmp1['line_width']] * len(groups) + render_markers = [tmp2['render_markers']] * len(groups) + marker_style = [tmp2['marker_style']] * len(groups) + marker_size = [tmp2['marker_size']] * len(groups) + marker_edge_colour = [tmp2['marker_edge_colour']] * len(groups) + marker_edge_width = [tmp2['marker_edge_width']] * len(groups) + if (subplots_enabled or + iterations_wid.children[1].children[0].children[0].value == + 'animation'): + line_colour = [tmp1['line_colour']] * len(groups) + marker_face_colour = [tmp2['marker_face_colour']] * len(groups) + else: + cols = sample_colours_from_colourmap(len(groups), 'jet') + line_colour = cols + marker_face_colour = cols + + tmp1 = viewer_options_wid.selected_values[0]['numbering'] + tmp2 = viewer_options_wid.selected_values[0]['legend'] + tmp3 = viewer_options_wid.selected_values[0]['figure'] + tmp4 = viewer_options_wid.selected_values[0]['image'] + new_figure_size = (tmp3['x_scale'] * figure_size[0], + tmp3['y_scale'] * figure_size[1]) + + # call helper _visualize + renderer = _visualize( + image=image, renderer=save_figure_wid.renderer[0], + render_image=render_image, render_landmarks=True, + image_is_masked=False, masked_enabled=False, + channels=channel_options_wid.selected_values['channels'], + glyph_enabled=channel_options_wid.selected_values['glyph_enabled'], + glyph_block_size=channel_options_wid.selected_values['glyph_block_size'], + glyph_use_negative=channel_options_wid.selected_values['glyph_use_negative'], + sum_enabled=channel_options_wid.selected_values['sum_enabled'], + groups=groups, with_labels=[None] * len(groups), + subplots_enabled=subplots_enabled, subplots_titles=subplots_titles, + image_axes_mode=True, render_lines=render_lines, + line_style=line_style, line_width=line_width, + line_colour=line_colour, render_markers=render_markers, + marker_style=marker_style, marker_size=marker_size, + marker_edge_width=marker_edge_width, + marker_edge_colour=marker_edge_colour, + marker_face_colour=marker_face_colour, + render_numbering=tmp1['render_numbering'], + numbers_horizontal_align=tmp1['numbers_horizontal_align'], + numbers_vertical_align=tmp1['numbers_vertical_align'], + numbers_font_name=tmp1['numbers_font_name'], + numbers_font_size=tmp1['numbers_font_size'], + numbers_font_style=tmp1['numbers_font_style'], + numbers_font_weight=tmp1['numbers_font_weight'], + numbers_font_colour=tmp1['numbers_font_colour'], + render_legend=tmp2['render_legend'], + legend_title=tmp2['legend_title'], + legend_font_name=tmp2['legend_font_name'], + legend_font_style=tmp2['legend_font_style'], + legend_font_size=tmp2['legend_font_size'], + legend_font_weight=tmp2['legend_font_weight'], + legend_marker_scale=tmp2['legend_marker_scale'], + legend_location=tmp2['legend_location'], + legend_bbox_to_anchor=tmp2['legend_bbox_to_anchor'], + legend_border_axes_pad=tmp2['legend_border_axes_pad'], + legend_n_columns=tmp2['legend_n_columns'], + legend_horizontal_spacing=tmp2['legend_horizontal_spacing'], + legend_vertical_spacing=tmp2['legend_vertical_spacing'], + legend_border=tmp2['legend_border'], + legend_border_padding=tmp2['legend_border_padding'], + legend_shadow=tmp2['legend_shadow'], + legend_rounded_corners=tmp2['legend_rounded_corners'], + render_axes=tmp3['render_axes'], + axes_font_name=tmp3['axes_font_name'], + axes_font_size=tmp3['axes_font_size'], + axes_font_style=tmp3['axes_font_style'], + axes_font_weight=tmp3['axes_font_weight'], + axes_x_limits=tmp3['axes_x_limits'], + axes_y_limits=tmp3['axes_y_limits'], + interpolation=tmp4['interpolation'], + alpha=tmp4['alpha'], figure_size=new_figure_size) # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer # define function that updates info text def update_info(name, value): # get selected image im = 0 if n_fitting_results > 1: - im = image_number_wid.selected_index + im = image_number_wid.selected_values['index'] # create output str if fitting_results[im].gt_shape is not None: - info_txt = r""" - Initial error: {:.4f} - Final error: {:.4f} - {} iterations - """.format(fitting_results[im].initial_error(error_type=value), - fitting_results[im].final_error(error_type=value), - fitting_results[im].n_iters) + info_wid.children[1].children[0].value = \ + "> Initial error: {:.4f}".format( + fitting_results[im].initial_error(error_type=value)) + info_wid.children[1].children[0].visible = True + info_wid.children[1].children[1].value = \ + "> Final error: {:.4f}".format( + fitting_results[im].final_error(error_type=value)) + info_wid.children[1].children[1].visible = True + info_wid.children[1].children[2].value = \ + "> {} iterations".format(fitting_results[im].n_iters) else: - info_txt = r""" - {} iterations - """.format(fitting_results[im].n_iters) + info_wid.children[1].children[0].value = '' + info_wid.children[1].children[0].visible = False + info_wid.children[1].children[1].value = '' + info_wid.children[1].children[1].visible = False + info_wid.children[1].children[2].value = "> {} iterations".format( + fitting_results[im].n_iters) if hasattr(fitting_results[im], 'n_levels'): # Multilevel result - info_txt += r""" - {} levels with downscale of {:.1f} - """.format(fitting_results[im].n_levels, - fitting_results[im].downscale) - - info_wid.children[1].value = _raw_info_string_to_latex(info_txt) + info_wid.children[1].children[3].value = \ + "> {} levels with downscale of {:.1f}".format( + fitting_results[im].n_levels, fitting_results[im].downscale) + info_wid.children[1].children[1].visible = True + else: + info_wid.children[1].children[1].value = '' + info_wid.children[1].children[1].visible = True # Create options widgets channel_options_wid = channel_options( - fitting_results[0].fitted_image.n_channels, - isinstance(fitting_results[0].fitted_image, MaskedImage), plot_function, - masked_default=False, toggle_show_default=True, - toggle_show_visible=False) - figure_options_wid = figure_options(plot_function, scale_default=1., - show_axes_default=True, - toggle_show_default=True, - toggle_show_visible=False) - info_wid = info_print(toggle_show_default=True, toggle_show_visible=False) - initial_figure_id = plt.figure() - save_figure_wid = save_figure_options(initial_figure_id, + channels_options_default, plot_function=plot_function, + toggle_show_default=True, toggle_show_visible=False) + + # viewer options widget + viewer_options_wid = viewer_options( + viewer_options_default, + ['markers', 'lines', 'figure_one', 'legend', 'numbering', 'image'], + objects_names=all_groups, plot_function=plot_function, + toggle_show_visible=False, toggle_show_default=True) + info_wid = info_print(n_bullets=4, toggle_show_default=True, + toggle_show_visible=False) + + # save figure widget + initial_renderer = MatplotlibImageViewer2d(figure_id=None, new_figure=True, + image=np.zeros((10, 10))) + save_figure_wid = save_figure_options(initial_renderer, toggle_show_default=True, toggle_show_visible=False) - # Create landmark groups checkboxes - all_groups_keys, all_labels_keys = _extract_groups_labels( - fitting_results[0].fitted_image) - final_result_wid = final_result_options(all_groups_keys, plot_function, - title='Final', - show_image_default=True, - subplots_enabled_default=True, - legend_default=True, - numbering_default=False, - toggle_show_default=True, - toggle_show_visible=False) + # final result and iterations options + final_result_wid = final_result_options( + final_result_options_default, plot_function=plot_function, + title='Final', toggle_show_default=True, toggle_show_visible=False) iterations_wid = iterations_result_options( - fitting_results[0].n_iters, not fitting_results[0].gt_shape is None, - fitting_results[0].fitted_image.landmarks['final'].lms.n_points, - plot_function, plot_errors_function, plot_displacements_function, - iter_str=iter_str, title='Iterations', show_image_default=True, - subplots_enabled_default=False, legend_default=True, - numbering_default=False, toggle_show_default=True, - toggle_show_visible=False) - iterations_wid.children[2].children[4].on_click(plot_errors_function) - iterations_wid.children[2].children[5].children[0].on_click( - plot_displacements_function) + iterations_result_options_default, plot_function=plot_function, + plot_errors_function=plot_errors_function, + plot_displacements_function=plot_displacements_function, + title='Iterations', toggle_show_default=True, toggle_show_visible=False) # Create error type radio buttons error_type_values = OrderedDict() error_type_values['Point-to-point Normalized Mean Error'] = 'me_norm' error_type_values['Point-to-point Mean Error'] = 'me' error_type_values['RMS Error'] = 'rmse' - error_type_wid = RadioButtonsWidget(values=error_type_values, - value='me_norm', - description='Error type') + error_type_wid = ipywidgets.RadioButtonsWidget( + values=error_type_values, value='me_norm', description='Error type') error_type_wid.on_trait_change(update_info, 'value') - plot_ced_but = ButtonWidget(description='Plot CED', visible=show_ced) - error_wid = ContainerWidget(children=[error_type_wid, plot_ced_but]) + plot_ced_but = ipywidgets.ButtonWidget(description='Plot CED', + visible=show_ced) + error_wid = ipywidgets.ContainerWidget(children=[error_type_wid, + plot_ced_but]) # define function that updates options' widgets state def update_widgets(name, value): @@ -1564,34 +2020,40 @@ def update_widgets(name, value): # update channel options update_channel_options( channel_options_wid, - n_channels=fitting_results[value].fitted_image.n_channels, - image_is_masked=isinstance(fitting_results[value].fitted_image, - MaskedImage)) + fitting_results[value].fitted_image.n_channels, + isinstance(fitting_results[value].fitted_image, + MaskedImage)) # update final result's options update_final_result_options(final_result_wid, group_keys, plot_function) # update iterations result's options - update_iterations_result_options( - iterations_wid, fitting_results[value].n_iters, - not fitting_results[value].gt_shape is None, - fitting_results[value].fitted_image.landmarks['final'].lms.n_points, - iter_str=iter_str) + iterations_result_options_default = \ + {'n_iters': fitting_results[value].n_iters, + 'image_has_gt_shape': not fitting_results[value].gt_shape is None, + 'n_points': fitting_results[value].fitted_image.landmarks['final'].lms.n_points} + update_iterations_result_options(iterations_wid, + iterations_result_options_default) # Create final widget - options_wid = TabWidget(children=[channel_options_wid, figure_options_wid]) - result_wid = TabWidget(children=[final_result_wid, iterations_wid]) + options_wid = ipywidgets.TabWidget(children=[channel_options_wid, + viewer_options_wid]) + result_wid = ipywidgets.TabWidget(children=[final_result_wid, + iterations_wid]) result_wid.on_trait_change(plot_function, 'selected_index') if n_fitting_results > 1: # image selection slider image_number_wid = animation_options( - index_min_val=0, index_max_val=n_fitting_results-1, - plot_function=plot_function, update_function=update_widgets, - index_step=1, index_default=0, - index_description='Image Number', index_minus_description='<', - index_plus_description='>', index_style='buttons', - index_text_editable=True, loop_default=True, interval_default=0.3, + index_selection_default, plot_function=plot_function, + update_function=update_widgets, index_description='Image Number', + index_minus_description='<', index_plus_description='>', + index_style=browser_style, index_text_editable=True, + loop_default=True, interval_default=0.3, toggle_show_title='Image Options', toggle_show_default=True, toggle_show_visible=False) + # final widget + logo_wid = ipywidgets.ContainerWidget(children=[logo(), + image_number_wid]) + # define function that combines the results' tab widget with the # animation # If animation is activated and the user selects the iterations tab, @@ -1610,10 +2072,6 @@ def save_fig_tab_fun(name, value): image_number_wid.children[1].children[1].children[0].children[1].value = True # final widget - tab_wid = TabWidget(children=[info_wid, result_wid, options_wid, - error_wid, save_figure_wid]) - tab_wid.on_trait_change(save_fig_tab_fun, 'selected_index') - wid = ContainerWidget(children=[image_number_wid, tab_wid]) if show_ced: tab_titles = ['Info', 'Result', 'Options', 'CED', 'Save figure'] else: @@ -1625,14 +2083,22 @@ def save_fig_tab_fun(name, value): plot_ced_but.visible = False # final widget - wid = TabWidget(children=[info_wid, result_wid, options_wid, error_wid, - save_figure_wid]) - tab_titles = ['Image info', 'Result', 'Options', 'Error type', - 'Save figure'] + logo_wid = logo() + tab_titles = ['Info', 'Result', 'Options', 'Error type', 'Save figure'] button_title = 'Fitting Result Menu' + + # final widget + cont_wid = ipywidgets.TabWidget(children=[info_wid, result_wid, options_wid, + error_wid, save_figure_wid]) + if n_fitting_results > 1: + cont_wid.on_trait_change(save_fig_tab_fun, 'selected_index') + # create popup widget if asked if popup: - wid = PopupWidget(children=[wid], button_text=button_title) + wid = ipywidgets.PopupWidget(children=[logo_wid, cont_wid], + button_text=button_title) + else: + wid = ipywidgets.ContainerWidget(children=[logo_wid, cont_wid]) # invoke plot_ced widget def plot_ced_fun(name): @@ -1651,25 +2117,18 @@ def plot_ced_fun(name): errors = [fit_errors, initial_errors] # call plot_ced - plot_ced_widget = plot_ced(errors, figure_size=(9, 5), popup=True, - error_type=error_type, error_range=None, - legend_entries=['Final Fitting', - 'Initialization'], - return_widget=True) + plot_ced_widget = plot_ced( + errors, figure_size=(9, 5), popup=True, error_type=error_type, + error_range=None, legend_entries=['Final Fitting', + 'Initialization'], + return_widget=True) # If another tab is selected, then close the widget. def close_plot_ced_fun(name, value): if value != 3: plot_ced_widget.close() plot_ced_but.visible = True - if n_fitting_results > 1: - tab_wid.on_trait_change(close_plot_ced_fun, 'selected_index') - else: - if popup: - wid.children[0].on_trait_change(close_plot_ced_fun, - 'selected_index') - else: - wid.on_trait_change(close_plot_ced_fun, 'selected_index') + cont_wid.on_trait_change(close_plot_ced_fun, 'selected_index') # If another error type, then close the widget def close_plot_ced_fun_2(name, value): @@ -1679,31 +2138,22 @@ def close_plot_ced_fun_2(name, value): plot_ced_but.on_click(plot_ced_fun) # display final widget - display(wid) + ipydisplay.display(wid) # set final tab titles - if popup: - if n_fitting_results > 1: - for (k, tl) in enumerate(tab_titles): - wid.children[0].children[1].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.children[0].set_title(k, tl) - else: - if n_fitting_results > 1: - for (k, tl) in enumerate(tab_titles): - wid.children[1].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.set_title(k, tl) + for (k, tl) in enumerate(tab_titles): + wid.children[1].set_title(k, tl) + result_wid.set_title(0, 'Final Fitting') result_wid.set_title(1, 'Iterations') options_wid.set_title(0, 'Channels') - options_wid.set_title(1, 'Figure') + options_wid.set_title(1, 'Viewer') # format options' widgets if n_fitting_results > 1: - format_animation_options(image_number_wid, index_text_width='0.5cm', + wid.children[0].remove_class('vbox') + wid.children[0].add_class('hbox') + format_animation_options(image_number_wid, index_text_width='1.0cm', container_padding='6px', container_margin='6px', container_border='1px solid black', @@ -1714,11 +2164,12 @@ def close_plot_ced_fun_2(name, value): container_border='1px solid black', toggle_button_font_weight='bold', border_visible=False) - format_figure_options(figure_options_wid, container_padding='6px', + format_viewer_options(viewer_options_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - border_visible=False) + border_visible=False, + suboptions_border_visible=True) format_info_print(info_wid, font_size_in_pt='9pt', container_padding='6px', container_margin='6px', container_border='1px solid black', @@ -1740,10 +2191,11 @@ def close_plot_ced_fun_2(name, value): tab_top_margin='0cm', border_visible=False) # Reset value to enable initial visualization - figure_options_wid.children[2].value = False + viewer_options_wid.children[1].children[1].children[2].children[2].value = \ + False -def plot_ced(errors, figure_size=(9, 5), popup=False, error_type='me_norm', +def plot_ced(errors, figure_size=(10, 8), popup=False, error_type='me_norm', error_range=None, legend_entries=None, return_widget=False): r""" Widget for visualizing the cumulative error curves of the provided errors. @@ -1753,38 +2205,43 @@ def plot_ced(errors, figure_size=(9, 5), popup=False, error_type='me_norm', ----------- errors : `list` of `list` of `float` The list of errors to be used. - figure_size : (`int`, `int`), optional The initial size of the plotted figures. - - popup : `boolean`, optional - If enabled, the widget will appear as a popup window. - - error_type : `str` ``{'me_norm', 'me', 'rmse'}``, optional + popup : `bool`, optional + If ``True``, the widget will appear as a popup window. + error_type : {``me_norm``, ``me``, ``rmse``}, optional Specifies the type of the provided errors. - error_range : `list` of `float` with length 3, optional Specifies the horizontal axis range, i.e. + + :: + error_range[0] = min_error error_range[1] = max_error error_range[2] = error_step - If None, then + + If ``None``, then + + :: + error_range = [0., 0.101, 0.005] for error_type = 'me_norm' error_range = [0., 20., 1.] for error_type = 'me' error_range = [0., 20., 1.] for error_type = 'rmse' legend_entries : `list` of `str` The entries of the legend. The list must have the same length as errors. - If None, the entries will have the form 'Curve %d'. - - return_widget : `boolean`, optional - If True, the widget object will be returned so that it can be used as a - part of a bigger widget. If False, the widget object is not returned, it - is just visualized. + If ``None``, the entries will have the form ``'Curve %d'``. + return_widget : `bool`, optional + If ``True``, the widget object will be returned so that it can be used + as part of a bigger widget. If ``False``, the widget object is not + returned, it is just visualized. """ - from menpofit.fittingresult import compute_cumulative_error + import IPython.html.widgets as ipywidgets + import IPython.display as ipydisplay + from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap + from menpofit.fittingresult import plot_cumulative_error_distribution - # make sure that images is a list even with one image member + # make sure that errors is a list even with one list member if not isinstance(errors[0], list): errors = [errors] @@ -1798,137 +2255,203 @@ def plot_ced(errors, figure_size=(9, 5), popup=False, error_type='me_norm', # get horizontal axis errors x_label_initial_value = 'Error' x_axis_limit_initial_value = 0 + x_axis_step_initial_value = 0 if error_range is None: if error_type == 'me_norm': error_range = [0., 0.101, 0.005] x_axis_limit_initial_value = 0.05 + x_axis_step_initial_value = 0.005 x_label_initial_value = 'Normalized Point-to-Point Error' elif error_type == 'me' or error_type == 'rmse': error_range = [0., 20., 0.5] x_axis_limit_initial_value = 5. + x_axis_step_initial_value = 0.5 x_label_initial_value = 'Point-to-Point Error' else: x_axis_limit_initial_value = (error_range[1] + error_range[0]) / 2 - x_axis = np.arange(error_range[0], error_range[1], error_range[2]) - - # compute cumulative error curves - ceds = [compute_cumulative_error(e, x_axis) for e in errors] - x_axis = [x_axis] * len(ceds) - - # initialize plot options dictionaries and legend entries - colors = [np.random.random((3,)) for _ in range(n_curves)] - plot_options_list = [] - for k in range(n_curves): - plot_options_list.append({'show_line':True, - 'linewidth':2, - 'linecolor':colors[k], - 'linestyle':'-', - 'show_marker':True, - 'markersize':10, - 'markerfacecolor':'w', - 'markeredgecolor':colors[k], - 'markerstyle':'s', - 'markeredgewidth':1, - 'legend_entry':legend_entries[k]}) + + # initial options dictionaries + figure_options = {'x_scale': 1., + 'y_scale': 1., + 'render_axes': True, + 'axes_font_name': 'sans-serif', + 'axes_font_size': 10, + 'axes_font_style': 'normal', + 'axes_font_weight': 'normal', + 'axes_x_limits': None, + 'axes_y_limits': (0., 1.)} + legend_options = {'render_legend': True, + 'legend_title': '', + 'legend_font_name': 'sans-serif', + 'legend_font_style': 'normal', + 'legend_font_size': 10, + 'legend_font_weight': 'normal', + 'legend_marker_scale': 1., + 'legend_location': 2, + 'legend_bbox_to_anchor': (1.05, 1.), + 'legend_border_axes_pad': 1., + 'legend_n_columns': 1, + 'legend_horizontal_spacing': 1., + 'legend_vertical_spacing': 1., + 'legend_border': True, + 'legend_border_padding': 0.5, + 'legend_shadow': False, + 'legend_rounded_corners': False} + grid_options = {'render_grid': True, + 'grid_line_style': '--', + 'grid_line_width': 0.5} + + colours = sample_colours_from_colourmap(n_curves, 'jet') + viewer_options_default = [] + for i in range(n_curves): + lines_options_default = {'render_lines': True, + 'line_width': 2, + 'line_colour': [colours[i]], + 'line_style': '-'} + markers_options = {'render_markers': True, + 'marker_size': 10, + 'marker_face_colour': ['w'], + 'marker_edge_colour': [colours[i]], + 'marker_style': 's', + 'marker_edge_width': 1} + tmp = {'lines': lines_options_default, + 'markers': markers_options, + 'legend': legend_options, + 'figure': figure_options, + 'grid': grid_options} + viewer_options_default.append(tmp) # define plot function def plot_function(name, value): + import matplotlib.pyplot as plt # clear current figure, but wait until the new data to be displayed are # generated - clear_output(wait=True) - - # get the current figure id - figure_id = save_figure_wid.figure_id - - # plot the graph with the selected options - new_figure_id = _plot_graph( - figure_id, horizontal_axis_values=x_axis, vertical_axis_values=ceds, - plot_options_list=plot_options_wid.selected_options, - legend_visible=legend_visible.value, - grid_visible=grid_visible.value, gridlinestyle=gridlinestyle.value, - x_limit=x_axis_limit.value, y_limit=y_axis_limit.value, + ipydisplay.clear_output(wait=True) + + # get options that need to be a list + render_lines = [] + line_colour = [] + line_style = [] + line_width = [] + render_markers = [] + marker_style = [] + marker_size = [] + marker_face_colour = [] + marker_edge_colour = [] + marker_edge_width = [] + for idx in range(n_curves): + tmp1 = viewer_options_wid.selected_values[idx]['lines'] + tmp2 = viewer_options_wid.selected_values[idx]['markers'] + render_lines.append(tmp1['render_lines']) + line_colour.append(tmp1['line_colour'][0]) + line_style.append(tmp1['line_style']) + line_width.append(tmp1['line_width']) + render_markers.append(tmp2['render_markers']) + marker_style.append(tmp2['marker_style']) + marker_size.append(tmp2['marker_size']) + marker_face_colour.append(tmp2['marker_face_colour'][0]) + marker_edge_colour.append(tmp2['marker_edge_colour'][0]) + marker_edge_width.append(tmp2['marker_edge_width']) + + # rest of options + tmp3 = viewer_options_wid.selected_values[0]['legend'] + tmp4 = viewer_options_wid.selected_values[0]['figure'] + tmp5 = viewer_options_wid.selected_values[0]['grid'] + new_figure_size = (tmp4['x_scale'] * figure_size[0], + tmp4['y_scale'] * figure_size[1]) + + # horizontal axis limits + x_axis_limits = (0, + np.arange(0, errors_max.value, errors_step.value)[-1]) + + # render + renderer = plot_cumulative_error_distribution( + errors, error_range=[0., errors_max.value, errors_step.value], + figure_id=save_figure_wid.renderer[0].figure_id, new_figure=False, title=title.value, x_label=x_label.value, y_label=y_label.value, - x_scale=fig.x_scale, y_scale=fig.y_scale, figure_size=figure_size, - axes_fontsize=axes_fontsize.value, - labels_fontsize=labels_fontsize.value) + legend_entries=str(legend_entries_wid.value).split('\n')[:n_curves], + render_lines=render_lines, line_colour=line_colour, + line_style=line_style, line_width=line_width, + render_markers=render_markers, marker_style=marker_style, + marker_size=marker_size, marker_face_colour=marker_face_colour, + marker_edge_colour=marker_edge_colour, + marker_edge_width=marker_edge_width, + render_legend=tmp3['render_legend'], + legend_title=tmp3['legend_title'], + legend_font_name=tmp3['legend_font_name'], + legend_font_style=tmp3['legend_font_style'], + legend_font_size=tmp3['legend_font_size'], + legend_font_weight=tmp3['legend_font_weight'], + legend_marker_scale=tmp3['legend_marker_scale'], + legend_location=tmp3['legend_location'], + legend_bbox_to_anchor=tmp3['legend_bbox_to_anchor'], + legend_border_axes_pad=tmp3['legend_border_axes_pad'], + legend_n_columns=tmp3['legend_n_columns'], + legend_horizontal_spacing=tmp3['legend_horizontal_spacing'], + legend_vertical_spacing=tmp3['legend_vertical_spacing'], + legend_border=tmp3['legend_border'], + legend_border_padding=tmp3['legend_border_padding'], + legend_shadow=tmp3['legend_shadow'], + legend_rounded_corners=tmp3['legend_rounded_corners'], + render_axes=tmp4['render_axes'], + axes_font_name=tmp4['axes_font_name'], + axes_font_size=tmp4['axes_font_size'], + axes_font_style=tmp4['axes_font_style'], + axes_font_weight=tmp4['axes_font_weight'], + axes_x_limits=x_axis_limits, + axes_y_limits=viewer_options_wid.selected_values[0]['figure']['axes_y_limits'], + figure_size=new_figure_size, render_grid=tmp5['render_grid'], + grid_line_style=tmp5['grid_line_style'], + grid_line_width=tmp5['grid_line_width']) + + plt.show() # save the current figure id - save_figure_wid.figure_id = new_figure_id + save_figure_wid.renderer[0] = renderer # create options widgets - # x label, y label, title container - x_label = TextWidget(description='Horizontal axis label', - value=x_label_initial_value) - y_label = TextWidget(description='Vertical axis label', - value='Images Proportion') - title = TextWidget(description='Figure title', - value='Cumulative error ditribution') - labels_wid = ContainerWidget(children=[x_label, y_label, title]) - - # figure size - fig = figure_options_two_scales(plot_function, x_scale_default=1., - y_scale_default=1., coupled_default=False, - show_axes_default=True, - toggle_show_default=True, - figure_scales_bounds=(0.1, 2), - figure_scales_step=0.1, - figure_scales_visible=True, - show_axes_visible=False, - toggle_show_visible=False) - # fontsizes - labels_fontsize = FloatTextWidget(description='Labels fontsize', value=12.) - axes_fontsize = FloatTextWidget(description='Axes fontsize', value=12.) - fontsize_wid = ContainerWidget(children=[labels_fontsize, axes_fontsize]) - - # checkboxes - grid_visible = CheckboxWidget(description='Grid visible', value=False) - gridlinestyle_dict = OrderedDict() - gridlinestyle_dict['solid'] = '-' - gridlinestyle_dict['dashed'] = '--' - gridlinestyle_dict['dash-dot'] = '-.' - gridlinestyle_dict['dotted'] = ':' - gridlinestyle = DropdownWidget(values=gridlinestyle_dict, - value=':', - description='Grid style', disabled=False) - - def gridlinestyle_visibility(name, value): - gridlinestyle.disabled = not value - grid_visible.on_trait_change(gridlinestyle_visibility, 'value') - legend_visible = CheckboxWidget(description='Legend visible', value=True) - checkbox_wid = ContainerWidget(children=[grid_visible, gridlinestyle, - legend_visible]) - - # container of various options - tmp_various_wid = ContainerWidget(children=[fontsize_wid, checkbox_wid]) - various_wid = ContainerWidget(children=[fig, tmp_various_wid]) - - # axis limits - y_axis_limit = FloatSliderWidget(min=0., max=1.1, step=0.1, - description='Y axis limit', value=1.) - x_axis_limit = FloatSliderWidget(min=error_range[0] + error_range[2], - max=error_range[1], - step=error_range[2], - description='X axis limit', - value=x_axis_limit_initial_value) - axis_limits_wid = ContainerWidget(children=[x_axis_limit, y_axis_limit]) - - # accordion widget - figure_wid = AccordionWidget(children=[axis_limits_wid, labels_wid, - various_wid]) - figure_wid.set_title(0, 'Axes Limits') - figure_wid.set_title(1, 'Labels and Title') - figure_wid.set_title(2, 'Figure Size, Grid and Legend') - - # per curve options - plot_options_wid = plot_options(plot_options_list, - plot_function=plot_function, - toggle_show_visible=False, - toggle_show_default=True) - - # save figure options widget - # create figure and store its id - initial_figure_id = plt.figure() - save_figure_wid = save_figure_options(initial_figure_id, + # error_range + errors_max = ipywidgets.FloatSliderWidget( + min=error_range[0] + error_range[2], max=error_range[1], + step=error_range[2], description='Error axis max', + value=x_axis_limit_initial_value) + if error_type == 'me_norm': + errors_step = ipywidgets.FloatSliderWidget( + min=0., max=0.05, step=0.001, description='Error axis step', + value=x_axis_step_initial_value) + else: + errors_step = ipywidgets.FloatSliderWidget( + min=0., max=error_range[1], step=error_range[2] / 10., + description='Error axis step', value=x_axis_step_initial_value) + error_range_wid = ipywidgets.ContainerWidget(children=[errors_max, + errors_step]) + + # legend_entries, x label, y label, title container + legend_entries_wid = ipywidgets.TextareaWidget( + description='Legend entries', value="\n".join(legend_entries)) + x_label = ipywidgets.TextWidget(description='Horizontal axis label', + value=x_label_initial_value) + y_label = ipywidgets.TextWidget(description='Vertical axis label', + value='Images Proportion') + title = ipywidgets.TextWidget(description='Figure title', + value=' ') + labels_wid = ipywidgets.ContainerWidget(children=[legend_entries_wid, + x_label, y_label, title]) + + # viewer options widget + viewer_options_wid = viewer_options(viewer_options_default, + ['lines', 'markers', 'legend', + 'figure_two', 'grid'], + objects_names=legend_entries, + plot_function=plot_function, + toggle_show_visible=False, + toggle_show_default=True, + labels=None) + + # save figure widget + initial_renderer = MatplotlibImageViewer2d(figure_id=None, new_figure=True, + image=np.zeros((10, 10))) + save_figure_wid = save_figure_options(initial_renderer, toggle_show_default=True, toggle_show_visible=False) @@ -1936,64 +2459,322 @@ def gridlinestyle_visibility(name, value): x_label.on_trait_change(plot_function, 'value') y_label.on_trait_change(plot_function, 'value') title.on_trait_change(plot_function, 'value') - grid_visible.on_trait_change(plot_function, 'value') - gridlinestyle.on_trait_change(plot_function, 'value') - legend_visible.on_trait_change(plot_function, 'value') - y_axis_limit.on_trait_change(plot_function, 'value') - x_axis_limit.on_trait_change(plot_function, 'value') - labels_fontsize.on_trait_change(plot_function, 'value') - axes_fontsize.on_trait_change(plot_function, 'value') + legend_entries_wid.on_trait_change(plot_function, 'value') + errors_max.on_trait_change(plot_function, 'value') + errors_step.on_trait_change(plot_function, 'value') # create final widget - wid = TabWidget(children=[figure_wid, plot_options_wid, - save_figure_wid]) + tab_wid = ipywidgets.TabWidget(children=[error_range_wid, labels_wid, + viewer_options_wid, + save_figure_wid]) # create popup widget if asked if popup: - wid = PopupWidget(children=[wid], button_text='CED Menu') + wid = ipywidgets.PopupWidget(children=[logo(), tab_wid], + button_text='CED Menu') + else: + wid = ipywidgets.ContainerWidget(children=[logo(), tab_wid], + button_text='CED Menu') # display final widget - display(wid) + ipydisplay.display(wid) # set final tab titles - tab_titles = ['Figure options', 'Per Curve options', 'Save figure'] - if n_curves == 1: - tab_titles[1] = 'Curve options' - if popup: - for (k, tl) in enumerate(tab_titles): - wid.children[0].set_title(k, tl) - else: - for (k, tl) in enumerate(tab_titles): - wid.set_title(k, tl) + tab_titles = ['Error axis options', 'Labels options', 'Viewer options', + 'Save figure'] + for (k, tl) in enumerate(tab_titles): + tab_wid.set_title(k, tl) # format options' widgets labels_wid.add_class('align-end') - axis_limits_wid.add_class('align-start') - fontsize_wid.add_class('align-end') - fontsize_wid.set_css('margin-right', '1cm') - checkbox_wid.add_class('align-end') - tmp_various_wid.remove_class('vbox') - tmp_various_wid.add_class('hbox') - format_plot_options(plot_options_wid, container_padding='1px', - container_margin='1px', - container_border='1px solid black', - toggle_button_font_weight='bold', border_visible=False, - suboptions_border_visible=True) - format_figure_options_two_scales(fig, container_padding='6px', - container_margin='6px', - container_border='1px solid black', - toggle_button_font_weight='bold', - border_visible=False) + legend_entries_wid.set_css('width', '6cm') + legend_entries_wid.set_css('height', '2cm') + x_label.set_css('width', '6cm') + y_label.set_css('width', '6cm') + title.set_css('width', '6cm') + errors_max.set_css('width', '6cm') + errors_step.set_css('width', '6cm') + format_viewer_options(viewer_options_wid, container_padding='6px', + container_margin='6px', + container_border='1px solid black', + toggle_button_font_weight='bold', + border_visible=False, + suboptions_border_visible=True) format_save_figure_options(save_figure_wid, container_padding='6px', container_margin='6px', container_border='1px solid black', toggle_button_font_weight='bold', - tab_top_margin='0cm', - border_visible=False) + tab_top_margin='0cm', border_visible=False) # Reset value to trigger initial visualization - grid_visible.value = True + title.value = 'Cumulative error distribution' # return widget object if asked if return_widget: return wid + + +def _visualize(image, renderer, render_image, render_landmarks, image_is_masked, + masked_enabled, channels, glyph_enabled, glyph_block_size, + glyph_use_negative, sum_enabled, groups, with_labels, + subplots_enabled, subplots_titles, image_axes_mode, + render_lines, line_style, line_width, line_colour, + render_markers, marker_style, marker_size, marker_edge_width, + marker_edge_colour, marker_face_colour, render_numbering, + numbers_horizontal_align, numbers_vertical_align, + numbers_font_name, numbers_font_size, numbers_font_style, + numbers_font_weight, numbers_font_colour, render_legend, + legend_title, legend_font_name, legend_font_style, + legend_font_size, legend_font_weight, legend_marker_scale, + legend_location, legend_bbox_to_anchor, legend_border_axes_pad, + legend_n_columns, legend_horizontal_spacing, + legend_vertical_spacing, legend_border, legend_border_padding, + legend_shadow, legend_rounded_corners, render_axes, + axes_font_name, axes_font_size, axes_font_style, + axes_font_weight, axes_x_limits, axes_y_limits, interpolation, + alpha, figure_size): + import matplotlib.pyplot as plt + + global glyph + if glyph is None: + from menpo.visualize.image import glyph + + # This makes the code shorter for dealing with masked images vs non-masked + # images + mask_arguments = ({'masked': masked_enabled} + if image_is_masked else {}) + + # plot + if render_image: + # image will be displayed + if render_landmarks and len(groups) > 0: + # there are selected landmark groups and they will be displayed + if subplots_enabled: + # calculate subplots structure + subplots = MatplotlibSubplots()._subplot_layout(len(groups)) + # show image with landmarks + for k, group in enumerate(groups): + if subplots_enabled: + # create subplot + plt.subplot(subplots[0], subplots[1], k + 1) + if render_legend: + # set subplot's title + plt.title(subplots_titles[group], + fontname=legend_font_name, + fontstyle=legend_font_style, + fontweight=legend_font_weight, + fontsize=legend_font_size) + if glyph_enabled or sum_enabled: + # image, landmarks, masked, glyph + renderer = glyph(image, vectors_block_size=glyph_block_size, + use_negative=glyph_use_negative, + channels=channels).\ + view_landmarks( + group=group, with_labels=with_labels[k], + without_labels=None, figure_id=renderer.figure_id, + new_figure=False, render_lines=render_lines[k], + line_style=line_style[k], line_width=line_width[k], + line_colour=line_colour[k], + render_markers=render_markers[k], + marker_style=marker_style[k], + marker_size=marker_size[k], + marker_edge_width=marker_edge_width[k], + marker_edge_colour=marker_edge_colour[k], + marker_face_colour=marker_face_colour[k], + render_numbering=render_numbering, + numbers_horizontal_align=numbers_horizontal_align, + numbers_vertical_align=numbers_vertical_align, + numbers_font_name=numbers_font_name, + numbers_font_size=numbers_font_size, + numbers_font_style=numbers_font_style, + numbers_font_weight=numbers_font_weight, + numbers_font_colour=numbers_font_colour, + render_legend=render_legend and not subplots_enabled, + legend_title=legend_title, + legend_font_name=legend_font_name, + legend_font_style=legend_font_style, + legend_font_size=legend_font_size, + legend_font_weight=legend_font_weight, + legend_marker_scale=legend_marker_scale, + legend_location=legend_location, + legend_bbox_to_anchor=legend_bbox_to_anchor, + legend_border_axes_pad=legend_border_axes_pad, + legend_n_columns=legend_n_columns, + legend_horizontal_spacing=legend_horizontal_spacing, + legend_vertical_spacing=legend_vertical_spacing, + legend_border=legend_border, + legend_border_padding=legend_border_padding, + legend_shadow=legend_shadow, + legend_rounded_corners=legend_rounded_corners, + render_axes=render_axes, + axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, + axes_y_limits=axes_y_limits, + interpolation=interpolation, alpha=alpha, + figure_size=figure_size, **mask_arguments) + else: + # image, landmarks, masked, not glyph + renderer = image.view_landmarks( + channels=channels, group=group, + with_labels=with_labels[k], without_labels=None, + figure_id=renderer.figure_id, new_figure=False, + render_lines=render_lines[k], line_style=line_style[k], + line_width=line_width[k], line_colour=line_colour[k], + render_markers=render_markers[k], + marker_style=marker_style[k], + marker_size=marker_size[k], + marker_edge_width=marker_edge_width[k], + marker_edge_colour=marker_edge_colour[k], + marker_face_colour=marker_face_colour[k], + render_numbering=render_numbering, + numbers_horizontal_align=numbers_horizontal_align, + numbers_vertical_align=numbers_vertical_align, + numbers_font_name=numbers_font_name, + numbers_font_size=numbers_font_size, + numbers_font_style=numbers_font_style, + numbers_font_weight=numbers_font_weight, + numbers_font_colour=numbers_font_colour, + render_legend=render_legend and not subplots_enabled, + legend_title=legend_title, + legend_font_name=legend_font_name, + legend_font_style=legend_font_style, + legend_font_size=legend_font_size, + legend_font_weight=legend_font_weight, + legend_marker_scale=legend_marker_scale, + legend_location=legend_location, + legend_bbox_to_anchor=legend_bbox_to_anchor, + legend_border_axes_pad=legend_border_axes_pad, + legend_n_columns=legend_n_columns, + legend_horizontal_spacing=legend_horizontal_spacing, + legend_vertical_spacing=legend_vertical_spacing, + legend_border=legend_border, + legend_border_padding=legend_border_padding, + legend_shadow=legend_shadow, + legend_rounded_corners=legend_rounded_corners, + render_axes=render_axes, axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, + axes_y_limits=axes_y_limits, + interpolation=interpolation, alpha=alpha, + figure_size=figure_size, **mask_arguments) + else: + # either there are not any landmark groups selected or they won't + # be displayed + if glyph_enabled or sum_enabled: + # image, not landmarks, masked, glyph + renderer = glyph(image, vectors_block_size=glyph_block_size, + use_negative=glyph_use_negative, + channels=channels).view( + render_axes=render_axes, axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, axes_y_limits=axes_y_limits, + figure_size=figure_size, interpolation=interpolation, + alpha=alpha, **mask_arguments) + else: + # image, not landmarks, masked, not glyph + renderer = image.view( + channels=channels, render_axes=render_axes, + axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, + axes_y_limits=axes_y_limits, figure_size=figure_size, + interpolation=interpolation, alpha=alpha, **mask_arguments) + else: + # image won't be displayed + if render_landmarks and len(groups) > 0: + # there are selected landmark groups and they will be displayed + if subplots_enabled: + # calculate subplots structure + subplots = MatplotlibSubplots()._subplot_layout(len(groups)) + # not image, landmarks + for k, group in enumerate(groups): + if subplots_enabled: + # create subplot + plt.subplot(subplots[0], subplots[1], k + 1) + if render_legend: + # set subplot's title + plt.title(subplots_titles[group], + fontname=legend_font_name, + fontstyle=legend_font_style, + fontweight=legend_font_weight, + fontsize=legend_font_size) + image.landmarks[group].lms.view( + image_view=image_axes_mode, render_lines=render_lines[k], + line_style=line_style[k], line_width=line_width[k], + line_colour=line_colour[k], + render_markers=render_markers[k], + marker_style=marker_style[k], marker_size=marker_size[k], + marker_edge_width=marker_edge_width[k], + marker_edge_colour=marker_edge_colour[k], + marker_face_colour=marker_face_colour[k], + render_axes=render_axes, axes_font_name=axes_font_name, + axes_font_size=axes_font_size, + axes_font_style=axes_font_style, + axes_font_weight=axes_font_weight, + axes_x_limits=axes_x_limits, axes_y_limits=axes_y_limits, + figure_size=figure_size) + if not subplots_enabled: + if len(groups) % 2 == 0: + plt.gca().invert_yaxis() + if render_legend: + # Options related to legend's font + prop = {'family': legend_font_name, + 'size': legend_font_size, + 'style': legend_font_style, + 'weight': legend_font_weight} + + # display legend on side + plt.gca().legend(groups, title=legend_title, prop=prop, + loc=legend_location, + bbox_to_anchor=legend_bbox_to_anchor, + borderaxespad=legend_border_axes_pad, + ncol=legend_n_columns, + columnspacing=legend_horizontal_spacing, + labelspacing=legend_vertical_spacing, + frameon=legend_border, + borderpad=legend_border_padding, + shadow=legend_shadow, + fancybox=legend_rounded_corners, + markerscale=legend_marker_scale) + + # show plot + plt.show() + + return renderer + + +def _check_n_parameters(n_params, n_levels, max_n_params): + r""" + Checks the maximum number of components per level either of the shape + or the appearance model. It must be None or int or float or a list of + those containing 1 or {n_levels} elements. + """ + str_error = ("n_params must be None or 1 <= int <= max_n_params or " + "a list of those containing 1 or {} elements").format(n_levels) + if not isinstance(n_params, list): + n_params_list = [n_params] * n_levels + elif len(n_params) == 1: + n_params_list = [n_params[0]] * n_levels + elif len(n_params) == n_levels: + n_params_list = n_params + else: + raise ValueError(str_error) + for i, comp in enumerate(n_params_list): + if comp is None: + n_params_list[i] = max_n_params[i] + else: + if isinstance(comp, int): + if comp > max_n_params[i]: + n_params_list[i] = max_n_params[i] + else: + raise ValueError(str_error) + return n_params_list diff --git a/menpofit/visualize/widgets/options.py b/menpofit/visualize/widgets/options.py new file mode 100644 index 0000000..9d3c019 --- /dev/null +++ b/menpofit/visualize/widgets/options.py @@ -0,0 +1,1631 @@ +from collections import OrderedDict + +from menpo.visualize.widgets.tools import (colour_selection, + format_colour_selection) +from menpo.visualize.widgets.options import (animation_options, + format_animation_options, + _compare_groups_and_labels) + + +def model_parameters(n_params, plot_function=None, params_str='', + mode='multiple', params_bounds=(-3., 3.), + plot_eig_visible=True, plot_eig_function=None, + toggle_show_default=True, toggle_show_visible=True, + toggle_show_name='Parameters'): + r""" + Creates a widget with Model Parameters. Specifically, it has: + 1) A slider for each parameter if mode is 'multiple'. + 2) A single slider and a drop down menu selection if mode is 'single'. + 3) A reset button. + 4) A button and two radio buttons for plotting the eigenvalues variance + ratio. + + The structure of the widgets is the following: + model_parameters_wid.children = [toggle_button, parameters_and_reset] + parameters_and_reset.children = [parameters_widgets, reset] + If plot_eig_visible is True: + reset = [plot_eigenvalues, reset_button] + Else: + reset = reset_button + If mode is single: + parameters_widgets.children = [drop_down_menu, slider] + If mode is multiple: + parameters_widgets.children = [all_sliders] + + The returned widget saves the selected values in the following fields: + model_parameters_wid.parameters_values + model_parameters_wid.mode + model_parameters_wid.plot_eig_visible + + To fix the alignment within this widget please refer to + `format_model_parameters()` function. + + To update the state of this widget, please refer to + `update_model_parameters()` function. + + Parameters + ---------- + n_params : `int` + The number of principal components to use for the sliders. + + plot_function : `function` or None, optional + The plot function that is executed when a widgets' value changes. + If None, then nothing is assigned. + + params_str : `str`, optional + The string that will be used for each parameters name. + + mode : 'single' or 'multiple', optional + If single, only a single slider is constructed along with a drop down + menu. + If multiple, a slider is constructed for each parameter. + + params_bounds : (`float`, `float`), optional + The minimum and maximum bounds, in std units, for the sliders. + + plot_eig_visible : `boolean`, optional + Defines whether the options for plotting the eigenvalues variance ratio + will be visible upon construction. + + plot_eig_function : `function` or None, optional + The plot function that is executed when the plot eigenvalues button is + clicked. If None, then nothing is assigned. + + toggle_show_default : `boolean`, optional + Defines whether the options will be visible upon construction. + + toggle_show_visible : `boolean`, optional + The visibility of the toggle button. + + toggle_show_name : `str`, optional + The name of the toggle button. + """ + import IPython.html.widgets as ipywidgets + # If only one slider requested, then set mode to multiple + if n_params == 1: + mode = 'multiple' + + # Create all necessary widgets + but = ipywidgets.ToggleButtonWidget(description=toggle_show_name, + value=toggle_show_default, + visible=toggle_show_visible) + reset_button = ipywidgets.ButtonWidget(description='Reset') + if mode == 'multiple': + sliders = [ipywidgets.FloatSliderWidget( + description="{}{}".format(params_str, p), + min=params_bounds[0], max=params_bounds[1], + value=0.) + for p in range(n_params)] + parameters_wid = ipywidgets.ContainerWidget(children=sliders) + else: + vals = OrderedDict() + for p in range(n_params): + vals["{}{}".format(params_str, p)] = p + slider = ipywidgets.FloatSliderWidget(description='', + min=params_bounds[0], + max=params_bounds[1], value=0.) + dropdown_params = ipywidgets.DropdownWidget(values=vals) + parameters_wid = ipywidgets.ContainerWidget( + children=[dropdown_params, slider]) + + # Group widgets + if plot_eig_visible: + plot_button = ipywidgets.ButtonWidget(description='Plot eigenvalues') + if plot_eig_function is not None: + plot_button.on_click(plot_eig_function) + plot_and_reset = ipywidgets.ContainerWidget( + children=[plot_button, reset_button]) + params_and_reset = ipywidgets.ContainerWidget(children=[parameters_wid, + plot_and_reset]) + else: + params_and_reset = ipywidgets.ContainerWidget(children=[parameters_wid, + reset_button]) + + # Widget container + model_parameters_wid = ipywidgets.ContainerWidget( + children=[but, params_and_reset]) + + # Save mode and parameters values + model_parameters_wid.parameters_values = [0.0] * n_params + model_parameters_wid.mode = mode + model_parameters_wid.plot_eig_visible = plot_eig_visible + + # set up functions + if mode == 'single': + # assign slider value to parameters values list + def save_slider_value(name, value): + model_parameters_wid.parameters_values[dropdown_params.value] = \ + value + slider.on_trait_change(save_slider_value, 'value') + + # set correct value to slider when drop down menu value changes + def set_slider_value(name, value): + slider.value = model_parameters_wid.parameters_values[value] + dropdown_params.on_trait_change(set_slider_value, 'value') + + # assign main plotting function when slider value changes + if plot_function is not None: + slider.on_trait_change(plot_function, 'value') + else: + # assign slider value to parameters values list + def save_slider_value_from_id(description, name, value): + i = int(description[len(params_str)::]) + model_parameters_wid.parameters_values[i] = value + + # partial function that helps get the widget's description str + def partial_widget(description): + return lambda name, value: save_slider_value_from_id(description, + name, value) + + # assign saving values and main plotting function to all sliders + for w in parameters_wid.children: + # The widget (w) is lexically scoped and so we need a way of + # ensuring that we don't just receive the final value of w at every + # iteration. Therefore we create another lambda function that + # creates a new lexical scoping so that we can ensure the value of w + # is maintained (as x) at each iteration. + # In JavaScript, we would just use the 'let' keyword... + w.on_trait_change(partial_widget(w.description), 'value') + if plot_function is not None: + w.on_trait_change(plot_function, 'value') + + # reset function + def reset_params(name): + model_parameters_wid.parameters_values = \ + [0.0] * len(model_parameters_wid.parameters_values) + if mode == 'multiple': + for ww in parameters_wid.children: + ww.value = 0. + else: + parameters_wid.children[0].value = 0 + parameters_wid.children[1].value = 0. + reset_button.on_click(reset_params) + + # Toggle button function + def show_options(name, value): + params_and_reset.visible = value + show_options('', toggle_show_default) + but.on_trait_change(show_options, 'value') + + return model_parameters_wid + + +def format_model_parameters(model_parameters_wid, container_padding='6px', + container_margin='6px', + container_border='1px solid black', + toggle_button_font_weight='bold', + border_visible=True): + r""" + Function that corrects the align (style format) of a given model_parameters + widget. Usage example: + model_parameters_wid = model_parameters() + display(model_parameters_wid) + format_model_parameters(model_parameters_wid) + + Parameters + ---------- + model_parameters_wid : + The widget object generated by the `model_parameters()` function. + + container_padding : `str`, optional + The padding around the widget, e.g. '6px' + + container_margin : `str`, optional + The margin around the widget, e.g. '6px' + + container_border : `str`, optional + The border around the widget, e.g. '1px solid black' + + toggle_button_font_weight : `str` + The font weight of the toggle button, e.g. 'bold' + + border_visible : `boolean`, optional + Defines whether to draw the border line around the widget. + """ + if model_parameters_wid.mode == 'single': + # align drop down menu and slider + model_parameters_wid.children[1].children[0].remove_class('vbox') + model_parameters_wid.children[1].children[0].add_class('hbox') + else: + # align sliders + model_parameters_wid.children[1].children[0].add_class('start') + + # align reset button to right + if model_parameters_wid.plot_eig_visible: + model_parameters_wid.children[1].children[1].remove_class('vbox') + model_parameters_wid.children[1].children[1].add_class('hbox') + model_parameters_wid.children[1].add_class('align-end') + + # set toggle button font bold + model_parameters_wid.children[0].set_css('font-weight', + toggle_button_font_weight) + + # margin and border around plot_eigenvalues widget + if model_parameters_wid.plot_eig_visible: + model_parameters_wid.children[1].children[1].children[0].set_css( + 'margin-right', container_margin) + + # margin and border around container widget + model_parameters_wid.set_css('padding', container_padding) + model_parameters_wid.set_css('margin', container_margin) + if border_visible: + model_parameters_wid.set_css('border', container_border) + + +def update_model_parameters(model_parameters_wid, n_params, plot_function=None, + params_str=''): + r""" + Function that updates the state of a given model_parameters widget if the + requested number of parameters has changed. Usage example: + model_parameters_wid = model_parameters(n_params=5) + display(model_parameters_wid) + format_model_parameters(model_parameters_wid) + update_model_parameters(model_parameters_wid, 3) + + Parameters + ---------- + model_parameters_wid : + The widget object generated by the `model_parameters()` function. + + n_params : `int` + The requested number of parameters. + + plot_function : `function` or None, optional + The plot function that is executed when a widgets' value changes. + If None, then nothing is assigned. + + params_str : `str`, optional + The string that will be used for each parameters name. + """ + import IPython.html.widgets as ipywidgets + + if model_parameters_wid.mode == 'multiple': + # get the number of enabled parameters (number of sliders) + enabled_params = len(model_parameters_wid.children[1].children[0].children) + if n_params != enabled_params: + # reset all parameters values + model_parameters_wid.parameters_values = [0.0] * n_params + # get params_bounds + pb = [model_parameters_wid.children[1].children[0].children[0].min, + model_parameters_wid.children[1].children[0].children[0].max] + # create sliders widgets + sliders = [ipywidgets.FloatSliderWidget( + description="{}{}".format(params_str, + p), + min=pb[0], max=pb[1], value=0.) + for p in range(n_params)] + # assign sliders to container + model_parameters_wid.children[1].children[0].children = sliders + + # assign slider value to parameters values list + def save_slider_value_from_id(description, name, value): + i = int(description[len(params_str)::]) + model_parameters_wid.parameters_values[i] = value + + # partial function that helps get the widget's description str + def partial_widget(description): + return lambda name, value: save_slider_value_from_id( + description, + name, value) + + # assign saving values and main plotting function to all sliders + for w in model_parameters_wid.children[1].children[0].children: + # The widget (w) is lexically scoped and so we need a way of + # ensuring that we don't just receive the final value of w at + # every iteration. Therefore we create another lambda function + # that creates a new lexical scoping so that we can ensure the + # value of w is maintained (as x) at each iteration + # In JavaScript, we would just use the 'let' keyword... + w.on_trait_change(partial_widget(w.description), 'value') + if plot_function is not None: + w.on_trait_change(plot_function, 'value') + else: + # get the number of enabled parameters (len of list of drop down menu) + enabled_params = len( + model_parameters_wid.children[1].children[0].children[0].values) + if n_params != enabled_params: + # reset all parameters values + model_parameters_wid.parameters_values = [0.0] * n_params + # change drop down menu values + vals = OrderedDict() + for p in range(n_params): + vals["{}{}".format(params_str, p)] = p + model_parameters_wid.children[1].children[0].children[0].values = \ + vals + # set initial value to the first and slider value to zero + model_parameters_wid.children[1].children[0].children[0].value = \ + vals["{}{}".format(params_str, 0)] + model_parameters_wid.children[1].children[0].children[1].value = 0. + + +def final_result_options(final_result_options_default, plot_function=None, + title='Final Result', toggle_show_default=True, + toggle_show_visible=True): + r""" + Creates a widget with Final Result Options. Specifically, it has: + 1) A set of toggle buttons representing usually the initial, final and + ground truth shapes. + 2) A checkbox that controls the rendering of the image. + 3) A set of radio buttons that define whether subplots are enabled. + 4) A toggle button that controls the visibility of all the above, i.e. + the final result options. + + The structure of the widgets is the following: + final_result_wid.children = [toggle_button, shapes_toggle_buttons, + options] + options.children = [plot_mode_radio_buttons, show_image_checkbox] + + The returned widget saves the selected values in the following fields: + final_result_wid.selected_values + + To fix the alignment within this widget please refer to + `format_final_result_options()` function. + + To update the state of this widget, please refer to + `update_final_result_options()` function. + + Parameters + ---------- + final_result_options_default : `dict` + The default options. For example: + final_result_options_default = {'all_groups': ['initial', 'final', + 'ground'], + 'selected_groups': ['final'], + 'render_image': True, + 'subplots_enabled': True} + plot_function : `function` or None, optional + The plot function that is executed when a widgets' value changes. + If None, then nothing is assigned. + title : `str`, optional + The title of the widget printed at the toggle button. + toggle_show_default : `bool`, optional + Defines whether the options will be visible upon construction. + toggle_show_visible : `bool`, optional + The visibility of the toggle button. + """ + import IPython.html.widgets as ipywidgets + # Toggle button that controls options' visibility + but = ipywidgets.ToggleButtonWidget(description=title, + value=toggle_show_default, + visible=toggle_show_visible) + + # Create widgets + shapes_checkboxes = [ipywidgets.LatexWidget(value='Select shape:')] + for group in final_result_options_default['all_groups']: + t = ipywidgets.ToggleButtonWidget( + description=group, + value=group in final_result_options_default['selected_groups']) + shapes_checkboxes.append(t) + render_image = ipywidgets.CheckboxWidget( + description='Render image', + value=final_result_options_default['render_image']) + mode = ipywidgets.RadioButtonsWidget( + description='Plot mode:', values={'Single': False, 'Multiple': True}, + value=final_result_options_default['subplots_enabled']) + + # Group widgets + shapes_wid = ipywidgets.ContainerWidget(children=shapes_checkboxes) + opts = ipywidgets.ContainerWidget(children=[mode, render_image]) + + # Widget container + final_result_wid = ipywidgets.ContainerWidget(children=[but, shapes_wid, + opts]) + + # Initialize variables + final_result_wid.selected_values = final_result_options_default + + # Groups function + def groups_fun(name, value): + final_result_wid.selected_values['selected_groups'] = [] + for i in shapes_wid.children[1::]: + if i.value: + final_result_wid.selected_values['selected_groups'].\ + append(str(i.description)) + for w in shapes_wid.children[1::]: + w.on_trait_change(groups_fun, 'value') + + # Render image function + def render_image_fun(name, value): + final_result_wid.selected_values['render_image'] = value + render_image.on_trait_change(render_image_fun, 'value') + + # Plot mode function + def plot_mode_fun(name, value): + final_result_wid.selected_values['subplots_enabled'] = value + mode.on_trait_change(plot_mode_fun, 'value') + + # Toggle button function + def show_options(name, value): + shapes_wid.visible = value + opts.visible = value + show_options('', toggle_show_default) + but.on_trait_change(show_options, 'value') + + # assign plot_function + if plot_function is not None: + render_image.on_trait_change(plot_function, 'value') + mode.on_trait_change(plot_function, 'value') + for w in shapes_wid.children[1::]: + w.on_trait_change(plot_function, 'value') + + return final_result_wid + + +def format_final_result_options(final_result_wid, container_padding='6px', + container_margin='6px', + container_border='1px solid black', + toggle_button_font_weight='bold', + border_visible=True): + r""" + Function that corrects the align (style format) of a given + final_result_options widget. Usage example: + final_result_options_default = {'all_groups': ['initial', 'final', + 'ground'], + 'selected_groups': ['final'], + 'render_image': True, + 'subplots_enabled': True} + final_result_wid = final_result_options(final_result_options_default) + display(final_result_wid) + format_final_result_options(final_result_wid) + + Parameters + ---------- + final_result_wid : + The widget object generated by the `final_result_options()` function. + container_padding : `str`, optional + The padding around the widget, e.g. '6px' + container_margin : `str`, optional + The margin around the widget, e.g. '6px' + container_border : `str`, optional + The border around the widget, e.g. '1px solid black' + toggle_button_font_weight : `str` + The font weight of the toggle button, e.g. 'bold' + border_visible : `boolean`, optional + Defines whether to draw the border line around the widget. + """ + # align shapes toggle buttons + final_result_wid.children[1].remove_class('vbox') + final_result_wid.children[1].add_class('hbox') + final_result_wid.children[1].add_class('align-center') + final_result_wid.children[1].children[0].set_css('margin-right', + container_margin) + + # align mode and legend options + final_result_wid.children[2].remove_class('vbox') + final_result_wid.children[2].add_class('hbox') + final_result_wid.children[2].children[0].set_css('margin-right', '20px') + + # set toggle button font bold + final_result_wid.children[0].set_css('font-weight', + toggle_button_font_weight) + final_result_wid.children[1].set_css('margin-top', container_margin) + + # margin and border around container widget + final_result_wid.set_css('padding', container_padding) + final_result_wid.set_css('margin', container_margin) + if border_visible: + final_result_wid.set_css('border', container_border) + + +def update_final_result_options(final_result_wid, group_keys, plot_function): + r""" + Function that updates the state of a given final_result_options widget if + the group keys of an image has changed. Usage example: + final_result_options_default = {'all_groups': ['group1', 'group2'], + 'selected_groups': ['group1'], + 'render_image': True, + 'subplots_enabled': True} + final_result_wid = final_result_options(final_result_options_default) + display(final_result_wid) + format_final_result_options(final_result_wid) + update_final_result_options(final_result_wid, group_keys=['group3']) + format_final_result_options(final_result_wid) + + Note that the `format_final_result_options()` function needs to be called + again after the `update_final_result_options()` function. + + Parameters + ---------- + final_result_wid : + The widget object generated by the `final_result_options()` function. + group_keys : `list` of `str` + A list of the available landmark groups. + plot_function : `function` or None + The plot function that is executed when a widgets' value changes. + If None, then nothing is assigned. + """ + import IPython.html.widgets as ipywidgets + # check if the new group_keys are the same as the old ones + if not _compare_groups_and_labels( + group_keys, [], final_result_wid.selected_values['all_groups'], []): + # Create all necessary widgets + shapes_checkboxes = [ipywidgets.LatexWidget(value='Select shape:')] + for group in group_keys: + t = ipywidgets.ToggleButtonWidget(description=group, value=True) + shapes_checkboxes.append(t) + + # Group widgets + final_result_wid.children[1].children = shapes_checkboxes + + # Initialize output variables + final_result_wid.selected_values['all_groups'] = group_keys + final_result_wid.selected_values['selected_groups'] = group_keys + + # Groups function + def groups_fun(name, value): + final_result_wid.selected_values['selected_groups'] = [] + for i in final_result_wid.children[1].children[1::]: + if i.value: + final_result_wid.selected_values['selected_groups'].append(str(i.description)) + for w in final_result_wid.children[1].children[1::]: + w.on_trait_change(groups_fun, 'value') + + # Toggle button function + def show_options(name, value): + final_result_wid.children[1].visible = value + final_result_wid.children[2].visible = value + show_options('', final_result_wid.children[0].value) + final_result_wid.children[0].on_trait_change(show_options, 'value') + + # assign plot_function + if plot_function is not None: + final_result_wid.children[2].children[0].on_trait_change( + plot_function, 'value') + final_result_wid.children[2].children[1].on_trait_change( + plot_function, 'value') + for w in final_result_wid.children[1].children[1::]: + w.on_trait_change(plot_function, 'value') + + +def iterations_result_options(iterations_result_options_default, + plot_function=None, plot_errors_function=None, + plot_displacements_function=None, + title='Iterations Result', + toggle_show_default=True, + toggle_show_visible=True): + r""" + Creates a widget with Iterations Result Options. Specifically, it has: + 1) Two radio buttons that select an options mode, depending on whether + the user wants to visualize iterations in ``Animation`` or ``Static`` + mode. + 2) If mode is ``Animation``, an animation options widget appears. + If mode is ``Static``, the iterations range is selected by two + sliders and there is an update plot button. + 3) A checkbox that controls the visibility of the image. + 4) A set of radio buttons that define whether subplots are enabled. + 5) A button to plot the error evolution. + 6) A button to plot the landmark points' displacement. + 7) A drop down menu to select which displacement to plot. + 8) A toggle button that controls the visibility of all the above, i.e. + the final result options. + + The structure of the widgets is the following: + iterations_result_wid.children = [toggle_button, all_options] + all_options.children = [iterations_mode_and_sliders, options] + iterations_mode_and_sliders.children = [iterations_mode_radio_buttons, + all_sliders] + all_sliders.children = [animation_slider, first_slider, second_slider, + update_and_axes] + update_and_axes.children = [same_axes_checkbox, update_button] + options.children = [render_image_checkbox, plot_errors_button, + plot_displacements] + plot_displacements.children = [plot_displacements_button, + plot_displacements_drop_down_menu] + + The returned widget saves the selected values in the following fields: + iterations_result_wid.selected_values + + To fix the alignment within this widget please refer to + `format_iterations_result_options()` function. + + To update the state of this widget, please refer to + `update_iterations_result_options()` function. + + Parameters + ---------- + iterations_result_options_default : `dict` + The default options. For example: + iterations_result_options_default = {'n_iters': 10, + 'image_has_gt_shape': True, + 'n_points': 68, + 'iter_str': 'iter_', + 'selected_groups': [0], + 'render_image': True, + 'subplots_enabled': True, + 'displacement_type': 'mean'} + plot_function : `function` or None, optional + The plot function that is executed when a widgets' value changes. + If None, then nothing is assigned. + plot_errors_function : `function` or None, optional + The plot function that is executed when the 'Plot Errors' button is + pressed. + If None, then nothing is assigned. + plot_displacements_function : `function` or None, optional + The plot function that is executed when the 'Plot Displacements' button + is pressed. + If None, then nothing is assigned. + title : `str`, optional + The title of the widget printed at the toggle button. + toggle_show_default : `bool`, optional + Defines whether the options will be visible upon construction. + toggle_show_visible : `bool`, optional + The visibility of the toggle button. + """ + import IPython.html.widgets as ipywidgets + # Create all necessary widgets + but = ipywidgets.ToggleButtonWidget(description=title, + value=toggle_show_default, + visible=toggle_show_visible) + iterations_mode = ipywidgets.RadioButtonsWidget( + values={'Animation': 'animation', 'Static': 'static'}, + value='animation', description='Mode:', visible=toggle_show_default) + # Don't assign the plot function to the animation_wid at this point. We + # first need to assign the get_groups function and then the plot_function() + # for synchronization reasons. + index_selection_default = { + 'min': 0, 'max': iterations_result_options_default['n_iters'] - 1, + 'step': 1, 'index': 0} + animation_wid = animation_options( + index_selection_default, plot_function=None, update_function=None, + index_description='Iteration', index_style='slider', loop_default=False, + interval_default=0.2, toggle_show_default=toggle_show_default, + toggle_show_visible=False) + first_slider_wid = ipywidgets.IntSliderWidget( + min=0, max=iterations_result_options_default['n_iters'] - 1, step=1, + value=0, description='From', visible=False) + second_slider_wid = ipywidgets.IntSliderWidget( + min=0, max=iterations_result_options_default['n_iters'] - 1, step=1, + value=iterations_result_options_default['n_iters'] - 1, + description='To', visible=False) + same_axes = ipywidgets.CheckboxWidget( + description='Same axes', + value=not iterations_result_options_default['subplots_enabled'], + visible=False) + update_but = ipywidgets.ButtonWidget(description='Update Plot', + visible=False) + render_image = ipywidgets.CheckboxWidget( + description='Render image', + value=iterations_result_options_default['render_image']) + plot_errors_button = ipywidgets.ButtonWidget(description='Plot Errors') + plot_displacements_button = ipywidgets.ButtonWidget( + description='Plot Displacements') + dropdown_menu = OrderedDict() + dropdown_menu['mean'] = 'mean' + dropdown_menu['median'] = 'median' + dropdown_menu['max'] = 'max' + dropdown_menu['min'] = 'min' + for p in range(iterations_result_options_default['n_points']): + dropdown_menu["point {}".format(p)] = p + plot_displacements_menu = ipywidgets.DropdownWidget( + values=dropdown_menu, + value=iterations_result_options_default['displacement_type']) + + # if just one iteration, disable multiple options + if iterations_result_options_default['n_iters'] == 1: + iterations_mode.value = 'animation' + iterations_mode.disabled = True + first_slider_wid.disabled = True + animation_wid.children[1].children[0].disabled = True + animation_wid.children[1].children[1].children[0].children[0].\ + disabled = True + animation_wid.children[1].children[1].children[0].children[1].\ + disabled = True + animation_wid.children[1].children[1].children[0].children[2].\ + disabled = True + second_slider_wid.disabled = True + plot_errors_button.disabled = True + plot_displacements_button.disabled = True + plot_displacements_menu.disabled = True + + # Group widgets + update_and_subplots = ipywidgets.ContainerWidget( + children=[same_axes, update_but]) + sliders = ipywidgets.ContainerWidget( + children=[animation_wid, first_slider_wid, second_slider_wid, + update_and_subplots]) + iterations_mode_and_sliders = ipywidgets.ContainerWidget( + children=[iterations_mode, sliders]) + plot_displacements = ipywidgets.ContainerWidget( + children=[plot_displacements_button, plot_displacements_menu]) + opts = ipywidgets.ContainerWidget( + children=[render_image, plot_errors_button, plot_displacements]) + all_options = ipywidgets.ContainerWidget( + children=[iterations_mode_and_sliders, opts]) + + # Widget container + iterations_result_wid = ipywidgets.ContainerWidget(children=[but, + all_options]) + + # Initialize variables + iterations_result_options_default['selected_groups'] = \ + _convert_iterations_to_groups( + 0, 0, iterations_result_options_default['iter_str']) + iterations_result_wid.selected_values = iterations_result_options_default + + # Define iterations mode visibility + def iterations_mode_selection(name, value): + if value == 'animation': + # get val that needs to be assigned + val = first_slider_wid.value + # update visibility + animation_wid.visible = True + first_slider_wid.visible = False + second_slider_wid.visible = False + same_axes.visible = False + update_but.visible = False + # set correct values + animation_wid.children[1].children[0].value = val + animation_wid.selected_values['index'] = val + first_slider_wid.value = 0 + second_slider_wid.value = \ + iterations_result_wid.selected_values['n_iters'] - 1 + else: + # get val that needs to be assigned + val = animation_wid.selected_values['index'] + # update visibility + animation_wid.visible = False + first_slider_wid.visible = True + second_slider_wid.visible = True + same_axes.visible = True + update_but.visible = True + # set correct values + second_slider_wid.value = val + first_slider_wid.value = val + animation_wid.children[1].children[0].value = 0 + animation_wid.selected_values['index'] = 0 + iterations_mode.on_trait_change(iterations_mode_selection, 'value') + + # Check first slider's value + def first_slider_val(name, value): + if value > second_slider_wid.value: + first_slider_wid.value = second_slider_wid.value + first_slider_wid.on_trait_change(first_slider_val, 'value') + + # Check second slider's value + def second_slider_val(name, value): + if value < first_slider_wid.value: + second_slider_wid.value = first_slider_wid.value + second_slider_wid.on_trait_change(second_slider_val, 'value') + + # Convert slider values to groups + def get_groups(name, value): + if iterations_mode.value == 'animation': + iterations_result_wid.selected_values['selected_groups'] = \ + _convert_iterations_to_groups( + animation_wid.selected_values['index'], + animation_wid.selected_values['index'], + iterations_result_wid.selected_values['iter_str']) + else: + iterations_result_wid.selected_values['selected_groups'] = \ + _convert_iterations_to_groups( + first_slider_wid.value, second_slider_wid.value, + iterations_result_wid.selected_values['iter_str']) + first_slider_wid.on_trait_change(get_groups, 'value') + second_slider_wid.on_trait_change(get_groups, 'value') + + # assign get_groups() to the slider of animation_wid + animation_wid.children[1].children[0].on_trait_change(get_groups, 'value') + + # Render image function + def render_image_fun(name, value): + iterations_result_wid.selected_values['render_image'] = value + render_image.on_trait_change(render_image_fun, 'value') + + # Same axes function + def same_axes_fun(name, value): + iterations_result_wid.selected_values['subplots_enabled'] = not value + same_axes.on_trait_change(same_axes_fun, 'value') + + # Displacement type function + def displacement_type_fun(name, value): + iterations_result_wid.selected_values['displacement_type'] = value + plot_displacements_menu.on_trait_change(displacement_type_fun, 'value') + + # Toggle button function + def show_options(name, value): + iterations_mode.visible = value + render_image.visible = value + plot_errors_button.visible = \ + iterations_result_wid.selected_values['image_has_gt_shape'] and value + plot_displacements.visible = value + if value: + if iterations_mode.value == 'animation': + animation_wid.visible = True + else: + first_slider_wid.visible = True + second_slider_wid.visible = True + same_axes.visible = True + update_but.visible = True + else: + animation_wid.visible = False + first_slider_wid.visible = False + second_slider_wid.visible = False + same_axes.visible = False + update_but.visible = False + show_options('', toggle_show_default) + but.on_trait_change(show_options, 'value') + + # assign general plot_function + if plot_function is not None: + def plot_function_but(name): + plot_function(name, 0) + + update_but.on_click(plot_function_but) + # Here we assign plot_function() to the slider of animation_wid, as + # we didn't do it at its creation. + animation_wid.children[1].children[0].on_trait_change(plot_function, + 'value') + render_image.on_trait_change(plot_function, 'value') + iterations_mode.on_trait_change(plot_function, 'value') + + # assign plot function of errors button + if plot_errors_function is not None: + plot_errors_button.on_click(plot_errors_function) + + # assign plot function of displacements button + if plot_displacements_function is not None: + plot_displacements_button.on_click(plot_displacements_function) + + return iterations_result_wid + + +def format_iterations_result_options(iterations_result_wid, + container_padding='6px', + container_margin='6px', + container_border='1px solid black', + toggle_button_font_weight='bold', + border_visible=True): + r""" + Function that corrects the align (style format) of a given + iterations_result_options widget. Usage example: + iterations_result_wid = iterations_result_options() + display(iterations_result_wid) + format_iterations_result_options(iterations_result_wid) + + Parameters + ---------- + iterations_result_wid : + The widget object generated by the `iterations_result_options()` + function. + container_padding : `str`, optional + The padding around the widget, e.g. '6px' + container_margin : `str`, optional + The margin around the widget, e.g. '6px' + container_border : `str`, optional + The border around the widget, e.g. '1px solid black' + toggle_button_font_weight : `str` + The font weight of the toggle button, e.g. 'bold' + border_visible : `bool`, optional + Defines whether to draw the border line around the widget. + """ + # format animations options + format_animation_options( + iterations_result_wid.children[1].children[0].children[1].children[0], + index_text_width='0.5cm', container_padding=container_padding, + container_margin=container_margin, container_border=container_border, + toggle_button_font_weight=toggle_button_font_weight, + border_visible=False) + + # align displacement button and drop down menu + iterations_result_wid.children[1].children[1].children[2].\ + remove_class('vbox') + iterations_result_wid.children[1].children[1].children[2].\ + add_class('hbox') + iterations_result_wid.children[1].children[1].children[2].\ + add_class('align-center') + iterations_result_wid.children[1].children[1].children[2].children[0].\ + set_css('margin-right', '0px') + iterations_result_wid.children[1].children[1].children[2].children[1].\ + set_css('margin-left', '0px') + + # align options + iterations_result_wid.children[1].children[1].remove_class('vbox') + iterations_result_wid.children[1].children[1].add_class('hbox') + iterations_result_wid.children[1].children[1].add_class('align-center') + iterations_result_wid.children[1].children[1].children[0].\ + set_css('margin-right', '30px') + iterations_result_wid.children[1].children[1].children[1].\ + set_css('margin-right', '30px') + + # align update button and same axes checkbox + iterations_result_wid.children[1].children[0].children[1].children[3].\ + remove_class('vbox') + iterations_result_wid.children[1].children[0].children[1].children[3].\ + add_class('hbox') + iterations_result_wid.children[1].children[0].children[1].children[3].children[0].\ + set_css('margin-right', '20px') + + # align sliders + iterations_result_wid.children[1].children[0].children[1].\ + add_class('align-end') + iterations_result_wid.children[1].children[0].children[1].\ + set_css('margin-bottom', '20px') + + # align sliders and iterations_mode + iterations_result_wid.children[1].children[0].remove_class('vbox') + iterations_result_wid.children[1].children[0].add_class('hbox') + iterations_result_wid.children[1].children[0].add_class('align-start') + + # align sliders and options + iterations_result_wid.children[1].add_class('align-end') + + # set toggle button font bold + iterations_result_wid.children[0].set_css('font-weight', + toggle_button_font_weight) + + # margin and border around container widget + iterations_result_wid.set_css('padding', container_padding) + iterations_result_wid.set_css('margin', container_margin) + if border_visible: + iterations_result_wid.set_css('border', container_border) + + +def update_iterations_result_options(iterations_result_wid, + iterations_result_default): + r""" + Function that updates the state of a given iterations_result_options widget. Usage example: + iterations_result_options_default = {'n_iters': 10, + 'image_has_gt_shape': True, + 'n_points': 68, + 'iter_str': 'iter_', + 'selected_groups': [0], + 'render_image': True, + 'subplots_enabled': True, + 'displacement_type': 'mean'} + iterations_result_wid = iterations_result_options(iterations_result_options_default) + display(iterations_result_wid) + format_iterations_result_options(iterations_result_wid) + iterations_result_options_default = {'n_iters': 100, + 'image_has_gt_shape': False, + 'n_points': 15, + 'iter_str': 'iter_', + 'selected_groups': [0], + 'render_image': False, + 'subplots_enabled': False, + 'displacement_type': 'median'} + update_iterations_result_options(iterations_result_wid, iterations_result_options_default) + + Parameters + ---------- + iterations_result_wid : + The widget generated by `iterations_result_options()` function. + iterations_result_options_default : `dict` + The default options. For example: + iterations_result_options_default = {'n_iters': 10, + 'image_has_gt_shape': True, + 'n_points': 68, + 'iter_str': 'iter_', + 'selected_groups': [0], + 'render_image': True, + 'subplots_enabled': True, + 'displacement_type': 'mean'} + """ + # if image_has_gt_shape flag has actually changed from the previous value + if ('image_has_gt_shape' in iterations_result_default and + iterations_result_default['image_has_gt_shape'] != + iterations_result_wid.selected_values['image_has_gt_shape']): + # set the plot errors visibility + iterations_result_wid.children[1].children[1].children[1].visible = \ + (iterations_result_wid.children[0].value and + iterations_result_default['image_has_gt_shape']) + # store the flag + iterations_result_wid.selected_values['image_has_gt_shape'] = \ + iterations_result_default['image_has_gt_shape'] + + # if n_points has actually changed from the previous value + if ('n_points' in iterations_result_default and + iterations_result_default['n_points'] != + iterations_result_wid.selected_values['n_points']): + # change the contents of the displacement types + select_menu = OrderedDict() + select_menu['mean'] = 'mean' + select_menu['median'] = 'median' + select_menu['max'] = 'max' + select_menu['min'] = 'min' + for p in range(iterations_result_default['n_points']): + select_menu["point {}".format(p + 1)] = p + iterations_result_wid.children[1].children[1].children[2].children[1].\ + values = select_menu + # store the number of points + iterations_result_wid.selected_values['n_points'] = \ + iterations_result_default['n_points'] + + # if displacement_type has actually changed from the previous value + if ('displacement_type' in iterations_result_default and + iterations_result_default['displacement_type'] != + iterations_result_wid.selected_values['displacement_type']): + iterations_result_wid.children[1].children[1].children[2].children[1].\ + value = iterations_result_default['displacement_type'] + + # if iter_str are actually different from the previous value + if ('iter_str' in iterations_result_default and + iterations_result_default['iter_str'] != + iterations_result_wid.selected_values['iter_str']): + iterations_result_wid.selected_values['iter_str'] = \ + iterations_result_default['iter_str'] + + # if render_image are actually different from the previous value + if ('render_image' in iterations_result_default and + iterations_result_default['render_image'] != + iterations_result_wid.selected_values['render_image']): + iterations_result_wid.children[1].children[1].children[0].value = \ + iterations_result_default['render_image'] + + # if subplots_enabled are actually different from the previous value + if ('subplots_enabled' in iterations_result_default and + iterations_result_default['subplots_enabled'] != + iterations_result_wid.selected_values['subplots_enabled']): + iterations_result_wid.children[1].children[0].children[1].children[3].children[0].value = \ + not iterations_result_default['subplots_enabled'] + + # if n_iters are actually different from the previous value + if ('n_iters' in iterations_result_default and + iterations_result_default['n_iters'] != + iterations_result_wid.selected_values['n_iters']): + # change the iterations_result_wid output + iterations_result_wid.selected_values['n_iters'] = \ + iterations_result_default['n_iters'] + iterations_result_wid.selected_values['selected_groups'] = \ + _convert_iterations_to_groups( + 0, 0, iterations_result_wid.selected_values['iter_str']) + + animation_options_wid = iterations_result_wid.children[1].children[0].children[1].children[0] + # set the iterations options state + if iterations_result_default['n_iters'] == 1: + # set sliders values and visibility + for t in range(4): + if t == 0: + # first slider + iterations_result_wid.children[1].children[0].children[1].children[1].value = 0 + iterations_result_wid.children[1].children[0].children[1].children[1].max = 0 + iterations_result_wid.children[1].children[0].children[1].children[1].visible = False + elif t == 1: + # second slider + iterations_result_wid.children[1].children[0].children[1].children[2].value = 0 + iterations_result_wid.children[1].children[0].children[1].children[2].max = 0 + iterations_result_wid.children[1].children[0].children[1].children[2].visible = False + elif t == 2: + # animation slider + animation_options_wid.selected_values['index'] = 0 + animation_options_wid.selected_values['max'] = 0 + animation_options_wid.children[1].children[0].value = 0 + animation_options_wid.children[1].children[0]. max = 0 + animation_options_wid.children[1].children[0].disabled = True + animation_options_wid.children[1].children[1].children[0].children[0].disabled = True + animation_options_wid.children[1].children[1].children[0].children[1].disabled = True + animation_options_wid.children[1].children[1].children[0].children[2].disabled = True + else: + # iterations mode + iterations_result_wid.children[1].children[0].children[0].value = 'animation' + #iterations_result_wid.groups = [iter_str + "0"] + iterations_result_wid.children[1].children[0].children[0].disabled = True + else: + # set sliders max and min values + for t in range(4): + if t == 0: + # first slider + iterations_result_wid.children[1].children[0].children[1].children[1].value = 0 + iterations_result_wid.children[1].children[0].children[1].children[1].max = \ + iterations_result_default['n_iters'] - 1 + iterations_result_wid.children[1].children[0].children[1].children[1].visible = False + elif t == 1: + # second slider + iterations_result_wid.children[1].children[0].children[1].children[2].value = \ + iterations_result_default['n_iters'] - 1 + iterations_result_wid.children[1].children[0].children[1].children[2].max = \ + iterations_result_default['n_iters'] - 1 + iterations_result_wid.children[1].children[0].children[1].children[2].visible = False + elif t == 2: + # animation slider + animation_options_wid.children[1].children[0].value = 0 + animation_options_wid.children[1].children[0].max = \ + iterations_result_default['n_iters'] - 1 + animation_options_wid.selected_values['index'] = 0 + animation_options_wid.selected_values['max'] = \ + iterations_result_default['n_iters'] - 1 + animation_options_wid.children[1].children[0].disabled = \ + False + animation_options_wid.children[1].children[1].children[0].children[0].disabled = False + animation_options_wid.children[1].children[1].children[0].children[1].disabled = True + animation_options_wid.children[1].children[1].children[0].children[2].disabled = False + else: + # iterations mode + iterations_result_wid.children[1].children[0].children[0].\ + value = 'animation' + #iterations_result_wid.groups = [iter_str + "0"] + iterations_result_wid.children[1].children[0].children[0].\ + disabled = False + + +def plot_options(plot_options_default, plot_function=None, + toggle_show_visible=True, toggle_show_default=True): + r""" + Creates a widget with Plot Options. Specifically, it has: + 1) A drop down menu for curve selection. + 2) A text area for the legend entry. + 3) A checkbox that controls line's visibility. + 4) A checkbox that controls markers' visibility. + 5) Options for line colour, style and width. + 6) Options for markers face colour, edge colour, size, edge width and + style. + 7) A toggle button that controls the visibility of all the above, i.e. + the plot options. + + The structure of the widgets is the following: + plot_options_wid.children = [toggle_button, options] + options.children = [curve_menu, per_curve_options_wid] + per_curve_options_wid = ipywidgets.ContainerWidget(children=[legend_entry, + line_marker_wid]) + line_marker_wid = ipywidgets.ContainerWidget(children=[line_widget, marker_widget]) + line_widget.children = [show_line_checkbox, line_options] + marker_widget.children = [show_marker_checkbox, marker_options] + line_options.children = [linestyle, linewidth, linecolour] + marker_options.children = [markerstyle, markersize, markeredgewidth, + markerfacecolour, markeredgecolour] + + The returned widget saves the selected values in the following dictionary: + plot_options_wid.selected_options + + To fix the alignment within this widget please refer to + `format_plot_options()` function. + + Parameters + ---------- + plot_options_default : list of `dict` + A list of dictionaries with the initial selected plot options per curve. + Example: + plot_options_1={'show_line':True, + 'linewidth':2, + 'linecolour':'r', + 'linestyle':'-', + 'show_marker':True, + 'markersize':20, + 'markerfacecolour':'r', + 'markeredgecolour':'b', + 'markerstyle':'o', + 'markeredgewidth':1, + 'legend_entry':'final errors'} + plot_options_2={'show_line':False, + 'linewidth':3, + 'linecolour':'r', + 'linestyle':'-', + 'show_marker':True, + 'markersize':60, + 'markerfacecolour':[0.1, 0.2, 0.3], + 'markeredgecolour':'k', + 'markerstyle':'x', + 'markeredgewidth':1, + 'legend_entry':'initial errors'} + plot_options_default = [plot_options_1, plot_options_2] + + plot_function : `function` or None, optional + The plot function that is executed when a widgets' value changes. + If None, then nothing is assigned. + + toggle_show_default : `boolean`, optional + Defines whether the options will be visible upon construction. + + toggle_show_visible : `boolean`, optional + The visibility of the toggle button. + """ + import IPython.html.widgets as ipywidgets + # make sure that plot_options_default is a list even with one member + if not isinstance(plot_options_default, list): + plot_options_default = [plot_options_default] + + # find number of curves + n_curves = len(plot_options_default) + + # Create widgets + # toggle button + but = ipywidgets.ToggleButtonWidget(description='Plot Options', + value=toggle_show_default, + visible=toggle_show_visible) + + # select curve drop down menu + curves_dict = OrderedDict() + for k in range(n_curves): + curves_dict['Curve ' + str(k)] = k + curve_selection = ipywidgets.DropdownWidget(values=curves_dict, + value=0, + description='Select curve', + visible=n_curves > 1) + + # legend entry + legend_entry = ipywidgets.TextWidget(description='Legend entry', + value=plot_options_default[0][ + 'legend_entry']) + + # show line, show markers checkboxes + show_line = ipywidgets.CheckboxWidget(description='Show line', + value=plot_options_default[0][ + 'show_line']) + show_marker = ipywidgets.CheckboxWidget(description='Show markers', + value=plot_options_default[0][ + 'show_marker']) + + # linewidth, markersize + linewidth = ipywidgets.FloatTextWidget(description='Width', + value=plot_options_default[0][ + 'linewidth']) + markersize = ipywidgets.IntTextWidget(description='Size', + value=plot_options_default[0][ + 'markersize']) + markeredgewidth = ipywidgets.FloatTextWidget( + description='Edge width', + value=plot_options_default[0]['markeredgewidth']) + + # markerstyle + markerstyle_dict = OrderedDict() + markerstyle_dict['point'] = '.' + markerstyle_dict['pixel'] = ',' + markerstyle_dict['circle'] = 'o' + markerstyle_dict['triangle down'] = 'v' + markerstyle_dict['triangle up'] = '^' + markerstyle_dict['triangle left'] = '<' + markerstyle_dict['triangle right'] = '>' + markerstyle_dict['tri down'] = '1' + markerstyle_dict['tri up'] = '2' + markerstyle_dict['tri left'] = '3' + markerstyle_dict['tri right'] = '4' + markerstyle_dict['octagon'] = '8' + markerstyle_dict['square'] = 's' + markerstyle_dict['pentagon'] = 'p' + markerstyle_dict['star'] = '*' + markerstyle_dict['hexagon 1'] = 'h' + markerstyle_dict['hexagon 2'] = 'H' + markerstyle_dict['plus'] = '+' + markerstyle_dict['x'] = 'x' + markerstyle_dict['diamond'] = 'D' + markerstyle_dict['thin diamond'] = 'd' + markerstyle = ipywidgets.DropdownWidget(values=markerstyle_dict, + value=plot_options_default[0][ + 'markerstyle'], + description='Style') + + # linestyle + linestyle_dict = OrderedDict() + linestyle_dict['solid'] = '-' + linestyle_dict['dashed'] = '--' + linestyle_dict['dash-dot'] = '-.' + linestyle_dict['dotted'] = ':' + linestyle = ipywidgets.DropdownWidget(values=linestyle_dict, + value=plot_options_default[0][ + 'linestyle'], + description='Style') + + # colours + # do not assign the plot_function here + linecolour = colour_selection(plot_options_default[0]['linecolour'], + title='Colour') + markerfacecolour = colour_selection( + plot_options_default[0]['markerfacecolour'], + title='Face Colour') + markeredgecolour = colour_selection( + plot_options_default[0]['markeredgecolour'], + title='Edge Colour') + + # Group widgets + line_options = ipywidgets.ContainerWidget( + children=[linestyle, linewidth, linecolour]) + marker_options = ipywidgets.ContainerWidget( + children=[markerstyle, markersize, + markeredgewidth, + markerfacecolour, + markeredgecolour]) + line_wid = ipywidgets.ContainerWidget(children=[show_line, line_options]) + marker_wid = ipywidgets.ContainerWidget( + children=[show_marker, marker_options]) + line_options_options_wid = ipywidgets.ContainerWidget( + children=[line_wid, marker_wid]) + options_wid = ipywidgets.ContainerWidget(children=[legend_entry, + line_options_options_wid]) + options_and_curve_wid = ipywidgets.ContainerWidget( + children=[curve_selection, + options_wid]) + plot_options_wid = ipywidgets.ContainerWidget( + children=[but, options_and_curve_wid]) + + # initialize output + plot_options_wid.selected_options = plot_options_default + + # line options visibility + def line_options_visible(name, value): + linestyle.disabled = not value + linewidth.disabled = not value + linecolour.children[0].disabled = not value + linecolour.children[1].children[0].disabled = not value + linecolour.children[1].children[1].disabled = not value + linecolour.children[1].children[2].disabled = not value + show_line.on_trait_change(line_options_visible, 'value') + + # marker options visibility + def marker_options_visible(name, value): + markerstyle.disabled = not value + markersize.disabled = not value + markeredgewidth.disabled = not value + markerfacecolour.children[0].disabled = not value + markerfacecolour.children[1].children[0].disabled = not value + markerfacecolour.children[1].children[1].disabled = not value + markerfacecolour.children[1].children[2].disabled = not value + markeredgecolour.children[0].disabled = not value + markeredgecolour.children[1].children[0].disabled = not value + markeredgecolour.children[1].children[1].disabled = not value + markeredgecolour.children[1].children[2].disabled = not value + show_marker.on_trait_change(marker_options_visible, 'value') + + # function that gets colour selection + def get_colour(colour_wid): + if colour_wid.children[0].value == 'custom': + return [float(colour_wid.children[1].children[0].value), + float(colour_wid.children[1].children[1].value), + float(colour_wid.children[1].children[2].value)] + else: + return colour_wid.children[0].value + + # assign options + def save_legend_entry(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'legend_entry'] = str(value) + + legend_entry.on_trait_change(save_legend_entry, 'value') + + def save_show_line(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'show_line'] = value + + show_line.on_trait_change(save_show_line, 'value') + + def save_show_marker(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'show_marker'] = value + + show_marker.on_trait_change(save_show_marker, 'value') + + def save_linewidth(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'linewidth'] = float(value) + + linewidth.on_trait_change(save_linewidth, 'value') + + def save_linestyle(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'linestyle'] = value + + linestyle.on_trait_change(save_linestyle, 'value') + + def save_markersize(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'markersize'] = int(value) + + markersize.on_trait_change(save_markersize, 'value') + + def save_markeredgewidth(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'markeredgewidth'] = float(value) + + markeredgewidth.on_trait_change(save_markeredgewidth, 'value') + + def save_markerstyle(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'markerstyle'] = value + + markerstyle.on_trait_change(save_markerstyle, 'value') + + def save_linecolour(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'linecolour'] = get_colour(linecolour) + + linecolour.children[0].on_trait_change(save_linecolour, 'value') + linecolour.children[1].children[0].on_trait_change(save_linecolour, 'value') + linecolour.children[1].children[1].on_trait_change(save_linecolour, 'value') + linecolour.children[1].children[2].on_trait_change(save_linecolour, 'value') + + def save_markerfacecolour(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'markerfacecolour'] = get_colour(markerfacecolour) + + markerfacecolour.children[0].on_trait_change(save_markerfacecolour, 'value') + markerfacecolour.children[1].children[0].on_trait_change( + save_markerfacecolour, 'value') + markerfacecolour.children[1].children[1].on_trait_change( + save_markerfacecolour, 'value') + markerfacecolour.children[1].children[2].on_trait_change( + save_markerfacecolour, 'value') + + def save_markeredgecolour(name, value): + plot_options_wid.selected_options[curve_selection.value][ + 'markeredgecolour'] = get_colour(markeredgecolour) + + markeredgecolour.children[0].on_trait_change(save_markeredgecolour, 'value') + markeredgecolour.children[1].children[0].on_trait_change( + save_markeredgecolour, 'value') + markeredgecolour.children[1].children[1].on_trait_change( + save_markeredgecolour, 'value') + markeredgecolour.children[1].children[2].on_trait_change( + save_markeredgecolour, 'value') + + # set correct value to slider when drop down menu value changes + def set_options(name, value): + legend_entry.value = plot_options_wid.selected_options[value][ + 'legend_entry'] + show_line.value = plot_options_wid.selected_options[value]['show_line'] + show_marker.value = plot_options_wid.selected_options[value][ + 'show_marker'] + linewidth.value = plot_options_wid.selected_options[value]['linewidth'] + linestyle.value = plot_options_wid.selected_options[value]['linestyle'] + markersize.value = plot_options_wid.selected_options[value][ + 'markersize'] + markerstyle.value = plot_options_wid.selected_options[value][ + 'markerstyle'] + markeredgewidth.value = plot_options_wid.selected_options[value][ + 'markeredgewidth'] + default_colour = plot_options_wid.selected_options[value]['linecolour'] + if not isinstance(default_colour, str): + r_val = default_colour[0] + g_val = default_colour[1] + b_val = default_colour[2] + default_colour = 'custom' + linecolour.children[1].children[0].value = r_val + linecolour.children[1].children[1].value = g_val + linecolour.children[1].children[2].value = b_val + linecolour.children[0].value = default_colour + default_colour = plot_options_wid.selected_options[value][ + 'markerfacecolour'] + if not isinstance(default_colour, str): + r_val = default_colour[0] + g_val = default_colour[1] + b_val = default_colour[2] + default_colour = 'custom' + markerfacecolour.children[1].children[0].value = r_val + markerfacecolour.children[1].children[1].value = g_val + markerfacecolour.children[1].children[2].value = b_val + markerfacecolour.children[0].value = default_colour + default_colour = plot_options_wid.selected_options[value][ + 'markeredgecolour'] + if not isinstance(default_colour, str): + r_val = default_colour[0] + g_val = default_colour[1] + b_val = default_colour[2] + default_colour = 'custom' + markeredgecolour.children[1].children[0].value = r_val + markeredgecolour.children[1].children[1].value = g_val + markeredgecolour.children[1].children[2].value = b_val + markeredgecolour.children[0].value = default_colour + curve_selection.on_trait_change(set_options, 'value') + + # Toggle button function + def toggle_fun(name, value): + options_and_curve_wid.visible = value + toggle_fun('', toggle_show_default) + but.on_trait_change(toggle_fun, 'value') + + # assign plot_function + if plot_function is not None: + legend_entry.on_trait_change(plot_function, 'value') + show_line.on_trait_change(plot_function, 'value') + linestyle.on_trait_change(plot_function, 'value') + linewidth.on_trait_change(plot_function, 'value') + show_marker.on_trait_change(plot_function, 'value') + markerstyle.on_trait_change(plot_function, 'value') + markeredgewidth.on_trait_change(plot_function, 'value') + markersize.on_trait_change(plot_function, 'value') + linecolour.children[0].on_trait_change(plot_function, 'value') + linecolour.children[1].children[0].on_trait_change(plot_function, + 'value') + linecolour.children[1].children[1].on_trait_change(plot_function, + 'value') + linecolour.children[1].children[2].on_trait_change(plot_function, + 'value') + markerfacecolour.children[0].on_trait_change(plot_function, 'value') + markerfacecolour.children[1].children[0].on_trait_change(plot_function, + 'value') + markerfacecolour.children[1].children[1].on_trait_change(plot_function, + 'value') + markerfacecolour.children[1].children[2].on_trait_change(plot_function, + 'value') + markeredgecolour.children[0].on_trait_change(plot_function, 'value') + markeredgecolour.children[1].children[0].on_trait_change(plot_function, + 'value') + markeredgecolour.children[1].children[1].on_trait_change(plot_function, + 'value') + markeredgecolour.children[1].children[2].on_trait_change(plot_function, + 'value') + + return plot_options_wid + + +def format_plot_options(plot_options_wid, container_padding='6px', + container_margin='6px', + container_border='1px solid black', + toggle_button_font_weight='bold', border_visible=True, + suboptions_border_visible=True): + r""" + Function that corrects the align (style format) of a given figure_options + widget. Usage example: + plot_options_wid = plot_options() + display(plot_options_wid) + format_plot_options(figure_options_wid) + + Parameters + ---------- + plot_options_wid : + The widget object generated by the `figure_options()` function. + + container_padding : `str`, optional + The padding around the widget, e.g. '6px' + + container_margin : `str`, optional + The margin around the widget, e.g. '6px' + + container_border : `str`, optional + The border around the widget, e.g. '1px solid black' + + toggle_button_font_weight : `str` + The font weight of the toggle button, e.g. 'bold' + + border_visible : `boolean`, optional + Defines whether to draw the border line around the widget. + + suboptions_border_visible : `boolean`, optional + Defines whether to draw the border line around the per curve options. + """ + # align line options with checkbox + plot_options_wid.children[1].children[1].children[1].children[0]. \ + add_class('align-end') + + # align marker options with checkbox + plot_options_wid.children[1].children[1].children[1].children[1]. \ + add_class('align-end') + + # set text boxes width + plot_options_wid.children[1].children[1].children[1].children[0].children[ + 1].children[1]. \ + set_css('width', '1cm') + plot_options_wid.children[1].children[1].children[1].children[1].children[ + 1].children[1]. \ + set_css('width', '1cm') + plot_options_wid.children[1].children[1].children[1].children[1].children[ + 1].children[2]. \ + set_css('width', '1cm') + + # align line and marker options + plot_options_wid.children[1].children[1].children[1].remove_class('vbox') + plot_options_wid.children[1].children[1].children[1].add_class('hbox') + if suboptions_border_visible: + plot_options_wid.children[1].children[1].set_css('margin', + container_margin) + plot_options_wid.children[1].children[1].set_css('border', + container_border) + + # align curve selection with line and marker options + plot_options_wid.children[1].add_class('align-start') + + # format colour options + format_colour_selection( + plot_options_wid.children[1].children[1].children[1].children[ + 0].children[1].children[2]) + format_colour_selection( + plot_options_wid.children[1].children[1].children[1].children[ + 1].children[1].children[3]) + format_colour_selection( + plot_options_wid.children[1].children[1].children[1].children[ + 1].children[1].children[4]) + + # set toggle button font bold + plot_options_wid.children[0].set_css('font-weight', + toggle_button_font_weight) + + # margin and border around container widget + plot_options_wid.set_css('padding', container_padding) + plot_options_wid.set_css('margin', container_margin) + if border_visible: + plot_options_wid.set_css('border', container_border) + + +def _convert_iterations_to_groups(from_iter, to_iter, iter_str): + r""" + Function that generates a list of group labels given the range bounds and + the str to be used. + """ + return ["{}{}".format(iter_str, i) for i in range(from_iter, to_iter + 1)]