Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 19 additions & 39 deletions mesa_llm/reasoning/react.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
import json
from typing import TYPE_CHECKING

from pydantic import BaseModel, Field

from mesa_llm.reasoning.reasoning import Observation, Plan, Reasoning

if TYPE_CHECKING:
from mesa_llm.llm_agent import LLMAgent


class ReActOutput(BaseModel):
reasoning: str = Field(
description="Step-by-step reasoning about the situation based on memory and observation"
)
action: str = Field(description="The specific action to take without using tools")


class ReActReasoning(Reasoning):
"""
Reasoning + Acting with alternating reasoning and action in flexible conversational format. Combines thinking and acting in natural language flow. Less structured than CoT but incorporates memory and communication history.
Expand Down Expand Up @@ -46,10 +36,9 @@ def get_react_system_prompt(self) -> str:
{persona_section}

# Instructions
Based on the information given to you, think about what you should do with proper reasoning, And then decide your plan of action. Respond in the
following format:
reasoning: [Your reasoning about the situation, including how your memory informs your decision]
action: [The action you decide to take - Do NOT use any tools here, just describe the action you will take]
Based on the information given to you, think about what you should do with proper reasoning.
Describe your thought process about the situation, including how your memory informs your decision.
Then, use the provided tools to take action if necessary.

"""
return system_prompt
Expand Down Expand Up @@ -115,25 +104,21 @@ def plan(
selected_tools
)

# ---------------- generate the plan ----------------
# ---------------- generate & execute the plan ----------------
rsp = self.agent.llm.generate(
prompt=prompt_list,
tool_schema=selected_tools_schema,
tool_choice="none",
response_format=ReActOutput,
tool_choice=tool_calls,
system_prompt=react_system_prompt,
)

formatted_response = json.loads(rsp.choices[0].message.content)

self.agent.memory.add_to_memory(type="plan", content=formatted_response)
response_message = rsp.choices[0].message
react_plan = Plan(
step=self.agent.model.steps, llm_plan=response_message, ttl=ttl
)

