Skip to content

Commit b49a275

Browse files
Allow extending the message interface (#305)
allow extending the message interface
1 parent 6f292f4 commit b49a275

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

lagent/memory/base_memory.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
class Memory:
77

8+
_item_cls = AgentMessage
9+
810
def __init__(self, recent_n=None) -> None:
911
self.memory: List[AgentMessage] = []
1012
self.recent_n = recent_n
@@ -24,11 +26,12 @@ def get_memory(
2426
return memory
2527

2628
def add(self, memories: Union[List[Dict], Dict, None]) -> None:
27-
for memory in memories if isinstance(memories,
28-
(list, tuple)) else [memories]:
29+
for memory in memories if isinstance(memories, (list, tuple)) else [memories]:
2930
if isinstance(memory, str):
30-
memory = AgentMessage(sender='user', content=memory)
31+
memory = self._item_cls(sender='user', content=memory)
3132
if isinstance(memory, AgentMessage):
33+
if not isinstance(memory, self._item_cls):
34+
memory = self._item_cls.model_validate(memory, from_attributes=True)
3235
self.memory.append(memory)
3336

3437
def delete(self, index: Union[List, int]) -> None:
@@ -46,10 +49,10 @@ def load(
4649
if overwrite:
4750
self.memory = []
4851
if isinstance(memories, dict):
49-
self.memory.append(AgentMessage(**memories))
52+
self.memory.append(self._item_cls.model_validate(memories))
5053
elif isinstance(memories, list):
5154
for m in memories:
52-
self.memory.append(AgentMessage(**m))
55+
self.memory.append(self._item_cls.model_validate(m))
5356
else:
5457
raise TypeError(f'{type(memories)} is not supported')
5558

0 commit comments

Comments
 (0)