diff --git a/examples/sagemaker/deploy_and_serve_endpoint.py b/examples/sagemaker/deploy_and_serve_endpoint.py index afc4cc1fc66b..e518183c39f3 100644 --- a/examples/sagemaker/deploy_and_serve_endpoint.py +++ b/examples/sagemaker/deploy_and_serve_endpoint.py @@ -1,7 +1,6 @@ import json -import boto3 -import sagemaker +import boto3 from sagemaker import serializers from sagemaker.model import Model from sagemaker.predictor import Predictor @@ -10,20 +9,22 @@ sm_client = boto_session.client("sagemaker") sm_role = boto_session.resource("iam").Role("SageMakerRole").arn -endpoint_name="" -image_uri="" -model_id="" # eg: Qwen/Qwen3-0.6B from https://huggingface.co/Qwen/Qwen3-0.6B -hf_token="" -prompt="" +endpoint_name = "" +image_uri = "" +model_id = ( + "" # eg: Qwen/Qwen3-0.6B from https://huggingface.co/Qwen/Qwen3-0.6B +) +hf_token = "" +prompt = "" model = Model( - name=endpoint_name, - image_uri=image_uri, - role=sm_role, - env={ - "SM_SGLANG_MODEL_PATH": model_id, - "HF_TOKEN": hf_token, - }, + name=endpoint_name, + image_uri=image_uri, + role=sm_role, + env={ + "SM_SGLANG_MODEL_PATH": model_id, + "HF_TOKEN": hf_token, + }, ) print("Model created successfully") print("Starting endpoint deployment (this may take 10-15 minutes)...") @@ -66,4 +67,3 @@ print("Warning: Response is not valid JSON. Returning as string.") print(f"Received model response: '{response}'") - diff --git a/scripts/ci/validate_and_download_models.py b/scripts/ci/validate_and_download_models.py index d1ef0f4e0a34..e357eccc941a 100755 --- a/scripts/ci/validate_and_download_models.py +++ b/scripts/ci/validate_and_download_models.py @@ -157,7 +157,8 @@ def check_incomplete_files(model_path: Path, cache_dir: str) -> List[str]: # Check if any files in the snapshot are symlinks to .incomplete blobs # This ensures we only flag incomplete files for THIS specific model, # not other models that might be downloading concurrently - for file_path in model_path.glob("*"): + # Use recursive glob to support Diffusers models with weights in subdirectories + for file_path in model_path.glob("**/*"): if file_path.is_symlink(): try: target = file_path.resolve() @@ -210,23 +211,24 @@ def validate_model_shards(model_path: Path) -> Tuple[bool, Optional[str], List[P Tuple of (is_valid, error_message, corrupted_files) - corrupted_files: List of paths to corrupted shard files that should be removed """ - # Pattern for sharded files: model-00001-of-00009.safetensors or pytorch_model-00001-of-00009.bin + # Pattern for sharded files: model-00001-of-00009.safetensors, pytorch_model-00001-of-00009.bin, + # or diffusion_pytorch_model-00001-of-00009.safetensors (for Diffusers models) + # Use word boundary to prevent matching files like tokenizer_model-* or optimizer_model-* shard_pattern = re.compile( - r"(?:model|pytorch_model)-(\d+)-of-(\d+)\.(safetensors|bin)" + r"(?:^|/)(?:model|pytorch_model|diffusion_pytorch_model)-(\d+)-of-(\d+)\.(safetensors|bin)" ) - # Find all shard files (both .safetensors and .bin) - shard_files = ( - list(model_path.glob("model-*-of-*.safetensors")) - + list(model_path.glob("model-*-of-*.bin")) - + list(model_path.glob("pytorch_model-*-of-*.bin")) + # Find all shard files recursively (both .safetensors and .bin) + # This supports both standard models (weights in root) and Diffusers models (weights in subdirs) + shard_files = list(model_path.glob("**/*-*-of-*.safetensors")) + list( + model_path.glob("**/*-*-of-*.bin") ) if not shard_files: - # No sharded files - check for any safetensors or bin files + # No sharded files - check for any safetensors or bin files recursively # Exclude non-model files like tokenizer, config, optimizer, etc. - all_safetensors = list(model_path.glob("*.safetensors")) - all_bins = list(model_path.glob("*.bin")) + all_safetensors = list(model_path.glob("**/*.safetensors")) + all_bins = list(model_path.glob("**/*.bin")) # Filter out non-model files excluded_prefixes = ["tokenizer", "optimizer", "training_", "config"] @@ -251,43 +253,61 @@ def validate_model_shards(model_path: Path) -> Tuple[bool, Optional[str], List[P return True, None, [] return False, "No model weight files found (safetensors or bin)", [] - # Extract total shard count from any shard filename - total_shards = None + # Group shards by subdirectory and total count + # This handles Diffusers models where different components (transformer/, vae/) + # have different numbers of shards + shard_groups = {} for shard_file in shard_files: - match = shard_pattern.search(shard_file.name) - if match: - total_shards = int(match.group(2)) - break - - if total_shards is None: - return False, "Could not determine total shard count from filenames", [] - - # Check that all shards exist - expected_shards = set(range(1, total_shards + 1)) - found_shards = set() - - for shard_file in shard_files: - match = shard_pattern.search(shard_file.name) + # Match against the full path string to get proper path separation + match = shard_pattern.search(str(shard_file)) if match: shard_num = int(match.group(1)) - found_shards.add(shard_num) + total = int(match.group(2)) + parent = shard_file.parent + key = (str(parent.relative_to(model_path)), total) + + if key not in shard_groups: + shard_groups[key] = set() + shard_groups[key].add(shard_num) + + if not shard_groups: + return False, "Could not determine shard groups from filenames", [] + + # Validate each group separately + for (parent_path, total_shards), found_shards in shard_groups.items(): + expected_shards = set(range(1, total_shards + 1)) + missing_shards = expected_shards - found_shards + + if missing_shards: + missing_list = sorted(missing_shards) + location = f" in {parent_path}" if parent_path != "." else "" + # Missing shards - nothing to remove, let download handle it + return ( + False, + f"Missing shards{location}: {missing_list} (expected {total_shards} total)", + [], + ) - missing_shards = expected_shards - found_shards + # Check for index file (look for specific patterns matching the shard prefixes) + # Standard models: model.safetensors.index.json or pytorch_model.safetensors.index.json + # Diffusers models: diffusion_pytorch_model.safetensors.index.json in subdirs + valid_index_patterns = [ + "model.safetensors.index.json", + "pytorch_model.safetensors.index.json", + "**/diffusion_pytorch_model.safetensors.index.json", + ] - if missing_shards: - missing_list = sorted(missing_shards) - # Missing shards - nothing to remove, let download handle it + index_files = [] + for pattern in valid_index_patterns: + index_files.extend(model_path.glob(pattern)) + + if not index_files: return ( False, - f"Missing shards: {missing_list} (expected {total_shards} total)", + "Missing required index file (model/pytorch_model/diffusion_pytorch_model.safetensors.index.json)", [], ) - # Check for index file - index_file = model_path / "model.safetensors.index.json" - if not index_file.exists(): - return False, "Missing model.safetensors.index.json", [] - # Validate each safetensors shard file for corruption print(f" Validating {len(shard_files)} shard file(s) for corruption...") corrupted_files = []