import math
import torch
import re
import gradio as gr
import numpy as np
import modules.scripts as scripts
import modules.images as saving
from modules import devices, processing, shared, sd_samplers_kdiffusion, sd_samplers_compvis, script_callbacks
from modules.processing import Processed
from modules.shared import opts, state
from ldm.models.diffusion import ddim
from PIL import Image
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, noise_like

# Matches an in-prompt override such as "<cfg_rescale:0.7>"; group 1 is the value text.
re_prompt_cfgr = re.compile(r"<cfg_rescale:([^>]+)>")

# Module-level state shared by the XYZ-grid axis hook, Script.process() and Script.cfg_replace().
cfg_rescale_fi = 0.0  # rescale factor read by cfg_replace() while sampling
_use_ui_rescale = True  # False right after the XYZ grid has supplied cfg_rescale_fi for one run

_XYZ_AXIS_LABEL = "Rescale CFG"


def _channel_bounds(channel, prec):
    """Return (lo, hi) bin indices bounding the populated range of an 8-bit channel.

    A bin counts as populated when it holds more than ``prec`` of all pixels.
    Falls back to the full range (0, 255) when no bin clears that threshold.
    """
    hist, _ = np.histogram(channel, 256, (0, 256))
    populated = np.flatnonzero(hist > hist.sum() * prec)
    if populated.size == 0:
        return 0, 255
    return int(populated.min()), int(populated.max())


def _stretch_channel(channel, lo, hi):
    """Clamp ``channel`` to [lo, hi] and linearly stretch that range to [0, 255].

    When the range is degenerate (hi <= lo) the channel is returned unchanged
    instead of dividing by zero.
    """
    if hi <= lo:
        return channel
    span = hi - lo
    return channel.point(lambda v: int(255 * (min(max(v, lo), hi) - lo) / span))


def _auto_color_fix(img, rec_strength):
    """Return a per-channel contrast-stretched RGB copy of ``img``.

    ``rec_strength`` scales the fraction of pixels (0.0005 * strength) a
    histogram bin must exceed to count as part of a channel's range; higher
    strength ignores more sparse outlier bins and stretches harder.
    """
    prec = 0.0005 * rec_strength
    if img.mode != "RGB":
        img = img.convert("RGB")
    channels = [_stretch_channel(ch, *_channel_bounds(ch, prec)) for ch in img.split()]
    return Image.merge("RGB", channels)


def _needs_color_fix(index, n_images, batch_size, show_original, has_grid):
    """Decide whether ``processed.images[index]`` should be color-fixed.

    Without ``show_original`` every image is fixed except a leading grid
    (index 0 when ``has_grid`` and there is more than one image).

    With ``show_original``, postprocess_batch_list() appended a copy of each
    batch after it. Each group of ``2 * batch_size`` images is therefore
    [originals..., copies...], and only the copies are fixed. A leading grid
    is skipped.
    """
    if show_original:
        if has_grid:
            if index == 0:
                return False
            index -= 1
        return index % (batch_size * 2) >= batch_size
    if index != 0:
        return True
    return n_images == 1 or not has_grid


