[feat]: support offload for flux2 #1034
New config file (referenced by the run script below as `configs/flux2/flux2_dev_offload.json`):

```json
{
    "model_cls": "flux2_dev",
    "task": "t2i",
    "infer_steps": 50,
    "sample_guide_scale": 4.0,
    "vae_scale_factor": 16,
    "feature_caching": "None",
    "enable_cfg": false,
    "patch_size": 2,
    "tokenizer_max_length": 512,
    "rope_type": "flashinfer",
    "text_encoder_out_layers": [10, 20, 30],
    "cpu_offload": true,
    "offload_granularity": "block"
}
```
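The two new keys are `cpu_offload` and `offload_granularity`: with `"block"` granularity, transformer blocks stay in host memory and are streamed to the device one at a time. A minimal sketch of the idea (hypothetical helper, not lightx2v's actual offload manager):

```python
import torch
import torch.nn as nn

def forward_with_block_offload(blocks: nn.ModuleList, x: torch.Tensor,
                               device: str = "cuda") -> torch.Tensor:
    """Block-granularity CPU offload: weights live on the CPU, and each
    block is uploaded right before its forward pass, then evicted."""
    for block in blocks:
        block.to(device, non_blocking=True)  # upload this block's weights
        x = block(x)                         # run it on the device
        block.to("cpu")                      # evict, freeing VRAM for the next block
    return x
```

Peak device memory is then roughly one block's weights plus activations, at the cost of extra PCIe traffic per step.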
Changes to the flux2 runner:

```diff
@@ -13,6 +13,8 @@
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v_platform.base.global_var import AI_DEVICE

+torch_device_module = getattr(torch, AI_DEVICE)
+

 def calculate_dimensions(target_area, ratio):
     width = math.sqrt(target_area * ratio)
```
> **Contributor** (on `torch_device_module = getattr(torch, AI_DEVICE)`): Using …
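For background on the line the comment targets: `getattr(torch, AI_DEVICE)` resolves the backend module (`torch.cuda`, `torch.xpu`, ...) once, so the later `empty_cache()` calls stop being CUDA-specific. A self-contained illustration (the `hasattr` guard is my addition, not part of this PR):

```python
import torch

AI_DEVICE = "cuda"  # in lightx2v this comes from lightx2v_platform.base.global_var

# Resolve the per-backend module once: torch.cuda, torch.xpu, etc.
torch_device_module = getattr(torch, AI_DEVICE)

def free_device_cache() -> None:
    """Device-agnostic stand-in for torch.cuda.empty_cache()."""
    if hasattr(torch_device_module, "empty_cache"):
        torch_device_module.empty_cache()
```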
@@ -45,8 +47,11 @@ def load_vae(self): | |||||||
|
|
||||||||
| def init_modules(self): | ||||||||
| logger.info(f"Initializing {self.config['model_cls']} modules...") | ||||||||
| self.load_model() | ||||||||
| self.model.set_scheduler(self.scheduler) | ||||||||
| if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False): | ||||||||
| self.load_model() | ||||||||
| self.model.set_scheduler(self.scheduler) | ||||||||
| elif self.config.get("lazy_load", False): | ||||||||
| assert self.config.get("cpu_offload", False) | ||||||||
|
> **Contributor** (on `assert self.config.get("cpu_offload", False)`): Using …
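Whatever the truncated suggestion said, the bare `assert` fails without a message and is skipped entirely under `python -O`; an explicit check (a sketch, not necessarily the reviewer's wording) would be:

```python
if self.config.get("lazy_load", False) and not self.config.get("cpu_offload", False):
    raise ValueError("flux2: lazy_load=True requires cpu_offload=True")
```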
@@ -59,8 +64,12 @@ def init_modules(self): | |||||||
| @ProfilingContext4DebugL2("Run Encoders") | ||||||||
| def _run_input_encoder_local_t2i(self): | ||||||||
| prompt = self.input_info.prompt | ||||||||
| if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): | ||||||||
| self.text_encoders = self.load_text_encoder() | ||||||||
| text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt) | ||||||||
| torch.cuda.empty_cache() | ||||||||
| if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): | ||||||||
| del self.text_encoders[0] | ||||||||
|
> **Contributor** (on `del self.text_encoders[0]`): Deleting only the first element of the list leaves any remaining encoders (and the stale `self.text_encoders` attribute) alive, so the memory they hold is not actually released.
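The exact suggested change is not captured above; a fuller cleanup along those lines would drop every encoder plus the containing attribute:

```python
# Release all text encoders, not just the first list element, so no
# stale reference keeps encoder weights alive on the device.
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
    self.text_encoders.clear()
    del self.text_encoders
    torch_device_module.empty_cache()
    gc.collect()
```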
@@ -70,7 +79,11 @@ def _run_input_encoder_local_t2i(self): | |||||||
| @ProfilingContext4DebugL2("Run Encoders I2I") | ||||||||
| def _run_input_encoder_local_i2i(self): | ||||||||
| prompt = self.input_info.prompt | ||||||||
| if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): | ||||||||
| self.text_encoders = self.load_text_encoder() | ||||||||
| text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt) | ||||||||
| if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): | ||||||||
| del self.text_encoders[0] | ||||||||
|
@@ -108,7 +121,7 @@ def _run_input_encoder_local_i2i(self): | |||||||
| if index == 0: | ||||||||
| self.input_info.target_shape = (image_height, image_width) | ||||||||
|
|
||||||||
| torch.cuda.empty_cache() | ||||||||
| torch_device_module.empty_cache() | ||||||||
| gc.collect() | ||||||||
|
|
||||||||
| return { | ||||||||
|
|
@@ -244,6 +257,9 @@ def set_img_shapes(self): | |||||||
|
|
||||||||
| @ProfilingContext4DebugL1("Run VAE Decoder") | ||||||||
| def run_vae_decoder(self, latents): | ||||||||
| if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): | ||||||||
| self.vae = self.load_vae() | ||||||||
|
|
||||||||
| B, _, C = latents.shape | ||||||||
|
|
||||||||
| H = int((self.input_info.latent_image_ids[0, :, 1].max() + 1).item()) | ||||||||
|
|
@@ -252,14 +268,20 @@ def run_vae_decoder(self, latents): | |||||||
| latents = latents.view(B, H, W, C).permute(0, 3, 1, 2) | ||||||||
|
|
||||||||
| bn_mean = self.vae.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) | ||||||||
| bn_std = torch.sqrt(self.vae.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.vae.config.batch_norm_eps) | ||||||||
| bn_std = torch.sqrt(self.vae.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.vae.config.batch_norm_eps).to(latents.device, latents.dtype) | ||||||||
| latents = latents * bn_std + bn_mean | ||||||||
|
|
||||||||
| latents = latents.reshape(B, C // 4, 2, 2, H, W) | ||||||||
| latents = latents.permute(0, 1, 4, 2, 5, 3) | ||||||||
| latents = latents.reshape(B, C // 4, H * 2, W * 2) | ||||||||
|
|
||||||||
| images = self.vae.decode(latents, self.input_info) | ||||||||
|
|
||||||||
| if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): | ||||||||
| del self.vae | ||||||||
| torch_device_module.empty_cache() | ||||||||
| gc.collect() | ||||||||
|
|
||||||||
| return images | ||||||||
|
|
||||||||
| @ProfilingContext4DebugL1("RUN pipeline") | ||||||||
|
|
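The `.to(latents.device, latents.dtype)` added to `bn_std` matters under offload: the VAE's BatchNorm statistics can still be on the CPU (or in float32) while `latents` sit on the device in a half dtype, and without the cast `latents * bn_std` raises a device/dtype mismatch. A standalone sketch of this denormalize-and-unpatchify step (function name and signature are mine; the tensor math mirrors the diff):

```python
import torch

def denorm_and_unpack(latents: torch.Tensor, running_mean: torch.Tensor,
                      running_var: torch.Tensor, eps: float) -> torch.Tensor:
    """(B, H, W, C) packed latents -> (B, C//4, 2H, 2W) spatial latents."""
    B, H, W, C = latents.shape
    latents = latents.permute(0, 3, 1, 2)  # (B, C, H, W)
    # Move BN stats to the latents' device/dtype before arithmetic --
    # exactly what the added .to(...) in the diff guards against.
    mean = running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
    std = torch.sqrt(running_var.view(1, -1, 1, 1) + eps).to(latents.device, latents.dtype)
    latents = latents * std + mean
    # Un-patchify: each group of 4 channels encodes a 2x2 spatial patch.
    latents = latents.reshape(B, C // 4, 2, 2, H, W)
    latents = latents.permute(0, 1, 4, 2, 5, 3)  # (B, C//4, H, 2, W, 2)
    return latents.reshape(B, C // 4, H * 2, W * 2)
```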
@@ -279,7 +301,7 @@ def run_pipeline(self, input_info): | |||||||
| image.save(input_info.save_result_path) | ||||||||
| logger.info(f"Image saved: {input_info.save_result_path}") | ||||||||
|
|
||||||||
| torch.cuda.empty_cache() | ||||||||
| torch_device_module.empty_cache() | ||||||||
| gc.collect() | ||||||||
|
|
||||||||
| if input_info.return_result_tensor: | ||||||||
|
|
||||||||
New run script:

```bash
#!/bin/bash
lightx2v_path=
model_path="/data/temp/FLUX.2-dev"
export CUDA_VISIBLE_DEVICES=3

source ${lightx2v_path}/scripts/base/base.sh

# Create output directory
mkdir -p ${lightx2v_path}/save_results

python -m lightx2v.infer \
    --model_cls flux2_dev \
    --task t2i \
    --model_path $model_path \
    --config_json "${lightx2v_path}/configs/flux2/flux2_dev_offload.json" \
    --prompt "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text 'BFL Diffusers' on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom." \
    --save_result_path "${lightx2v_path}/save_results/flux2_dev_offload.png" \
    --seed 42
```
> **Contributor:** Catching the generic `Exception` class is too broad and can mask unrelated errors. For PyTorch memory allocation failures, it is better to catch `RuntimeError` specifically, as that is what `torch.empty` typically raises when pinned memory allocation fails.
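The code this comment refers to is not part of the hunks shown above; a sketch of the recommended pattern (helper name assumed):

```python
import torch

def try_pin_memory(t: torch.Tensor) -> torch.Tensor:
    """Attempt a pinned-memory copy for fast async host-to-device transfers;
    fall back to the pageable tensor if the pinned allocation fails."""
    try:
        return t.pin_memory()
    except RuntimeError:
        # Pinned (page-locked) allocation failed, e.g. under host memory
        # pressure. Catching RuntimeError rather than bare Exception avoids
        # swallowing unrelated bugs.
        return t
```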