Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:

strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
os: [ubuntu-latest, windows-latest, macos-13]
version: ["3.9", "3.x"]
defaults:
run:
Expand Down
193 changes: 193 additions & 0 deletions RAT/utils/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""
Plots using the matplotlib library
"""
import matplotlib.pyplot as plt
import numpy as np
from RAT.rat_core import PlotEventData, makeSLDProfileXY


class Figure:
"""
Creates a plotting figure.
"""

def __init__(self, row: int = 1, col: int = 2):
"""
Initializes the figure and the subplots.

Parameters
----------
row : int
The number of rows in subplot
col : int
The number of columns in subplot
"""
self._fig, self._ax = \
plt.subplots(row, col, num="Reflectivity Algorithms Toolbox (RAT)")
plt.show(block=False)
self._esc_pressed = False
self._close_clicked = False
self._fig.canvas.mpl_connect("key_press_event",
self._process_button_press)
self._fig.canvas.mpl_connect('close_event',
self._close)

def wait_for_close(self):
"""
Waits for the user to close the figure
using the esc key.
"""
while not (self._esc_pressed or self._close_clicked):
plt.waitforbuttonpress(timeout=0.005)
plt.close(self._fig)

def _process_button_press(self, event):
"""
Process the key_press_event.
"""
if event.key == 'escape':
self._esc_pressed = True

def _close(self, _):
"""
Process the close_event.
"""
self._close_clicked = True


def plot_errorbars(ax, x, y, err, onesided, color):
"""
Plots the error bars.

Parameters
----------
ax : matplotlib.axes._axes.Axes
The axis on which to draw errorbars
x : np.ndarray
The shifted data x axis data
y : np.ndarray
The shifted data y axis data
err : np.ndarray
The shifted data e data
onesided : bool
A boolean to indicate whether to draw one sided errorbars
color : str
The hex representing the color of the errorbars
"""
y_error = [[0]*len(err), err] if onesided else err
ax.errorbar(x=x,
y=y,
yerr=y_error,
fmt='none',
ecolor=color,
elinewidth=1,
capsize=0)
ax.scatter(x=x, y=y, s=3, marker="o", color=color)


def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
"""
Clears the previous plots and updates the ref and SLD plots.

Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
fig : Figure
The figure class that has two subplots
delay : bool
Controls whether to delay 0.005s after plot is created

