Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 45 additions & 10 deletions RAT/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, init_list: Union[Sequence[object], object] = None, name_field
# Set class to be used for this instance of the ClassList, checking that all elements of the input list are of
# the same type and have unique values of the specified name_field
if init_list:
self._class_handle = type(init_list[0])
self._class_handle = self._determine_class_handle(init_list)
self._check_classes(init_list)
self._check_unique_name_fields(init_list)

Expand All @@ -61,15 +61,15 @@ def __repr__(self):
output = repr(self.data)
return output

def __setitem__(self, index: int, set_dict: dict[str, Any]) -> None:
"""Assign the values of an existing object's attributes using a dictionary."""
self._setitem(index, set_dict)
def __setitem__(self, index: int, item: 'RAT.models') -> None:
"""Replace the object at an existing index of the ClassList."""
self._setitem(index, item)

def _setitem(self, index: int, set_dict: dict[str, Any]) -> None:
def _setitem(self, index: int, item: 'RAT.models') -> None:
"""Auxiliary routine of "__setitem__" used to enable wrapping."""
self._validate_name_field(set_dict)
for key, value in set_dict.items():
setattr(self.data[index], key, value)
self._check_classes(self + [item])
self._check_unique_name_fields(self + [item])
self.data[index] = item

def __delitem__(self, index: int) -> None:
"""Delete an object from the list by index."""
Expand All @@ -85,8 +85,10 @@ def __iadd__(self, other: Sequence[object]) -> 'ClassList':

def _iadd(self, other: Sequence[object]) -> 'ClassList':
"""Auxiliary routine of "__iadd__" used to enable wrapping."""
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
other = [other]
if not hasattr(self, '_class_handle'):
self._class_handle = type(other[0])
self._class_handle = self._determine_class_handle(self + other)
self._check_classes(self + other)
self._check_unique_name_fields(self + other)
super().__iadd__(other)
Expand Down Expand Up @@ -201,12 +203,20 @@ def index(self, item: Union[object, str], *args) -> int:

def extend(self, other: Sequence[object]) -> None:
"""Extend the ClassList by adding another sequence."""
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
other = [other]
if not hasattr(self, '_class_handle'):
self._class_handle = type(other[0])
self._class_handle = self._determine_class_handle(self + other)
self._check_classes(self + other)
self._check_unique_name_fields(self + other)
self.data.extend(other)

def set_fields(self, index: int, **kwargs) -> None:
"""Assign the values of an existing object's attributes using keyword arguments."""
self._validate_name_field(kwargs)
for key, value in kwargs.items():
setattr(self.data[index], key, value)

def get_names(self) -> list[str]:
"""Return a list of the values of the name_field attribute of each class object in the list.

Expand Down Expand Up @@ -302,3 +312,28 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object,
object with that value of the name_field attribute cannot be found.
"""
return next((model for model in self.data if getattr(model, self.name_field) == value), value)

@staticmethod
def _determine_class_handle(input_list: Sequence[object]):
"""When inputting a sequence of object to a ClassList, the _class_handle should be set as the type of the
element which satisfies "issubclass" for all of the other elements.

Parameters
----------
input_list : Sequence [object]
A list of instances to populate the ClassList.

Returns
-------
class_handle : type
The type object of the element fulfilling the condition of satisfying "issubclass" for all of the other
elements.
"""
for this_element in input_list:
if all([issubclass(type(instance), type(this_element)) for instance in input_list]):
class_handle = type(this_element)
break
else:
class_handle = type(input_list[0])

return class_handle
3 changes: 2 additions & 1 deletion RAT/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ def model_post_init(self, __context: Any) -> None:

# Wrap ClassList routines - when any of these routines are called, the wrapper will force revalidation of the
# model, handle errors and reset previous values if necessary.
methods_to_wrap = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend']
methods_to_wrap = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend',
'set_fields']
for class_list in class_lists:
attribute = getattr(self, class_list)
for methodName in methods_to_wrap:
Expand Down
95 changes: 80 additions & 15 deletions tests/test_classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings

from RAT.classlist import ClassList
from tests.utils import InputAttributes
from tests.utils import InputAttributes, SubInputAttributes


