Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions langgraph/checkpoint/redis/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def __init__(
checkpoint_prefix=checkpoint_prefix,
checkpoint_write_prefix=checkpoint_write_prefix,
)
self.loop = asyncio.get_running_loop()
# Deferred: the event loop is captured in asetup() so that the saver can
# be constructed outside an async context (Issue #179).
self.loop: Optional[asyncio.AbstractEventLoop] = None

# Instance-level cache for frequently used keys (limited size to prevent memory issues)
self._key_cache: Dict[str, str] = {}
Expand Down Expand Up @@ -243,6 +245,13 @@ async def __aexit__(

async def asetup(self) -> None:
"""Set up the checkpoint saver."""
# Capture the running event loop here so that sync wrapper methods
# (get_tuple, put, put_writes, …) can dispatch coroutines to it via
# asyncio.run_coroutine_threadsafe. Deferring this to asetup() instead
# of __init__ lets callers construct the saver outside an async context
# (Issue #179).
self.loop = asyncio.get_running_loop()

self.create_indexes()
await self.checkpoints_index.create(overwrite=False)
await self.checkpoint_writes_index.create(overwrite=False)
Expand Down Expand Up @@ -1307,6 +1316,20 @@ def put_writes(
task_id (str): Identifier for the task creating the writes.
task_path (str): Path of the task creating the writes.
"""
if self.loop is None:
raise RuntimeError(
"AsyncRedisSaver must be set up before calling synchronous methods. "
"Call `await saver.asetup()` or use `async with saver:` first."
)
Comment thread
bsbodden marked this conversation as resolved.
try:
if asyncio.get_running_loop() is self.loop:
raise asyncio.InvalidStateError(
"Synchronous calls to AsyncRedisSaver are only allowed from a "
"different thread. From the main thread, use the async interface. "
"For example, use `await checkpointer.aput_writes(...)`."
)
except RuntimeError:
pass
return asyncio.run_coroutine_threadsafe(
self.aput_writes(config, writes, task_id), self.loop
).result()
Expand All @@ -1315,12 +1338,17 @@ def get_channel_values(
self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = ""
) -> Dict[str, Any]:
"""Retrieve channel_values using efficient FT.SEARCH with checkpoint_id (sync wrapper)."""
if self.loop is None:
raise RuntimeError(
"AsyncRedisSaver must be set up before calling synchronous methods. "
"Call `await saver.asetup()` or use `async with saver:` first."
)
try:
if asyncio.get_running_loop() is self.loop:
raise asyncio.InvalidStateError(
"Synchronous calls to AsyncRedisSaver are only allowed from a "
"different thread. From the main thread, use the async interface."
"For example, use `await checkpointer.get_channel_values(...)`."
"different thread. From the main thread, use the async interface. "
"For example, use `await checkpointer.aget_channel_values(...)`."
)
except RuntimeError:
pass
Expand All @@ -1345,6 +1373,11 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
Raises:
asyncio.InvalidStateError: If called from the wrong thread/event loop
"""
if self.loop is None:
raise RuntimeError(
"AsyncRedisSaver must be set up before calling synchronous methods. "
"Call `await saver.asetup()` or use `async with saver:` first."
)
Comment thread
bsbodden marked this conversation as resolved.
try:
# check if we are in the main thread, only bg threads can block
if asyncio.get_running_loop() is self.loop:
Expand Down Expand Up @@ -1381,6 +1414,11 @@ def put(
Raises:
asyncio.InvalidStateError: If called from the wrong thread/event loop
"""
if self.loop is None:
raise RuntimeError(
"AsyncRedisSaver must be set up before calling synchronous methods. "
"Call `await saver.asetup()` or use `async with saver:` first."
)
Comment thread
bsbodden marked this conversation as resolved.
try:
# check if we are in the main thread, only bg threads can block
if asyncio.get_running_loop() is self.loop:
Expand Down
49 changes: 48 additions & 1 deletion langgraph/checkpoint/redis/ashallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def __init__(
checkpoint_prefix=checkpoint_prefix,
checkpoint_write_prefix=checkpoint_write_prefix,
)
self.loop = asyncio.get_running_loop()
# Deferred: the event loop is captured in asetup() so that the saver can
# be constructed outside an async context (Issue #179).
self.loop: Optional[asyncio.AbstractEventLoop] = None

# Instance-level cache for frequently used keys (limited size to prevent memory issues)
self._key_cache: Dict[str, str] = {}
Expand Down Expand Up @@ -139,6 +141,13 @@ async def from_conn_string(

async def asetup(self) -> None:
"""Initialize Redis indexes asynchronously."""
# Capture the running event loop here so that sync wrapper methods
# (get_tuple, put, put_writes, …) can dispatch coroutines to it via
# asyncio.run_coroutine_threadsafe. Deferring this to asetup() instead
# of __init__ lets callers construct the saver outside an async context
# (Issue #179).
self.loop = asyncio.get_running_loop()

await self.checkpoints_index.create(overwrite=False)
await self.checkpoint_writes_index.create(overwrite=False)

Expand Down Expand Up @@ -725,6 +734,11 @@ def create_indexes(self) -> None:

def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Retrieve a checkpoint tuple from Redis synchronously."""
if self.loop is None:
raise RuntimeError(
"AsyncShallowRedisSaver must be set up before calling synchronous methods. "
"Call `await saver.asetup()` or use `async with saver:` first."
)
try:
if asyncio.get_running_loop() is self.loop:
raise asyncio.InvalidStateError(
Expand All @@ -747,6 +761,20 @@ def put(
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Store only the latest checkpoint synchronously."""
if self.loop is None:
raise RuntimeError(
"AsyncShallowRedisSaver must be set up before calling synchronous methods. "
"Call `await saver.asetup()` or use `async with saver:` first."
)
Comment thread
bsbodden marked this conversation as resolved.
try:
if asyncio.get_running_loop() is self.loop:
raise asyncio.InvalidStateError(
"Synchronous calls to AsyncShallowRedisSaver are only allowed from a "
"different thread. From the main thread, use the async interface. "
"For example, use `await checkpointer.aput(...)`."
)
except RuntimeError:
pass
return asyncio.run_coroutine_threadsafe(
self.aput(config, checkpoint, metadata, new_versions), self.loop
).result()
Expand All @@ -759,6 +787,20 @@ def put_writes(
task_path: str = "",
) -> None:
"""Store intermediate writes synchronously."""
if self.loop is None:
raise RuntimeError(
"AsyncShallowRedisSaver must be set up before calling synchronous methods. "
"Call `await saver.asetup()` or use `async with saver:` first."
)
Comment thread
bsbodden marked this conversation as resolved.
try:
if asyncio.get_running_loop() is self.loop:
raise asyncio.InvalidStateError(
"Synchronous calls to AsyncShallowRedisSaver are only allowed from a "
"different thread. From the main thread, use the async interface. "
"For example, use `await checkpointer.aput_writes(...)`."
)
except RuntimeError:
pass
return asyncio.run_coroutine_threadsafe(
self.aput_writes(config, writes, task_id), self.loop
).result()
Expand All @@ -771,6 +813,11 @@ def get_channel_values(
channel_versions: Optional[Dict[str, Any]] = None,
) -> dict[str, Any]:
"""Retrieve channel_values dictionary with properly constructed message objects (sync wrapper)."""
if self.loop is None:
raise RuntimeError(
"AsyncShallowRedisSaver must be set up before calling synchronous methods. "
"Call `await saver.asetup()` or use `async with saver:` first."
)
try:
if asyncio.get_running_loop() is self.loop:
raise asyncio.InvalidStateError(
Expand Down
75 changes: 75 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,3 +848,78 @@ async def test_root_graph_checkpoint(
checkpoints = [c async for c in checkpointer.alist(config)]
assert len(checkpoints) > 0
assert checkpoints[-1].checkpoint["id"] == latest["id"]


# --- Issue #179: AsyncRedisSaver construction outside async context ---


def test_async_redis_saver_construction_outside_event_loop(redis_url: str) -> None:
"""AsyncRedisSaver should be constructable outside an async context (Issue #179).

Previously, AsyncRedisSaver.__init__ called asyncio.get_running_loop() which
raised RuntimeError when no event loop was running.
"""
# This must not raise RuntimeError even when there is no running event loop
saver = AsyncRedisSaver(redis_url)
assert saver is not None
# Loop should be None until asetup() is called
assert saver.loop is None


def test_async_redis_saver_construction_with_client_outside_event_loop(
redis_url: str,
) -> None:
"""AsyncRedisSaver should accept a pre-built client without a running loop (Issue #179).

The typical use-case from the issue: constructing the saver synchronously,
then setting up (and using it) later inside an async lifespan handler.
"""
from redis.asyncio import Redis as AsyncRedis

client = AsyncRedis.from_url(redis_url)
try:
saver = AsyncRedisSaver(redis_client=client)
assert saver is not None
assert saver.loop is None
finally:
asyncio.run(client.aclose())


@pytest.mark.asyncio
async def test_async_redis_saver_loop_captured_in_asetup(redis_url: str) -> None:
"""asetup() must capture the running event loop so sync wrappers work (Issue #179)."""
saver = AsyncRedisSaver(redis_url)
assert saver.loop is None # not yet set

await saver.asetup()

# After asetup the loop attribute must point to the current running loop
assert saver.loop is not None
assert saver.loop is asyncio.get_running_loop()

await saver._redis.aclose()


@pytest.mark.asyncio
async def test_async_redis_saver_context_manager_after_sync_construction(
redis_url: str,
) -> None:
"""Saver built before entering the async context manager must still work."""
# Construct before entering `async with`; in this async test a loop is already
# running, but this still verifies the saver is usable end-to-end once setup
# happens on context-manager entry.
saver = AsyncRedisSaver(redis_url)

async with saver:
# After entering the context the loop must be set
assert saver.loop is asyncio.get_running_loop()

# Basic functional smoke test
config: RunnableConfig = {
"configurable": {"thread_id": "issue-179-test", "checkpoint_ns": ""}
}
chk: Checkpoint = empty_checkpoint()
meta: CheckpointMetadata = {"source": "input", "step": 0, "writes": {}}
await saver.aput(config, chk, meta, {})
result = await saver.aget_tuple(config)
assert result is not None
70 changes: 70 additions & 0 deletions tests/test_shallow_async.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, AsyncGenerator, Dict

import pytest
Expand Down Expand Up @@ -494,3 +495,72 @@ async def test_shallow_redis_saver_inline_storage(redis_url: str) -> None:
# Clean up test data
await redis_client.flushdb()
await redis_client.aclose()


# --- Issue #179: AsyncShallowRedisSaver construction outside async context ---


def test_async_shallow_redis_saver_construction_outside_event_loop(
redis_url: str,
) -> None:
"""AsyncShallowRedisSaver should be constructable outside an async context (Issue #179).

Previously, AsyncShallowRedisSaver.__init__ called asyncio.get_running_loop() which
raised RuntimeError when no event loop was running.
"""
# This must not raise RuntimeError even when there is no running event loop
saver = AsyncShallowRedisSaver(redis_url)
assert saver is not None
# Loop should be None until asetup() is called
assert saver.loop is None


def test_async_shallow_redis_saver_construction_with_client_outside_event_loop(
redis_url: str,
) -> None:
"""AsyncShallowRedisSaver accepts a pre-built client without a running loop (Issue #179)."""
from redis.asyncio import Redis as AsyncRedis

client = AsyncRedis.from_url(redis_url)
try:
saver = AsyncShallowRedisSaver(redis_client=client)
assert saver is not None
assert saver.loop is None
finally:
asyncio.run(client.aclose())


@pytest.mark.asyncio
async def test_async_shallow_redis_saver_loop_captured_in_asetup(
redis_url: str,
) -> None:
"""asetup() must capture the running event loop so sync wrappers work (Issue #179)."""
saver = AsyncShallowRedisSaver(redis_url)
assert saver.loop is None # not yet set

await saver.asetup()

assert saver.loop is not None
assert saver.loop is asyncio.get_running_loop()

await saver._redis.aclose()


@pytest.mark.asyncio
async def test_async_shallow_redis_saver_context_manager_after_sync_construction(
redis_url: str,
) -> None:
"""Saver constructed before entering the async context manager must still work."""
saver = AsyncShallowRedisSaver(redis_url)

async with saver:
assert saver.loop is asyncio.get_running_loop()

config: RunnableConfig = {
"configurable": {"thread_id": "issue-179-shallow-test", "checkpoint_ns": ""}
}
chk: Checkpoint = empty_checkpoint()
meta: CheckpointMetadata = {"source": "input", "step": 0, "writes": {}}
await saver.aput(config, chk, meta, {})
result = await saver.aget_tuple(config)
assert result is not None
Loading