Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 18 additions & 75 deletions src/orion/core/utils/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,85 +21,28 @@
NodeType = TypeVar("NodeType", bound="TreeNode")


# pylint: disable=too-few-public-methods
class PreOrderTraversal(Iterable[NodeType]):
"""Iterate on a tree in a pre-order traversal fashion
def PreOrderTraversal(tree_node: NodeType) -> Iterator[NodeType]:
"""Iterate on a tree in a pre-order traversal fashion"""
stack = [tree_node]
while stack:
node = stack.pop()
yield node
stack.extend(node.children[::-1])

Attributes
----------
stack: list of `orion.core.utils.tree.TreeNode`
Nodes logged during iteration

"""

__slots__ = ("stack",)

def __init__(self, tree_node: NodeType):
"""Initialize the stack for iteration"""
self.stack = [tree_node]

def __iter__(self) -> Iterator[NodeType]:
"""Get the iterator"""
return self

def __next__(self) -> NodeType:
"""Get the next node in pre-order traversal"""
try:
node = self.stack.pop()
except IndexError as exc:
raise StopIteration from exc

self.stack += node.children[::-1]

return node


# pylint: disable=too-few-public-methods
class DepthFirstTraversal(Iterable[NodeType]):
"""Iterate on a tree in a pre-order traversal fashion

Attributes
----------
stack: list of `orion.core.utils.tree.TreeNode`
Nodes logged during iteration
seen: set of `orion.core.utils.tree.TreeNode`
Nodes which have been returned during iteration

"""

__slots__ = ("stack", "seen")

def __init__(self, tree_node: NodeType):
"""Initialize the stack and set of seen nodes for iteration"""
self.stack = [tree_node]
self.seen: set[NodeType] = set()
def DepthFirstTraversal(tree_node: NodeType) -> Iterable[NodeType]:
"""Iterate on a tree in a post-order traversal fashion"""
seen: set[NodeType] = set()

def _compute_potential(self) -> list[NodeType]:
"""Filter out seen nodes from the stack"""
if not self.stack:
return []

return list(filter(lambda n: n not in self.seen, self.stack[-1].children))

def __iter__(self) -> Iterator[NodeType]:
"""Get the iterator"""
return self

def __next__(self) -> NodeType:
"""Get the next node in depth-first traversal"""
potential = self._compute_potential()
while self.stack and potential:
self.stack.extend(potential[::-1])
potential = self._compute_potential()

try:
node = self.stack.pop()
except IndexError as exc:
raise StopIteration from exc

self.seen.add(node)
def _inner(node: NodeType) -> Iterable[NodeType]:
if node in seen:
return
seen.add(node)
for child in node.children:
yield from _inner(child)
yield node

return node
return _inner(tree_node)


class TreeNode(Generic[T], Iterable[T]):
Expand Down