@pytest.fixture
Expand Down Expand Up @@ -59,7 +59,21 @@ def test_input_sequence(self, input_sequence: Sequence[object]) -> None:
"""
class_list = ClassList(input_sequence)
assert class_list.data == list(input_sequence)
assert isinstance(input_sequence[-1], class_list._class_handle)
for element in input_sequence:
assert isinstance(element, class_list._class_handle)

@pytest.mark.parametrize("input_sequence", [
([InputAttributes(name='Alice'), SubInputAttributes(name='Bob')]),
([SubInputAttributes(name='Alice'), InputAttributes(name='Bob')]),
])
def test_input_sequence_subclass(self, input_sequence: Sequence[object]) -> None:
"""For an input of a sequence containing objects of a class and its subclasses, the ClassList should be a list
equal to the input sequence, and _class_handle should be set to the type of the parent class.
"""
class_list = ClassList(input_sequence)
assert class_list.data == list(input_sequence)
for element in input_sequence:
assert isinstance(element, class_list._class_handle)

@pytest.mark.parametrize("empty_input", [([]), (())])
def test_empty_input(self, empty_input: Sequence[object]) -> None:
Expand Down Expand Up @@ -119,26 +133,33 @@ def test_repr_empty_classlist() -> None:
assert repr(ClassList()) == repr([])


@pytest.mark.parametrize(["new_values", "expected_classlist"], [
({'name': 'Eve'}, ClassList([InputAttributes(name='Eve'), InputAttributes(name='Bob')])),
({'name': 'John', 'surname': 'Luther'},
@pytest.mark.parametrize(["new_item", "expected_classlist"], [
(InputAttributes(name='Eve'), ClassList([InputAttributes(name='Eve'), InputAttributes(name='Bob')])),
(InputAttributes(name='John', surname='Luther'),
ClassList([InputAttributes(name='John', surname='Luther'), InputAttributes(name='Bob')])),
])
def test_setitem(two_name_class_list: 'ClassList', new_values: dict[str, Any], expected_classlist: 'ClassList') -> None:
"""We should be able to set values in an element of a ClassList using a dictionary."""
def test_setitem(two_name_class_list: ClassList, new_item: InputAttributes, expected_classlist: ClassList) -> None:
"""We should be able to set values in an element of a ClassList using a new object."""
class_list = two_name_class_list
class_list[0] = new_values
class_list[0] = new_item
assert class_list == expected_classlist


@pytest.mark.parametrize("new_item", [
(InputAttributes(name='Bob')),
])
def test_setitem_same_name_field(two_name_class_list: 'ClassList', new_item: InputAttributes) -> None:
"""If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError."""
with pytest.raises(ValueError, match="Input list contains objects with the same value of the name attribute"):
two_name_class_list[0] = new_item


@pytest.mark.parametrize("new_values", [
({'name': 'Bob'}),
'Bob',
])
def test_setitem_same_name_field(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None:
def test_setitem_different_classes(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None:
"""If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError."""
with pytest.raises(ValueError, match=f"Input arguments contain the {two_name_class_list.name_field} "
f"'{new_values[two_name_class_list.name_field]}', "
f"which is already specified in the ClassList"):
with pytest.raises(ValueError, match=f"Input list contains elements of type other than 'InputAttributes'"):
two_name_class_list[0] = new_values


Expand All @@ -160,9 +181,11 @@ def test_delitem_not_present(two_name_class_list: 'ClassList') -> None:
(ClassList(InputAttributes(name='Eve'))),
([InputAttributes(name='Eve')]),
(InputAttributes(name='Eve'),),
(InputAttributes(name='Eve')),
])
def test_iadd(two_name_class_list: 'ClassList', added_list: Iterable, three_name_class_list: 'ClassList') -> None:
"""We should be able to use the "+=" operator to add iterables to a ClassList."""
"""We should be able to use the "+=" operator to add iterables to a ClassList. Individual objects should be wrapped
in a list before being added."""
class_list = two_name_class_list
class_list += added_list
assert class_list == three_name_class_list
Expand Down Expand Up @@ -439,9 +462,11 @@ def test_index_not_present(two_name_class_list: 'ClassList', index_value: Union[
(ClassList(InputAttributes(name='Eve'))),
([InputAttributes(name='Eve')]),
(InputAttributes(name='Eve'),),
(InputAttributes(name='Eve')),
])
def test_extend(two_name_class_list: 'ClassList', extended_list: Sequence, three_name_class_list: 'ClassList') -> None:
"""We should be able to extend a ClassList using another ClassList or a sequence"""
"""We should be able to extend a ClassList using another ClassList or a sequence. Individual objects should be
wrapped in a list before being added."""
class_list = two_name_class_list
class_list.extend(extended_list)
assert class_list == three_name_class_list
Expand All @@ -460,6 +485,30 @@ def test_extend_empty_classlist(extended_list: Sequence, one_name_class_list: 'C
assert isinstance(extended_list[-1], class_list._class_handle)


@pytest.mark.parametrize(["new_values", "expected_classlist"], [
({'name': 'Eve'}, ClassList([InputAttributes(name='Eve'), InputAttributes(name='Bob')])),
({'name': 'John', 'surname': 'Luther'},
ClassList([InputAttributes(name='John', surname='Luther'), InputAttributes(name='Bob')])),
])
def test_set_fields(two_name_class_list: 'ClassList', new_values: dict[str, Any], expected_classlist: 'ClassList')\
-> None:
"""We should be able to set field values in an element of a ClassList using keyword arguments."""
class_list = two_name_class_list
class_list.set_fields(0, **new_values)
assert class_list == expected_classlist


@pytest.mark.parametrize("new_values", [
({'name': 'Bob'}),
])
def test_set_fields_same_name_field(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None:
"""If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError."""
with pytest.raises(ValueError, match=f"Input arguments contain the {two_name_class_list.name_field} "
f"'{new_values[two_name_class_list.name_field]}', "
f"which is already specified in the ClassList"):
two_name_class_list.set_fields(0, **new_values)


@pytest.mark.parametrize(["class_list", "expected_names"], [
(ClassList([InputAttributes(name='Alice'), InputAttributes(name='Bob')]), ["Alice", "Bob"]),
(ClassList([InputAttributes(id='Alice'), InputAttributes(id='Bob')], name_field='id'), ["Alice", "Bob"]),
Expand Down Expand Up @@ -563,3 +612,19 @@ def test__get_item_from_name_field(two_name_class_list: 'ClassList',
If the value is not the name_field of an object defined in the ClassList, we should return the value.
"""
assert two_name_class_list._get_item_from_name_field(value) == expected_output


