diff --git a/src/orion/core/utils/tree.py b/src/orion/core/utils/tree.py index e128dd12b..443feffbc 100644 --- a/src/orion/core/utils/tree.py +++ b/src/orion/core/utils/tree.py @@ -21,85 +21,28 @@ NodeType = TypeVar("NodeType", bound="TreeNode") -# pylint: disable=too-few-public-methods -class PreOrderTraversal(Iterable[NodeType]): - """Iterate on a tree in a pre-order traversal fashion +def PreOrderTraversal(tree_node: NodeType) -> Iterator[NodeType]: + """Iterate on a tree in a pre-order traversal fashion""" + stack = [tree_node] + while stack: + node = stack.pop() + yield node + stack.extend(node.children[::-1]) - Attributes - ---------- - stack: list of `orion.core.utils.tree.TreeNode` - Nodes logged during iteration - - """ - - __slots__ = ("stack",) - - def __init__(self, tree_node: NodeType): - """Initialize the stack for iteration""" - self.stack = [tree_node] - - def __iter__(self) -> Iterator[NodeType]: - """Get the iterator""" - return self - - def __next__(self) -> NodeType: - """Get the next node in pre-order traversal""" - try: - node = self.stack.pop() - except IndexError as exc: - raise StopIteration from exc - - self.stack += node.children[::-1] - - return node - - -# pylint: disable=too-few-public-methods -class DepthFirstTraversal(Iterable[NodeType]): - """Iterate on a tree in a pre-order traversal fashion - Attributes - ---------- - stack: list of `orion.core.utils.tree.TreeNode` - Nodes logged during iteration - seen: set of `orion.core.utils.tree.TreeNode` - Nodes which have been returned during iteration - - """ - - __slots__ = ("stack", "seen") - - def __init__(self, tree_node: NodeType): - """Initialize the stack and set of seen nodes for iteration""" - self.stack = [tree_node] - self.seen: set[NodeType] = set() +def DepthFirstTraversal(tree_node: NodeType) -> Iterable[NodeType]: + """Iterate on a tree in a post-order traversal fashion""" + seen: set[NodeType] = set() - def _compute_potential(self) -> list[NodeType]: - """Filter out seen nodes from the stack""" - if not self.stack: - return [] - - return list(filter(lambda n: n not in self.seen, self.stack[-1].children)) - - def __iter__(self) -> Iterator[NodeType]: - """Get the iterator""" - return self - - def __next__(self) -> NodeType: - """Get the next node in depth-first traversal""" - potential = self._compute_potential() - while self.stack and potential: - self.stack.extend(potential[::-1]) - potential = self._compute_potential() - - try: - node = self.stack.pop() - except IndexError as exc: - raise StopIteration from exc - - self.seen.add(node) + def _inner(node: NodeType) -> Iterable[NodeType]: + if node in seen: + return + seen.add(node) + for child in node.children: + yield from _inner(child) + yield node - return node + return _inner(tree_node) class TreeNode(Generic[T], Iterable[T]):