15 changes: 15 additions & 0 deletions configs/flux2/flux2_dev_offload.json
@@ -0,0 +1,15 @@
+{
+    "model_cls": "flux2_dev",
+    "task": "t2i",
+    "infer_steps": 50,
+    "sample_guide_scale": 4.0,
+    "vae_scale_factor": 16,
+    "feature_caching": "None",
+    "enable_cfg": false,
+    "patch_size": 2,
+    "tokenizer_max_length": 512,
+    "rope_type": "flashinfer",
+    "text_encoder_out_layers": [10, 20, 30],
+    "cpu_offload": true,
+    "offload_granularity": "block"
+}
10 changes: 8 additions & 2 deletions lightx2v/common/ops/utils.py
@@ -3,6 +3,7 @@
 from pathlib import Path

 import torch
+from loguru import logger
 from safetensors import safe_open

 from lightx2v.utils.envs import *
@@ -81,10 +82,15 @@ def create_pin_tensor(tensor, transpose=False, dtype=None):
         dtype: Target data type of the pinned tensor (optional, defaults to source tensor's dtype)

     Returns:
-        Pinned memory tensor (on CPU) with optional transposition applied
+        Pinned memory tensor (on CPU) with optional transposition applied.
+        Falls back to regular CPU tensor if pinned memory allocation fails.
     """
     dtype = dtype or tensor.dtype
-    pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
+    try:
+        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
+    except Exception as e:
Contributor


medium

Catching the generic Exception class is too broad and can mask unrelated errors. For PyTorch memory allocation failures, it is better to catch RuntimeError specifically, as that is what torch.empty typically raises when pinned memory allocation fails.

Suggested change
-    except Exception as e:
+    except RuntimeError as e:
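
As a quick illustration of the exception-handling point above (a sketch, not repository code; pin_or_fallback is a hypothetical name):

import torch

def pin_or_fallback(t: torch.Tensor) -> torch.Tensor:
    # Mirror the patched create_pin_tensor behavior, but catch RuntimeError
    # specifically, as suggested above.
    try:
        pinned = torch.empty(t.shape, pin_memory=True, dtype=t.dtype)
    except RuntimeError:
        # Pinned allocation can fail when page-locked memory is exhausted;
        # a pageable tensor still works, just without async-copy overlap.
        pinned = torch.empty(t.shape, dtype=t.dtype)
    return pinned.copy_(t)

# Host-to-device copies overlap with compute only when the source is pinned.
weights = pin_or_fallback(torch.randn(1024, 1024))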

+        logger.warning(f"Failed to allocate pinned memory (shape={tensor.shape}, dtype={dtype}): {e}. Falling back to regular CPU memory.")
+        pin_tensor = torch.empty(tensor.shape, dtype=dtype)
     pin_tensor = pin_tensor.copy_(tensor)
     if transpose:
         pin_tensor = pin_tensor.t()
6 changes: 6 additions & 0 deletions lightx2v/models/networks/flux2/weights/transformer_weights.py
@@ -238,12 +238,18 @@ def to_cuda(self, non_blocking=True):
             block.to_cuda(non_blocking=non_blocking)
         for block in self.single_blocks:
             block.to_cuda(non_blocking=non_blocking)
+        self.double_stream_modulation_img_linear.to_cuda(non_blocking=non_blocking)
+        self.double_stream_modulation_txt_linear.to_cuda(non_blocking=non_blocking)
+        self.single_stream_modulation_linear.to_cuda(non_blocking=non_blocking)

     def to_cpu(self, non_blocking=True):
         for block in self.double_blocks:
             block.to_cpu(non_blocking=non_blocking)
         for block in self.single_blocks:
             block.to_cpu(non_blocking=non_blocking)
+        self.double_stream_modulation_img_linear.to_cpu(non_blocking=non_blocking)
+        self.double_stream_modulation_txt_linear.to_cpu(non_blocking=non_blocking)
+        self.single_stream_modulation_linear.to_cpu(non_blocking=non_blocking)


 # Backward-compatible aliases
34 changes: 28 additions & 6 deletions lightx2v/models/runners/flux2/flux2_runner.py
@@ -13,6 +13,8 @@
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v_platform.base.global_var import AI_DEVICE

+torch_device_module = getattr(torch, AI_DEVICE)
Contributor


high

Using getattr(torch, AI_DEVICE) is risky. If AI_DEVICE is a device string like "cuda:0" or "cpu", this will raise an AttributeError. Typically, AI_DEVICE refers to the device identifier used with torch.device(), while getattr expects a module name like "cuda" or "mps". Additionally, torch does not have a cpu attribute that acts as a device module. Consider extracting the device type (e.g., AI_DEVICE.split(':')[0]) and handling the "cpu" case explicitly to avoid a crash.
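
A minimal sketch of the safer lookup described here (illustrative only; resolve_device_module is a hypothetical helper, and AI_DEVICE is assumed to be a string such as "cuda", "cuda:0", or "cpu"):

import torch

def resolve_device_module(ai_device: str):
    # Keep only the device type: "cuda:0" -> "cuda".
    device_type = ai_device.split(":")[0]
    if device_type == "cpu":
        # No accelerator cache to manage on CPU; callers must handle None.
        return None
    return getattr(torch, device_type)  # e.g. torch.cuda, torch.mps

torch_device_module = resolve_device_module("cuda:0")
if torch_device_module is not None:
    torch_device_module.empty_cache()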



 def calculate_dimensions(target_area, ratio):
     width = math.sqrt(target_area * ratio)
@@ -45,8 +47,11 @@ def load_vae(self):

     def init_modules(self):
         logger.info(f"Initializing {self.config['model_cls']} modules...")
-        self.load_model()
-        self.model.set_scheduler(self.scheduler)
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
+            self.load_model()
+            self.model.set_scheduler(self.scheduler)
+        elif self.config.get("lazy_load", False):
+            assert self.config.get("cpu_offload", False)
Contributor


medium

Using assert for runtime configuration validation is discouraged because assertions can be disabled in optimized Python execution (using the -O flag). It is better to raise a ValueError to ensure the check is always performed.

Suggested change
-            assert self.config.get("cpu_offload", False)
+            if not self.config.get("cpu_offload", False):
+                raise ValueError("cpu_offload must be enabled when lazy_load is true")


         task = self.config.get("task", "t2i")
         if task == "i2i":
@@ -59,8 +64,12 @@
     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_t2i(self):
         prompt = self.input_info.prompt
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            self.text_encoders = self.load_text_encoder()
         text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
-        torch.cuda.empty_cache()
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            del self.text_encoders[0]
Contributor


medium

Deleting only the first element of the list leaves self.text_encoders as an empty list []. It is cleaner and more consistent with how self.vae is handled (line 281) to delete the entire attribute, which also avoids potential IndexError if the list is accessed elsewhere while empty.

Suggested change
-            del self.text_encoders[0]
+            del self.text_encoders

+        torch_device_module.empty_cache()
         gc.collect()
         return {
             "text_encoder_output": text_encoder_output,
@@ -70,7 +79,11 @@ def _run_input_encoder_local_t2i(self):
     @ProfilingContext4DebugL2("Run Encoders I2I")
     def _run_input_encoder_local_i2i(self):
         prompt = self.input_info.prompt
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            self.text_encoders = self.load_text_encoder()
         text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            del self.text_encoders[0]
Contributor


medium

Deleting only the first element of the list leaves self.text_encoders as an empty list []. It is cleaner and more consistent with how self.vae is handled (line 281) to delete the entire attribute.

Suggested change
-            del self.text_encoders[0]
+            del self.text_encoders


         image_path = self.input_info.image_path
         from PIL import Image
@@ -108,7 +121,7 @@ def _run_input_encoder_local_i2i(self):
             if index == 0:
                 self.input_info.target_shape = (image_height, image_width)

-        torch.cuda.empty_cache()
+        torch_device_module.empty_cache()
         gc.collect()

         return {
@@ -244,6 +257,9 @@ def set_img_shapes(self):

     @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            self.vae = self.load_vae()
+
         B, _, C = latents.shape

         H = int((self.input_info.latent_image_ids[0, :, 1].max() + 1).item())
@@ -252,14 +268,20 @@
         latents = latents.view(B, H, W, C).permute(0, 3, 1, 2)

         bn_mean = self.vae.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
-        bn_std = torch.sqrt(self.vae.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.vae.config.batch_norm_eps)
+        bn_std = torch.sqrt(self.vae.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.vae.config.batch_norm_eps).to(latents.device, latents.dtype)
         latents = latents * bn_std + bn_mean

         latents = latents.reshape(B, C // 4, 2, 2, H, W)
         latents = latents.permute(0, 1, 4, 2, 5, 3)
         latents = latents.reshape(B, C // 4, H * 2, W * 2)

         images = self.vae.decode(latents, self.input_info)
+
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            del self.vae
+            torch_device_module.empty_cache()
+            gc.collect()
+
         return images

     @ProfilingContext4DebugL1("RUN pipeline")
@@ -279,7 +301,7 @@ def run_pipeline(self, input_info):
             image.save(input_info.save_result_path)
             logger.info(f"Image saved: {input_info.save_result_path}")

-        torch.cuda.empty_cache()
+        torch_device_module.empty_cache()
         gc.collect()

         if input_info.return_result_tensor:
18 changes: 18 additions & 0 deletions scripts/flux2/infer_flux2_dev_offload.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+lightx2v_path=
+model_path="/data/temp/FLUX.2-dev"
+export CUDA_VISIBLE_DEVICES=3
+
+source ${lightx2v_path}/scripts/base/base.sh
+
+# Create output directory
+mkdir -p ${lightx2v_path}/save_results
+
+python -m lightx2v.infer \
+    --model_cls flux2_dev \
+    --task t2i \
+    --model_path $model_path \
+    --config_json "${lightx2v_path}/configs/flux2/flux2_dev_offload.json" \
+    --prompt "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text 'BFL Diffusers' on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom." \
+    --save_result_path "${lightx2v_path}/save_results/flux2_dev_offload.png" \
+    --seed 42