Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
d631586
split text in MM-DiT
xibosun Nov 22, 2024
d48b78b
check the use_parallel_vae flag for CogVideo
xibosun Nov 22, 2024
d35e495
fix dimensions in all_gather
xibosun Nov 27, 2024
faafcd1
optimizations on H100
xibosun Nov 27, 2024
a41b7c1
support optimized USP in Flux
xibosun Nov 28, 2024
55b8711
do not split text if indivisible by sp_degree
xibosun Nov 28, 2024
726f402
polish optimized USP
xibosun Nov 28, 2024
61b4b90
update diffusers version in setup.py
xibosun Nov 28, 2024
b8b0b10
merge to main
xibosun Nov 28, 2024
d43176d
fix bugs
xibosun Nov 29, 2024
5d7c886
unify USP interface
xibosun Nov 29, 2024
3c17dea
optimized USP in CogVideo
xibosun Nov 29, 2024
cc1f2da
use optimized USP in cogvideo
xibosun Dec 3, 2024
73a071b
add CogVideoX1.5-5B performance on H20 and L20
xibosun Dec 3, 2024
7a687df
merge upstream main
xibosun Dec 3, 2024
8d6de01
rename files and update docs
xibosun Dec 3, 2024
c2fdea2
Merge remote-tracking branch 'upstream/main' into text_slice
xibosun Dec 5, 2024
0172a69
decouple runtime state from USP
xibosun Dec 5, 2024
d45155f
add doc for adding new models
xibosun Dec 17, 2024
e5e21c7
Merge remote-tracking branch 'upstream/main' into text_slice
xibosun Dec 17, 2024
3fd6809
fix typos
xibosun Dec 17, 2024
a5318fb
add docs for adding models
xibosun Dec 18, 2024
a88cf5b
fix docs for adding models
xibosun Dec 18, 2024
57db9de
fix docs for adding new models
xibosun Dec 18, 2024
564a483
fix docs for adding models
xibosun Dec 19, 2024
01a68a2
add figure to illustrate USP
xibosun Dec 19, 2024
b495702
Merge remote-tracking branch 'upstream/main' into text_slice
xibosun Feb 10, 2025
4bbbd1b
feat: add usp implementations
xibosun Feb 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 68 additions & 11 deletions xfuser/model_executor/layers/usp.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# This file implements USP with torch version >= '2.5.0'
import torch
from torch.nn import functional as F
from torch.distributed.tensor.experimental._attention import _templated_ring_attention
aten = torch.ops.aten

import torch.distributed._functional_collectives as ft_c

from torch.distributed.tensor.experimental._attention import _templated_ring_attention

from yunchang.globals import PROCESS_GROUP

from xfuser.core.distributed import (
Expand All @@ -14,17 +14,74 @@
get_ring_parallel_world_size,
)

from xfuser.envs import PACKAGES_CHECKER
env_info = PACKAGES_CHECKER.get_packages_info()
HAS_FLASH_ATTN = env_info["has_flash_attn"]

aten = torch.ops.aten


def ring_attn(query, key, value, dropout_p=0.0, is_causal=False):
    """Run ring attention over PROCESS_GROUP.RING_PG via torch's templated kernel.

    Dispatches to the flash SDP backend when flash-attn is available, otherwise
    to the memory-efficient SDP backend, and adapts to the torch >= 2.6 API of
    ``_templated_ring_attention`` (which takes an extra positional argument and
    exposes ``_cp_options``).

    Args:
        query, key, value: attention inputs already sharded along the ring
            dimension (presumably (batch, heads, seq, dim) as expected by the
            aten SDP ops — TODO confirm against callers).
        dropout_p: attention dropout probability, forwarded to the backend.
        is_causal: whether to apply a causal mask, forwarded to the backend.

    Returns:
        The attention output tensor (first element of the backend's result tuple).
    """
    import re

    # BUG FIX: the original compared torch.__version__ as a *string*
    # ("2.10.0" < "2.6.0" lexicographically), which would select the wrong
    # API on torch 2.10+. Compare numeric (major, minor) instead; dev/local
    # suffixes such as "2.6.0a0+cu121" are handled by matching only the
    # leading digits.
    m = re.match(r"(\d+)\.(\d+)", torch.__version__)
    is_torch_2_6_plus = m is not None and (int(m.group(1)), int(m.group(2))) >= (2, 6)

    # Backend-independent kwargs; the efficient backend needs two extras.
    kwargs = {
        "dropout_p": dropout_p,
        "is_causal": is_causal,
    }
    if HAS_FLASH_ATTN:
        op = aten._scaled_dot_product_flash_attention
    else:
        op = aten._scaled_dot_product_efficient_attention
        kwargs = {
            **kwargs,
            "attn_bias": None,
            "compute_log_sumexp": True,
        }

    if is_torch_2_6_plus:
        # torch >= 2.6 load-balances causal workloads by reordering the
        # sequence; disable it so shards keep their original layout.
        from torch.distributed.tensor.experimental._attention import _cp_options
        _cp_options.enable_load_balance = False
        out, *_ = _templated_ring_attention(
            PROCESS_GROUP.RING_PG,
            # NOTE(review): extra positional argument required by the 2.6 API
            # (presumably the sequence dimension index — confirm against the
            # torch source for the installed version).
            1,
            op,
            query,
            key,
            value,
            **kwargs,
        )
    else:
        out, *_ = _templated_ring_attention(
            PROCESS_GROUP.RING_PG,
            op,
            query,
            key,
            value,
            **kwargs,
        )
    return out


Expand Down
34 changes: 25 additions & 9 deletions xfuser/model_executor/layers/usp_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,43 @@

from yunchang.globals import PROCESS_GROUP
from yunchang.ring.ring_flash_attn import ring_flash_attn_forward
from yunchang.ring.ring_pytorch_attn import ring_pytorch_attn_func

from xfuser.core.distributed import (
get_sequence_parallel_world_size,
get_ulysses_parallel_world_size,
get_ring_parallel_world_size,
)

from xfuser.envs import PACKAGES_CHECKER
env_info = PACKAGES_CHECKER.get_packages_info()
HAS_FLASH_ATTN = env_info["has_flash_attn"]


def ring_attn(query, key, value, dropout_p=0.0, is_causal=False):
    """Legacy ring attention over PROCESS_GROUP.RING_PG using yunchang kernels.

    Uses the flash-attn ring kernel when available and falls back to the pure
    PyTorch ring implementation otherwise. Inputs arrive with heads in dim 2;
    they are transposed to the (batch, seq, heads, dim) layout the ring
    kernels expect and transposed back before returning.
    """
    # Ring kernels want the sequence dimension ahead of the head dimension.
    q = query.transpose(1, 2).contiguous()
    k = key.transpose(1, 2).contiguous()
    v = value.transpose(1, 2).contiguous()

    # Standard 1/sqrt(d) attention scaling, derived from the head dimension.
    scale = q.shape[-1] ** (-0.5)

    if HAS_FLASH_ATTN:
        result = ring_flash_attn_forward(
            PROCESS_GROUP.RING_PG,
            q,
            k,
            v,
            softmax_scale=scale,
            dropout_p=dropout_p,
            causal=is_causal,
        )
        out = result[0]
    else:
        # Pure-PyTorch fallback; takes the process group as a keyword instead.
        out = ring_pytorch_attn_func(
            q,
            k,
            v,
            dropout_p=dropout_p,
            softmax_scale=scale,
            causal=is_causal,
            group=PROCESS_GROUP.RING_PG,
        )

    # Restore the caller's (batch, heads, seq, dim) layout.
    return out.transpose(1, 2).contiguous()

Expand Down
Loading