Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions chromadb/api/async_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,11 @@ def _get_client(self) -> httpx.AsyncClient:
+ " (https://github.com/chroma-core/chroma)"
)

limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
self._clients[loop_hash] = httpx.AsyncClient(
timeout=None,
headers=headers,
verify=self._settings.chroma_server_ssl_verify or False,
limits=limits,
limits=self.http_limits,
)

return self._clients[loop_hash]
Expand Down Expand Up @@ -527,7 +526,7 @@ async def _get(
return GetResult(
ids=resp_json["ids"],
embeddings=resp_json.get("embeddings", None),
metadatas=metadatas, # type: ignore
metadatas=metadatas,
documents=resp_json.get("documents", None),
data=None,
uris=resp_json.get("uris", None),
Expand Down Expand Up @@ -723,7 +722,7 @@ async def _query(
ids=resp_json["ids"],
distances=resp_json.get("distances", None),
embeddings=resp_json.get("embeddings", None),
metadatas=metadata_batches, # type: ignore
metadatas=metadata_batches,
documents=resp_json.get("documents", None),
uris=resp_json.get("uris", None),
data=None,
Expand Down
38 changes: 35 additions & 3 deletions chromadb/api/base_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,47 @@
import httpx

import chromadb.errors as errors
from chromadb.config import Settings
from chromadb.config import Component, Settings, System

logger = logging.getLogger(__name__)


class BaseHTTPClient:
# inherits from Component so that it can create an init function to use system
# this way it can build limits from the settings in System
class BaseHTTPClient(Component):
_settings: Settings
pre_flight_checks: Any = None
keepalive_secs: int = 40
DEFAULT_KEEPALIVE_SECS: float = 40.0

def __init__(self, system: System):
super().__init__(system)
self._settings = system.settings
keepalive_setting = self._settings.chroma_http_keepalive_secs
self.keepalive_secs: Optional[float] = (
keepalive_setting
if keepalive_setting is not None
else BaseHTTPClient.DEFAULT_KEEPALIVE_SECS
)
self._http_limits = self._build_limits()

def _build_limits(self) -> httpx.Limits:
limit_kwargs: Dict[str, Any] = {}
if self.keepalive_secs is not None:
limit_kwargs["keepalive_expiry"] = self.keepalive_secs

max_connections = self._settings.chroma_http_max_connections
if max_connections is not None:
limit_kwargs["max_connections"] = max_connections

max_keepalive_connections = self._settings.chroma_http_max_keepalive_connections
if max_keepalive_connections is not None:
limit_kwargs["max_keepalive_connections"] = max_keepalive_connections

return httpx.Limits(**limit_kwargs)

@property
def http_limits(self) -> httpx.Limits:
return self._http_limits

@staticmethod
def _validate_host(host: str) -> None:
Expand Down
16 changes: 10 additions & 6 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,14 @@ def __init__(self, system: System):
default_api_path=system.settings.chroma_server_api_default_path,
)

limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
self._session = httpx.Client(timeout=None, limits=limits)
if self._settings.chroma_server_ssl_verify is not None:
self._session = httpx.Client(
timeout=None,
limits=self.http_limits,
verify=self._settings.chroma_server_ssl_verify,
)
else:
self._session = httpx.Client(timeout=None, limits=self.http_limits)

self._header = system.settings.chroma_server_headers or {}
self._header["Content-Type"] = "application/json"
Expand All @@ -90,8 +96,6 @@ def __init__(self, system: System):
+ " (https://github.com/chroma-core/chroma)"
)

if self._settings.chroma_server_ssl_verify is not None:
self._session = httpx.Client(verify=self._settings.chroma_server_ssl_verify)
if self._header is not None:
self._session.headers.update(self._header)

Expand Down Expand Up @@ -492,7 +496,7 @@ def _get(
return GetResult(
ids=resp_json["ids"],
embeddings=resp_json.get("embeddings", None),
metadatas=metadatas, # type: ignore
metadatas=metadatas,
documents=resp_json.get("documents", None),
data=None,
uris=resp_json.get("uris", None),
Expand Down Expand Up @@ -700,7 +704,7 @@ def _query(
ids=resp_json["ids"],
distances=resp_json.get("distances", None),
embeddings=resp_json.get("embeddings", None),
metadatas=metadata_batches, # type: ignore
metadatas=metadata_batches,
documents=resp_json.get("documents", None),
uris=resp_json.get("uris", None),
data=None,
Expand Down
4 changes: 4 additions & 0 deletions chromadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
# eg ["http://localhost:8000"]
chroma_server_cors_allow_origins: List[str] = []

chroma_http_keepalive_secs: Optional[float] = 40.0
chroma_http_max_connections: Optional[int] = None
chroma_http_max_keepalive_connections: Optional[int] = None

# ==================
# Server config
# ==================
Expand Down
47 changes: 44 additions & 3 deletions chromadb/test/test_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
from typing import Any, Callable, Generator, cast
from unittest.mock import patch
from typing import Any, Callable, Generator, cast, Dict, Tuple
from unittest.mock import MagicMock, patch
import chromadb
from chromadb.config import Settings
from chromadb.config import Settings, System
from chromadb.api import ClientAPI
import chromadb.server.fastapi
from chromadb.api.fastapi import FastAPI
import pytest
import tempfile
import os
Expand Down Expand Up @@ -110,3 +111,43 @@ def test_http_client_with_inconsistent_port_settings(
str(e)
== "Chroma server http port provided in settings[8001] is different to the one provided in HttpClient: [8002]"
)


def make_sync_client_factory() -> Tuple[Callable[..., Any], Dict[str, Any]]:
captured: Dict[str, Any] = {}

# takes any positional args to match httpx.Client
def factory(*_: Any, **kwargs: Any) -> Any:
captured.update(kwargs)
session = MagicMock()
session.headers = {}
return session

return factory, captured


def test_fastapi_uses_http_limits_from_settings() -> None:
settings = Settings(
chroma_api_impl="chromadb.api.fastapi.FastAPI",
chroma_server_host="localhost",
chroma_server_http_port=9000,
chroma_server_ssl_verify=True,
chroma_http_keepalive_secs=12.5,
chroma_http_max_connections=64,
chroma_http_max_keepalive_connections=16,
)
system = System(settings)

factory, captured = make_sync_client_factory()

with patch.object(FastAPI, "require", side_effect=[MagicMock(), MagicMock()]):
with patch("chromadb.api.fastapi.httpx.Client", side_effect=factory):
api = FastAPI(system)

api.stop()
limits = captured["limits"]
assert limits.keepalive_expiry == 12.5
assert limits.max_connections == 64
assert limits.max_keepalive_connections == 16
assert captured["timeout"] is None
assert captured["verify"] is True
18 changes: 18 additions & 0 deletions chromadb/test/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,21 @@ def test_runtime_dependencies() -> None:
assert data.starts == ["D", "C"]
system.stop()
assert data.stops == ["C", "D"]


def test_http_client_setting_defaults() -> None:
settings = Settings()
assert settings.chroma_http_keepalive_secs == 40.0
assert settings.chroma_http_max_connections is None
assert settings.chroma_http_max_keepalive_connections is None


def test_http_client_setting_overrides() -> None:
settings = Settings(
chroma_http_keepalive_secs=5.5,
chroma_http_max_connections=123,
chroma_http_max_keepalive_connections=17,
)
assert settings.chroma_http_keepalive_secs == 5.5
assert settings.chroma_http_max_connections == 123
assert settings.chroma_http_max_keepalive_connections == 17
Loading