Skip to content

Commit d2829fe

Browse files
ruchej
authored and AUTOMATIC1111 committed
undo some changes from AUTOMATIC1111#15823 and fix whitespace
1 parent 203e5e7 commit d2829fe

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

modules/sd_samplers_kdiffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import inspect
33
import k_diffusion.sampling
4-
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers
4+
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices
55
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
66
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
77

@@ -115,7 +115,7 @@ def get_sigmas(self, p, steps):
115115
if scheduler.need_inner_model:
116116
sigmas_kwargs['inner_model'] = self.model_wrap
117117

118-
sigmas = scheduler.function(n=steps, **sigmas_kwargs)
118+
sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=devices.cpu)
119119

120120
if discard_next_to_last_sigma:
121121
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

modules/sd_schedulers.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import dataclasses
2-
32
import torch
4-
53
import k_diffusion
6-
74
import numpy as np
85

96
from modules import shared
107

8+
119
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative.

    The derivative dx/dsigma of the probability-flow ODE is the noise
    direction, i.e. the residual between the noisy sample and the
    denoised estimate, scaled by the current noise level.
    """
    residual = x - denoised
    return residual / sigma
1412

13+
1514
k_diffusion.sampling.to_d = to_d
1615

16+
1717
@dataclasses.dataclass
1818
class Scheduler:
1919
name: str
@@ -25,21 +25,22 @@ class Scheduler:
2525
aliases: list = None
2626

2727

def uniform(n, sigma_min, sigma_max, inner_model, device):
    """Uniform schedule: take the inner model's own n sigmas, moved to *device*.

    sigma_min / sigma_max are accepted for signature compatibility with the
    other schedulers but are not used — the inner model decides the values.
    """
    sigmas = inner_model.get_sigmas(n)
    return sigmas.to(device)
3030

3131

def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
    """SGM-style uniform schedule.

    Spaces n timesteps uniformly between t(sigma_max) and t(sigma_min) —
    dropping the final endpoint — maps each back to sigma space via the
    inner model, and appends the terminal 0.0 sigma.  Returned tensor is
    moved to *device* and has n + 1 entries.
    """
    t_start = inner_model.sigma_to_t(torch.tensor(sigma_max))
    t_end = inner_model.sigma_to_t(torch.tensor(sigma_min))
    timesteps = torch.linspace(t_start, t_end, n + 1)[:-1]
    sigs = [inner_model.t_to_sigma(ts) for ts in timesteps]
    sigs.append(0.0)
    return torch.FloatTensor(sigs).to(device)
4141

42-
def get_align_your_steps_sigmas(n, sigma_min, sigma_max):
42+
43+
def get_align_your_steps_sigmas(n, sigma_min, sigma_max, device):
4344
# https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
4445
def loglinear_interp(t_steps, num_steps):
4546
"""
@@ -65,12 +66,13 @@ def loglinear_interp(t_steps, num_steps):
6566
else:
6667
sigmas.append(0.0)
6768

68-
return torch.FloatTensor(sigmas)
69+
return torch.FloatTensor(sigmas).to(device)
70+
6971

def kl_optimal(n, sigma_min, sigma_max, device):
    """KL-optimal noise schedule.

    Interpolates linearly in arctan space between arctan(sigma_max) and
    arctan(sigma_min), then maps back through tan; yields n + 1 sigmas on
    *device*, from sigma_max down to sigma_min.
    """
    alpha_lo = torch.arctan(torch.tensor(sigma_min, device=device))
    alpha_hi = torch.arctan(torch.tensor(sigma_max, device=device))
    # Fractional progress 0 → 1 across the n + 1 steps.
    frac = torch.arange(n + 1, device=device) / n
    return torch.tan(frac * alpha_lo + (1.0 - frac) * alpha_hi)
7678

0 commit comments

Comments (0)