diff --git a/lib/triton_helpers.py b/lib/triton_helpers.py index e9de548..f607923 100644 --- a/lib/triton_helpers.py +++ b/lib/triton_helpers.py @@ -12,7 +12,7 @@ from pathlib import Path from typing import Any, Callable, Optional, TypedDict from redis import StrictRedis -from tenacity import retry, stop_after_attempt, wait_random +from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed, wait_random from tritonclient.http import InferenceServerClient @@ -34,9 +34,11 @@ def check_triton_server_health(url: str, timeout: int = 10, scheme: str = "http" url = f"{scheme}://{url}" try: urllib.request.urlopen(f"{url}/v2/health/live", timeout=timeout) - except urllib.error.URLError as e: + except (urllib.error.URLError, ConnectionError) as e: raise AssertionError(CONNECTION_ERR_MSG.format(url=url)) from e +wait_for_triton_server = retry(stop=stop_after_delay(60), wait=wait_fixed(2), reraise=True)(check_triton_server_health) + @retry(stop=stop_after_attempt(3), wait=wait_random(1, 2), reraise=True) def get_triton_inference_stats(client: InferenceServerClient): return client.get_inference_statistics()['model_stats'] diff --git a/worker.py b/worker.py index 5f8bad5..1b3422f 100755 --- a/worker.py +++ b/worker.py @@ -31,7 +31,7 @@ from miniray.lib.sig_term_handler import SigTermHandler from miniray.lib.resource_manager import ResourceManager, ResourceLimitError from miniray.lib.worker_helpers import ExponentialBackoff -from miniray.lib.triton_helpers import TRITON_SERVER_ADDRESS, check_triton_server_health +from miniray.lib.triton_helpers import TRITON_SERVER_ADDRESS, check_triton_server_health, wait_for_triton_server from miniray.lib.system_helpers import get_cgroup_cpu_usage, get_cgroup_mem_usage, get_gpu_stats, get_gpu_mem_usage, get_gpu_utilization from miniray.lib.statsd_helpers import statsd from miniray.lib.helpers import Limits, desc, GB_TO_BYTES, TASK_TIMEOUT_GRACE_SECONDS, JOB_CACHE_SIZE @@ -508,6 +508,9 @@ def main(): procs: dict[int, Optional[Task]] = dict.fromkeys(range(sum(rm.cpu_totals.values()))) + if triton_client is not None: + wait_for_triton_server(url=TRITON_SERVER_ADDRESS) + while not sigterm_handler.raised: r_master.set(ACTIVE_KEY, 1, ex=SLEEP_TIME_MAX+1) backoff.sleep()