30 changes: 15 additions & 15 deletions examples/sagemaker/deploy_and_serve_endpoint.py
@@ -1,7 +1,6 @@
 import json
+import boto3
 import sagemaker
-
-import boto3
 from sagemaker import serializers
 from sagemaker.model import Model
 from sagemaker.predictor import Predictor
@@ -10,20 +9,22 @@
 sm_client = boto_session.client("sagemaker")
 sm_role = boto_session.resource("iam").Role("SageMakerRole").arn
 
-endpoint_name="<YOUR_ENDPOINT_NAME>"
-image_uri="<YOUR_DOCKER_IMAGE_URI>"
-model_id="<YOUR_MODEL_ID>" # eg: Qwen/Qwen3-0.6B from https://huggingface.co/Qwen/Qwen3-0.6B
-hf_token="<YOUR_HUGGINGFACE_TOKEN>"
-prompt="<YOUR_ENDPOINT_PROMPT>"
+endpoint_name = "<YOUR_ENDPOINT_NAME>"
+image_uri = "<YOUR_DOCKER_IMAGE_URI>"
+model_id = (
+    "<YOUR_MODEL_ID>"  # eg: Qwen/Qwen3-0.6B from https://huggingface.co/Qwen/Qwen3-0.6B
+)
+hf_token = "<YOUR_HUGGINGFACE_TOKEN>"
+prompt = "<YOUR_ENDPOINT_PROMPT>"
 
 model = Model(
-    name=endpoint_name,
-    image_uri=image_uri,
-    role=sm_role,
-    env={
-        "SM_SGLANG_MODEL_PATH": model_id,
-        "HF_TOKEN": hf_token,
-    },
+    name=endpoint_name,
+    image_uri=image_uri,
+    role=sm_role,
+    env={
+        "SM_SGLANG_MODEL_PATH": model_id,
+        "HF_TOKEN": hf_token,
+    },
 )
 print("Model created successfully")
 print("Starting endpoint deployment (this may take 10-15 minutes)...")
@@ -66,4 +67,3 @@
print("Warning: Response is not valid JSON. Returning as string.")

print(f"Received model response: '{response}'")

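
For anyone smoke-testing the deployed endpoint without the SageMaker Python SDK, here is a minimal sketch using boto3's sagemaker-runtime client. The {"text": ...} payload shape is an assumption about the SGLang serving container's request schema, not something this file specifies.

import json

import boto3

runtime = boto3.client("sagemaker-runtime")
response = runtime.invoke_endpoint(
    EndpointName="<YOUR_ENDPOINT_NAME>",
    ContentType="application/json",
    # Assumed request schema; adjust to the container's actual API.
    Body=json.dumps({"text": "<YOUR_ENDPOINT_PROMPT>"}),
)
# The response body is a streaming object; read and decode it before parsing.
print(json.loads(response["Body"].read().decode("utf-8")))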
96 changes: 58 additions & 38 deletions scripts/ci/validate_and_download_models.py
@@ -1,3 +1,3 @@
 #!/usr/bin/env python3
 """
 Validate model integrity for CI runners and download if needed.
@@ -157,7 +157,8 @@
     # Check if any files in the snapshot are symlinks to .incomplete blobs
     # This ensures we only flag incomplete files for THIS specific model,
     # not other models that might be downloading concurrently
-    for file_path in model_path.glob("*"):
+    # Use recursive glob to support Diffusers models with weights in subdirectories
+    for file_path in model_path.glob("**/*"):
         if file_path.is_symlink():
             try:
                 target = file_path.resolve()
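
Viewed on its own, the check this hunk modifies reads as below. The function name and the ".incomplete" suffix test are assumptions about the elided loop body, based on how huggingface_hub names partially downloaded blobs.

from pathlib import Path

def has_incomplete_downloads(model_path: Path) -> bool:
    # Recursive glob covers flat repos as well as Diffusers layouts
    # with weights under subdirectories such as transformer/ or vae/.
    for file_path in model_path.glob("**/*"):
        if file_path.is_symlink():
            target = file_path.resolve()
            # huggingface_hub suffixes in-progress blob downloads with ".incomplete"
            if target.name.endswith(".incomplete"):
                return True
    return False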
@@ -210,23 +211,24 @@
         Tuple of (is_valid, error_message, corrupted_files)
         - corrupted_files: List of paths to corrupted shard files that should be removed
     """
-    # Pattern for sharded files: model-00001-of-00009.safetensors or pytorch_model-00001-of-00009.bin
+    # Pattern for sharded files: model-00001-of-00009.safetensors, pytorch_model-00001-of-00009.bin,
+    # or diffusion_pytorch_model-00001-of-00009.safetensors (for Diffusers models)
+    # Use word boundary to prevent matching files like tokenizer_model-* or optimizer_model-*
     shard_pattern = re.compile(
-        r"(?:model|pytorch_model)-(\d+)-of-(\d+)\.(safetensors|bin)"
+        r"(?:^|/)(?:model|pytorch_model|diffusion_pytorch_model)-(\d+)-of-(\d+)\.(safetensors|bin)"
     )
 
-    # Find all shard files (both .safetensors and .bin)
-    shard_files = (
-        list(model_path.glob("model-*-of-*.safetensors"))
-        + list(model_path.glob("model-*-of-*.bin"))
-        + list(model_path.glob("pytorch_model-*-of-*.bin"))
+    # Find all shard files recursively (both .safetensors and .bin)
+    # This supports both standard models (weights in root) and Diffusers models (weights in subdirs)
+    shard_files = list(model_path.glob("**/*-*-of-*.safetensors")) + list(
+        model_path.glob("**/*-*-of-*.bin")
     )
 
     if not shard_files:
-        # No sharded files - check for any safetensors or bin files
+        # No sharded files - check for any safetensors or bin files recursively
         # Exclude non-model files like tokenizer, config, optimizer, etc.
-        all_safetensors = list(model_path.glob("*.safetensors"))
-        all_bins = list(model_path.glob("*.bin"))
+        all_safetensors = list(model_path.glob("**/*.safetensors"))
+        all_bins = list(model_path.glob("**/*.bin"))
 
         # Filter out non-model files
         excluded_prefixes = ["tokenizer", "optimizer", "training_", "config"]
@@ -251,43 +253,61 @@
             return True, None, []
         return False, "No model weight files found (safetensors or bin)", []
 
-    # Extract total shard count from any shard filename
-    total_shards = None
+    # Group shards by subdirectory and total count
+    # This handles Diffusers models where different components (transformer/, vae/)
+    # have different numbers of shards
+    shard_groups = {}
     for shard_file in shard_files:
-        match = shard_pattern.search(shard_file.name)
-        if match:
-            total_shards = int(match.group(2))
-            break
-
-    if total_shards is None:
-        return False, "Could not determine total shard count from filenames", []
-
-    # Check that all shards exist
-    expected_shards = set(range(1, total_shards + 1))
-    found_shards = set()
-
-    for shard_file in shard_files:
-        match = shard_pattern.search(shard_file.name)
+        # Match against the full path string to get proper path separation
+        match = shard_pattern.search(str(shard_file))
         if match:
             shard_num = int(match.group(1))
-            found_shards.add(shard_num)
+            total = int(match.group(2))
+            parent = shard_file.parent
+            key = (str(parent.relative_to(model_path)), total)
+
+            if key not in shard_groups:
+                shard_groups[key] = set()
+            shard_groups[key].add(shard_num)
+
+    if not shard_groups:
+        return False, "Could not determine shard groups from filenames", []
+
+    # Validate each group separately
+    for (parent_path, total_shards), found_shards in shard_groups.items():
+        expected_shards = set(range(1, total_shards + 1))
+        missing_shards = expected_shards - found_shards
+
+        if missing_shards:
+            missing_list = sorted(missing_shards)
+            location = f" in {parent_path}" if parent_path != "." else ""
+            # Missing shards - nothing to remove, let download handle it
+            return (
+                False,
+                f"Missing shards{location}: {missing_list} (expected {total_shards} total)",
+                [],
+            )
 
-    missing_shards = expected_shards - found_shards
+    # Check for index file (look for specific patterns matching the shard prefixes)
+    # Standard models: model.safetensors.index.json or pytorch_model.safetensors.index.json
+    # Diffusers models: diffusion_pytorch_model.safetensors.index.json in subdirs
+    valid_index_patterns = [
+        "model.safetensors.index.json",
+        "pytorch_model.safetensors.index.json",
+        "**/diffusion_pytorch_model.safetensors.index.json",
+    ]
 
-    if missing_shards:
-        missing_list = sorted(missing_shards)
-        # Missing shards - nothing to remove, let download handle it
+    index_files = []
+    for pattern in valid_index_patterns:
+        index_files.extend(model_path.glob(pattern))
+
+    if not index_files:
         return (
             False,
-            f"Missing shards: {missing_list} (expected {total_shards} total)",
+            "Missing required index file (model/pytorch_model/diffusion_pytorch_model.safetensors.index.json)",
             [],
         )
-
-    # Check for index file
-    index_file = model_path / "model.safetensors.index.json"
-    if not index_file.exists():
-        return False, "Missing model.safetensors.index.json", []
 
     # Validate each safetensors shard file for corruption
     print(f" Validating {len(shard_files)} shard file(s) for corruption...")
    corrupted_files = []
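
As a worked illustration of the new grouping logic, here is a self-contained sketch run against hypothetical Diffusers-style paths. The file names are invented, and the grouping key mirrors shard_groups in the diff above; note that the anchored pattern skips the tokenizer_model-* file.

import re
from collections import defaultdict

shard_pattern = re.compile(
    r"(?:^|/)(?:model|pytorch_model|diffusion_pytorch_model)-(\d+)-of-(\d+)\.(safetensors|bin)"
)

# Invented example paths: transformer/ is complete, vae/ is missing shard 2.
paths = [
    "transformer/diffusion_pytorch_model-00001-of-00003.safetensors",
    "transformer/diffusion_pytorch_model-00002-of-00003.safetensors",
    "transformer/diffusion_pytorch_model-00003-of-00003.safetensors",
    "vae/diffusion_pytorch_model-00001-of-00002.safetensors",
    "tokenizer_model-00001-of-00002.bin",  # must not match the pattern
]

groups = defaultdict(set)  # (subdir, total_shards) -> shard numbers found
for p in paths:
    m = shard_pattern.search(p)
    if m:
        parent = p.rsplit("/", 1)[0] if "/" in p else "."
        groups[(parent, int(m.group(2)))].add(int(m.group(1)))

for (parent, total), found in sorted(groups.items()):
    missing = sorted(set(range(1, total + 1)) - found)
    print(f"{parent}: {len(found)}/{total} shards, missing {missing or 'none'}")
# transformer: 3/3 shards, missing none
# vae: 1/2 shards, missing [2]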