Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 8 additions & 24 deletions tests/unit/vertexai/genai/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,9 +1070,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict(
)

mock_agent_engine = mock.Mock()
mock_agent_engine.async_create_session = mock.AsyncMock(
return_value={"id": "session1"}
)
mock_agent_engine.create_session.return_value = {"id": "session1"}
stream_query_return_value = [
{
"id": "1",
Expand All @@ -1088,13 +1086,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict(
},
]

async def _async_iterator(iterable):
for item in iterable:
yield item

mock_agent_engine.async_stream_query.return_value = _async_iterator(
stream_query_return_value
)
mock_agent_engine.stream_query.return_value = iter(stream_query_return_value)
mock_vertexai_client.return_value.agent_engines.get.return_value = (
mock_agent_engine
)
Expand All @@ -1108,10 +1100,10 @@ async def _async_iterator(iterable):
mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with(
name="projects/test-project/locations/us-central1/reasoningEngines/123"
)
mock_agent_engine.async_create_session.assert_called_once_with(
mock_agent_engine.create_session.assert_called_once_with(
user_id="123", state={"a": "1"}
)
mock_agent_engine.async_stream_query.assert_called_once_with(
mock_agent_engine.stream_query.assert_called_once_with(
user_id="123", session_id="session1", message="agent prompt"
)

Expand Down Expand Up @@ -1162,9 +1154,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string(
)

mock_agent_engine = mock.Mock()
mock_agent_engine.async_create_session = mock.AsyncMock(
return_value={"id": "session1"}
)
mock_agent_engine.create_session.return_value = {"id": "session1"}
stream_query_return_value = [
{
"id": "1",
Expand All @@ -1180,13 +1170,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string(
},
]

async def _async_iterator(iterable):
for item in iterable:
yield item

mock_agent_engine.async_stream_query.return_value = _async_iterator(
stream_query_return_value
)
mock_agent_engine.stream_query.return_value = iter(stream_query_return_value)
mock_vertexai_client.return_value.agent_engines.get.return_value = (
mock_agent_engine
)
Expand All @@ -1200,10 +1184,10 @@ async def _async_iterator(iterable):
mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with(
name="projects/test-project/locations/us-central1/reasoningEngines/123"
)
mock_agent_engine.async_create_session.assert_called_once_with(
mock_agent_engine.create_session.assert_called_once_with(
user_id="123", state={"a": "1"}
)
mock_agent_engine.async_stream_query.assert_called_once_with(
mock_agent_engine.stream_query.assert_called_once_with(
user_id="123", session_id="session1", message="agent prompt"
)

Expand Down
20 changes: 9 additions & 11 deletions vertexai/_genai/_evals_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,10 @@ def agent_run_wrapper(
and type(agent_engine).__name__ == "AgentEngine"
):
agent_engine_instance = agent_engine
return asyncio.run(
inference_fn_arg(
row=row_arg,
contents=contents_arg,
agent_engine=agent_engine_instance,
)
return inference_fn_arg(
row=row_arg,
contents=contents_arg,
agent_engine=agent_engine_instance,
)

future = executor.submit(
Expand Down Expand Up @@ -1265,7 +1263,7 @@ def _run_agent(
)


async def _execute_agent_run_with_retry(
def _execute_agent_run_with_retry(
row: pd.Series,
contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict],
agent_engine: types.AgentEngine,
Expand All @@ -1287,7 +1285,7 @@ async def _execute_agent_run_with_retry(
)
user_id = session_inputs.user_id
session_state = session_inputs.state
session = await agent_engine.async_create_session(
session = agent_engine.create_session(
user_id=user_id,
state=session_state,
)
Expand All @@ -1298,7 +1296,7 @@ async def _execute_agent_run_with_retry(
for attempt in range(max_retries):
try:
responses = []
async for event in agent_engine.async_stream_query(
for event in agent_engine.stream_query(
user_id=user_id,
session_id=session["id"],
message=contents,
Expand All @@ -1317,7 +1315,7 @@ async def _execute_agent_run_with_retry(
)
if attempt == max_retries - 1:
return {"error": f"Resource exhausted after retries: {e}"}
await asyncio.sleep(2**attempt)
time.sleep(2**attempt)
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Unexpected error during generate_content on attempt %d/%d: %s",
Expand All @@ -1328,7 +1326,7 @@ async def _execute_agent_run_with_retry(

if attempt == max_retries - 1:
return {"error": f"Failed after retries: {e}"}
await asyncio.sleep(1)
time.sleep(1)
return {"error": f"Failed to get agent run results after {max_retries} retries"}


Expand Down
Loading