Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions RAT/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, init_list: Union[Sequence[object], object] = None, name_field
# Set class to be used for this instance of the ClassList, checking that all elements of the input list are of
# the same type and have unique values of the specified name_field
if init_list:
self._class_handle = type(init_list[0])
self._class_handle = self._determine_class_handle(init_list)
self._check_classes(init_list)
self._check_unique_name_fields(init_list)

Expand All @@ -61,15 +61,22 @@ def __repr__(self):
output = repr(self.data)
return output

def __setitem__(self, index: int, set_dict: dict[str, Any]) -> None:
"""Assign the values of an existing object's attributes using a dictionary."""
self._setitem(index, set_dict)
def __setitem__(self, index: int, item: Union['RAT.models', dict[str, Any]]) -> None:
"""Assign the values of an existing object's attributes using either a replacement object or a dictionary
containing key-value pairs.
"""
self._setitem(index, item)

def _setitem(self, index: int, set_dict: dict[str, Any]) -> None:
def _setitem(self, index: int, item: Union['RAT.models', dict[str, Any]]) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be better to support one type i.e only models or dict (my pick is models) then provide a fallback for anyone who wants to use the alternate e.g. add_rows_as_type. Otherwise there will be failing edge cases in the current design

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's fine with me. I'll get on it!

"""Auxiliary routine of "__setitem__" used to enable wrapping."""
self._validate_name_field(set_dict)
for key, value in set_dict.items():
setattr(self.data[index], key, value)
if isinstance(item, dict):
self._validate_name_field(item)
for key, value in item.items():
setattr(self.data[index], key, value)
else:
self._check_classes(self + [item])
self._check_unique_name_fields(self + [item])
self.data[index] = item

def __delitem__(self, index: int) -> None:
"""Delete an object from the list by index."""
Expand All @@ -85,8 +92,10 @@ def __iadd__(self, other: Sequence[object]) -> 'ClassList':

def _iadd(self, other: Sequence[object]) -> 'ClassList':
"""Auxiliary routine of "__iadd__" used to enable wrapping."""
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
other = [other]
if not hasattr(self, '_class_handle'):
self._class_handle = type(other[0])
self._class_handle = self._determine_class_handle(self + other)
self._check_classes(self + other)
self._check_unique_name_fields(self + other)
super().__iadd__(other)
Expand Down Expand Up @@ -201,8 +210,10 @@ def index(self, item: Union[object, str], *args) -> int:

def extend(self, other: Sequence[object]) -> None:
"""Extend the ClassList by adding another sequence."""
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
other = [other]
if not hasattr(self, '_class_handle'):
self._class_handle = type(other[0])
self._class_handle = self._determine_class_handle(self + other)
self._check_classes(self + other)
self._check_unique_name_fields(self + other)
self.data.extend(other)
Expand Down Expand Up @@ -302,3 +313,28 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object,
object with that value of the name_field attribute cannot be found.
"""
return next((model for model in self.data if getattr(model, self.name_field) == value), value)

@staticmethod
def _determine_class_handle(input_list: Sequence[object]):
"""When inputting a sequence of object to a ClassList, the _class_handle should be set as the type of the
element which satisfies "issubclass" for all of the other elements.

Parameters
----------
input_list : Sequence [object]
A list of instances to populate the ClassList.

Returns
-------
class_handle : type
The type object of the element fulfilling the condition of satisfying "issubclass" for all of the other
elements.
"""
for this_element in input_list:
if all([issubclass(type(instance), type(this_element)) for instance in input_list]):
class_handle = type(this_element)
break
else:
class_handle = type(input_list[0])

return class_handle
74 changes: 68 additions & 6 deletions tests/test_classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings

from RAT.classlist import ClassList
from tests.utils import InputAttributes
from tests.utils import InputAttributes, SubInputAttributes


@pytest.fixture
Expand Down Expand Up @@ -59,7 +59,21 @@ def test_input_sequence(self, input_sequence: Sequence[object]) -> None:
"""
class_list = ClassList(input_sequence)
assert class_list.data == list(input_sequence)
assert isinstance(input_sequence[-1], class_list._class_handle)
for element in input_sequence:
assert isinstance(element, class_list._class_handle)

@pytest.mark.parametrize("input_sequence", [
([InputAttributes(name='Alice'), SubInputAttributes(name='Bob')]),
([SubInputAttributes(name='Alice'), InputAttributes(name='Bob')]),
])
def test_input_sequence_subclass(self, input_sequence: Sequence[object]) -> None:
"""For an input of a sequence containing objects of a class and its subclasses, the ClassList should be a list
equal to the input sequence, and _class_handle should be set to the type of the parent class.
"""
class_list = ClassList(input_sequence)
assert class_list.data == list(input_sequence)
for element in input_sequence:
assert isinstance(element, class_list._class_handle)

@pytest.mark.parametrize("empty_input", [([]), (())])
def test_empty_input(self, empty_input: Sequence[object]) -> None:
Expand Down Expand Up @@ -119,28 +133,58 @@ def test_repr_empty_classlist() -> None:
assert repr(ClassList()) == repr([])


