diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 439f2f0d..6d72e624 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -15,7 +15,7 @@ jobs: strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-latest] + os: [ubuntu-latest, windows-latest, macos-13] version: ["3.9", "3.x"] defaults: run: diff --git a/RAT/utils/plotting.py b/RAT/utils/plotting.py new file mode 100644 index 00000000..d53840b6 --- /dev/null +++ b/RAT/utils/plotting.py @@ -0,0 +1,193 @@ +""" +Plots using the matplotlib library +""" +import matplotlib.pyplot as plt +import numpy as np +from RAT.rat_core import PlotEventData, makeSLDProfileXY + + +class Figure: + """ + Creates a plotting figure. + """ + + def __init__(self, row: int = 1, col: int = 2): + """ + Initializes the figure and the subplots. + + Parameters + ---------- + row : int + The number of rows in subplot + col : int + The number of columns in subplot + """ + self._fig, self._ax = \ + plt.subplots(row, col, num="Reflectivity Algorithms Toolbox (RAT)") + plt.show(block=False) + self._esc_pressed = False + self._close_clicked = False + self._fig.canvas.mpl_connect("key_press_event", + self._process_button_press) + self._fig.canvas.mpl_connect('close_event', + self._close) + + def wait_for_close(self): + """ + Waits for the user to close the figure + using the esc key. + """ + while not (self._esc_pressed or self._close_clicked): + plt.waitforbuttonpress(timeout=0.005) + plt.close(self._fig) + + def _process_button_press(self, event): + """ + Process the key_press_event. + """ + if event.key == 'escape': + self._esc_pressed = True + + def _close(self, _): + """ + Process the close_event. + """ + self._close_clicked = True + + +def plot_errorbars(ax, x, y, err, onesided, color): + """ + Plots the error bars. + + Parameters + ---------- + ax : matplotlib.axes._axes.Axes + The axis on which to draw errorbars + x : np.ndarray + The shifted data x axis data + y : np.ndarray + The shifted data y axis data + err : np.ndarray + The shifted data e data + onesided : bool + A boolean to indicate whether to draw one sided errorbars + color : str + The hex representing the color of the errorbars + """ + y_error = [[0]*len(err), err] if onesided else err + ax.errorbar(x=x, + y=y, + yerr=y_error, + fmt='none', + ecolor=color, + elinewidth=1, + capsize=0) + ax.scatter(x=x, y=y, s=3, marker="o", color=color) + + +def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True): + """ + Clears the previous plots and updates the ref and SLD plots. + + Parameters + ---------- + data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + fig : Figure + The figure class that has two subplots + delay : bool + Controls whether to delay 0.005s after plot is created + + Returns + ------- + fig : Figure + The figure class that has two subplots + """ + if fig is None: + fig = Figure() + elif fig._ax.shape != (2,): + fig._fig.clf() + fig._ax = fig._fig.subplots(1, 2) + + ref_plot = fig._ax[0] + sld_plot = fig._ax[1] + + # Clears the previous plots + ref_plot.cla() + sld_plot.cla() + + for i, (r, sd, sld, layer) in enumerate(zip(data.reflectivity, + data.shiftedData, + data.sldProfiles, + data.resampledLayers)): + + r, sd, sld, layer = map(lambda x: x[0], (r, sd, sld, layer)) + + # Calculate the divisor + div = 1 if i == 0 else 2**(4*(i+1)) + + # Plot the reflectivity on plot (1,1) + ref_plot.plot(r[:, 0], + r[:, 1]/div, + label=f'ref {i+1}', + linewidth=2) + color = ref_plot.get_lines()[-1].get_color() + + if data.dataPresent[i]: + sd_x = sd[:, 0] + sd_y, sd_e = map(lambda x: x/div, (sd[:, 1], sd[:, 2])) + + # Plot the errorbars + indices_removed = np.flip(np.nonzero(sd_y - sd_e < 0)[0]) + sd_x_r, sd_y_r, sd_e_r = map(lambda x: + np.delete(x, indices_removed), + (sd_x, sd_y, sd_e)) + plot_errorbars(ref_plot, sd_x_r, sd_y_r, sd_e_r, False, color) + + # Plot one sided errorbars + indices_selected = [x for x in indices_removed + if x not in np.nonzero(sd_y < 0)[0]] + sd_x_s, sd_y_s, sd_e_s = map(lambda x: + [x[i] for i in indices_selected], + (sd_x, sd_y, sd_e)) + plot_errorbars(ref_plot, sd_x_s, sd_y_s, sd_e_s, True, color) + + # Plot the slds on plot (1,2) + for j in range(1, sld.shape[1]): + sld_plot.plot(sld[:, 0], + sld[:, j], + label=f'sld {i+1}', + color=color, + linewidth=2) + + if data.resample[i] == 1 or data.modelType == 'custom xy': + new = makeSLDProfileXY(layer[0, 1], + layer[-1, 1], + data.subRoughs[i], + layer, + len(layer), + 1.0) + + sld_plot.plot([row[0]-49 for row in new], + [row[1] for row in new], + color=color, + linewidth=1) + + # Format the axis + ref_plot.set_yscale('log') + ref_plot.set_xscale('log') + ref_plot.set_xlabel('Qz') + ref_plot.set_ylabel('Ref') + ref_plot.legend() + ref_plot.grid() + + sld_plot.set_xlabel('Z') + sld_plot.set_ylabel('SLD') + sld_plot.legend() + sld_plot.grid() + + if delay: + plt.pause(0.005) + + return fig diff --git a/cpp/rat.cpp b/cpp/rat.cpp index bb281915..c1f829a4 100644 --- a/cpp/rat.cpp +++ b/cpp/rat.cpp @@ -16,6 +16,7 @@ setup_pybind11(cfg) #include "RAT/RATMain_initialize.h" #include "RAT/RATMain_terminate.h" #include "RAT/RATMain_types.h" +#include "RAT/makeSLDProfileXY.h" #include "RAT/classHandle.hpp" #include "RAT/dylib.hpp" #include "RAT/events/eventManager.h" @@ -1165,6 +1166,27 @@ py::tuple RATMain(const ProblemDefinition& problem_def, const Cells& cells, cons bayesResultsFromStruct8T(bayesResults)); } +py::array_t makeSLDProfileXY(real_T bulk_in, + real_T bulk_out, + real_T ssub, + const py::array_t &layers, + real_T number_of_layers, + real_T repeats) +{ + coder::array out; + coder::array layers_array = pyArrayToRatArray2d(layers); + RAT::makeSLDProfileXY(bulk_in, + bulk_out, + ssub, + layers_array, + number_of_layers, + repeats, + out); + + return pyArrayFromRatArray2d(out); + +} + class Module { public: @@ -1434,5 +1456,7 @@ PYBIND11_MODULE(rat_core, m) { .def_readwrite("fitLimits", &ProblemDefinition::fitLimits) .def_readwrite("otherLimits", &ProblemDefinition::otherLimits); - m.def("RATMain", &RATMain, "Entry point for the main reflectivity computation."); + m.def("RATMain", &RATMain, "Entry point for the main reflectivity computation."); + + m.def("makeSLDProfileXY", &makeSLDProfileXY, "Creates the profiles for the SLD plots"); } diff --git a/requirements.txt b/requirements.txt index 6504da7f..d86d2294 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ pybind11 >= 2.4 pydantic >= 2.4.2, <= 2.6.4 pytest >= 7.4.0 pytest-cov >= 4.1.0 +matplotlib >= 3.8.3 StrEnum >= 0.4.15; python_version < '3.11' diff --git a/setup.py b/setup.py index 2a6fdadd..da896906 100644 --- a/setup.py +++ b/setup.py @@ -159,7 +159,7 @@ def build_libraries(self, libraries): libraries = [libevent], ext_modules = ext_modules, python_requires = '>=3.9', - install_requires = ['numpy >= 1.20', 'prettytable >= 3.9.0', 'pydantic >= 2.4.2, <= 2.6.4'], + install_requires = ['numpy >= 1.20', 'prettytable >= 3.9.0', 'pydantic >= 2.4.2, <= 2.6.4', 'matplotlib >= 3.8.3'], extras_require = {':python_version < "3.11"': ['StrEnum >= 0.4.15'], 'Dev': ['pytest>=7.4.0', 'pytest-cov>=4.1.0'], 'Matlab_latest': ['matlabengine'], diff --git a/tests/test_data/plotting_data.pickle b/tests/test_data/plotting_data.pickle new file mode 100644 index 00000000..1202e6ae Binary files /dev/null and b/tests/test_data/plotting_data.pickle differ diff --git a/tests/test_plotting.py b/tests/test_plotting.py new file mode 100644 index 00000000..62e8dd5e --- /dev/null +++ b/tests/test_plotting.py @@ -0,0 +1,179 @@ +import os +import re +import csv +import pytest +import pickle +from unittest.mock import patch +from unittest.mock import MagicMock +import numpy as np +import matplotlib.pyplot as plt +from RAT.rat_core import PlotEventData +from RAT.utils.plotting import Figure, plot_ref_sld + + +TEST_DIR_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), + 'test_data') + + +def data() -> PlotEventData: + """ + Creates the fixture for the tests. + """ + data_path = os.path.join(TEST_DIR_PATH, 'plotting_data.pickle') + with open(data_path, 'rb') as f: + loaded_data = pickle.load(f) + + data = PlotEventData() + data.modelType = loaded_data['modelType'] + data.dataPresent = loaded_data['dataPresent'] + data.subRoughs = loaded_data['subRoughs'] + data.resample = loaded_data['resample'] + data.resampledLayers = loaded_data['resampledLayers'] + data.reflectivity = loaded_data['reflectivity'] + data.shiftedData = loaded_data['shiftedData'] + data.sldProfiles = loaded_data['sldProfiles'] + return data + + +@pytest.fixture +def fig() -> Figure: + """ + Creates the fixture for the tests. + """ + plt.close('all') + figure = Figure(1, 3) + fig = plot_ref_sld(fig=figure, data=data()) + return fig + + +def test_figure_axis_formating(fig: Figure) -> None: + """ + Tests the axis formating of the figure. + """ + ref_plot = fig._ax[0] + sld_plot = fig._ax[1] + + assert fig._fig.axes[0].get_subplotspec().get_gridspec().get_geometry() == (1, 2) + assert fig._ax.shape == (2,) + + assert ref_plot.get_xlabel() == "Qz" + assert ref_plot.get_xscale() == "log" + assert ref_plot.get_ylabel() == "Ref" + assert ref_plot.get_yscale() == "log" + assert [label._text for label in ref_plot.get_legend().texts] == ['ref 1', 'ref 2', 'ref 3'] + + assert sld_plot.get_xlabel() == "Z" + assert sld_plot.get_xscale() == "linear" + assert sld_plot.get_ylabel() == "SLD" + assert sld_plot.get_yscale() == "linear" + assert [label._text for label in sld_plot.get_legend().texts] == ['sld 1', 'sld 2', 'sld 3'] + + +def test_figure_color_formating(fig: Figure) -> None: + """ + Tests the color formating of the figure. + """ + ref_plot = fig._ax[0] + sld_plot = fig._ax[1] + + assert len(ref_plot.get_lines()) == 3 + assert len(sld_plot.get_lines()) == 6 + + for axis_ix in range(len(ref_plot.get_lines())): + ax1 = axis_ix*2 + ax2 = ax1 + 1 + + # Tests whether the color of the line and the errorbars match on the ref_plot + assert (ref_plot.containers[ax1][2][0]._original_edgecolor == + ref_plot.containers[ax2][2][0]._original_edgecolor == + ref_plot.get_lines()[axis_ix].get_color()) + + # Tests whether the color of the sld and resampled_sld match on the sld_plot + assert (sld_plot.get_lines()[ax1].get_color() == + sld_plot.get_lines()[ax2].get_color()) + + +def test_eventhandlers_linked_to_figure(fig: Figure) -> None: + """ + Tests whether the eventhandlers for close_event + and key_press_event in the figure are linked to the + class methods. + """ + pattern = r'\(([^\]]+)\)' + + for ix, val in fig._fig.canvas.callbacks.callbacks['close_event'].items(): + if str(type(val)) == "": + break + canvas_close_event_callback = fig._fig.canvas.callbacks.callbacks['close_event'][ix]._func_ref.__repr__() + close_event_callback = re.findall(pattern, + canvas_close_event_callback)[0] + assert close_event_callback == "_close" + assert hasattr(Figure, "_close") + + for ix, val in fig._fig.canvas.callbacks.callbacks['key_press_event'].items(): + if str(type(val)) == "": + break + canvas_key_press_event_callback = fig._fig.canvas.callbacks.callbacks['key_press_event'][ix]._func_ref.__repr__() + key_press_event_callback = re.findall(pattern, + canvas_key_press_event_callback)[0] + assert key_press_event_callback == "_process_button_press" + assert hasattr(Figure, "_process_button_press") + + +def test_eventhandler_variable_update(fig: Figure) -> None: + """ + Tests whether the eventhandlers for close_event + and key_press_event update variables that stop + while loop in wait_for_close. + """ + assert not fig._esc_pressed + on_key_mock_event = type('MockEvent', (object,), {'key': 'escape'}) + fig._process_button_press(on_key_mock_event) + assert fig._esc_pressed + + assert not fig._close_clicked + fig._close('test') + assert fig._close_clicked + + +@patch("RAT.utils.plotting.plt.waitforbuttonpress") +def test_wait_for_close(mock: MagicMock, fig: Figure) -> None: + """ + Tests the _wait_for_close method stops the + while loop when _esc_pressed is True. + """ + def mock_wait_for_button_press(timeout): + fig._esc_pressed = True + + mock.side_effect = mock_wait_for_button_press + assert not fig._esc_pressed + fig.wait_for_close() + assert fig._esc_pressed + + +@patch("RAT.utils.plotting.makeSLDProfileXY") +def test_sld_profile_function_call(mock: MagicMock) -> None: + """ + Tests the makeSLDProfileXY function called with + correct args. + """ + plot_ref_sld(data()) + + assert mock.call_count == 3 + assert mock.call_args_list[0].args[0] == 2.07e-06 + assert mock.call_args_list[0].args[1] == 6.28e-06 + assert mock.call_args_list[0].args[2] == 0.0 + assert mock.call_args_list[0].args[4] == 82 + assert mock.call_args_list[0].args[5] == 1.0 + + assert mock.call_args_list[1].args[0] == 2.07e-06 + assert mock.call_args_list[1].args[1] == 1.83e-06 + assert mock.call_args_list[1].args[2] == 0.0 + assert mock.call_args_list[1].args[4] == 128 + assert mock.call_args_list[1].args[5] == 1.0 + + assert mock.call_args_list[2].args[0] == 2.07e-06 + assert mock.call_args_list[2].args[1] == -5.87e-07 + assert mock.call_args_list[2].args[2] == 0.0 + assert mock.call_args_list[2].args[4] == 153 + assert mock.call_args_list[2].args[5] == 1.0