diff --git a/RAT/classlist.py b/RAT/classlist.py index 263f6e82..41fdd155 100644 --- a/RAT/classlist.py +++ b/RAT/classlist.py @@ -42,7 +42,7 @@ def __init__(self, init_list: Union[Sequence[object], object] = None, name_field # Set class to be used for this instance of the ClassList, checking that all elements of the input list are of # the same type and have unique values of the specified name_field if init_list: - self._class_handle = type(init_list[0]) + self._class_handle = self._determine_class_handle(init_list) self._check_classes(init_list) self._check_unique_name_fields(init_list) @@ -61,15 +61,15 @@ def __repr__(self): output = repr(self.data) return output - def __setitem__(self, index: int, set_dict: dict[str, Any]) -> None: - """Assign the values of an existing object's attributes using a dictionary.""" - self._setitem(index, set_dict) + def __setitem__(self, index: int, item: 'RAT.models') -> None: + """Replace the object at an existing index of the ClassList.""" + self._setitem(index, item) - def _setitem(self, index: int, set_dict: dict[str, Any]) -> None: + def _setitem(self, index: int, item: 'RAT.models') -> None: """Auxiliary routine of "__setitem__" used to enable wrapping.""" - self._validate_name_field(set_dict) - for key, value in set_dict.items(): - setattr(self.data[index], key, value) + self._check_classes(self + [item]) + self._check_unique_name_fields(self + [item]) + self.data[index] = item def __delitem__(self, index: int) -> None: """Delete an object from the list by index.""" @@ -85,8 +85,10 @@ def __iadd__(self, other: Sequence[object]) -> 'ClassList': def _iadd(self, other: Sequence[object]) -> 'ClassList': """Auxiliary routine of "__iadd__" used to enable wrapping.""" + if other and not (isinstance(other, Sequence) and not isinstance(other, str)): + other = [other] if not hasattr(self, '_class_handle'): - self._class_handle = type(other[0]) + self._class_handle = self._determine_class_handle(self + other) self._check_classes(self + other) self._check_unique_name_fields(self + other) super().__iadd__(other) @@ -201,12 +203,20 @@ def index(self, item: Union[object, str], *args) -> int: def extend(self, other: Sequence[object]) -> None: """Extend the ClassList by adding another sequence.""" + if other and not (isinstance(other, Sequence) and not isinstance(other, str)): + other = [other] if not hasattr(self, '_class_handle'): - self._class_handle = type(other[0]) + self._class_handle = self._determine_class_handle(self + other) self._check_classes(self + other) self._check_unique_name_fields(self + other) self.data.extend(other) + def set_fields(self, index: int, **kwargs) -> None: + """Assign the values of an existing object's attributes using keyword arguments.""" + self._validate_name_field(kwargs) + for key, value in kwargs.items(): + setattr(self.data[index], key, value) + def get_names(self) -> list[str]: """Return a list of the values of the name_field attribute of each class object in the list. @@ -302,3 +312,28 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object, object with that value of the name_field attribute cannot be found. """ return next((model for model in self.data if getattr(model, self.name_field) == value), value) + + @staticmethod + def _determine_class_handle(input_list: Sequence[object]): + """When inputting a sequence of object to a ClassList, the _class_handle should be set as the type of the + element which satisfies "issubclass" for all of the other elements. + + Parameters + ---------- + input_list : Sequence [object] + A list of instances to populate the ClassList. + + Returns + ------- + class_handle : type + The type object of the element fulfilling the condition of satisfying "issubclass" for all of the other + elements. + """ + for this_element in input_list: + if all([issubclass(type(instance), type(this_element)) for instance in input_list]): + class_handle = type(this_element) + break + else: + class_handle = type(input_list[0]) + + return class_handle diff --git a/RAT/project.py b/RAT/project.py index 7d36b02a..cd98e5f8 100644 --- a/RAT/project.py +++ b/RAT/project.py @@ -170,7 +170,8 @@ def model_post_init(self, __context: Any) -> None: # Wrap ClassList routines - when any of these routines are called, the wrapper will force revalidation of the # model, handle errors and reset previous values if necessary. - methods_to_wrap = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend'] + methods_to_wrap = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend', + 'set_fields'] for class_list in class_lists: attribute = getattr(self, class_list) for methodName in methods_to_wrap: diff --git a/tests/test_classlist.py b/tests/test_classlist.py index deee1d2f..e5f09da0 100644 --- a/tests/test_classlist.py +++ b/tests/test_classlist.py @@ -7,7 +7,7 @@ import warnings from RAT.classlist import ClassList -from tests.utils import InputAttributes +from tests.utils import InputAttributes, SubInputAttributes @pytest.fixture @@ -59,7 +59,21 @@ def test_input_sequence(self, input_sequence: Sequence[object]) -> None: """ class_list = ClassList(input_sequence) assert class_list.data == list(input_sequence) - assert isinstance(input_sequence[-1], class_list._class_handle) + for element in input_sequence: + assert isinstance(element, class_list._class_handle) + + @pytest.mark.parametrize("input_sequence", [ + ([InputAttributes(name='Alice'), SubInputAttributes(name='Bob')]), + ([SubInputAttributes(name='Alice'), InputAttributes(name='Bob')]), + ]) + def test_input_sequence_subclass(self, input_sequence: Sequence[object]) -> None: + """For an input of a sequence containing objects of a class and its subclasses, the ClassList should be a list + equal to the input sequence, and _class_handle should be set to the type of the parent class. + """ + class_list = ClassList(input_sequence) + assert class_list.data == list(input_sequence) + for element in input_sequence: + assert isinstance(element, class_list._class_handle) @pytest.mark.parametrize("empty_input", [([]), (())]) def test_empty_input(self, empty_input: Sequence[object]) -> None: @@ -119,26 +133,33 @@ def test_repr_empty_classlist() -> None: assert repr(ClassList()) == repr([]) -@pytest.mark.parametrize(["new_values", "expected_classlist"], [ - ({'name': 'Eve'}, ClassList([InputAttributes(name='Eve'), InputAttributes(name='Bob')])), - ({'name': 'John', 'surname': 'Luther'}, +@pytest.mark.parametrize(["new_item", "expected_classlist"], [ + (InputAttributes(name='Eve'), ClassList([InputAttributes(name='Eve'), InputAttributes(name='Bob')])), + (InputAttributes(name='John', surname='Luther'), ClassList([InputAttributes(name='John', surname='Luther'), InputAttributes(name='Bob')])), ]) -def test_setitem(two_name_class_list: 'ClassList', new_values: dict[str, Any], expected_classlist: 'ClassList') -> None: - """We should be able to set values in an element of a ClassList using a dictionary.""" +def test_setitem(two_name_class_list: ClassList, new_item: InputAttributes, expected_classlist: ClassList) -> None: + """We should be able to set values in an element of a ClassList using a new object.""" class_list = two_name_class_list - class_list[0] = new_values + class_list[0] = new_item assert class_list == expected_classlist +@pytest.mark.parametrize("new_item", [ + (InputAttributes(name='Bob')), +]) +def test_setitem_same_name_field(two_name_class_list: 'ClassList', new_item: InputAttributes) -> None: + """If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError.""" + with pytest.raises(ValueError, match="Input list contains objects with the same value of the name attribute"): + two_name_class_list[0] = new_item + + @pytest.mark.parametrize("new_values", [ - ({'name': 'Bob'}), + 'Bob', ]) -def test_setitem_same_name_field(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None: +def test_setitem_different_classes(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None: """If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError.""" - with pytest.raises(ValueError, match=f"Input arguments contain the {two_name_class_list.name_field} " - f"'{new_values[two_name_class_list.name_field]}', " - f"which is already specified in the ClassList"): + with pytest.raises(ValueError, match=f"Input list contains elements of type other than 'InputAttributes'"): two_name_class_list[0] = new_values @@ -160,9 +181,11 @@ def test_delitem_not_present(two_name_class_list: 'ClassList') -> None: (ClassList(InputAttributes(name='Eve'))), ([InputAttributes(name='Eve')]), (InputAttributes(name='Eve'),), + (InputAttributes(name='Eve')), ]) def test_iadd(two_name_class_list: 'ClassList', added_list: Iterable, three_name_class_list: 'ClassList') -> None: - """We should be able to use the "+=" operator to add iterables to a ClassList.""" + """We should be able to use the "+=" operator to add iterables to a ClassList. Individual objects should be wrapped + in a list before being added.""" class_list = two_name_class_list class_list += added_list assert class_list == three_name_class_list @@ -439,9 +462,11 @@ def test_index_not_present(two_name_class_list: 'ClassList', index_value: Union[ (ClassList(InputAttributes(name='Eve'))), ([InputAttributes(name='Eve')]), (InputAttributes(name='Eve'),), + (InputAttributes(name='Eve')), ]) def test_extend(two_name_class_list: 'ClassList', extended_list: Sequence, three_name_class_list: 'ClassList') -> None: - """We should be able to extend a ClassList using another ClassList or a sequence""" + """We should be able to extend a ClassList using another ClassList or a sequence. Individual objects should be + wrapped in a list before being added.""" class_list = two_name_class_list class_list.extend(extended_list) assert class_list == three_name_class_list @@ -460,6 +485,30 @@ def test_extend_empty_classlist(extended_list: Sequence, one_name_class_list: 'C assert isinstance(extended_list[-1], class_list._class_handle) +@pytest.mark.parametrize(["new_values", "expected_classlist"], [ + ({'name': 'Eve'}, ClassList([InputAttributes(name='Eve'), InputAttributes(name='Bob')])), + ({'name': 'John', 'surname': 'Luther'}, + ClassList([InputAttributes(name='John', surname='Luther'), InputAttributes(name='Bob')])), +]) +def test_set_fields(two_name_class_list: 'ClassList', new_values: dict[str, Any], expected_classlist: 'ClassList')\ + -> None: + """We should be able to set field values in an element of a ClassList using keyword arguments.""" + class_list = two_name_class_list + class_list.set_fields(0, **new_values) + assert class_list == expected_classlist + + +@pytest.mark.parametrize("new_values", [ + ({'name': 'Bob'}), +]) +def test_set_fields_same_name_field(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None: + """If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError.""" + with pytest.raises(ValueError, match=f"Input arguments contain the {two_name_class_list.name_field} " + f"'{new_values[two_name_class_list.name_field]}', " + f"which is already specified in the ClassList"): + two_name_class_list.set_fields(0, **new_values) + + @pytest.mark.parametrize(["class_list", "expected_names"], [ (ClassList([InputAttributes(name='Alice'), InputAttributes(name='Bob')]), ["Alice", "Bob"]), (ClassList([InputAttributes(id='Alice'), InputAttributes(id='Bob')], name_field='id'), ["Alice", "Bob"]), @@ -563,3 +612,19 @@ def test__get_item_from_name_field(two_name_class_list: 'ClassList', If the value is not the name_field of an object defined in the ClassList, we should return the value. """ assert two_name_class_list._get_item_from_name_field(value) == expected_output + + +@pytest.mark.parametrize(["input_list", "expected_type"], [ + ([InputAttributes(name='Alice')], InputAttributes), + ([InputAttributes(name='Alice'), SubInputAttributes(name='Bob')], InputAttributes), + ([SubInputAttributes(name='Alice'), InputAttributes(name='Bob')], InputAttributes), + ([SubInputAttributes(name='Alice'), SubInputAttributes(name='Bob')], SubInputAttributes), + ([SubInputAttributes(name='Alice'), SubInputAttributes(name='Bob'), InputAttributes(name='Eve')], InputAttributes), + ([InputAttributes(name='Alice'), dict(name='Bob')], InputAttributes), + ([dict(name='Alice'), InputAttributes(name='Bob')], dict), +]) +def test_determine_class_handle(input_list: 'ClassList', expected_type: type) -> None: + """The _class_handle for the ClassList should be the type that satisfies the condition "isinstance(element, type)" + for all elements in the ClassList. + """ + assert ClassList._determine_class_handle(input_list) == expected_type diff --git a/tests/test_project.py b/tests/test_project.py index 78a96a36..a06ecc8b 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -17,7 +17,7 @@ def test_project(): """Add parameters to the default project, so each ClassList can be tested properly.""" test_project = RAT.project.Project() - test_project.data[0] = {'data': np.array([[1, 1, 1]])} + test_project.data.set_fields(0, data=np.array([[1, 1, 1]])) test_project.parameters.append(name='Test SLD') test_project.custom_files.append(name='Test Custom File') test_project.layers.append(name='Test Layer', SLD='Test SLD') @@ -161,7 +161,7 @@ def test_rename_models(test_project, model: str, field: str) -> None: """When renaming a model in the project, the new name should be recorded when that model is referred to elsewhere in the project. """ - getattr(test_project, model)[-1] = {'name': 'New Name'} + getattr(test_project, model).set_fields(-1, name='New Name') attribute = RAT.project.model_names_used_in[model].attribute assert getattr(getattr(test_project, attribute)[-1], field) == 'New Name' @@ -307,7 +307,7 @@ def test_wrap_set(test_project, class_list: str, field: str) -> None: orig_class_list = copy.deepcopy(test_attribute) with contextlib.redirect_stdout(io.StringIO()) as print_str: - test_attribute[0] = {field: 'undefined'} + test_attribute.set_fields(0, **{field: 'undefined'}) assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "undefined" in ' f'the "{field}" field of "{class_list}" must be defined in ' f'"{RAT.project.values_defined_in[class_list+"."+field]}".\033[0m\n') diff --git a/tests/utils.py b/tests/utils.py index 4c4778c7..199979e2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,3 +11,8 @@ def __eq__(self, other: Any): if isinstance(other, InputAttributes): return self.__dict__ == other.__dict__ return False + + +class SubInputAttributes(InputAttributes): + """Trivial subclass of InputAttributes""" + pass