Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 129 additions & 111 deletions mesa_llm/llm_agent.py

Large diffs are not rendered by default.

95 changes: 54 additions & 41 deletions mesa_llm/module_llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
from typing import Any

from dotenv import load_dotenv
from litellm import acompletion, completion, litellm
Expand All @@ -22,33 +23,34 @@

class ModuleLLM:
"""
A module that provides a simple interface for using LLMs
A module that provides a simple interface for using LLMs.

Note : Currently supports OpenAI, Anthropic, xAI, Huggingface, Ollama, OpenRouter, NovitaAI, Gemini
Note: Currently supports OpenAI, Anthropic, xAI, Huggingface,
Ollama, OpenRouter, NovitaAI, Gemini.
"""

def __init__(
self,
llm_model: str,
api_base: str | None = None,
system_prompt: str | None = None,
):
) -> None:
"""
Initialize the LLM module
Initialize the LLM module.

Args:
llm_model: The model to use for the LLM in the format
"{provider}/{model}" (for example, "openai/gpt-4o").
api_base: The API base to use if the LLM provider is Ollama
system_prompt: The system prompt to use for the LLM
llm_model (str): The model to use in the format "{provider}/{model}"
(for example, "openai/gpt-4o").
api_base (str | None): The API base URL. Required for Ollama providers.
system_prompt (str | None): The system prompt to use for the LLM.

Raises:
ValueError: If llm_model is not in the expected "{provider}/{model}"
format, or if the provider API key is missing.
"""
self.api_base = api_base
self.llm_model = llm_model
self.system_prompt = system_prompt
self.api_base: str | None = api_base
self.llm_model: str = llm_model
self.system_prompt: str | None = system_prompt

