Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

from collections import deque
from typing import TYPE_CHECKING, Any, Generic

from cudf_polars.typing import (
Expand Down Expand Up @@ -41,8 +42,13 @@ def traversal(nodes: Sequence[NodeT]) -> Generator[NodeT, None, None]:
Unique nodes in the expressions, parent before child, children
in-order from left to right.
"""
seen = set(nodes)
lifo = list(nodes)
seen: set[NodeT] = set()
lifo: deque[NodeT] = deque()

for node in nodes:
if node not in seen:
lifo.append(node)
seen.add(node)

while lifo:
node = lifo.pop()
Expand All @@ -53,6 +59,40 @@ def traversal(nodes: Sequence[NodeT]) -> Generator[NodeT, None, None]:
lifo.append(child)


def post_traversal(nodes: Sequence[NodeT]) -> Generator[NodeT, None, None]:
"""
Post-order traversal of nodes in an expression.

Parameters
----------
nodes
Roots of expressions to traverse.

Yields
------
Unique nodes in the expressions, child before parent, children
in-order from left to right.
"""
seen: set[NodeT] = set()
lifo: deque[NodeT] = deque()

for node in nodes:
if node not in seen:
lifo.append(node)
seen.add(node)

while lifo:
node = lifo[-1]
for child in node.children:
if child not in seen:
lifo.append(child)
seen.add(child)
break
else:
yield node
lifo.pop()


def reuse_if_unchanged(
node: NodeT, fn: GenericTransformer[NodeT, NodeT, StateT_co]
) -> NodeT:
Expand Down
41 changes: 38 additions & 3 deletions python/cudf_polars/tests/dsl/test_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from cudf_polars.dsl.traversal import (
CachingVisitor,
make_recursive,
post_traversal,
reuse_if_unchanged,
traversal,
)
Expand All @@ -39,7 +40,7 @@ def make_expr(dt, n1, n2):


def test_traversal_unique():
dt = DataType(pl.datatypes.Int8())
dt = DataType(pl.Int8())

e1 = make_expr(dt, "a", "a")
unique_exprs = list(traversal([e1]))
Expand All @@ -63,6 +64,40 @@ def test_traversal_unique():
assert unique_exprs == [e3, expr.Col(dt, "b"), expr.Col(dt, "a")]


def test_post_traversal_unique():
dt = DataType(pl.Int8())

e1 = make_expr(dt, "a", "a")
unique_exprs = list(post_traversal([e1]))
assert unique_exprs == [expr.Col(dt, "a"), e1]

e2 = make_expr(dt, "a", "b")
unique_exprs = list(post_traversal([e2]))
assert unique_exprs == [expr.Col(dt, "a"), expr.Col(dt, "b"), e2]

e3 = make_expr(dt, "b", "a")
unique_exprs = list(post_traversal([e3]))
assert unique_exprs == [expr.Col(dt, "b"), expr.Col(dt, "a"), e3]


def test_post_traversal_multi():
dt = DataType(pl.Int8())

e1 = make_expr(dt, "a", "a")
e2 = make_expr(dt, "a", "b")
e3 = make_expr(dt, "b", "a")

unique_exprs = list(post_traversal([e1, e2, e3]))
assert len(unique_exprs) == 5
assert unique_exprs == [
expr.Col(dt, "b"),
expr.Col(dt, "a"),
e3,
e2,
e1,
]


def rename(e, rec):
mapping = rec.state["mapping"]
if isinstance(e, expr.Col) and e.name in mapping:
Expand All @@ -71,7 +106,7 @@ def rename(e, rec):


def test_caching_visitor():
dt = DataType(pl.datatypes.Int8())
dt = DataType(pl.Int8())

e1 = make_expr(dt, "a", "b")

Expand All @@ -95,7 +130,7 @@ def test_caching_visitor():


def test_noop_visitor():
dt = DataType(pl.datatypes.Int8())
dt = DataType(pl.Int8())

e1 = make_expr(dt, "a", "b")

Expand Down