Skip to content

Implement torch compile and mxfp8 for flux#2579

Open
hlahkar wants to merge 2 commits intopytorch:mainfrom
hlahkar:flux_compile
Open

Implement torch compile and mxfp8 for flux#2579
hlahkar wants to merge 2 commits intopytorch:mainfrom
hlahkar:flux_compile

Conversation

@hlahkar
Copy link

@hlahkar hlahkar commented Mar 15, 2026

This PR implements torch compile and mxfp8 dtype computation for Flux model. Note: frozen encoder are not included in for mxfp8 quantization.

Test Report fo the test cases implementd for torch compile for Flux:

Running 15 items in this shard: tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_multiple_calls, tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_uses_specified_backend, tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_wraps_all_blocks, tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_forward_after_compile, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_encoder_forward_after_compile, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_no_fullgraph_for_encoders, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_clip_layers, tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_t5_blocks, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_disabled, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_with_model_component, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_without_model_component, tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_encoder_compile_disabled, tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_compiled, tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_when_disabled, tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_without_loss_component

tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_multiple_calls PASSED [ 6%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_uses_specified_backend PASSED [ 13%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_apply_compile_wraps_all_blocks PASSED [ 20%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileFlux::test_forward_after_compile PASSED [ 26%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_encoder_forward_after_compile PASSED [ 33%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_no_fullgraph_for_encoders PASSED [ 40%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_clip_layers PASSED [ 46%]
tests/unit_tests/test_compile_flux.py::TestApplyCompileToEncoders::test_wraps_t5_blocks PASSED [ 53%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_disabled PASSED [ 60%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_with_model_component PASSED [ 66%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_compile_enabled_without_model_component PASSED [ 73%]
tests/unit_tests/test_compile_flux.py::TestCompileConfigGating::test_encoder_compile_disabled PASSED [ 80%]
tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_compiled PASSED [ 86%]
tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_when_disabled PASSED [ 93%]
tests/unit_tests/test_compile_flux.py::TestMSELossCompile::test_mse_loss_not_compiled_without_loss_component PASSED [100%]

========================================================================================================================== warnings summary ==========================================================================================================================
../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:365: 14 warnings
/usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:365: DeprecationWarning: torch.jit.script_method is deprecated. Please switch to torch.compile or torch.export.
warnings.warn(

:488
:488: DeprecationWarning: builtin type SwigPyPacked has no module attribute

:488
:488: DeprecationWarning: builtin type SwigPyObject has no module attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================================================== 15 passed, 16 warnings in 9.74s ===================================================================================================================

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 15, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant