11import dataclasses
2-
32import torch
4-
53import k_diffusion
6-
74import numpy as np
85
96from modules import shared
107
8+
119def to_d (x , sigma , denoised ):
1210 """Converts a denoiser output to a Karras ODE derivative."""
1311 return (x - denoised ) / sigma
1412
13+
1514k_diffusion .sampling .to_d = to_d
1615
16+
1717@dataclasses .dataclass
1818class Scheduler :
1919 name : str
@@ -25,21 +25,22 @@ class Scheduler:
2525 aliases : list = None
2626
2727
28- def uniform (n , sigma_min , sigma_max , inner_model ):
29- return inner_model .get_sigmas (n )
28+ def uniform (n , sigma_min , sigma_max , inner_model , device ):
29+ return inner_model .get_sigmas (n ). to ( device )
3030
3131
32- def sgm_uniform (n , sigma_min , sigma_max , inner_model ):
32+ def sgm_uniform (n , sigma_min , sigma_max , inner_model , device ):
3333 start = inner_model .sigma_to_t (torch .tensor (sigma_max ))
3434 end = inner_model .sigma_to_t (torch .tensor (sigma_min ))
3535 sigs = [
3636 inner_model .t_to_sigma (ts )
3737 for ts in torch .linspace (start , end , n + 1 )[:- 1 ]
3838 ]
3939 sigs += [0.0 ]
40- return torch .FloatTensor (sigs )
40+ return torch .FloatTensor (sigs ). to ( device )
4141
42- def get_align_your_steps_sigmas (n , sigma_min , sigma_max ):
42+
43+ def get_align_your_steps_sigmas (n , sigma_min , sigma_max , device ):
4344 # https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
4445 def loglinear_interp (t_steps , num_steps ):
4546 """
@@ -65,12 +66,13 @@ def loglinear_interp(t_steps, num_steps):
6566 else :
6667 sigmas .append (0.0 )
6768
68- return torch .FloatTensor (sigmas )
69+ return torch .FloatTensor (sigmas ).to (device )
70+
6971
70- def kl_optimal (n , sigma_min , sigma_max ):
71- alpha_min = torch .arctan (torch .tensor (sigma_min ))
72- alpha_max = torch .arctan (torch .tensor (sigma_max ))
73- step_indices = torch .arange (n + 1 )
72+ def kl_optimal (n , sigma_min , sigma_max , device ):
73+ alpha_min = torch .arctan (torch .tensor (sigma_min , device = device ))
74+ alpha_max = torch .arctan (torch .tensor (sigma_max , device = device ))
75+ step_indices = torch .arange (n + 1 , device = device )
7476 sigmas = torch .tan (step_indices / n * alpha_min + (1.0 - step_indices / n ) * alpha_max )
7577 return sigmas
7678
0 commit comments