Skip to content

Commit bb24c13

Browse files
committed
infotext support for #14978
1 parent aabedcb commit bb24c13

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

modules/infotext_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,9 @@ def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
359359
if "Emphasis" not in res:
360360
res["Emphasis"] = "Original"
361361

362+
if "Refiner switch by sampling steps" not in res:
363+
res["Refiner switch by sampling steps"] = False
364+
362365
infotext_versions.backcompat(res)
363366

364367
for key in skip_fields:

modules/infotext_versions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
v160 = version.parse("1.6.0")
77
v170_tsnr = version.parse("v1.7.0-225")
8+
v180 = version.parse("1.8.0")
89

910

1011
def parse_version(text):
@@ -40,3 +41,5 @@ def backcompat(d):
4041
if ver < v170_tsnr:
4142
d["Downcast alphas_cumprod"] = True
4243

44+
if ver < v180 and d.get('Refiner'):
45+
d["Refiner switch by sampling steps"] = True

modules/sd_samplers_common.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,16 @@ def torchsde_randn(size, dtype, device, seed):
155155
replace_torchsde_browinan()
156156

157157

158-
def apply_refiner(cfg_denoiser, sigma):
159-
if opts.refiner_switch_by_sample_steps:
158+
def apply_refiner(cfg_denoiser, sigma=None):
159+
if opts.refiner_switch_by_sample_steps or not sigma:
160160
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
161+
cfg_denoiser.p.extra_generation_params["Refiner switch by sampling steps"] = True
162+
161163
else:
162164
# torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
163165
try:
164166
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
165-
except AttributeError: # for samplers that dont use sigmas (DDIM) sigma is actually the timestep
167+
except AttributeError: # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
166168
timestep = torch.max(sigma).to(dtype=int)
167169
completed_ratio = (999 - timestep) / 1000
168170

0 commit comments

Comments (0)