if "/" not in llm_model:
raise ValueError(
Expand All @@ -62,44 +64,48 @@ def __init__(
if self.api_base is None:
self.api_base = "http://localhost:11434"
logger.warning(
"Using default Ollama API base: %s. If inference is not working, you may need to set the API base to the correct URL.",
"Using default Ollama API base: %s. If inference is not working, "
"you may need to set the API base to the correct URL.",
self.api_base,
)
else:
try:
self.api_key = os.environ[f"{provider}_API_KEY"]
self.api_key: str = os.environ[f"{provider}_API_KEY"]
except KeyError as err:
raise ValueError(
f"No API key found for {provider}. Please set the {provider}_API_KEY environment variable (e.g., in your .env file)."
f"No API key found for {provider}. Please set the "
f"{provider}_API_KEY environment variable (e.g., in your .env file)."
) from err

if not litellm.supports_function_calling(model=self.llm_model):
logger.warning(
"%s does not support function calling. This model may not be able to use tools. Please check the model documentation at https://docs.litellm.ai/docs/providers for more information.",
"%s does not support function calling. This model may not be able "
"to use tools. Please check the model documentation at "
"https://docs.litellm.ai/docs/providers for more information.",
self.llm_model,
)

def _build_messages(self, prompt: str | list[str] | None = None) -> list[dict]:
def _build_messages(
self, prompt: str | list[str] | None = None
) -> list[dict[str, str]]:
"""
Format the prompt messages for the LLM of the form : {"role": ..., "content": ...}
Format the prompt messages for the LLM.

Args:
prompt: The prompt to generate a response for (str, list of strings, or None)
prompt (str | list[str] | None): The prompt to generate a response for.

Returns:
The messages for the LLM
list[dict[str, str]]: Messages in {"role": ..., "content": ...} format.
"""
messages = []
messages: list[dict[str, str]] = []

# Always include a system message. Default to empty string if no system prompt to support Ollama
system_content = self.system_prompt if self.system_prompt else ""
messages.append({"role": "system", "content": system_content})

if prompt:
if isinstance(prompt, str):
messages.append({"role": "user", "content": prompt})
elif isinstance(prompt, list):
# Use extend to add all prompts from the list
messages.extend([{"role": "user", "content": p} for p in prompt])

return messages
Expand All @@ -112,26 +118,25 @@ def _build_messages(self, prompt: str | list[str] | None = None) -> list[dict]:
def generate(
self,
prompt: str | list[str] | None = None,
tool_schema: list[dict] | None = None,
tool_schema: list[dict[str, Any]] | None = None,
tool_choice: str = "auto",
response_format: dict | object | None = None,
) -> str:
response_format: dict[str, Any] | object | None = None,
) -> Any:
"""
Generate a response from the LLM using litellm based on the prompt
Generate a response from the LLM using litellm.

Args:
prompt: The prompt to generate a response for (str, list of strings, or None)
tool_schema: The schema of the tools to use
tool_choice: The choice of tool to use
response_format: The format of the response
prompt (str | list[str] | None): The prompt to generate a response for.
tool_schema (list[dict[str, Any]] | None): Schema of tools available.
tool_choice (str): Tool selection strategy. Defaults to "auto".
response_format (dict[str, Any] | object | None): Desired response format.

Returns:
The response from the LLM
Any: The raw litellm response object.
"""

messages = self._build_messages(prompt)

completion_kwargs = {
completion_kwargs: dict[str, Any] = {
"model": self.llm_model,
"messages": messages,
"tools": tool_schema,
Expand All @@ -141,28 +146,36 @@ def generate(
if self.api_base:
completion_kwargs["api_base"] = self.api_base

response = completion(**completion_kwargs)

return response
return completion(**completion_kwargs)

async def agenerate(
self,
prompt: str | list[str] | None = None,
tool_schema: list[dict] | None = None,
tool_schema: list[dict[str, Any]] | None = None,
tool_choice: str = "auto",
response_format: dict | object | None = None,
) -> str:
response_format: dict[str, Any] | object | None = None,
) -> Any:
"""
Asynchronous version of generate() method for parallel LLM calls.
Asynchronous version of generate() for parallel LLM calls.

Args:
prompt (str | list[str] | None): The prompt to generate a response for.
tool_schema (list[dict[str, Any]] | None): Schema of tools available.
tool_choice (str): Tool selection strategy. Defaults to "auto".
response_format (dict[str, Any] | object | None): Desired response format.

Returns:
Any: The raw litellm response object.
"""
messages = self._build_messages(prompt)
response: Any = None
async for attempt in AsyncRetrying(
wait=wait_exponential(multiplier=1, min=1, max=60),
retry=retry_if_exception_type(RETRYABLE_EXCEPTIONS),
reraise=True,
):
with attempt:
completion_kwargs = {
completion_kwargs: dict[str, Any] = {
"model": self.llm_model,
"messages": messages,
"tools": tool_schema,
Expand Down
6 changes: 4 additions & 2 deletions mesa_llm/recording/record_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def _auto_save():
def step_wrapper(self: "Model", *args, **kwargs): # type: ignore[override]
# Record beginning of step
if hasattr(self, "recorder"):
self.recorder.record_model_event("step_start", {"step": self.steps}) # type: ignore[attr-defined]
# self.recorder.record_model_event("step_start", {"step": self.step}) # type: ignore[attr-defined]
self.recorder.record_model_event("step_start", {"step": self._time})

# Execute the original step logic
result = original_step(self, *args, **kwargs) # type: ignore[misc]
Expand All @@ -110,7 +111,8 @@ def step_wrapper(self: "Model", *args, **kwargs): # type: ignore[override]
if hasattr(self, "recorder"):
_attach_recorder_to_agents(self, self.recorder) # type: ignore[attr-defined]
# Record end of step after agents have acted
self.recorder.record_model_event("step_end", {"step": self.steps}) # type: ignore[attr-defined]
# self.recorder.record_model_event("step_end", {"step": self.step}) # type: ignore[attr-defined]
self.recorder.record_model_event("step_end", {"step": self._time})

return result

Expand Down
45 changes: 11 additions & 34 deletions mesa_llm/tools/inbuilt_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
OrthogonalMooreGrid,
OrthogonalVonNeumannGrid,
)
from mesa.space import (
ContinuousSpace,
MultiGrid,
SingleGrid,
)
from mesa.experimental.continuous_space import ContinuousSpace

from mesa_llm.tools.tool_decorator import tool

Expand Down Expand Up @@ -63,7 +59,7 @@ def _get_agent_position(agent: "LLMAgent") -> Any:
@tool
def move_one_step(agent: "LLMAgent", direction: str) -> str:
"""
Moves agents one step in specified cardinal/diagonal directions (North, South, East, West, NorthEast, NorthWest, SouthEast, SouthWest). Automatically handles different Mesa grid types including SingleGrid, MultiGrid, OrthogonalGrids, and ContinuousSpace.
Moves agents one step in specified cardinal/diagonal directions (North, South, East, West, NorthEast, NorthWest, SouthEast, SouthWest). Automatically handles different Mesa grid types including OrthogonalGrids and ContinuousSpace.

Args:
direction: The direction to move in. Must be one of:
Expand Down Expand Up @@ -113,40 +109,25 @@ def move_one_step(agent: "LLMAgent", direction: str) -> str:
return teleport_to_location(agent, target_coordinates)

space = getattr(agent.model, "space", None)
grid_or_space = None
if isinstance(grid, SingleGrid | MultiGrid):
grid_or_space = grid
elif isinstance(space, ContinuousSpace):
grid_or_space = space

if grid_or_space is not None:
if isinstance(space, ContinuousSpace):
dx, dy = direction_map_xy[direction]
x, y = _get_agent_position(agent)
new_pos = (x + dx, y + dy)

if grid_or_space.torus:
new_pos = grid_or_space.torus_adj(new_pos)
elif grid_or_space.out_of_bounds(new_pos):
if space.torus:
new_pos = space.torus_adj(new_pos)
elif space.out_of_bounds(new_pos):
return (
f"Agent {agent.unique_id} is at the boundary and cannot move "
f"{direction}. Try a different direction."
)

if isinstance(grid_or_space, SingleGrid) and not grid_or_space.is_cell_empty(
new_pos
):
return (
f"Agent {agent.unique_id} cannot move {direction} because "
"the target cell is occupied."
)

target_coordinates = tuple(new_pos)
return teleport_to_location(agent, target_coordinates)

raise ValueError(
"Unsupported environment for move_one_step. Expected SingleGrid, "
"MultiGrid, OrthogonalMooreGrid, OrthogonalVonNeumannGrid, or "
"ContinuousSpace."
"Unsupported environment for move_one_step. Expected "
"OrthogonalMooreGrid, OrthogonalVonNeumannGrid, or ContinuousSpace."
)


Expand All @@ -168,21 +149,17 @@ def teleport_to_location(
"""
target_coordinates = tuple(target_coordinates)

if isinstance(agent.model.grid, SingleGrid | MultiGrid):
agent.model.grid.move_agent(agent, target_coordinates)

elif isinstance(agent.model.grid, OrthogonalMooreGrid | OrthogonalVonNeumannGrid):
if isinstance(agent.model.grid, OrthogonalMooreGrid | OrthogonalVonNeumannGrid):
cell = agent.model.grid._cells[target_coordinates]
agent.cell = cell

elif isinstance(agent.model.space, ContinuousSpace):
agent.model.space.move_agent(agent, target_coordinates)
agent.pos = target_coordinates

else:
raise ValueError(
"Unsupported environment for teleport_to_location. Expected "
"SingleGrid, MultiGrid, OrthogonalMooreGrid, "
"OrthogonalVonNeumannGrid, or ContinuousSpace."
"OrthogonalMooreGrid, OrthogonalVonNeumannGrid, or ContinuousSpace."
)

return f"agent {agent.unique_id} moved to {target_coordinates}."
Expand Down
15 changes: 2 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

import pytest
from litellm import Choices, Message, ModelResponse

# from mesa.space import MultiGrid
from mesa.model import Model
from mesa.space import MultiGrid

from mesa_llm.llm_agent import LLMAgent
from mesa_llm.memory.st_memory import ShortTermMemory
Expand Down Expand Up @@ -96,18 +97,6 @@ def basic_model():
return Model(rng=42)


@pytest.fixture
def grid_model():
"""Create model with MultiGrid"""

class GridModel(Model):
def __init__(self):
super().__init__(rng=42)
self.grid = MultiGrid(10, 10, torus=False)

return GridModel()


@pytest.fixture
def basic_agent(basic_model):
"""Create single agent with memory"""
Expand Down
Loading