Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions mesa_llm/memory/episodic_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ def __init__(
"llm_model must be provided for the usage of episodic memory"
)

super().__init__(agent, llm_model=llm_model, api_base=api_base, display=display)
super().__init__(
agent,
llm_model=llm_model,
api_base=api_base,
display=display,
)

self.max_capacity = max_capacity
self.memory_entries = deque(maxlen=self.max_capacity)
Expand Down Expand Up @@ -257,13 +262,17 @@ def get_communication_history(self) -> str:
"""
Get the communication history
"""
return "\n".join(
[
f"step {entry.step}: {entry.content['message']}\n\n"
for entry in self.memory_entries
if "message" in entry.content
]
)
lines = []
for entry in self.memory_entries:
if "message" not in entry.content:
continue
msgs = entry.content["message"]
if isinstance(msgs, list):
for msg in msgs:
lines.append(f"step {entry.step}: {msg}\n\n")
else:
lines.append(f"step {entry.step}: {msgs}\n\n")
return "\n".join(lines)

async def aprocess_step(self, pre_step: bool = False):
"""
Expand Down
28 changes: 24 additions & 4 deletions mesa_llm/memory/lt_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ class LongTermMemory(Memory):
agent : the agent that the memory belongs to
display : whether to display the memory
llm_model : the model to use for the summarization
additive_event_types : event types accumulated as lists within a step.
Defaults to ``{"message", "action"}``.

"""

