Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autogen/beta/config/gemini/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def build_tools(schemas: list[ToolSchema]) -> list[types.Tool] | None:
types.FunctionDeclaration(
name=t.function.name,
description=t.function.description,
parameters=_ensure_object_schema(t.function.parameters),
parameters_json_schema=_ensure_object_schema(t.function.parameters),
)
)

Expand Down
5 changes: 4 additions & 1 deletion autogen/beta/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .final import Toolkit, tool
from .shell import LocalShellTool
from .skills import SkillSearchToolkit, SkillsToolkit
from .toolkits import FilesystemToolkit
from .toolkits import FilesystemToolkit, MCPServer, MCPServerConfig, MCPStdioServerConfig

__all__ = (
"CodeExecutionTool",
Expand All @@ -31,7 +31,10 @@
"FilesystemToolkit",
"ImageGenerationTool",
"LocalShellTool",
"MCPServer",
"MCPServerConfig",
"MCPServerTool",
"MCPStdioServerConfig",
"MemoryTool",
"NetworkPolicy",
"ShellTool",
Expand Down
8 changes: 7 additions & 1 deletion autogen/beta/tools/toolkits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,11 @@
# SPDX-License-Identifier: Apache-2.0

from .filesystem import FilesystemToolkit
from .mcp_server import MCPServer, MCPServerConfig, MCPStdioServerConfig

__all__ = ("FilesystemToolkit",)
__all__ = (
"FilesystemToolkit",
"MCPServer",
"MCPServerConfig",
"MCPStdioServerConfig",
)
12 changes: 12 additions & 0 deletions autogen/beta/tools/toolkits/mcp_server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) 2026, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0

from .toolkit import MCPServer
from .types import MCPServerConfig, MCPStdioServerConfig

__all__ = (
"MCPServer",
"MCPServerConfig",
"MCPStdioServerConfig",
)
256 changes: 256 additions & 0 deletions autogen/beta/tools/toolkits/mcp_server/toolkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
# Copyright (c) 2026, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import AsyncIterator, Iterable
from contextlib import ExitStack, asynccontextmanager
from dataclasses import replace
from typing import Any

import httpx
from mcp import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import streamable_http_client
from mcp.types import CallToolResult, TextContent
from mcp.types import Tool as MCPTool

from autogen.beta.annotations import Context, Variable
from autogen.beta.events.tool_events import (
ToolCallEvent,
ToolErrorEvent,
ToolResultEvent,
)
from autogen.beta.middleware import BaseMiddleware, ToolExecution, ToolMiddleware, ToolResultType
from autogen.beta.tools.final import Toolkit
from autogen.beta.tools.final.function_tool import FunctionDefinition, FunctionToolSchema
from autogen.beta.tools.tool import Tool

from .types import MCPServerConfig, MCPStdioServerConfig

AnyMCPConfig = MCPServerConfig | MCPStdioServerConfig


@asynccontextmanager
async def _mcp_session(config: AnyMCPConfig) -> AsyncIterator[ClientSession]:
"""Open a short-lived MCP ``ClientSession`` for one operation.

Dispatches on the config type — HTTP/streamable-http for
:class:`MCPServerConfig`, stdio subprocess for :class:`MCPStdioServerConfig`.
"""
if isinstance(config, MCPStdioServerConfig):
params = StdioServerParameters(
command=config.command, # type: ignore[arg-type]
args=list(config.args or []), # type: ignore[arg-type]
env=config.env, # type: ignore[arg-type]
cwd=config.cwd, # type: ignore[arg-type]
encoding=config.encoding,
)
async with (
stdio_client(params) as (read_stream, write_stream),
ClientSession(read_stream, write_stream) as session,
):
await session.initialize()
yield session
else:
async with (
streamable_http_client(
config.server_url,
http_client=httpx.AsyncClient(
headers=config.headers,
timeout=config.connection_timeout,
),
) as (read_stream, write_stream, _),
ClientSession(read_stream, write_stream) as session,
):
await session.initialize()
yield session


class _MCPProxyTool(Tool):
"""A function-tool-shaped proxy that forwards calls to a remote MCP server."""

__slots__ = ("name", "schema", "_config", "_middleware")

