diff --git a/docker/Dockerfile b/docker/Dockerfile index 5dd7ab36d6c9..920675d29d38 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -396,6 +396,12 @@ RUN CMAKE_VERSION=3.31.1 \ && cp -r "${CMAKE_INSTALLER}/share/"* /usr/local/share/ \ && rm -rf "${CMAKE_INSTALLER}" "${CMAKE_INSTALLER}.tar.gz" +# Install S3 dependencies +RUN python3 -m pip install \ + boto3 \ + runai-model-streamer \ + runai-model-streamer-s3 + # Install just RUN curl --proto '=https' --tlsv1.2 --retry 3 --retry-delay 2 -sSf https://just.systems/install.sh | \ sed "s|https://github.com|https://${GITHUB_ARTIFACTORY}|g" | \ diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 44406239d751..9eff5aa4dfad 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -110,6 +110,7 @@ def __init__( model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, sampling_defaults: str = "openai", quantize_and_serve: bool = False, + device: Optional[str] = None, ) -> None: # Parse args self.model_path = model_path @@ -124,7 +125,7 @@ def __init__( self._validate_quantize_and_serve_config() # Get hf config - self._maybe_pull_model_tokenizer_from_remote() + self._maybe_pull_model_tokenizer_from_remote(device) self.model_override_args = json.loads(model_override_args) kwargs = {} if override_config_file and override_config_file.strip(): @@ -134,6 +135,7 @@ def __init__( trust_remote_code=trust_remote_code, revision=revision, model_override_args=self.model_override_args, + device=device, **kwargs, ) self.hf_text_config = get_hf_text_config(self.hf_config) @@ -257,6 +259,7 @@ def from_server_args( model_impl=server_args.model_impl, sampling_defaults=server_args.sampling_defaults, quantize_and_serve=server_args.quantize_and_serve, + device=server_args.device, override_config_file=server_args.decrypted_config_file, **kwargs, ) @@ -837,7 +840,7 @@ def get_default_sampling_params(self) -> dict[str, Any]: return 
default_sampling_params - def _maybe_pull_model_tokenizer_from_remote(self) -> None: + def _maybe_pull_model_tokenizer_from_remote(self, device: Optional[str]) -> None: """ Pull the model config files to a temporary directory in case of remote. @@ -854,7 +857,7 @@ def _maybe_pull_model_tokenizer_from_remote(self) -> None: # BaseConnector implements __del__() to clean up the local dir. # Since config files need to exist all the time, so we DO NOT use # with statement to avoid closing the client. - client = create_remote_connector(self.model_path) + client = create_remote_connector(self.model_path, device) if is_remote_url(self.model_path): client.pull_files(allow_pattern=["*config.json"]) self.model_weights = self.model_path diff --git a/python/sglang/srt/connector/__init__.py b/python/sglang/srt/connector/__init__.py index c9663a836d14..36110b7e5b3b 100644 --- a/python/sglang/srt/connector/__init__.py +++ b/python/sglang/srt/connector/__init__.py @@ -2,6 +2,7 @@ import enum import logging +import os from sglang.srt.connector.base_connector import ( BaseConnector, @@ -23,6 +24,7 @@ class ConnectorType(str, enum.Enum): def create_remote_connector(url, device, **kwargs) -> BaseConnector: + url = verify_if_url_is_gcs_bucket(url) connector_type = parse_connector_type(url) if connector_type == "redis": return RedisConnector(url) @@ -34,6 +36,16 @@ def create_remote_connector(url, device, **kwargs) -> BaseConnector: raise ValueError(f"Invalid connector type: {url}") +def verify_if_url_is_gcs_bucket(url): + if url.startswith("gs://"): + os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = "https://storage.googleapis.com" + os.environ["AWS_ENDPOINT_URL"] = "https://storage.googleapis.com" + os.environ["RUNAI_STREAMER_S3_USE_VIRTUAL_ADDRESSING"] = "0" + os.environ["AWS_EC2_METADATA_DISABLED"] = "true" + url = url.replace("gs://", "s3://", 1) + return url + + def get_connector_type(client: BaseConnector) -> ConnectorType: if isinstance(client, BaseKVConnector): return ConnectorType.KV diff --git 
a/python/sglang/srt/connector/s3.py b/python/sglang/srt/connector/s3.py index 7bef8f5d5225..aab7f7914151 100644 --- a/python/sglang/srt/connector/s3.py +++ b/python/sglang/srt/connector/s3.py @@ -101,7 +101,9 @@ def pull_files( return for file in files: - destination_file = os.path.join(self.local_dir, file.removeprefix(base_dir)) + destination_file = os.path.join( + self.local_dir, file.removeprefix(f"{base_dir}/") + ) local_dir = Path(destination_file).parent os.makedirs(local_dir, exist_ok=True) self.client.download_file(bucket_name, file, destination_file) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 787de125728c..c6874fd543b4 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -274,6 +274,7 @@ def __init__( tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, revision=server_args.revision, + device=server_args.device, ) self._initialize_multi_item_delimiter_text() diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index ef04b85c6c7c..82b78e7de5ef 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -238,6 +238,7 @@ def get_config( trust_remote_code: bool, revision: Optional[str] = None, model_override_args: Optional[dict] = None, + device: Optional[str] = None, **kwargs, ): is_gguf = check_gguf_file(model) @@ -249,7 +250,7 @@ def get_config( # BaseConnector implements __del__() to clean up the local dir. # Since config files need to exist all the time, so we DO NOT use # with statement to avoid closing the client. 
- client = create_remote_connector(model) + client = create_remote_connector(model, device) client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) model = client.get_local_dir() @@ -411,6 +412,7 @@ def get_tokenizer( tokenizer_mode: str = "auto", trust_remote_code: bool = False, tokenizer_revision: Optional[str] = None, + device: Optional[str] = None, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: """Gets a tokenizer for the given model name via Huggingface.""" @@ -437,7 +439,7 @@ def get_tokenizer( # BaseConnector implements __del__() to clean up the local dir. # Since config files need to exist all the time, so we DO NOT use # with statement to avoid closing the client. - client = create_remote_connector(tokenizer_name) + client = create_remote_connector(tokenizer_name, device) client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) tokenizer_name = client.get_local_dir()