Expand All @@ -23,7 +25,20 @@ def __init__(
display: bool = True,
llm_model: str = "openai/gpt-4o-mini",
api_base: str | None = None,
additive_event_types: list[str] | set[str] | tuple[str, ...] | None = None,
):
"""
Initialize long-term memory.

Args:
agent : the agent that owns this memory
display : whether memory entries should be displayed
llm_model : the model used for long-term summarization
api_base : the API base URL to use for the LLM provider
additive_event_types : event types that accumulate multiple values
within a step instead of overwriting. Defaults to
``{"message", "action"}``.
"""
if not llm_model:
raise ValueError(
"llm_model must be provided for the usage of long term memory"
Expand All @@ -34,6 +49,7 @@ def __init__(
llm_model=llm_model,
api_base=api_base,
display=display,
additive_event_types=additive_event_types,
)

self.long_term_memory = ""
Expand Down Expand Up @@ -95,10 +111,12 @@ def process_step(self, pre_step: bool = False):
return

elif self.buffer and self.buffer.step is None:
self.step_content.update(self.buffer.content)
merged_content = self._merge_step_contents(
self.step_content, self.buffer.content
)
new_entry = MemoryEntry(
agent=self.agent,
content=self.step_content,
content=merged_content,
step=self.agent.model.steps,
)
self.buffer = new_entry
Expand Down Expand Up @@ -126,10 +144,12 @@ async def aprocess_step(self, pre_step: bool = False):
return

elif self.buffer and self.buffer.step is None:
self.step_content.update(self.buffer.content)
merged_content = self._merge_step_contents(
self.step_content, self.buffer.content
)
new_entry = MemoryEntry(
agent=self.agent,
content=self.step_content,
content=merged_content,
step=self.agent.model.steps,
)
self.buffer = new_entry
Expand Down
84 changes: 65 additions & 19 deletions mesa_llm/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ def format_nested_dict(data, indent_level=0):
continue

lines.append(f"\n[bold cyan][{key.title()}][/bold cyan]")
if isinstance(value, dict):
lines.extend(format_nested_dict(value, 1))
elif isinstance(value, list):
for i, item in enumerate(value):
if isinstance(value, list):
for i, item in enumerate(value, 1):
lines.append(f" [blue]({i})[/blue]")
if isinstance(item, dict):
lines.append(f" [blue]├──[/blue] [cyan]({i + 1})[/cyan]")
lines.extend(format_nested_dict(item, 2))
else:
lines.append(f" [blue]├──[/blue] [cyan]{item}[/cyan]")
lines.append(f" [blue]└──[/blue] [cyan]{item}[/cyan]")
elif isinstance(value, dict):
lines.extend(format_nested_dict(value, 1))
else:
lines.append(f" [blue]└──[/blue] [cyan]{value} :[/cyan]")
lines.append(f" [blue]└──[/blue] [cyan]{value}[/cyan]")

content = "\n".join(lines)

Expand All @@ -93,17 +93,24 @@ def display(self):

class Memory(ABC):
"""
Create a memory generic parent class that can be used to create different types of memories
Generic parent class for memory backends.

Attributes:
agent : the agent that the memory belongs to
llm_model : the model to use for the summarization if used
display : whether to display the memory
additive_event_types : event types that accumulate multiple values
within a step. Defaults to ``{"message", "action"}``.

Content Addition
- Before each agent step, the agent can add new events to the memory through `add_to_memory(type, content)` so that the memory can be used to reason about the most recent events as well as the past events.
- During the step, actions, messages, and plans are added to the memory through `add_to_memory(type, content)`
- During the step, content for types in ``additive_event_types`` is accumulated as a list; all other types overwrite the previous value for that step.
- At the end of the step, the memory is processed via `process_step()`, managing when memory entries are added,consolidated, displayed, or removed

Default behavior
- By default, ``additive_event_types == {"message", "action"}``.
- Repeated ``message`` or ``action`` entries within one step are accumulated as a list.
- Repeated ``observation`` or ``plan`` entries within one step overwrite the previous value unless configured otherwise.
"""

def __init__(
Expand All @@ -112,6 +119,7 @@ def __init__(
llm_model: str | None = None,
display: bool = True,
api_base: str | None = None,
additive_event_types: list[str] | set[str] | tuple[str, ...] | None = None,
):
"""
Initialize the memory
Expand All @@ -121,6 +129,11 @@ def __init__(
llm_model : the model to use for summarization
display : whether to display memory entries in the console
api_base : the API base URL to use for the LLM provider
additive_event_types : event types that should accumulate multiple
values within the same step instead of overwriting. Defaults to
``{"message", "action"}``. For example, ``message`` and
``action`` accumulate by default, while ``observation`` and
``plan`` overwrite unless explicitly included here.
"""
self.agent = agent
if llm_model:
Expand All @@ -129,7 +142,7 @@ def __init__(
self.display = display

self.step_content: dict = {}
self.last_observation: dict = {}
self.additive_event_types = set(additive_event_types or {"message", "action"})

@abstractmethod
def get_prompt_ready(self) -> str:
Expand All @@ -156,24 +169,57 @@ def process_step(self, pre_step: bool = False):
async def aprocess_step(self, pre_step: bool = False):
return self.process_step(pre_step)

@staticmethod
def _coerce_additive_values(value):
    """Return *value* as a flat list.

    A list input is shallow-copied so callers may mutate the result
    without touching the stored entry; any other value is wrapped in a
    single-element list.
    """
    return list(value) if isinstance(value, list) else [value]

def _merge_step_contents(self, current_content: dict, staged_content: dict) -> dict:
    """
    Merge the current step buffer with staged pre-step content.

    Non-additive keys keep the staged value, matching the previous
    overwrite semantics during finalization. Additive event types are
    concatenated in chronological order so events from both halves of
    the step are preserved.
    """
    merged = dict(current_content)
    coerce = self._coerce_additive_values
    for event_type, staged_value in staged_content.items():
        already_present = event_type in merged
        if already_present and event_type in self.additive_event_types:
            # Staged (pre-step) events happened first, so they lead the list.
            merged[event_type] = coerce(staged_value) + coerce(merged[event_type])
        else:
            merged[event_type] = staged_value
    return merged

def add_to_memory(self, type: str, content: dict):
    """
    Add a new entry to the memory.

    Event types in ``self.additive_event_types`` accumulate multiple values
    within the same step. All other types use overwrite semantics.
    By default, ``self.additive_event_types == {"message", "action"}``.
    For example, repeated ``message`` entries are stored as a list, while
    repeated ``observation`` entries overwrite the previous value.

    Args:
        type : the event type of this entry (e.g. "message", "action",
            "observation", "plan")
        content : the event payload for this step

    Raises:
        TypeError: if ``content`` is not a dict
    """
    if not isinstance(content, dict):
        raise TypeError(
            "Expected 'content' to be dict, "
            f"got {content.__class__.__name__}: {content!r}"
        )

    # NOTE(review): the old observation-diffing branch referenced
    # self.last_observation, which was removed from __init__ in this
    # change; the branch is dropped so "observation" now follows plain
    # overwrite semantics like any other non-additive type.
    if type in self.additive_event_types:
        # Accumulate discrete events so concurrent entries are preserved
        existing = self.step_content.get(type)
        if existing is None:
            self.step_content[type] = [content]
        elif isinstance(existing, list):
            existing.append(content)
        else:
            # Migrate a legacy single-dict entry into a list
            self.step_content[type] = [existing, content]
    else:
        self.step_content[type] = content

Expand Down
31 changes: 22 additions & 9 deletions mesa_llm/memory/st_lt_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class STLTMemory(Memory):
- A short term memory who stores the n (int) most recent interactions (observations, planning, discussions)
- A long term memory that is a summary of the memories that are removed from short term memory (summary
completed/refactored as it goes)
- Event types in ``additive_event_types`` accumulate within a step.
Defaults to ``{"message", "action"}``.

Logic behind the implementation
- **Short-term capacity**: Configurable number of recent memory entries (default: short_term_capacity = 5)
Expand All @@ -34,6 +36,7 @@ def __init__(
display: bool = True,
llm_model: str | None = None,
api_base: str | None = None,
additive_event_types: list[str] | set[str] | tuple[str, ...] | None = None,
):
"""
Initialize the memory
Expand All @@ -43,6 +46,9 @@ def __init__(
llm_model : the model to use for the summarization
api_base : the API base URL to use for the LLM provider
agent : the agent that the memory belongs to
additive_event_types : event types that accumulate multiple values
within a step instead of overwriting. Defaults to
``{"message", "action"}``.
"""
if not llm_model:
raise ValueError(
Expand All @@ -54,6 +60,7 @@ def __init__(
llm_model=llm_model,
api_base=api_base,
display=display,
additive_event_types=additive_event_types,
)

self.capacity = short_term_capacity
Expand Down Expand Up @@ -139,10 +146,12 @@ def _process_step_core(self, pre_step: bool):
return None, []

pre_step_entry = self.short_term_memory.pop()
self.step_content.update(pre_step_entry.content)
merged_content = self._merge_step_contents(
self.step_content, pre_step_entry.content
)
new_entry = MemoryEntry(
agent=self.agent,
content=self.step_content,
content=merged_content,
step=self.agent.model.steps,
)
self.short_term_memory.append(new_entry)
Expand Down Expand Up @@ -224,10 +233,14 @@ def get_communication_history(self) -> str:
"""
Get the communication history
"""
return "\n".join(
[
f"step {entry.step}: {entry.content['message']}\n\n"
for entry in self.short_term_memory
if "message" in entry.content
]
)
lines = []
for entry in self.short_term_memory:
if "message" not in entry.content:
continue
msgs = entry.content["message"]
if isinstance(msgs, list):
for msg in msgs:
lines.append(f"step {entry.step}: {msg}\n\n")
else:
lines.append(f"step {entry.step}: {msgs}\n\n")
return "\n".join(lines)
38 changes: 29 additions & 9 deletions mesa_llm/memory/st_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,35 @@ class ShortTermMemory(Memory):
n : positive number of short-term memories to remember
display : whether to display the memory
llm_model : the model to use for the summarization
additive_event_types : event types accumulated as lists within a step.
Defaults to ``{"message", "action"}``.
"""

def __init__(
self,
agent: "LLMAgent",
n: int = 5,
display: bool = True,
additive_event_types: list[str] | set[str] | tuple[str, ...] | None = None,
):
"""
Initialize short-term memory.

Args:
agent : the agent that owns this memory
n : maximum number of finalized short-term entries to keep
display : whether memory entries should be displayed
additive_event_types : event types that accumulate multiple values
within a step instead of overwriting. Defaults to
``{"message", "action"}``.
"""
if n < 1:
raise ValueError("n must be >= 1 for ShortTermMemory")

super().__init__(
agent=agent,
display=display,
additive_event_types=additive_event_types,
)
self.n = n
self.short_term_memory = deque(maxlen=self.n)
Expand Down Expand Up @@ -61,8 +76,9 @@ def process_step(self, pre_step: bool = False):

new_entry = None
if self._current_step_entry is not None:
merged_content = dict(self.step_content)
merged_content.update(self._current_step_entry.content)
merged_content = self._merge_step_contents(
self.step_content, self._current_step_entry.content
)
new_entry = MemoryEntry(
agent=self.agent,
content=merged_content,
Expand Down Expand Up @@ -98,10 +114,14 @@ def get_communication_history(self) -> str:
"""
Get the communication history
"""
return "\n".join(
[
f"step {entry.step}: {entry.content['message']}\n\n"
for entry in self.short_term_memory
if "message" in entry.content
]
)
lines = []
for entry in self.short_term_memory:
if "message" not in entry.content:
continue
msgs = entry.content["message"]
if isinstance(msgs, list):
for msg in msgs:
lines.append(f"step {entry.step}: {msg}\n\n")
else:
lines.append(f"step {entry.step}: {msgs}\n\n")
return "\n".join(lines)
Loading
Loading