55
66class Memory :
77
8+ _item_cls = AgentMessage
9+
810 def __init__ (self , recent_n = None ) -> None :
911 self .memory : List [AgentMessage ] = []
1012 self .recent_n = recent_n
@@ -24,11 +26,12 @@ def get_memory(
2426 return memory
2527
2628 def add (self , memories : Union [List [Dict ], Dict , None ]) -> None :
27- for memory in memories if isinstance (memories ,
28- (list , tuple )) else [memories ]:
29+ for memory in memories if isinstance (memories , (list , tuple )) else [memories ]:
2930 if isinstance (memory , str ):
30- memory = AgentMessage (sender = 'user' , content = memory )
31+ memory = self . _item_cls (sender = 'user' , content = memory )
3132 if isinstance (memory , AgentMessage ):
33+ if not isinstance (memory , self ._item_cls ):
34+ memory = self ._item_cls .model_validate (memory , from_attributes = True )
3235 self .memory .append (memory )
3336
3437 def delete (self , index : Union [List , int ]) -> None :
@@ -46,10 +49,10 @@ def load(
4649 if overwrite :
4750 self .memory = []
4851 if isinstance (memories , dict ):
49- self .memory .append (AgentMessage ( ** memories ))
52+ self .memory .append (self . _item_cls . model_validate ( memories ))
5053 elif isinstance (memories , list ):
5154 for m in memories :
52- self .memory .append (AgentMessage ( ** m ))
55+ self .memory .append (self . _item_cls . model_validate ( m ))
5356 else :
5457 raise TypeError (f'{ type (memories )} is not supported' )
5558
0 commit comments