[WIP] Refactor Model Tests #12822
base: main
Changes from all commits: 1f026ad, 1c55871, bffa3a9, aa29af8, 0f1a4e0, fe451c3, 489480b, 0fdd9d3, c366b5a
@@ -32,6 +32,21 @@

```python
def pytest_configure(config):
    config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
    config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
    config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
    config.addinivalue_line("markers", "training: marks tests for training functionality")
    config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
    config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
    config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
    config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
    config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
    config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
    config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
    config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
    config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
    config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
    config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
```
Comment on lines +44 to +48 (Member): Don't you think it's better off in the context of a pipeline? We could try to refactor everything under … I am not too strongly opinionated about it.
```python
    config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")


def pytest_addoption(parser):
```
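The `is_attention` and `is_context_parallel` decorators imported from `testing_utils` later in this PR are not shown in this diff; presumably they wrap the markers registered above. A plausible definition, stated here only as an assumption:

```python
import pytest

# Assumed implementation; the diff does not include testing_utils, so these
# two aliases are a guess based on the markers registered in pytest_configure.
is_attention = pytest.mark.attention
is_context_parallel = pytest.mark.context_parallel
```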
@@ -0,0 +1,38 @@

```python
from .attention import AttentionTesterMixin, ContextParallelTesterMixin
from .common import ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .quantization import (
    BitsAndBytesTesterMixin,
    GGUFTesterMixin,
    ModelOptTesterMixin,
    QuantizationTesterMixin,
    QuantoTesterMixin,
    TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin


__all__ = [
    "ContextParallelTesterMixin",
    "AttentionTesterMixin",
    "BitsAndBytesTesterMixin",
    "CPUOffloadTesterMixin",
    "GGUFTesterMixin",
    "GroupOffloadTesterMixin",
    "IPAdapterTesterMixin",
    "LayerwiseCastingTesterMixin",
    "LoraTesterMixin",
    "MemoryTesterMixin",
    "ModelOptTesterMixin",
    "ModelTesterMixin",
    "QuantizationTesterMixin",
    "QuantoTesterMixin",
    "SingleFileTesterMixin",
    "TorchAoTesterMixin",
    "TorchCompileTesterMixin",
    "TrainingTesterMixin",
]
```
@@ -0,0 +1,263 @@

```python
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch
import torch.multiprocessing as mp

from diffusers.models._modeling_parallel import ContextParallelConfig
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
    AttnProcessor,
)

from ...testing_utils import is_attention, is_context_parallel, require_torch_multi_accelerator, torch_device


@is_attention
class AttentionTesterMixin:
    """
    Mixin class for testing attention processor and module functionality on models.
    Tests functionality from AttentionModuleMixin including:
    - Attention processor management (set/get)
    - QKV projection fusion/unfusion
    - Attention backends (XFormers, NPU, etc.)
    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - base_precision: Tolerance for floating point comparisons (default: 1e-3)
```
Member: This feels unnecessarily restrictive. Would rather rely on method-specific …
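One way that "method-specific" suggestion could look, purely as a sketch (the attribute names here are invented, not from the PR):

```python
# Sketch: tolerances live next to the checks that use them and can be
# overridden per model test class, instead of one shared base_precision.
class AttentionTesterMixin:
    fuse_qkv_atol = 1e-3  # consumed only by test_fuse_unfuse_qkv_projections


class SomeModelTests(AttentionTesterMixin):
    fuse_qkv_atol = 5e-3  # loosen the fusion check for this particular model only
```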
```python
    - uses_custom_attn_processor: Whether model uses custom attention processors (default: False)
    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
    Pytest mark: attention
    Use `pytest -m "not attention"` to skip these tests
```
Member: How do we implement it in an individual model testing class? For example, say we want to skip it for model X where its attention class doesn't inherit from …
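One possible pattern, shown only as a sketch (nothing in this PR defines it): the model-specific test class overrides the inherited test and skips it explicitly.

```python
import pytest

from tests.models.testing_mixins import AttentionTesterMixin  # assumed import path


class ModelXAttentionTests(AttentionTesterMixin):
    # model_class, get_init_dict, and get_dummy_inputs omitted for brevity;
    # "ModelX" stands in for a model whose attention blocks do not use AttentionModuleMixin.

    @pytest.mark.skip(reason="ModelX attention does not inherit from AttentionModuleMixin")
    def test_fuse_unfuse_qkv_projections(self):
        pass
```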
| """ | ||
|
|
||
| base_precision = 1e-3 | ||
|
|
||
| def test_fuse_unfuse_qkv_projections(self): | ||
| init_dict = self.get_init_dict() | ||
| inputs_dict = self.get_dummy_inputs() | ||
| model = self.model_class(**init_dict) | ||
| model.to(torch_device) | ||
| model.eval() | ||
|
|
||
| if not hasattr(model, "fuse_qkv_projections"): | ||
| pytest.skip("Model does not support QKV projection fusion.") | ||
|
|
||
| # Get output before fusion | ||
| with torch.no_grad(): | ||
| output_before_fusion = model(**inputs_dict) | ||
| if isinstance(output_before_fusion, dict): | ||
| output_before_fusion = output_before_fusion.to_tuple()[0] | ||
|
Comment on lines +68 to +70 (Member): (nit) Let's have a function to normalize the outputs so that it's ready to go to …
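A possible shape for that helper, as a sketch (the name `_normalize_output` is invented here, not from the PR):

```python
def _normalize_output(output):
    # Unwrap ModelOutput-style returns so comparisons can go straight into torch.allclose.
    if isinstance(output, dict):
        output = output.to_tuple()[0]
    return output
```

Each `if isinstance(output, dict): output = output.to_tuple()[0]` pair in this file could then collapse to a single `_normalize_output(model(**inputs_dict))` call before the comparisons.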
```python
        # Fuse projections
        model.fuse_qkv_projections()

        # Verify fusion occurred by checking for fused attributes
        has_fused_projections = False
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
                    has_fused_projections = True
                    assert module.fused_projections, "fused_projections flag should be True"
                    break

        if has_fused_projections:
            # Get output after fusion
            with torch.no_grad():
                output_after_fusion = model(**inputs_dict)
                if isinstance(output_after_fusion, dict):
                    output_after_fusion = output_after_fusion.to_tuple()[0]

            # Verify outputs match
            assert torch.allclose(output_before_fusion, output_after_fusion, atol=self.base_precision), (
                "Output should not change after fusing projections"
            )

            # Unfuse projections
            model.unfuse_qkv_projections()

            # Verify unfusion occurred
            for module in model.modules():
                if isinstance(module, AttentionModuleMixin):
                    assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
                    assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
                    assert not module.fused_projections, "fused_projections flag should be False"

            # Get output after unfusion
            with torch.no_grad():
                output_after_unfusion = model(**inputs_dict)
                if isinstance(output_after_unfusion, dict):
                    output_after_unfusion = output_after_unfusion.to_tuple()[0]

            # Verify outputs still match
            assert torch.allclose(output_before_fusion, output_after_unfusion, atol=self.base_precision), (
                "Output should match original after unfusing projections"
            )

    def test_get_set_processor(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        # Check if model has attention processors
        if not hasattr(model, "attn_processors"):
            pytest.skip("Model does not have attention processors.")

        # Test getting processors
        processors = model.attn_processors
        assert isinstance(processors, dict), "attn_processors should return a dict"
        assert len(processors) > 0, "Model should have at least one attention processor"

        # Test that all processors can be retrieved via get_processor
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                processor = module.get_processor()
                assert processor is not None, "get_processor should return a processor"

                # Test setting a new processor
                new_processor = AttnProcessor()
                module.set_processor(new_processor)
                retrieved_processor = module.get_processor()
                assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"

    def test_attention_processor_dict(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict of new processors
        new_processors = {key: AttnProcessor() for key in current_processors.keys()}

        # Set processors using dict
        model.set_attn_processor(new_processors)

        # Verify all processors were set
        updated_processors = model.attn_processors
        for key in current_processors.keys():
            assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"

    def test_attention_processor_count_mismatch_raises_error(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict with wrong number of processors
        wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}

        # Verify error is raised
        with pytest.raises(ValueError) as exc_info:
            model.set_attn_processor(wrong_processors)

        assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"


def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
    try:
        # Setup distributed environment
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )
        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")

        model = model_class(**init_dict)
        model.to(device)
        model.eval()

        inputs_on_device = {}
        for key, value in inputs_dict.items():
            if isinstance(value, torch.Tensor):
                inputs_on_device[key] = value.to(device)
            else:
                inputs_on_device[key] = value

        cp_config = ContextParallelConfig(**cp_dict)
        model.enable_parallelism(config=cp_config)

        with torch.no_grad():
            output = model(**inputs_on_device)
            if isinstance(output, dict):
                output = output.to_tuple()[0]

        if rank == 0:
            result_queue.put(("success", output.shape))

    except Exception as e:
        if rank == 0:
            result_queue.put(("error", str(e)))
    finally:
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()


@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
```
Member: Would prefer this under a separate module.
```python
    base_precision = 1e-3

    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
    def test_context_parallel_inference(self, cp_type):
        if not torch.distributed.is_available():
            pytest.skip("torch.distributed is not available.")

        if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
            pytest.skip("Context parallel requires at least 2 CUDA devices.")
```
```python
        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

        world_size = 2
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        cp_dict = {cp_type: world_size}

        ctx = mp.get_context("spawn")
        result_queue = ctx.Queue()

        mp.spawn(
            _context_parallel_worker,
            args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
            nprocs=world_size,
            join=True,
        )

        status, result = result_queue.get(timeout=60)
        assert status == "success", f"Context parallel inference failed: {result}"
```

Member (on the `mp.spawn(` call): Should we not check if the model class implements …
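Assuming the check the reviewer has in mind is for `enable_parallelism` (an assumption; the worker calls it), the guard could be as small as the sketch below, with the helper name invented for illustration:

```python
import pytest


def _maybe_skip_if_parallelism_unsupported(model_class):
    # Sketch: the test could call this before spawning workers so the skip
    # happens in the parent process rather than failing inside the worker.
    if not hasattr(model_class, "enable_parallelism"):
        pytest.skip("Model does not implement enable_parallelism.")
```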
Comment: Do we want to club this with attention backends?