Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 47 additions & 11 deletions RATapi/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,52 @@ def check_indices(problem: ProblemDefinition) -> None:
)


class FileHandles:
"""Class to defer creation of custom file handles.

Parameters
----------
files : ClassList[CustomFile]
A list of custom file models.
"""

def __init__(self, files):
self.index = 0
self.files = [*files]

def __iter__(self):
self.index = 0
return self

def get_handle(self, index):
"""Returns file handle for a given custom file.

Parameters
----------
index : int
The index of the custom file.

"""
custom_file = self.files[index]
full_path = os.path.join(custom_file.path, custom_file.filename)
if custom_file.language == Languages.Python:
file_handle = get_python_handle(custom_file.filename, custom_file.function_name, custom_file.path)
elif custom_file.language == Languages.Matlab:
file_handle = RATapi.wrappers.MatlabWrapper(full_path).getHandle()
elif custom_file.language == Languages.Cpp:
file_handle = RATapi.wrappers.DylibWrapper(full_path, custom_file.function_name).getHandle()

return file_handle

def __next__(self):
if self.index < len(self.files):
custom_file = self.get_handle(self.index)
self.index += 1
return custom_file
else:
raise StopIteration


def make_cells(project: RATapi.Project) -> Cells:
"""Constructs the cells input required for the compiled RAT code.

Expand Down Expand Up @@ -344,16 +390,6 @@ def make_cells(project: RATapi.Project) -> Cells:
else:
simulation_limits.append([0.0, 0.0])

file_handles = []
for custom_file in project.custom_files:
full_path = os.path.join(custom_file.path, custom_file.filename)
if custom_file.language == Languages.Python:
file_handles.append(get_python_handle(custom_file.filename, custom_file.function_name, custom_file.path))
elif custom_file.language == Languages.Matlab:
file_handles.append(RATapi.wrappers.MatlabWrapper(full_path).getHandle())
elif custom_file.language == Languages.Cpp:
file_handles.append(RATapi.wrappers.DylibWrapper(full_path, custom_file.function_name).getHandle())

# Populate the set of cells
cells = Cells()
cells.f1 = [[0, 1]] * len(project.contrasts) # This is marked as "to do" in RAT
Expand All @@ -369,7 +405,7 @@ def make_cells(project: RATapi.Project) -> Cells:
cells.f11 = [param.name for param in project.bulk_in]
cells.f12 = [param.name for param in project.bulk_out]
cells.f13 = [param.name for param in project.resolution_parameters]
cells.f14 = file_handles
cells.f14 = FileHandles(project.custom_files)
cells.f15 = [param.type for param in project.backgrounds]
cells.f16 = [param.type for param in project.resolutions]

Expand Down
11 changes: 6 additions & 5 deletions cpp/rat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ struct Cells {
py::list f11;
py::list f12;
py::list f13;
py::list f14;
py::object f14;
py::list f15;
py::list f16;
py::list f17;
Expand Down Expand Up @@ -844,12 +844,13 @@ coder::array<RAT::cell_wrap_6, 2U> pyListToRatCellWrap6(py::list values)
return result;
}

coder::array<RAT::cell_wrap_6, 2U> py_function_array_to_rat_cell_wrap_6(py::list values)
coder::array<RAT::cell_wrap_6, 2U> py_function_array_to_rat_cell_wrap_6(py::object values)
{
auto handles = py::cast<py::list>(values);
coder::array<RAT::cell_wrap_6, 2U> result;
result.set_size(1, values.size());
result.set_size(1, handles.size());
int32_T idx {0};
for (py::handle array: values)
for (py::handle array: handles)
{
auto func = py::cast<py::function>(array);
std::string func_ptr = convertPtr2String<CallbackInterface>(new Library(func));
Expand Down Expand Up @@ -1585,7 +1586,7 @@ PYBIND11_MODULE(rat_core, m) {
cell.f11 = t[10].cast<py::list>();
cell.f12 = t[11].cast<py::list>();
cell.f13 = t[12].cast<py::list>();
cell.f14 = t[13].cast<py::list>();
cell.f14 = t[13].cast<py::object>();
cell.f15 = t[14].cast<py::list>();
cell.f16 = t[15].cast<py::list>();
cell.f17 = t[16].cast<py::list>();
Expand Down
60 changes: 21 additions & 39 deletions tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,25 +624,7 @@ def test_make_input(test_project, test_problem, test_cells, test_limits, test_pr
"domainRatio",
]

mocked_matlab_future = mock.MagicMock()
mocked_engine = mock.MagicMock()
mocked_matlab_future.result.return_value = mocked_engine

with mock.patch.object(
RATapi.wrappers.MatlabWrapper,
"loader",
mocked_matlab_future,
), mock.patch.object(RATapi.rat_core, "DylibEngine", mock.MagicMock()), mock.patch.object(
RATapi.inputs,
"get_python_handle",
mock.MagicMock(return_value=dummy_function),
), mock.patch.object(
RATapi.wrappers.MatlabWrapper,
"getHandle",
mock.MagicMock(return_value=dummy_function),
), mock.patch.object(RATapi.wrappers.DylibWrapper, "getHandle", mock.MagicMock(return_value=dummy_function)):
problem, cells, limits, priors, controls = make_input(test_project, RATapi.Controls())

problem, cells, limits, priors, controls = make_input(test_project, RATapi.Controls())
problem = pickle.loads(pickle.dumps(problem))
check_problem_equal(problem, test_problem)
cells = pickle.loads(pickle.dumps(cells))
Expand Down Expand Up @@ -768,25 +750,7 @@ def test_make_cells(test_project, test_cells, request) -> None:
"""The cells object should be populated according to the input project object."""
test_project = request.getfixturevalue(test_project)
test_cells = request.getfixturevalue(test_cells)

mocked_matlab_future = mock.MagicMock()
mocked_engine = mock.MagicMock()
mocked_matlab_future.result.return_value = mocked_engine
with mock.patch.object(
RATapi.wrappers.MatlabWrapper,
"loader",
mocked_matlab_future,
), mock.patch.object(RATapi.rat_core, "DylibEngine", mock.MagicMock()), mock.patch.object(
RATapi.inputs,
"get_python_handle",
mock.MagicMock(return_value=dummy_function),
), mock.patch.object(
RATapi.wrappers.MatlabWrapper,
"getHandle",
mock.MagicMock(return_value=dummy_function),
), mock.patch.object(RATapi.wrappers.DylibWrapper, "getHandle", mock.MagicMock(return_value=dummy_function)):
cells = make_cells(test_project)

cells = make_cells(test_project)
check_cells_equal(cells, test_cells)


Expand Down Expand Up @@ -865,7 +829,25 @@ def check_cells_equal(actual_cells, expected_cells) -> None:
"NaN" if np.isnan(el) else el for entry in actual_cells.f6 for el in entry
] == ["NaN" if np.isnan(el) else el for entry in expected_cells.f6 for el in entry]

for index in chain(range(3, 6), range(7, 21)):
mocked_matlab_future = mock.MagicMock()
mocked_engine = mock.MagicMock()
mocked_matlab_future.result.return_value = mocked_engine
with mock.patch.object(
RATapi.wrappers.MatlabWrapper,
"loader",
mocked_matlab_future,
), mock.patch.object(RATapi.rat_core, "DylibEngine", mock.MagicMock()), mock.patch.object(
RATapi.inputs,
"get_python_handle",
mock.MagicMock(return_value=dummy_function),
), mock.patch.object(
RATapi.wrappers.MatlabWrapper,
"getHandle",
mock.MagicMock(return_value=dummy_function),
), mock.patch.object(RATapi.wrappers.DylibWrapper, "getHandle", mock.MagicMock(return_value=dummy_function)):
assert list(actual_cells.f14) == expected_cells.f14

for index in chain(range(3, 6), range(7, 14), range(15, 21)):
field = f"f{index}"
assert getattr(actual_cells, field) == getattr(expected_cells, field)

Expand Down