# ---------------- execute the plan ----------------
react_plan = self.execute_tool_call(
formatted_response["action"],
selected_tools=selected_tools,
ttl=ttl,
tool_calls=tool_calls,
self.agent.memory.add_to_memory(
type="plan", content={"content": str(react_plan)}
)

return react_plan
Expand Down Expand Up @@ -186,26 +171,21 @@ async def aplan(
selected_tools
)

# ---------------- generate the plan ----------------

# ---------------- generate & execute the plan ----------------
rsp = await self.agent.llm.agenerate(
prompt=prompt_list,
tool_schema=selected_tools_schema,
tool_choice="none",
response_format=ReActOutput,
tool_choice=tool_calls,
system_prompt=react_system_prompt,
)

formatted_response = json.loads(rsp.choices[0].message.content)

await self.agent.memory.aadd_to_memory(type="plan", content=formatted_response)
response_message = rsp.choices[0].message
react_plan = Plan(
step=self.agent.model.steps, llm_plan=response_message, ttl=ttl
)

# ---------------- execute the plan ----------------
react_plan = await self.aexecute_tool_call(
formatted_response["action"],
selected_tools=selected_tools,
ttl=ttl,
tool_calls=tool_calls,
await self.agent.memory.aadd_to_memory(
type="plan", content={"content": str(react_plan)}
)

return react_plan
24 changes: 10 additions & 14 deletions tests/test_integration/test_memory_reasoning.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,36 +364,32 @@ def test_plan_records_to_memory(self, monkeypatch):
"""ReAct plan() stores formatted response to memory."""
agent, memory, reasoning = self._setup(monkeypatch)

rsp_react = make_react_response()
rsp_exec = make_llm_response("executing")
agent.llm.generate = Mock(side_effect=[rsp_react, rsp_exec])
rsp_react = make_llm_response("test reasoning")
agent.llm.generate = Mock(return_value=rsp_react)

obs = Observation(step=0, self_state={}, local_state={})
plan = reasoning.plan(obs=obs)

assert isinstance(plan, Plan)
assert memory.step_content["plan"]["reasoning"] == "test reasoning"
assert memory.step_content["plan"]["action"] == "test action"
assert memory.step_content["plan_execution"]["content"] == str(plan)
assert agent.llm.generate.call_count == 2
assert memory.step_content["plan"]["content"] == str(plan)
assert "plan_execution" not in memory.step_content
assert agent.llm.generate.call_count == 1

def test_async_plan_works(self, monkeypatch):
"""aplan() completes with STLTMemory."""
agent, memory, reasoning = self._setup(monkeypatch)

rsp_react = make_react_response()
rsp_exec = make_llm_response("executing")
agent.llm.agenerate = AsyncMock(side_effect=[rsp_react, rsp_exec])
rsp_react = make_llm_response("test reasoning")
agent.llm.agenerate = AsyncMock(return_value=rsp_react)
agent.memory.aadd_to_memory = AsyncMock(side_effect=memory.add_to_memory)

obs = Observation(step=0, self_state={}, local_state={})
plan = asyncio.run(reasoning.aplan(obs=obs))

assert isinstance(plan, Plan)
assert memory.step_content["plan"]["reasoning"] == "test reasoning"
assert memory.step_content["plan"]["action"] == "test action"
assert memory.step_content["plan_execution"]["content"] == str(plan)
assert agent.llm.agenerate.await_count == 2
assert memory.step_content["plan"]["content"] == str(plan)
assert "plan_execution" not in memory.step_content
assert agent.llm.agenerate.await_count == 1


class TestReActWithShortTermMemory:
Expand Down
27 changes: 27 additions & 0 deletions tests/test_llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,3 +1159,30 @@ def test_system_prompt_proxies_llm_prompt(basic_agent):
basic_agent.llm.system_prompt = "LLM prompt"

assert basic_agent.system_prompt == "LLM prompt"


def test_observation_excludes_system_prompt(monkeypatch):
"""ensure system_prompt is NOT included in the observation state to save tokens."""

class DummyModel(Model):
def __init__(self):
super().__init__(rng=42)
self.grid = MultiGrid(3, 3, torus=False)

model = DummyModel()
system_prompt = "DUMMY_SYSTEM_PROMPT"
agent = LLMAgent(
model=model,
reasoning=ReActReasoning,
system_prompt=system_prompt,
vision=-1,
)
model.grid.place_agent(agent, (1, 1))

# bypass memory
monkeypatch.setattr(agent.memory, "add_to_memory", lambda *args, **kwargs: None)

obs = agent.generate_obs()

assert "system_prompt" not in obs.self_state
assert system_prompt not in str(obs.self_state)
101 changes: 21 additions & 80 deletions tests/test_reasoning/test_react.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,12 @@
# tests/test_reasoning/test_react.py

import asyncio
import json
from unittest.mock import AsyncMock, Mock

import pytest

from mesa_llm.reasoning.react import ReActOutput, ReActReasoning
from mesa_llm.reasoning.reasoning import Observation, Plan


class TestReActOutput:
"""Test the ReActOutput model."""

def test_react_output_creation(self):
"""Test creating a ReActOutput with valid data."""
output = ReActOutput(
reasoning="I need to move to a better position", action="move_north"
)

assert output.reasoning == "I need to move to a better position"
assert output.action == "move_north"

def test_react_output_schema_includes_field_descriptions(self):
"""Structured output schema should keep the field guidance text."""
schema = ReActOutput.model_json_schema()

assert schema["properties"]["reasoning"]["description"] == (
"Step-by-step reasoning about the situation based on memory and observation"
)
assert schema["properties"]["action"]["description"] == (
"The specific action to take without using tools"
)
from mesa_llm.reasoning.react import ReActReasoning
from mesa_llm.reasoning.reasoning import Observation


class TestReActReasoning:
Expand All @@ -52,8 +27,6 @@ def test_get_react_system_prompt(self, mock_agent):

assert "Agent Persona" in prompt
assert "Agent persona" in prompt
assert "reasoning:" in prompt
assert "action:" in prompt

def test_get_react_system_prompt_omits_empty_persona(self, mock_agent):
"""Empty agent persona should not add a persona section."""
Expand Down Expand Up @@ -103,26 +76,17 @@ def test_plan_with_prompt(self, llm_response_factory, mock_agent):
mock_agent.tool_manager.get_all_tools_schema.return_value = {}

mock_agent.llm.generate.return_value = llm_response_factory(
content=json.dumps(
{"reasoning": "Custom reasoning", "action": "custom_action"}
)
content="Custom reasoning"
)

# Mock execute_tool_call
mock_plan = Plan(step=1, llm_plan=Mock())
reasoning = ReActReasoning(mock_agent)
reasoning.execute_tool_call = Mock(return_value=mock_plan)

obs = Observation(step=1, self_state={}, local_state={})
result = reasoning.plan(obs=obs, prompt="Custom prompt")

assert result == mock_plan
reasoning.execute_tool_call.assert_called_once_with(
"custom_action",
selected_tools=None,
ttl=1,
tool_calls="auto",
)
assert result.step == mock_agent.model.steps
assert result.llm_plan.content == "Custom reasoning"
assert result.ttl == 1

def test_plan_with_selected_tools(self, llm_response_factory, mock_agent):
"""Test plan method with selected tools."""
Expand All @@ -136,26 +100,19 @@ def test_plan_with_selected_tools(self, llm_response_factory, mock_agent):
mock_agent.tool_manager.get_all_tools_schema.return_value = {}

mock_agent.llm.generate.return_value = llm_response_factory(
content=json.dumps({"reasoning": "Test reasoning", "action": "test_action"})
content="Test reasoning"
)

# Mock execute_tool_call
mock_plan = Plan(step=1, llm_plan=Mock())
reasoning = ReActReasoning(mock_agent)
reasoning.execute_tool_call = Mock(return_value=mock_plan)

obs = Observation(step=1, self_state={}, local_state={})
selected_tools = ["tool1", "tool2"]
result = reasoning.plan(obs=obs, ttl=3, selected_tools=selected_tools)

assert result == mock_plan
assert result.step == mock_agent.model.steps
assert result.llm_plan.content == "Test reasoning"
assert result.ttl == 3
mock_agent.tool_manager.get_all_tools_schema.assert_called_with(selected_tools)
reasoning.execute_tool_call.assert_called_once_with(
"test_action",
selected_tools=selected_tools,
ttl=3,
tool_calls="auto",
)

def test_plan_with_custom_tool_calls(self, llm_response_factory, mock_agent):
"""Test plan method forwards a custom execution tool choice."""
Expand All @@ -169,23 +126,18 @@ def test_plan_with_custom_tool_calls(self, llm_response_factory, mock_agent):
mock_agent.tool_manager.get_all_tools_schema.return_value = {}

mock_agent.llm.generate.return_value = llm_response_factory(
content=json.dumps({"reasoning": "Test reasoning", "action": "test_action"})
content="Test reasoning"
)

mock_plan = Plan(step=1, llm_plan=Mock())
reasoning = ReActReasoning(mock_agent)
reasoning.execute_tool_call = Mock(return_value=mock_plan)

obs = Observation(step=1, self_state={}, local_state={})
result = reasoning.plan(obs=obs, tool_calls="required")

assert result == mock_plan
reasoning.execute_tool_call.assert_called_once_with(
"test_action",
selected_tools=None,
ttl=1,
tool_calls="required",
)
assert result.step == mock_agent.model.steps
assert result.llm_plan.content == "Test reasoning"
assert result.ttl == 1
assert mock_agent.llm.generate.call_args.kwargs["tool_choice"] == "required"

def test_plan_no_prompt_error(self, mock_agent):
"""Test plan method raises error when no prompt is provided."""
Expand Down Expand Up @@ -215,31 +167,21 @@ def test_aplan_async_version(self, llm_response_factory, mock_agent):
mock_agent.tool_manager.get_all_tools_schema.return_value = {}

mock_agent.llm.agenerate = AsyncMock(
return_value=llm_response_factory(
content=json.dumps(
{"reasoning": "Async reasoning", "action": "async_action"}
)
)
return_value=llm_response_factory(content="Async reasoning")
)

# Mock aexecute_tool_call
mock_plan = Plan(step=1, llm_plan=Mock())
reasoning = ReActReasoning(mock_agent)
reasoning.aexecute_tool_call = AsyncMock(return_value=mock_plan)

obs = Observation(step=1, self_state={}, local_state={})

# Test async execution
result = asyncio.run(reasoning.aplan(obs=obs, ttl=4))

assert result == mock_plan
assert result.step == mock_agent.model.steps
assert result.llm_plan.content == "Async reasoning"
assert result.ttl == 4
mock_agent.llm.agenerate.assert_called_once()
reasoning.aexecute_tool_call.assert_called_once_with(
"async_action",
selected_tools=None,
ttl=4,
tool_calls="auto",
)
assert mock_agent.llm.agenerate.call_args.kwargs["tool_choice"] == "auto"

def test_aplan_no_prompt_error(self, mock_agent):
"""Test aplan method raises error when no prompt is provided."""
Expand Down Expand Up @@ -268,11 +210,10 @@ def test_plan_uses_scoped_system_prompt(self, llm_response_factory, mock_agent):
mock_agent.tool_manager.get_all_tools_schema.return_value = {}

mock_agent.llm.generate.return_value = llm_response_factory(
content=json.dumps({"reasoning": "Test reasoning", "action": "test_action"})
content="Test reasoning"
)

reasoning = ReActReasoning(mock_agent)
reasoning.execute_tool_call = Mock(return_value=Plan(step=1, llm_plan=Mock()))

obs = Observation(step=1, self_state={}, local_state={})
expected_prompt = reasoning.get_react_system_prompt()
Expand Down
Loading