def __init__(
self,
config: AnyMCPConfig,
raw_tool: MCPTool,
middleware: tuple[ToolMiddleware, ...] = (),
) -> None:
self._config = config
self._middleware = middleware
self.name = raw_tool.name
self.schema = FunctionToolSchema(
function=FunctionDefinition(
name=self.name,
description=raw_tool.description or "",
parameters=dict(raw_tool.inputSchema or {}),
)
)

async def schemas(self, context: "Context") -> list[FunctionToolSchema]:
return [self.schema]

def register(
self,
stack: "ExitStack",
context: "Context",
*,
middleware: Iterable["BaseMiddleware"] = (),
) -> None:
execution: ToolExecution = self
for hook in reversed(self._middleware):
execution = _wrap_middleware(hook, execution)
for mw in middleware:
execution = _wrap_middleware(mw.on_tool_execution, execution)

async def execute(event: "ToolCallEvent", context: "Context") -> None:
result = await execution(event, context)
await context.send(result)

stack.enter_context(context.stream.where(ToolCallEvent.name == self.name).sub_scope(execute))

async def __call__(self, event: "ToolCallEvent", context: "Context") -> "ToolResultEvent | ToolErrorEvent":
try:
async with _mcp_session(self._config) as session:
result = await session.call_tool(self.name, event.serialized_arguments)
except Exception as e:
return ToolErrorEvent.from_call(event, error=e)

if result.isError:
return ToolErrorEvent.from_call(event, error=RuntimeError(_extract_content(result)))

return ToolResultEvent.from_call(event, result=_extract_content(result))


class MCPServer(Toolkit):
"""Expose the tools of an MCP server as ordinary local tools.

Accepts either:

* a URL string or :class:`MCPServerConfig` for a remote (streamable-http)
server, or
* an :class:`MCPStdioServerConfig` for a locally-launched server
communicating over stdin/stdout.

Tool discovery is lazy: the first call to :meth:`schemas` performs the
MCP handshake, lists the server's tools, and registers a proxy for each
one. The agent never sees that these are MCP tools — they look and behave
like ordinary :class:`FunctionTool` instances.
"""

__slots__ = ("config", "_discovered")

def __init__(
self,
server: str | MCPServerConfig | MCPStdioServerConfig,
*,
middleware: Iterable[ToolMiddleware] = (),
) -> None:
if isinstance(server, str):
server = MCPServerConfig(server_url=server)
self.config: AnyMCPConfig = server
self._discovered = False

label = server.server_label if isinstance(server.server_label, str) else ""
super().__init__(
name=label or "mcp_toolkit",
middleware=middleware,
)

async def schemas(self, context: "Context") -> Iterable[FunctionToolSchema]:
await self._discover_tools(context)
return await super().schemas(context)

async def _discover_tools(self, context: "Context") -> None:
if self._discovered:
return

resolved = _resolve_config(self.config, context)

async with _mcp_session(resolved) as session:
raw_tools = (await session.list_tools()).tools

allowed = resolved.allowed_tools
blocked = set(resolved.blocked_tools or [])

for raw in raw_tools:
if allowed is not None and raw.name not in allowed:
continue
if raw.name in blocked:
continue
self.tools.append(
_MCPProxyTool(
config=resolved,
raw_tool=raw,
middleware=self._middleware,
)
)

self._discovered = True


def _wrap_middleware(hook: "ToolMiddleware", inner: "ToolExecution") -> "ToolExecution":
async def call(event: "ToolCallEvent", context: "Context") -> "ToolResultType":
return await hook(inner, event, context)

return call


def _extract_content(result: CallToolResult) -> str:
"""Flatten an MCP ``tools/call`` result into a string for the model."""
parts = result.content
if not parts:
return result.model_dump_json(exclude_none=True)

chunks: list[str] = []
for p in parts:
if isinstance(p, TextContent):
chunks.append(p.text)
else:
chunks.append(p.model_dump_json(exclude_none=True))
return "\n".join(chunks)


