Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,13 +1091,19 @@ def directional_dfs_traverse(
nodes = set()
user_nodes = set()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, this dfs transversal order should in principle be different from the iterative one.

Since we are doing DAGs I think this should be fine.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

made it the same.

def dfs_traverse(node: node.Node):
nodes.add(node)
for n in next_nodes_fn(node):
if n not in nodes:
dfs_traverse(n)
if node.user_defined:
user_nodes.add(node)
def dfs_traverse_iterative(start_node: node.Node):
"""Iterative DFS to avoid recursion depth limits with large DAGs."""
stack = [start_node]
nodes.add(start_node)
while stack:
n = stack.pop()
if n.user_defined:
user_nodes.add(n)
# reversed() preserves the same traversal order as the recursive version
for next_n in reversed(next_nodes_fn(n)):
if next_n not in nodes:
nodes.add(next_n)
stack.append(next_n)
Comment on lines +1099 to +1106
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can lead to duplicate nodes being on stack I think because you don't mark the nodes as seen until you pop from the stack instead of marking them when you push onto the stack.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, fixed.


missing_vars = []
for var in starting_nodes:
Expand All @@ -1108,7 +1114,7 @@ def dfs_traverse(node: node.Node):
# if it's not in the runtime inputs, it's a properly missing variable
missing_vars.append(var)
continue # collect all missing final variables
dfs_traverse(self.nodes[var])
dfs_traverse_iterative(self.nodes[var])
if missing_vars:
missing_vars_str = ",\n".join(missing_vars)
raise ValueError(f"Unknown nodes [{missing_vars_str}] requested. Check for typos?")
Expand Down
114 changes: 114 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import inspect
import pathlib
import sys
import uuid
from itertools import permutations

Expand All @@ -26,6 +27,7 @@
import hamilton.graph_utils
import hamilton.htypes
from hamilton import ad_hoc_utils, base, graph, node
from hamilton import function_modifiers as fm
from hamilton.execution import graph_functions
from hamilton.function_modifiers import schema
from hamilton.lifecycle import base as lifecycle_base
Expand Down Expand Up @@ -538,6 +540,118 @@ def test_get_downstream_nodes():
assert actual_nodes == expected_nodes


def test_get_upstream_nodes_large_chain_no_recursion_error():
"""Regression test: get_upstream_nodes with only final_node on a large chain DAG.

A recursive DFS would exceed Python's recursion limit (~1000) when traversing
a long dependency chain from a single final node. This test verifies that
the iterative DFS in directional_dfs_traverse handles large DAGs correctly.

Chain size is chosen to exceed recursion limit: 1200 nodes > 1000.
"""

def step(prev: float) -> float:
"""Single step in a linear chain."""
return prev + 1.0

# Build a linear chain: node_0 -> node_1 -> ... -> node_N
chain_size = sys.getrecursionlimit() + 200 # Exceeds recursion limit
config = {}
for i in range(chain_size):
prev = f"node_{i - 1}" if i > 0 else 0.0
config[f"node_{i}"] = {
"prev": fm.source(prev) if i > 0 else fm.value(0.0),
}
decorated = fm.parameterize(**config)(step)
module = ad_hoc_utils.create_temporary_module(decorated, module_name="large_chain")

fg = graph.FunctionGraph.from_modules(module, config={})
final_node = f"node_{chain_size - 1}"

# This would raise RecursionError with recursive DFS
nodes, user_nodes = fg.get_upstream_nodes([final_node])

assert len(nodes) == chain_size
assert len(user_nodes) == 0
assert all(fg.nodes[f"node_{i}"] in nodes for i in range(chain_size))


def test_get_upstream_nodes_diamond_dag():
"""Tests that diamond-shaped DAGs don't produce duplicate visits.

DAG shape:
x, y (inputs)
|
left right (both depend on x and y)
\\ /
bottom (depends on left and right)

The shared inputs x and y are reachable via both left and right.
With a naive iterative DFS (mark-on-pop), x and y could be pushed
onto the stack multiple times. This verifies they appear exactly once.
"""

def left(x: int, y: int) -> int:
return x + y

def right(x: int, y: int) -> int:
return x * y

def bottom(left: int, right: int) -> int:
return left + right

module = ad_hoc_utils.create_temporary_module(left, right, bottom)
fg = graph.FunctionGraph.from_modules(module, config={})
nodes, user_nodes = fg.get_upstream_nodes(["bottom"])

assert len(nodes) == 5 # x, y, left, right, bottom
assert {n.name for n in nodes} == {"x", "y", "left", "right", "bottom"}
# x and y are external inputs
assert {n.name for n in user_nodes} == {"x", "y"}


def test_get_upstream_nodes_single_node():
"""Tests traversal of a single node with no dependencies."""

def solo() -> int:
return 42

module = ad_hoc_utils.create_temporary_module(solo)
fg = graph.FunctionGraph.from_modules(module, config={})
nodes, user_nodes = fg.get_upstream_nodes(["solo"])

assert len(nodes) == 1
assert {n.name for n in nodes} == {"solo"}
assert len(user_nodes) == 0


def test_get_upstream_nodes_overlapping_starting_nodes():
"""Tests that overlapping subgraphs from multiple starting nodes are handled correctly.

DAG shape:
shared (input)
/ \\
a b (both depend on shared)

Requesting both a and b as starting nodes means 'shared' is reachable
from both traversals. It should still appear exactly once in the result.
"""

def a(shared: int) -> int:
return shared + 1

def b(shared: int) -> int:
return shared + 2

module = ad_hoc_utils.create_temporary_module(a, b)
fg = graph.FunctionGraph.from_modules(module, config={})
nodes, user_nodes = fg.get_upstream_nodes(["a", "b"])

assert len(nodes) == 3 # shared, a, b
assert {n.name for n in nodes} == {"shared", "a", "b"}
assert {n.name for n in user_nodes} == {"shared"}


def test_function_graph_from_multiple_sources():
fg = graph.FunctionGraph.from_modules(
tests.resources.dummy_functions, tests.resources.parametrized_nodes, config={}
Expand Down
Loading