Returns
-------
fig : Figure
The figure class that has two subplots
"""
if fig is None:
fig = Figure()
elif fig._ax.shape != (2,):
fig._fig.clf()
fig._ax = fig._fig.subplots(1, 2)

ref_plot = fig._ax[0]
sld_plot = fig._ax[1]

# Clears the previous plots
ref_plot.cla()
sld_plot.cla()

for i, (r, sd, sld, layer) in enumerate(zip(data.reflectivity,
data.shiftedData,
data.sldProfiles,
data.resampledLayers)):

r, sd, sld, layer = map(lambda x: x[0], (r, sd, sld, layer))

# Calculate the divisor
div = 1 if i == 0 else 2**(4*(i+1))

# Plot the reflectivity on plot (1,1)
ref_plot.plot(r[:, 0],
r[:, 1]/div,
label=f'ref {i+1}',
linewidth=2)
color = ref_plot.get_lines()[-1].get_color()

if data.dataPresent[i]:
sd_x = sd[:, 0]
sd_y, sd_e = map(lambda x: x/div, (sd[:, 1], sd[:, 2]))

# Plot the errorbars
indices_removed = np.flip(np.nonzero(sd_y - sd_e < 0)[0])
sd_x_r, sd_y_r, sd_e_r = map(lambda x:
np.delete(x, indices_removed),
(sd_x, sd_y, sd_e))
plot_errorbars(ref_plot, sd_x_r, sd_y_r, sd_e_r, False, color)

# Plot one sided errorbars
indices_selected = [x for x in indices_removed
if x not in np.nonzero(sd_y < 0)[0]]
sd_x_s, sd_y_s, sd_e_s = map(lambda x:
[x[i] for i in indices_selected],
(sd_x, sd_y, sd_e))
plot_errorbars(ref_plot, sd_x_s, sd_y_s, sd_e_s, True, color)

# Plot the slds on plot (1,2)
for j in range(1, sld.shape[1]):
sld_plot.plot(sld[:, 0],
sld[:, j],
label=f'sld {i+1}',
color=color,
linewidth=2)

if data.resample[i] == 1 or data.modelType == 'custom xy':
new = makeSLDProfileXY(layer[0, 1],
layer[-1, 1],
data.subRoughs[i],
layer,
len(layer),
1.0)

sld_plot.plot([row[0]-49 for row in new],
[row[1] for row in new],
color=color,
linewidth=1)

# Format the axis
ref_plot.set_yscale('log')
ref_plot.set_xscale('log')
ref_plot.set_xlabel('Qz')
ref_plot.set_ylabel('Ref')
ref_plot.legend()
ref_plot.grid()

sld_plot.set_xlabel('Z')
sld_plot.set_ylabel('SLD')
sld_plot.legend()
sld_plot.grid()

if delay:
plt.pause(0.005)

return fig
26 changes: 25 additions & 1 deletion cpp/rat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ setup_pybind11(cfg)
#include "RAT/RATMain_initialize.h"
#include "RAT/RATMain_terminate.h"
#include "RAT/RATMain_types.h"
#include "RAT/makeSLDProfileXY.h"
#include "RAT/classHandle.hpp"
#include "RAT/dylib.hpp"
#include "RAT/events/eventManager.h"
Expand Down Expand Up @@ -1165,6 +1166,27 @@ py::tuple RATMain(const ProblemDefinition& problem_def, const Cells& cells, cons
bayesResultsFromStruct8T(bayesResults));
}

py::array_t<real_T> makeSLDProfileXY(real_T bulk_in,
real_T bulk_out,
real_T ssub,
const py::array_t<real_T> &layers,
real_T number_of_layers,
real_T repeats)
{
coder::array<real_T, 2U> out;
coder::array<real_T, 2U> layers_array = pyArrayToRatArray2d(layers);
RAT::makeSLDProfileXY(bulk_in,
bulk_out,
ssub,
layers_array,
number_of_layers,
repeats,
out);

return pyArrayFromRatArray2d(out);

}

class Module
{
public:
Expand Down Expand Up @@ -1434,5 +1456,7 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("fitLimits", &ProblemDefinition::fitLimits)
.def_readwrite("otherLimits", &ProblemDefinition::otherLimits);

m.def("RATMain", &RATMain, "Entry point for the main reflectivity computation.");
m.def("RATMain", &RATMain, "Entry point for the main reflectivity computation.");

m.def("makeSLDProfileXY", &makeSLDProfileXY, "Creates the profiles for the SLD plots");
}
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ pybind11 >= 2.4
pydantic >= 2.4.2, <= 2.6.4
pytest >= 7.4.0
pytest-cov >= 4.1.0
matplotlib >= 3.8.3
StrEnum >= 0.4.15; python_version < '3.11'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def build_libraries(self, libraries):
libraries = [libevent],
ext_modules = ext_modules,
python_requires = '>=3.9',
install_requires = ['numpy >= 1.20', 'prettytable >= 3.9.0', 'pydantic >= 2.4.2, <= 2.6.4'],
install_requires = ['numpy >= 1.20', 'prettytable >= 3.9.0', 'pydantic >= 2.4.2, <= 2.6.4', 'matplotlib >= 3.8.3'],
extras_require = {':python_version < "3.11"': ['StrEnum >= 0.4.15'],
'Dev': ['pytest>=7.4.0', 'pytest-cov>=4.1.0'],
'Matlab_latest': ['matlabengine'],
Expand Down
Binary file added tests/test_data/plotting_data.pickle
Binary file not shown.
Loading