@pytest.mark.parametrize(["input_list", "expected_type"], [
([InputAttributes(name='Alice')], InputAttributes),
([InputAttributes(name='Alice'), SubInputAttributes(name='Bob')], InputAttributes),
([SubInputAttributes(name='Alice'), InputAttributes(name='Bob')], InputAttributes),
([SubInputAttributes(name='Alice'), SubInputAttributes(name='Bob')], SubInputAttributes),
([SubInputAttributes(name='Alice'), SubInputAttributes(name='Bob'), InputAttributes(name='Eve')], InputAttributes),
([InputAttributes(name='Alice'), dict(name='Bob')], InputAttributes),
([dict(name='Alice'), InputAttributes(name='Bob')], dict),
])
def test_determine_class_handle(input_list: 'ClassList', expected_type: type) -> None:
"""The _class_handle for the ClassList should be the type that satisfies the condition "isinstance(element, type)"
for all elements in the ClassList.
"""
assert ClassList._determine_class_handle(input_list) == expected_type
6 changes: 3 additions & 3 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
def test_project():
"""Add parameters to the default project, so each ClassList can be tested properly."""
test_project = RAT.project.Project()
test_project.data[0] = {'data': np.array([[1, 1, 1]])}
test_project.data.set_fields(0, data=np.array([[1, 1, 1]]))
test_project.parameters.append(name='Test SLD')
test_project.custom_files.append(name='Test Custom File')
test_project.layers.append(name='Test Layer', SLD='Test SLD')
Expand Down Expand Up @@ -161,7 +161,7 @@ def test_rename_models(test_project, model: str, field: str) -> None:
"""When renaming a model in the project, the new name should be recorded when that model is referred to elsewhere
in the project.
"""
getattr(test_project, model)[-1] = {'name': 'New Name'}
getattr(test_project, model).set_fields(-1, name='New Name')
attribute = RAT.project.model_names_used_in[model].attribute
assert getattr(getattr(test_project, attribute)[-1], field) == 'New Name'

Expand Down Expand Up @@ -307,7 +307,7 @@ def test_wrap_set(test_project, class_list: str, field: str) -> None:
orig_class_list = copy.deepcopy(test_attribute)

with contextlib.redirect_stdout(io.StringIO()) as print_str:
test_attribute[0] = {field: 'undefined'}
test_attribute.set_fields(0, **{field: 'undefined'})
assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "undefined" in '
f'the "{field}" field of "{class_list}" must be defined in '
f'"{RAT.project.values_defined_in[class_list+"."+field]}".\033[0m\n')
Expand Down
5 changes: 5 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@ def __eq__(self, other: Any):
if isinstance(other, InputAttributes):
return self.__dict__ == other.__dict__
return False


class SubInputAttributes(InputAttributes):
"""Trivial subclass of InputAttributes"""
pass