
Training for FLUX

Table of Contents

Environment Setup

  1. Create and activate a new conda environment:

    conda create -n omini python=3.10
    conda activate omini
  2. Install required packages:

    pip install -r requirements.txt

Dataset Preparation

  1. Download Subject200K dataset for subject-driven generation:

    bash train/script/data_download/data_download1.sh
  2. Download text-to-image-2M dataset for spatial alignment control tasks:

    bash train/script/data_download/data_download2.sh

    Note: By default, only a few files will be downloaded. You can edit data_download2.sh to download more data, and update the config file accordingly.

Quick Start

Use these scripts to start training immediately:

  1. Subject-driven generation:

    bash train/script/train_subject.sh
  2. Spatial control tasks (Canny-to-image, colorization, depth map, etc.):

    bash train/script/train_spatial_alignment.sh
  3. Multi-condition training:

    bash train/script/train_multi_condition.sh
  4. Feature reuse (OminiControl2):

    bash train/script/train_feature_reuse.sh
  5. Compact token representation (OminiControl2):

    bash train/script/train_compact_token_representation.sh
  6. Token integration (OminiControl2):

    bash train/script/train_token_intergration.sh

Basic Training

Tasks from OminiControl

  1. Subject-driven generation:

    bash train/script/train_subject.sh
  2. Spatial control tasks (using canny-to-image as example):

    bash train/script/train_spatial_alignment.sh
    Supported tasks
    • Canny edge to image (canny)
    • Image colorization (coloring)
    • Image deblurring (deblurring)
    • Depth map to image (depth)
    • Image to depth map (depth_pred)
    • Image inpainting (fill)
    • Super resolution (sr)

    🌟 Change the condition_type parameter in the config file to switch between tasks.

Note: Check the script files (train/script/) and config files (train/configs/) for WandB and GPU settings.

Creating Your Own Task

You can create a custom task by building a new dataset and modifying the test code:

  1. Create a custom dataset: Your custom dataset should follow the format of Subject200KDataset in omini/train_flux/train_subject.py. Each sample should contain:

    • Image: the target image (image)
    • Text: description of the image (description)
    • Conditions: image conditions for generation
    • Position delta:
      • Use position_delta = (0, 0) to align the condition with the generated image
      • Use position_delta = (0, -a) to separate them (a = condition width / 16)

    Explanation:
    The model places both the condition and generated image in a shared coordinate system. position_delta shifts the condition image in this space.

    Each unit equals one patch (16 pixels). For a 512px-wide condition image (32 patches), position_delta = (0, -32) moves it fully to the left.

    This controls whether conditions and generated images share space or appear side-by-side.

  2. Modify the test code: Define test_function() in train_custom.py. Refer to the function in train_subject.py for examples. Make sure to keep the position_delta parameter consistent with your dataset.
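The position_delta arithmetic described above can be sketched in a few lines (the 16-pixel patch size comes from the explanation; the helper name is hypothetical, for illustration only):

```python
# Hypothetical helper illustrating the position_delta arithmetic.
# One position unit corresponds to one 16-pixel patch.
PATCH_SIZE = 16

def side_by_side_delta(condition_width_px: int) -> tuple[int, int]:
    """Delta that shifts the condition fully to the left of the output."""
    return (0, -(condition_width_px // PATCH_SIZE))

aligned = (0, 0)                     # condition overlaps the generated image
separated = side_by_side_delta(512)  # (0, -32) for a 512px-wide condition
```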

Training Configuration

Batch Size

We recommend a batch size of 1 for stable training. To simulate an effective batch size of n, set accumulate_grad_batches to n.
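As a sketch, gradient accumulation might be configured like this (where exactly the accumulate_grad_batches key lives in the config is an assumption; check train/configs/ for the actual layout):

```yaml
train:
  batch_size: 1                # keep the per-step batch size at 1
  accumulate_grad_batches: 8   # accumulate gradients over 8 steps,
                               # simulating an effective batch size of 8
```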

Optimizer

The default optimizer is Prodigy. To use AdamW instead, modify the config file:

optimizer:
  type: AdamW
  lr: 1e-4
  weight_decay: 0.001

LoRA Configuration

The default LoRA rank is 4. Increase it for more complex tasks (keep the r and lora_alpha parameters equal):

lora_config:
  r: 128
  lora_alpha: 128

Trainable Modules

The target_modules parameter uses regex patterns to specify which modules to train. See PEFT Documentation for details.

Default configuration trains all modules affecting image tokens:

target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"

To train only attention components (to_q, to_k, to_v), use:

target_modules: "(.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v)"
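The `(?<!single_)` negative lookbehind in these patterns keeps the double-stream alternatives from also matching inside `single_transformer_blocks` names. A small check with a reduced pattern (module names are hypothetical, in the style of diffusers' FLUX transformer; PEFT applies a string-valued target_modules as a full regex match):

```python
import re

# Reduced pattern: only the to_q alternatives from the config above.
pattern = (
    r"(.*(?<!single_)transformer_blocks\.[0-9]+\.attn\.to_q"
    r"|.*single_transformer_blocks\.[0-9]+\.attn\.to_q)"
)

# Hypothetical module names for illustration.
names = [
    "transformer_blocks.0.attn.to_q",         # double-stream block: matches
    "single_transformer_blocks.3.attn.to_q",  # single-stream block: matches
    "transformer_blocks.0.attn.to_out.0",     # not in this reduced pattern
]

matched = [n for n in names if re.fullmatch(pattern, n)]
```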

Advanced Training

Multi-condition

A basic multi-condition implementation is available in train_multi_condition.py:

bash train/script/train_multi_condition.sh

Efficient Generation (OminiControl2)

OminiControl2 introduces techniques to improve generation efficiency:

Feature Reuse (KV-Cache)

  1. Enable independent_condition in the config file during training:

    model:
      independent_condition: true
  2. During inference, set kv_cache = True in the generate function to speed up generation.

Example:

bash train/script/train_feature_reuse.sh

Note: Feature reuse speeds up generation but may slightly reduce performance and increase training time.

Compact Encoding Representation

Reduce the condition image resolution and use position_scale to align it with the output image:

train:
  dataset:
    condition_size: 
-     - 512
-     - 512
+     - 256
+     - 256
+   position_scale: 2
    target_size: 
      - 512
      - 512

Example:

bash train/script/train_compact_token_representation.sh
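The token savings follow from the patch arithmetic described earlier (16-pixel patches; the helper below is illustrative, not part of the codebase):

```python
# Illustrative token-count arithmetic for compact encoding (16px patches).
PATCH = 16

def num_tokens(width_px: int, height_px: int) -> int:
    return (width_px // PATCH) * (height_px // PATCH)

full_res = num_tokens(512, 512)  # 1024 condition tokens at full resolution
compact = num_tokens(256, 256)   # 256 condition tokens: a 4x reduction
scale = 512 // 256               # position_scale = 2 realigns the smaller grid
```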

Token Integration (for Fill task)

Further reduce tokens by merging condition and generation tokens into a unified sequence. (Refer to the paper for details.)

Example:

bash train/script/train_token_intergration.sh

Citation

If you find this code useful, please cite our papers:

@article{tan2024ominicontrol,
  title={OminiControl: Minimal and Universal Control for Diffusion Transformer},
  author={Tan, Zhenxiong and Liu, Songhua and Yang, Xingyi and Xue, Qiaochu and Wang, Xinchao},
  journal={arXiv preprint arXiv:2411.15098},
  year={2024}
}

@article{tan2025ominicontrol2,
  title={OminiControl2: Efficient Conditioning for Diffusion Transformers},
  author={Tan, Zhenxiong and Xue, Qiaochu and Yang, Xingyi and Liu, Songhua and Wang, Xinchao},
  journal={arXiv preprint arXiv:2503.08280},
  year={2025}
}