Skip to content

Commit 46bd005

Browse files
authored
Convert tensors to float in Helios's optimized_scale function (#13214)
Convert tensors to float in optimized_scale function
1 parent 8ec0a5c commit 46bd005

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/diffusers/pipelines/helios/pipeline_helios_pyramid.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@
7676

7777

7878
def optimized_scale(positive_flat, negative_flat):
79+
positive_flat = positive_flat.float()
80+
negative_flat = negative_flat.float()
7981
# Calculate dot production
8082
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
8183
# Squared norm of uncondition

0 commit comments

Comments
 (0)