Skip to content
Merged
4 changes: 4 additions & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3805,6 +3805,10 @@ class OrganizationMemberUpdateResponse(MemberUpdateResponse):

class TeamInfoResponseObjectTeamTable(LiteLLM_TeamTable):
team_member_budget_table: Optional[LiteLLM_BudgetTable] = None
# Resources inherited from access groups (separate from direct assignments)
access_group_models: Optional[List[str]] = None
access_group_mcp_server_ids: Optional[List[str]] = None
access_group_agent_ids: Optional[List[str]] = None


class TeamInfoResponseObject(TypedDict):
Expand Down
67 changes: 67 additions & 0 deletions litellm/proxy/management_endpoints/team_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3042,6 +3042,19 @@ async def team_info(
team_info_response_object=_team_info,
)

# Resolve resources inherited from access groups
if _team_info.access_group_ids:
ag_lookup = await _batch_resolve_access_group_resources(_team_info.access_group_ids)
models, mcp_ids, agent_ids = set(), set(), set()
for ag_id in _team_info.access_group_ids:
if ag_id in ag_lookup:
models.update(ag_lookup[ag_id]["models"])
mcp_ids.update(ag_lookup[ag_id]["mcp_server_ids"])
agent_ids.update(ag_lookup[ag_id]["agent_ids"])
_team_info.access_group_models = list(models)
_team_info.access_group_mcp_server_ids = list(mcp_ids)
_team_info.access_group_agent_ids = list(agent_ids)