class Script(scripts.Script):
    """Always-visible script that applies CFG Rescale and an optional auto color fix."""

    def __init__(self):
        global _use_ui_rescale
        # Originals kept so patched sampler methods can be restored.
        self.old_denoising = sd_samplers_kdiffusion.CFGDenoiser.combine_denoised
        self.old_schedule = ddim.DDIMSampler.make_schedule
        self.old_sample = ddim.DDIMSampler.p_sample_ddim
        _use_ui_rescale = True

        def find_module(module_names):
            """Return the loaded script module whose name is in ``module_names``, else None."""
            if isinstance(module_names, str):
                module_names = [s.strip() for s in module_names.split(",")]
            for data in scripts.scripts_data:
                if data.script_class.__module__ in module_names and hasattr(data, "module"):
                    return data.module
            return None

        def rescale_opt(p, x, xs):
            """XYZ-grid apply hook: force the next process() call to use ``x``."""
            global cfg_rescale_fi, _use_ui_rescale
            cfg_rescale_fi = x
            _use_ui_rescale = False

        xyz_grid = find_module("xyz_grid.py, xy_grid.py")
        # __init__ runs once per script runner (txt2img and img2img); register the axis only once.
        if xyz_grid and not any(getattr(opt, "label", None) == _XYZ_AXIS_LABEL for opt in xyz_grid.axis_options):
            xyz_grid.axis_options.append(xyz_grid.AxisOption(_XYZ_AXIS_LABEL, float, rescale_opt))

    def title(self):
        return "CFG Rescale Extension"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the accordion UI; returns the components passed to process()/postprocess()."""
        with gr.Accordion("CFG Rescale", open=True, elem_id="cfg_rescale"):
            rescale = gr.Slider(label="CFG Rescale", show_label=False, minimum=0.0, maximum=1.0, step=0.01, value=0.0)
            with gr.Row():
                recolor = gr.Checkbox(label="Auto Color Fix", value=False)
                rec_strength = gr.Slider(label="Fix Strength", interactive=True, visible=False,
                                         elem_id=self.elem_id("rec_strength"),
                                         minimum=0.1, maximum=10.0, step=0.1, value=1.0)
                show_original = gr.Checkbox(label="Keep Original Images", elem_id=self.elem_id("show_original"),
                                            visible=False, value=False)

        def show_recolor_strength(rec_checked):
            # The strength slider and "keep originals" only matter while the fix is enabled.
            return [gr.update(visible=rec_checked), gr.update(visible=rec_checked)]

        recolor.change(
            fn=show_recolor_strength,
            inputs=recolor,
            outputs=[rec_strength, show_original]
        )

        # "Auto Color Fix" is not written by process(); on_infotext_pasted derives it
        # from the presence of "Auto Color Fix Strength".
        self.infotext_fields = [
            (rescale, "CFG Rescale"),
            (recolor, "Auto Color Fix"),
            (rec_strength, "Auto Color Fix Strength"),
        ]
        self.paste_field_names = [field_name for _, field_name in self.infotext_fields]
        return [rescale, recolor, rec_strength, show_original]

    def cfg_replace(self, x_out, conds_list, uncond, cond_scale):
        """Replacement for CFGDenoiser.combine_denoised that applies CFG Rescale.

        With a factor of 0 this is plain classifier-free guidance. Otherwise each
        CFG result ``xcfg`` is scaled by ``fi * std(cond) / std(xcfg) + (1 - fi)``.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)
        fi = cfg_rescale_fi
        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                if fi == 0:
                    denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
                else:
                    # NOTE(review): with several conds per image (AND prompts) each
                    # iteration overwrites denoised[i], so only the last cond is kept.
                    xcfg = denoised_uncond[i] + (x_out[cond_index] - denoised_uncond[i]) * (cond_scale * weight)
                    xrescaled = torch.std(x_out[cond_index]) / torch.std(xcfg)
                    xfinal = fi * xrescaled + (1.0 - fi)
                    denoised[i] = xfinal * xcfg
        return denoised

    def process(self, p, rescale, recolor, rec_strength, show_original):
        """Resolve the rescale factor, install cfg_replace, and record infotext params."""
        global cfg_rescale_fi, _use_ui_rescale
        if _use_ui_rescale:
            # A "<cfg_rescale:x>" tag in the prompt overrides the slider; the tag is always stripped.
            rescale_override = None

            def found(m):
                nonlocal rescale_override
                try:
                    rescale_override = float(m.group(1))
                except ValueError:
                    rescale_override = None
                return ""

            p.prompt = re_prompt_cfgr.sub(found, p.prompt)
            if rescale_override is not None:
                rescale = rescale_override
            cfg_rescale_fi = rescale
        else:
            # The value was set by the XYZ-grid axis hook for this run.
            rescale = cfg_rescale_fi
        _use_ui_rescale = True

        sd_samplers_kdiffusion.CFGDenoiser.combine_denoised = self.cfg_replace
        if rescale > 0:
            p.extra_generation_params["CFG Rescale"] = rescale
        if recolor:
            p.extra_generation_params["Auto Color Fix Strength"] = rec_strength
            # postprocess() saves both the original and the fixed image itself.
            p.do_not_save_samples = True

    def postprocess_batch_list(self, p, pp, rescale, recolor, rec_strength, show_original, batch_number):
        """When keeping originals, append a copy of each batch image (and its prompt/seed data)."""
        if recolor and show_original:
            num = len(pp.images)
            for i in range(num):
                pp.images.append(pp.images[i])
                p.prompts.append(p.prompts[i])
                p.negative_prompts.append(p.negative_prompts[i])
                p.seeds.append(p.seeds[i])
                p.subseeds.append(p.subseeds[i])

    def postprocess(self, p, processed, rescale, recolor, rec_strength, show_original):
        """Restore the original combine_denoised and apply/save the auto color fix."""
        sd_samplers_kdiffusion.CFGDenoiser.combine_denoised = self.old_denoising
        if not recolor:
            return

        saved = 0  # index into p.all_* of the next generated image to save
        n_img = len(processed.images)
        for i in range(n_img):
            if not _needs_color_fix(i, n_img, p.batch_size, show_original, opts.return_grid):
                continue
            original = processed.images[i]
            fixed = _auto_color_fix(original, rec_strength)
            if opts.samples_save:
                ind = saved
                saved += 1
                prompt_infotext = processing.create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, index=ind)
                # Only the corrected image carries the "colorfix" suffix.
                saving.save_image(original, p.outpath_samples, "", seed=p.all_seeds[ind],
                                  prompt=p.all_prompts[ind], info=prompt_infotext, p=p)
                saving.save_image(fixed, p.outpath_samples, "", seed=p.all_seeds[ind],
                                  prompt=p.all_prompts[ind], info=prompt_infotext, p=p, suffix="colorfix")
            processed.images[i] = fixed


def on_infotext_pasted(infotext, params):
    """Normalize pasted infotext params for this extension's fields (mutates ``params``)."""
    if "CFG Rescale" not in params:
        params["CFG Rescale"] = 0
    if "CFG Rescale φ" in params:
        params["CFG Rescale"] = params.pop("CFG Rescale φ")
    # "CFG Rescale phi" belongs to Neutral Prompt when that script is installed.
    if "CFG Rescale phi" in params and scripts.scripts_txt2img.script("Neutral Prompt") is None:
        params["CFG Rescale"] = params.pop("CFG Rescale phi")
    if "DDIM Trailing" not in params:
        params["DDIM Trailing"] = False
    if "Auto Color Fix" not in params:
        # process() only records the strength; its presence means the fix was enabled.
        params["Auto Color Fix"] = "Auto Color Fix Strength" in params


script_callbacks.on_infotext_pasted(on_infotext_pasted)