def _resolve_value(value: Any, context: "Context") -> Any:
if not isinstance(value, Variable):
return value
name = value.name
if name in context.variables:
return context.variables[name]
if value.default is not Ellipsis:
return value.default
if value.default_factory is not Ellipsis:
return value.default_factory()
raise KeyError(f"Context variable {name!r} not found and no default provided")


def _resolve_config(config: AnyMCPConfig, context: "Context") -> AnyMCPConfig:
if isinstance(config, MCPStdioServerConfig):
return replace(
config,
command=_resolve_value(config.command, context),
args=list(_resolve_value(config.args, context) or []),
env=_resolve_value(config.env, context),
cwd=_resolve_value(config.cwd, context),
server_label=_resolve_value(config.server_label, context) or "",
description=_resolve_value(config.description, context),
allowed_tools=_resolve_value(config.allowed_tools, context),
blocked_tools=_resolve_value(config.blocked_tools, context),
)

headers = dict(_resolve_value(config.headers, context) or {})
auth = _resolve_value(config.authorization_token, context)
if auth and "Authorization" not in headers:
headers["Authorization"] = f"Bearer {auth}"

return replace(
config,
server_url=_resolve_value(config.server_url, context),
server_label=_resolve_value(config.server_label, context) or "",
authorization_token=auth,
description=_resolve_value(config.description, context),
allowed_tools=_resolve_value(config.allowed_tools, context),
blocked_tools=_resolve_value(config.blocked_tools, context),
headers=headers or None,
)
48 changes: 48 additions & 0 deletions autogen/beta/tools/toolkits/mcp_server/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) 2026, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0


from dataclasses import dataclass, field
from pathlib import Path

from autogen.beta.annotations import Variable


@dataclass
class MCPServerConfig:
"""
Configuration for a remote (HTTP / streamable-http) MCP server.
It's important to specify AUTH headers as most MCP servers force auth nowadays.
"""

server_url: str | Variable
server_label: str | Variable = ""
authorization_token: str | Variable | None = None
description: str | Variable | None = None
allowed_tools: list[str] | Variable | None = None
blocked_tools: list[str] | Variable | None = None
headers: dict[str, str] | Variable | None = None
connection_timeout: float = 30.0


@dataclass
class MCPStdioServerConfig:
"""
Configuration for a local MCP server that communicates over stdin/stdout.

The server is launched as a subprocess (``command`` + ``args``) and the
MCP protocol is spoken across its stdio pipes. Use this for locally
installed MCP servers shipped as CLIs (e.g. ``npx -y @some/mcp-server``,
``uvx some-mcp-server``, or a script in your project).
"""

command: str | Variable
args: list[str] | Variable = field(default_factory=list)
env: dict[str, str] | Variable | None = None
cwd: str | Path | Variable | None = None
server_label: str | Variable = ""
description: str | Variable | None = None
allowed_tools: list[str] | Variable | None = None
blocked_tools: list[str] | Variable | None = None
encoding: str = "utf-8"
6 changes: 3 additions & 3 deletions test/beta/config/gemini/tools/test_tool_to_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_tool_to_api() -> None:
types.FunctionDeclaration(
name=schema.function.name,
description=schema.function.description,
parameters=schema.function.parameters,
parameters_json_schema=schema.function.parameters,
)
]
)
Expand All @@ -43,7 +43,7 @@ def test_parameterless_tool_empty_dict_gets_object_schema() -> None:
types.FunctionDeclaration(
name="list_skills",
description="List installed skills.",
parameters={"type": "object", "properties": {}},
parameters_json_schema={"type": "object", "properties": {}},
)
]
)
Expand All @@ -67,7 +67,7 @@ def test_parameterless_tool_null_type_gets_object_schema() -> None:
types.FunctionDeclaration(
name="list_skills",
description="List installed skills.",
parameters={"type": "object", "properties": {}},
parameters_json_schema={"type": "object", "properties": {}},
)
]
)
Expand Down
2 changes: 1 addition & 1 deletion test/beta/config/gemini/tools/test_web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def test_mixed_with_function_tool(context: Context) -> None:
types.FunctionDeclaration(
name=func_schema.function.name,
description=func_schema.function.description,
parameters=func_schema.function.parameters,
parameters_json_schema=func_schema.function.parameters,
)
]
),
Expand Down
Loading