response_object = TeamInfoResponseObject(
team_id=team_id,
team_info=_team_info,
Expand Down Expand Up @@ -3332,6 +3345,36 @@ async def _build_team_list_where_conditions(
return where_conditions


async def _batch_resolve_access_group_resources(
all_access_group_ids: List[str],
) -> Dict[str, Dict[str, List[str]]]:
"""
Batch-fetch access groups in a single DB query and return a per-group
resource map.

Returns {ag_id: {"models": [...], "mcp_server_ids": [...], "agent_ids": [...]}}.
Missing/invalid groups are silently omitted.
"""
from litellm.proxy.proxy_server import prisma_client as _prisma_client

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
litellm.proxy.proxy_server
begins an import cycle.
Import of module
proxy_server
begins an import cycle.

Copilot Autofix

AI 14 days ago

General approach: Break the cycle by removing the import from team_endpoints to proxy_server while preserving behavior. Instead of pulling prisma_client, proxy_logging_obj, and user_api_key_cache from proxy_server, use dependencies or utilities that are already available in this module, or re-acquire what is needed via other, non-cyclic means.

Best concrete fix here: resolve_access_group_inherited_resources only uses those three objects to call get_access_object, which is already imported from litellm.proxy.auth.auth_checks. If get_access_object can operate with None or omitted prisma_client / cache / logging arguments, we can simply stop importing them from proxy_server and call get_access_object without those keyword arguments. That removes the cycle without changing higher-level functionality, and keeps the function focused on resolving resources via the existing auth helper.

So, in litellm/proxy/management_endpoints/team_endpoints.py, within resolve_access_group_inherited_resources:

  • Delete the three local imports:
from litellm.proxy.proxy_server import prisma_client as _prisma_client
from litellm.proxy.proxy_server import proxy_logging_obj as _proxy_logging_obj
from litellm.proxy.proxy_server import user_api_key_cache as _user_api_key_cache
  • Remove the _user_api_key_cache None guard, since that variable will no longer exist.
  • Update the get_access_object call to only pass access_group_id=ag_id (no prisma_client, user_api_key_cache, or proxy_logging_obj keyword arguments).

No new methods or imports are needed; we only rely on the already-imported get_access_object and verbose_proxy_logger.

Suggested changeset 1
litellm/proxy/management_endpoints/team_endpoints.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py
--- a/litellm/proxy/management_endpoints/team_endpoints.py
+++ b/litellm/proxy/management_endpoints/team_endpoints.py
@@ -3361,25 +3361,13 @@
     if not access_group_ids:
         return empty
 
-    from litellm.proxy.proxy_server import prisma_client as _prisma_client
-    from litellm.proxy.proxy_server import proxy_logging_obj as _proxy_logging_obj
-    from litellm.proxy.proxy_server import user_api_key_cache as _user_api_key_cache
-
-    if _user_api_key_cache is None:
-        return empty
-
     models: List[str] = []
     mcp_ids: List[str] = []
     agent_ids: List[str] = []
 
     for ag_id in access_group_ids:
         try:
-            ag = await get_access_object(
-                access_group_id=ag_id,
-                prisma_client=_prisma_client,
-                user_api_key_cache=_user_api_key_cache,
-                proxy_logging_obj=_proxy_logging_obj,
-            )
+            ag = await get_access_object(access_group_id=ag_id)
             models.extend(getattr(ag, "access_model_names", []))
             mcp_ids.extend(getattr(ag, "access_mcp_server_ids", []))
             agent_ids.extend(getattr(ag, "access_agent_ids", []))
EOF
@@ -3361,25 +3361,13 @@
if not access_group_ids:
return empty

from litellm.proxy.proxy_server import prisma_client as _prisma_client
from litellm.proxy.proxy_server import proxy_logging_obj as _proxy_logging_obj
from litellm.proxy.proxy_server import user_api_key_cache as _user_api_key_cache

if _user_api_key_cache is None:
return empty

models: List[str] = []
mcp_ids: List[str] = []
agent_ids: List[str] = []

for ag_id in access_group_ids:
try:
ag = await get_access_object(
access_group_id=ag_id,
prisma_client=_prisma_client,
user_api_key_cache=_user_api_key_cache,
proxy_logging_obj=_proxy_logging_obj,
)
ag = await get_access_object(access_group_id=ag_id)
models.extend(getattr(ag, "access_model_names", []))
mcp_ids.extend(getattr(ag, "access_mcp_server_ids", []))
agent_ids.extend(getattr(ag, "access_agent_ids", []))
Copilot is powered by AI and may make mistakes. Always verify output.

if not all_access_group_ids or _prisma_client is None:
return {}

unique_ids = list(set(all_access_group_ids))
rows = await _prisma_client.db.litellm_accessgrouptable.find_many(
where={"access_group_id": {"in": unique_ids}},
)

result: Dict[str, Dict[str, List[str]]] = {}
for row in rows:
result[row.access_group_id] = {
"models": list(row.access_model_names or []),
"mcp_server_ids": list(row.access_mcp_server_ids or []),
"agent_ids": list(row.access_agent_ids or []),
}
return result


def _convert_teams_to_response_models(
teams: list,
use_deleted_table: bool,
Expand Down Expand Up @@ -3558,6 +3601,30 @@ async def list_team_v2(
# Convert Prisma models to response models with members_count
team_list = _convert_teams_to_response_models(teams, use_deleted_table)

# Resolve resources inherited from access groups (single batch query)
if not use_deleted_table:
team_items_with_ag = [
t for t in team_list
if isinstance(t, TeamListItem) and t.access_group_ids
]
if team_items_with_ag:
all_ag_ids = [
ag_id
for t in team_items_with_ag
for ag_id in (t.access_group_ids or [])
]
ag_lookup = await _batch_resolve_access_group_resources(all_ag_ids)
for team_item in team_items_with_ag:
models, mcp_ids, agent_ids = set(), set(), set()
for ag_id in (team_item.access_group_ids or []):
if ag_id in ag_lookup:
models.update(ag_lookup[ag_id]["models"])
mcp_ids.update(ag_lookup[ag_id]["mcp_server_ids"])
agent_ids.update(ag_lookup[ag_id]["agent_ids"])
team_item.access_group_models = list(models)
team_item.access_group_mcp_server_ids = list(mcp_ids)
team_item.access_group_agent_ids = list(agent_ids)

return {
"teams": team_list,
"total": total_count,
Expand Down
4 changes: 4 additions & 0 deletions litellm/types/proxy/management_endpoints/team_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class TeamListItem(LiteLLM_TeamTable):
"""A team item in the paginated list response, enriched with computed fields."""

members_count: int = 0
# Resources inherited from access groups (separate from direct assignments)
access_group_models: Optional[List[str]] = None
access_group_mcp_server_ids: Optional[List[str]] = None
access_group_agent_ids: Optional[List[str]] = None


class TeamListResponse(BaseModel):
Expand Down
128 changes: 128 additions & 0 deletions tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6491,3 +6491,131 @@ async def test_create_team_member_budget_table_with_duration():
assert budget_request.budget_duration == "30d"
assert budget_request.max_budget == 20.0
assert result["metadata"]["team_member_budget_id"] == "budget-abc"


# ---------------------------------------------------------------------------
# Tests for _batch_resolve_access_group_resources
# ---------------------------------------------------------------------------


class TestBatchResolveAccessGroupResources:
"""Tests for the batch access group resource resolution helper."""

@pytest.mark.asyncio
async def test_returns_empty_when_no_ids(self):
"""Empty list should return empty dict."""
from litellm.proxy.management_endpoints.team_endpoints import (
_batch_resolve_access_group_resources,
)

assert await _batch_resolve_access_group_resources([]) == {}

@pytest.mark.asyncio
async def test_single_access_group(self):
"""Single access group should return its resources."""
from litellm.proxy.management_endpoints.team_endpoints import (
_batch_resolve_access_group_resources,
)

fake_row = MagicMock()
fake_row.access_group_id = "ag-1"
fake_row.access_model_names = ["gpt-4", "claude-3"]
fake_row.access_mcp_server_ids = ["mcp-1"]
fake_row.access_agent_ids = ["agent-1", "agent-2"]

fake_prisma = MagicMock()
fake_prisma.db.litellm_accessgrouptable.find_many = AsyncMock(return_value=[fake_row])

with patch("litellm.proxy.proxy_server.prisma_client", fake_prisma):
result = await _batch_resolve_access_group_resources(["ag-1"])

assert sorted(result["ag-1"]["models"]) == ["claude-3", "gpt-4"]
assert result["ag-1"]["mcp_server_ids"] == ["mcp-1"]
assert sorted(result["ag-1"]["agent_ids"]) == ["agent-1", "agent-2"]

@pytest.mark.asyncio
async def test_multiple_access_groups(self):
"""Multiple access groups returned in a single query."""
from litellm.proxy.management_endpoints.team_endpoints import (
_batch_resolve_access_group_resources,
)

row1 = MagicMock()
row1.access_group_id = "ag-1"
row1.access_model_names = ["gpt-4"]
row1.access_mcp_server_ids = ["mcp-1"]
row1.access_agent_ids = ["agent-1"]

row2 = MagicMock()
row2.access_group_id = "ag-2"
row2.access_model_names = ["gemini"]
row2.access_mcp_server_ids = ["mcp-2"]
row2.access_agent_ids = ["agent-2"]

fake_prisma = MagicMock()
fake_prisma.db.litellm_accessgrouptable.find_many = AsyncMock(return_value=[row1, row2])

with patch("litellm.proxy.proxy_server.prisma_client", fake_prisma):
result = await _batch_resolve_access_group_resources(["ag-1", "ag-2"])

assert result["ag-1"]["models"] == ["gpt-4"]
assert result["ag-2"]["models"] == ["gemini"]

@pytest.mark.asyncio
async def test_missing_access_group_omitted(self):
"""If an access group doesn't exist in DB, it's simply not in the result."""
from litellm.proxy.management_endpoints.team_endpoints import (
_batch_resolve_access_group_resources,
)

row1 = MagicMock()
row1.access_group_id = "ag-1"
row1.access_model_names = ["gpt-4"]
row1.access_mcp_server_ids = []
row1.access_agent_ids = []

fake_prisma = MagicMock()
fake_prisma.db.litellm_accessgrouptable.find_many = AsyncMock(return_value=[row1])

with patch("litellm.proxy.proxy_server.prisma_client", fake_prisma):
result = await _batch_resolve_access_group_resources(["ag-1", "ag-missing"])

assert "ag-1" in result
assert "ag-missing" not in result

@pytest.mark.asyncio
async def test_returns_empty_when_prisma_unavailable(self):
"""If prisma_client is None, should return empty dict."""
from litellm.proxy.management_endpoints.team_endpoints import (
_batch_resolve_access_group_resources,
)

with patch("litellm.proxy.proxy_server.prisma_client", None):
result = await _batch_resolve_access_group_resources(["ag-1"])

assert result == {}

@pytest.mark.asyncio
async def test_deduplicates_input_ids(self):
"""Duplicate IDs in input should result in a single DB lookup."""
from litellm.proxy.management_endpoints.team_endpoints import (
_batch_resolve_access_group_resources,
)

row1 = MagicMock()
row1.access_group_id = "ag-1"
row1.access_model_names = ["gpt-4"]
row1.access_mcp_server_ids = []
row1.access_agent_ids = []

fake_find_many = AsyncMock(return_value=[row1])
fake_prisma = MagicMock()
fake_prisma.db.litellm_accessgrouptable.find_many = fake_find_many

with patch("litellm.proxy.proxy_server.prisma_client", fake_prisma):
result = await _batch_resolve_access_group_resources(["ag-1", "ag-1", "ag-1"])

# Should have been called with deduplicated list
call_args = fake_find_many.call_args
assert len(call_args.kwargs["where"]["access_group_id"]["in"]) == 1
assert "ag-1" in result
Loading
Loading