@pytest.mark.parametrize(["new_item", "expected_classlist"], [
(InputAttributes(name='Eve'), ClassList([InputAttributes(name='Eve'), InputAttributes(name='Bob')])),
(InputAttributes(name='John', surname='Luther'),
ClassList([InputAttributes(name='John', surname='Luther'), InputAttributes(name='Bob')])),
])
def test_setitem_obj(two_name_class_list: ClassList, new_item: InputAttributes, expected_classlist: ClassList)\
-> None:
"""We should be able to set values in an element of a ClassList using a new object."""
class_list = two_name_class_list
class_list[0] = new_item
assert class_list == expected_classlist


@pytest.mark.parametrize(["new_values", "expected_classlist"], [
({'name': 'Eve'}, ClassList([InputAttributes(name='Eve'), InputAttributes(name='Bob')])),
({'name': 'John', 'surname': 'Luther'},
ClassList([InputAttributes(name='John', surname='Luther'), InputAttributes(name='Bob')])),
])
def test_setitem(two_name_class_list: 'ClassList', new_values: dict[str, Any], expected_classlist: 'ClassList') -> None:
def test_setitem_dict(two_name_class_list: 'ClassList', new_values: dict[str, Any], expected_classlist: 'ClassList') -> None:
"""We should be able to set values in an element of a ClassList using a dictionary."""
class_list = two_name_class_list
class_list[0] = new_values
assert class_list == expected_classlist


@pytest.mark.parametrize("new_item", [
(InputAttributes(name='Bob')),
])
def test_setitem_same_name_field_obj(two_name_class_list: 'ClassList', new_item: InputAttributes) -> None:
"""If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError."""
with pytest.raises(ValueError, match="Input list contains objects with the same value of the name attribute"):
two_name_class_list[0] = new_item


@pytest.mark.parametrize("new_values", [
({'name': 'Bob'}),
])
def test_setitem_same_name_field(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None:
def test_setitem_same_name_field_dict(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None:
"""If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError."""
with pytest.raises(ValueError, match=f"Input arguments contain the {two_name_class_list.name_field} "
f"'{new_values[two_name_class_list.name_field]}', "
f"which is already specified in the ClassList"):
two_name_class_list[0] = new_values

@pytest.mark.parametrize("new_values", [
'Bob',
])
def test_setitem_different_classes(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None:
"""If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError."""
with pytest.raises(ValueError, match=f"Input list contains elements of type other than 'InputAttributes'"):
two_name_class_list[0] = new_values


def test_delitem(two_name_class_list: 'ClassList', one_name_class_list: 'ClassList') -> None:
"""We should be able to delete elements from a ClassList with the del operator."""
Expand All @@ -160,9 +204,11 @@ def test_delitem_not_present(two_name_class_list: 'ClassList') -> None:
(ClassList(InputAttributes(name='Eve'))),
([InputAttributes(name='Eve')]),
(InputAttributes(name='Eve'),),
(InputAttributes(name='Eve')),
])
def test_iadd(two_name_class_list: 'ClassList', added_list: Iterable, three_name_class_list: 'ClassList') -> None:
"""We should be able to use the "+=" operator to add iterables to a ClassList."""
"""We should be able to use the "+=" operator to add iterables to a ClassList. Individual objects should be wrapped
in a list before being added."""
class_list = two_name_class_list
class_list += added_list
assert class_list == three_name_class_list
Expand Down Expand Up @@ -439,9 +485,11 @@ def test_index_not_present(two_name_class_list: 'ClassList', index_value: Union[
(ClassList(InputAttributes(name='Eve'))),
([InputAttributes(name='Eve')]),
(InputAttributes(name='Eve'),),
(InputAttributes(name='Eve')),
])
def test_extend(two_name_class_list: 'ClassList', extended_list: Sequence, three_name_class_list: 'ClassList') -> None:
"""We should be able to extend a ClassList using another ClassList or a sequence"""
"""We should be able to extend a ClassList using another ClassList or a sequence. Individual objects should be
wrapped in a list before being added."""
class_list = two_name_class_list
class_list.extend(extended_list)
assert class_list == three_name_class_list
Expand Down Expand Up @@ -563,3 +611,17 @@ def test__get_item_from_name_field(two_name_class_list: 'ClassList',
If the value is not the name_field of an object defined in the ClassList, we should return the value.
"""
assert two_name_class_list._get_item_from_name_field(value) == expected_output


@pytest.mark.parametrize("input_list", [
([InputAttributes(name='Alice')]),
([InputAttributes(name='Alice'), SubInputAttributes(name='Bob')]),
([SubInputAttributes(name='Alice'), InputAttributes(name='Bob')]),
([SubInputAttributes(name='Alice'), SubInputAttributes(name='Bob'), InputAttributes(name='Eve')]),
([InputAttributes(name='Alice'), dict(name='Bob')]),
])
def test_determine_class_handle(input_list: 'ClassList') -> None:
"""The _class_handle for the ClassList should be the type that satisfies the condition "isinstance(element, type)"
for all elements in the ClassList
"""
assert ClassList._determine_class_handle(input_list) == InputAttributes
5 changes: 5 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@ def __eq__(self, other: Any):
if isinstance(other, InputAttributes):
return self.__dict__ == other.__dict__
return False


class SubInputAttributes(InputAttributes):
"""Trivial subclass of InputAttributes"""
pass