Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions lib/triton_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pathlib import Path
from typing import Any, Callable, Optional, TypedDict
from redis import StrictRedis
from tenacity import retry, stop_after_attempt, wait_random
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed, wait_random
from tritonclient.http import InferenceServerClient


Expand All @@ -34,9 +34,11 @@ def check_triton_server_health(url: str, timeout: int = 10, scheme: str = "http"
url = f"{scheme}://{url}"
try:
urllib.request.urlopen(f"{url}/v2/health/live", timeout=timeout)
except urllib.error.URLError as e:
except (urllib.error.URLError, ConnectionError) as e:
raise AssertionError(CONNECTION_ERR_MSG.format(url=url)) from e

# Blocking variant of the health check: keep retrying every 2s for up to 60s,
# then re-raise the last AssertionError if the server never came up.
wait_for_triton_server = retry(
    stop=stop_after_delay(60),
    wait=wait_fixed(2),
    reraise=True,
)(check_triton_server_health)

@retry(stop=stop_after_attempt(3), wait=wait_random(1, 2), reraise=True)
def get_triton_inference_stats(client: InferenceServerClient):
    """Fetch per-model inference statistics from a Triton server.

    Retries up to 3 times (random 1-2s backoff) before re-raising the
    underlying error. Returns the 'model_stats' entry of the server's
    inference-statistics response (raises KeyError if it is absent).
    """
    stats = client.get_inference_statistics()
    return stats['model_stats']
Expand Down
5 changes: 4 additions & 1 deletion worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from miniray.lib.sig_term_handler import SigTermHandler
from miniray.lib.resource_manager import ResourceManager, ResourceLimitError
from miniray.lib.worker_helpers import ExponentialBackoff
from miniray.lib.triton_helpers import TRITON_SERVER_ADDRESS, check_triton_server_health
from miniray.lib.triton_helpers import TRITON_SERVER_ADDRESS, check_triton_server_health, wait_for_triton_server
from miniray.lib.system_helpers import get_cgroup_cpu_usage, get_cgroup_mem_usage, get_gpu_stats, get_gpu_mem_usage, get_gpu_utilization
from miniray.lib.statsd_helpers import statsd
from miniray.lib.helpers import Limits, desc, GB_TO_BYTES, TASK_TIMEOUT_GRACE_SECONDS, JOB_CACHE_SIZE
Expand Down Expand Up @@ -508,6 +508,9 @@ def main():

procs: dict[int, Optional[Task]] = dict.fromkeys(range(sum(rm.cpu_totals.values())))

if triton_client is not None:
wait_for_triton_server(url=TRITON_SERVER_ADDRESS)

while not sigterm_handler.raised:
r_master.set(ACTIVE_KEY, 1, ex=SLEEP_TIME_MAX+1)
backoff.sleep()
Expand Down