Implement torch compile and mxfp8 for flux#2579
Open
hlahkar wants to merge 2 commits intopytorch:mainfrom
Open
Implement torch compile and mxfp8 for flux#2579hlahkar wants to merge 2 commits intopytorch:mainfrom
hlahkar wants to merge 2 commits intopytorch:mainfrom
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR implements torch compile and mxfp8 dtype computation for Flux model. Note: frozen encoder are not included in for mxfp8 quantization.
Test Report fo the test cases implementd for torch compile for Flux:
Running 15 items in this shard: tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_multiple_calls, tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_uses_specified_backend, tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_wraps_all_blocks, tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_forward_after_compile, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_encoder_forward_after_compile, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_no_fullgraph_for_encoders, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_clip_layers, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_t5_blocks, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_disabled, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_with_model_component, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_without_model_component, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_encoder_compile_disabled, tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_compiled, tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_when_disabled, tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_without_loss_component
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_multiple_calls PASSED [ 6%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_uses_specified_backend PASSED [ 13%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_wraps_all_blocks PASSED [ 20%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_forward_after_compile PASSED [ 26%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_encoder_forward_after_compile PASSED [ 33%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_no_fullgraph_for_encoders PASSED [ 40%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_clip_layers PASSED [ 46%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_t5_blocks PASSED [ 53%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_disabled PASSED [ 60%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_with_model_component PASSED [ 66%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_without_model_component PASSED [ 73%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_encoder_compile_disabled PASSED [ 80%]
tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_compiled PASSED [ 86%]
tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_when_disabled PASSED [ 93%]
tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_without_loss_component PASSED [100%]
========================================================================================================================== warnings summary ==========================================================================================================================
../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:365: 14 warnings
/usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:365: DeprecationWarning:
torch.jit.script_methodis deprecated. Please switch totorch.compileortorch.export.warnings.warn(
:488
:488: DeprecationWarning: builtin type SwigPyPacked has no module attribute
:488
:488: DeprecationWarning: builtin type SwigPyObject has no module attribute
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================================================== 15 passed, 16 warnings in 9.74s ===================================================================================================================