Merged
3 changes: 3 additions & 0 deletions .github/workflows/nightly-test-nvidia.yml
@@ -47,6 +47,9 @@ concurrency:
   group: nightly-test-nvidia-${{ github.ref }}
   cancel-in-progress: true
 
+env:
+  SGLANG_IS_IN_CI: true
+
 jobs:
   # General tests - 1 GPU
   nightly-test-general-1-gpu-runner:
3 changes: 3 additions & 0 deletions .github/workflows/pr-test.yml
@@ -25,6 +25,9 @@ concurrency:
   group: pr-test-${{ github.ref }}
   cancel-in-progress: true
 
+env:
+  SGLANG_IS_IN_CI: true
+
 jobs:
   # =============================================== check changes ====================================================
   check-changes:
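Both workflow files now export `SGLANG_IS_IN_CI: true` to every job. For context, a minimal sketch of how a helper like `sglang.utils.is_in_ci` (the import removed from `weight_utils.py` below) could consume this variable; the real implementation in sglang may differ:

```python
import os

def is_in_ci() -> bool:
    # Assumed behavior: treat the env var exported by the CI workflows
    # as a boolean flag; "true"/"1" enable CI-only code paths.
    return os.environ.get("SGLANG_IS_IN_CI", "").lower() in ("true", "1")
```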
74 changes: 64 additions & 10 deletions python/sglang/srt/model_loader/weight_utils.py
@@ -47,7 +47,6 @@
     _validate_sharded_model,
 )
 from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once
-from sglang.utils import is_in_ci
 
 logger = logging.getLogger(__name__)
 
@@ -421,6 +420,55 @@ def find_local_hf_snapshot_dir(
     return None
 
 
+def _validate_weights_after_download(
+    hf_folder: str,
+    allow_patterns: List[str],
+    model_name_or_path: str,
+) -> None:
+    """Validate downloaded weight files to catch corruption early.
+
+    This function validates safetensors files after download to catch
+    corruption issues (truncated downloads, network errors, etc.) before
+    model loading fails with cryptic errors.
+
+    Args:
+        hf_folder: Path to the downloaded model folder
+        allow_patterns: Patterns used to match weight files
+        model_name_or_path: Model identifier for error messages
+
+    Raises:
+        RuntimeError: If any weight files are corrupted
+    """
+    import glob as glob_module
+
+    # Find all weight files that were downloaded
+    weight_files: List[str] = []
+    for pattern in allow_patterns:
+        weight_files.extend(glob_module.glob(os.path.join(hf_folder, pattern)))
+
+    if not weight_files:
+        return  # No weight files to validate
+
+    # Validate safetensors files
+    corrupted_files = []
+    for f in weight_files:
+        if f.endswith(".safetensors") and os.path.exists(f):
+            if not _validate_safetensors_file(f):
+                corrupted_files.append(os.path.basename(f))
+
+    if corrupted_files:
+        # Clean up corrupted files so the next attempt re-downloads them
+        _cleanup_corrupted_files_selective(
+            model_name_or_path,
+            [os.path.join(hf_folder, f) for f in corrupted_files],
+        )
+        raise RuntimeError(
+            f"Downloaded model files are corrupted for {model_name_or_path}: "
+            f"{corrupted_files}. The corrupted files have been removed. "
+            "Please retry to re-download the model."
+        )
+
+
 def download_weights_from_hf(
     model_name_or_path: str,
     cache_dir: Optional[str],
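`_validate_weights_after_download` delegates the per-file check to `_validate_safetensors_file`, which is defined elsewhere and not shown in this diff. A minimal sketch of what such a check could look like, based on the published safetensors layout (an 8-byte little-endian header length, a JSON header, then the tensor data); the helper name and exact checks here are assumptions:

```python
import json
import os
import struct

def validate_safetensors_file(path: str) -> bool:
    """Return False if the file is truncated or its header is unreadable."""
    try:
        size = os.path.getsize(path)
        with open(path, "rb") as fp:
            # First 8 bytes: little-endian u64 giving the JSON header length.
            (header_len,) = struct.unpack("<Q", fp.read(8))
            if 8 + header_len > size:
                return False  # file ends before the header does
            header = json.loads(fp.read(header_len))
        # Each tensor entry records [begin, end) byte offsets into the data
        # section; a truncated download ends before the largest offset.
        data_size = size - 8 - header_len
        for name, meta in header.items():
            if name == "__metadata__":
                continue
            if meta["data_offsets"][1] > data_size:
                return False
        return True
    except (OSError, ValueError, KeyError, struct.error):
        return False
```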
@@ -446,17 +494,19 @@ def download_weights_from_hf(
         str: The path to the downloaded model weights.
     """
 
-    if is_in_ci():
[Contributor review comment on this line: "we should only do all these cache behaviors in CI"]
-        # If the weights are already local, skip downloading and returns the path.
-        # This is used to skip too-many Huggingface API calls in CI.
-        path = find_local_hf_snapshot_dir(
-            model_name_or_path, cache_dir, allow_patterns, revision
-        )
-        if path is not None:
-            return path
+    # Always check for a valid local cache first.
+    # This validates cached files and cleans up corrupted ones.
+    path = find_local_hf_snapshot_dir(
+        model_name_or_path, cache_dir, allow_patterns, revision
+    )
+    if path is not None:
+        # Valid local cache found, skip download
+        return path
 
+    # In CI, skip HF API calls if we're in offline mode or want to avoid rate limits.
+    # We already checked for a local cache above, so if we're here we need to download.
     if not huggingface_hub.constants.HF_HUB_OFFLINE:
-        # Before we download we look at that is available:
+        # Before we download we look at what is available:
         fs = HfFileSystem()
         file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
 
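`_cleanup_corrupted_files_selective` is likewise referenced but defined elsewhere in the module. A hedged sketch of the behavior its call site implies (delete only the bad shards so a retry re-fetches just those); the name reuse and logging are illustrative:

```python
import logging
import os
from typing import List

logger = logging.getLogger(__name__)

def cleanup_corrupted_files_selective(
    model_name_or_path: str, file_paths: List[str]
) -> None:
    """Best-effort removal of corrupted shards; intact files stay cached."""
    for path in file_paths:
        try:
            os.remove(path)
            logger.warning(
                "Removed corrupted weight file %s for %s", path, model_name_or_path
            )
        except OSError:
            # Already gone or not removable; the retry will surface any issue.
            pass
```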
@@ -480,6 +530,10 @@
             revision=revision,
             local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
         )
+
+    # Validate downloaded files to catch corruption early
+    _validate_weights_after_download(hf_folder, allow_patterns, model_name_or_path)
+
     return hf_folder
 
 
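End to end, the new error message tells callers to retry. A usage sketch of how a caller might honor that contract: `download_weights_from_hf` and its parameters come from this diff, but the retry wrapper itself is illustrative and the full signature may accept more arguments:

```python
def download_with_retry(model: str, attempts: int = 2) -> str:
    last_err: RuntimeError | None = None
    for _ in range(attempts):
        try:
            return download_weights_from_hf(
                model, cache_dir=None, allow_patterns=["*.safetensors"]
            )
        except RuntimeError as err:
            # Corrupted shards were already deleted by the validation path,
            # so the next attempt re-downloads only the missing files.
            last_err = err
    raise last_err
```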