from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput


@torch.no_grad()
def sd_pipeline_call(
        pipeline: StableDiffusionPipeline,
        prompt_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None):
""" Modification of the standard SD pipeline call to support NeTI embeddings passed with prompt_embeds argument."""
    # 0. Default height and width to the UNet's native resolution
    height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
    width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor

    # 1. Define call parameters
    batch_size = 1
    device = pipeline._execution_device

    # 2. Encode the (possibly empty) negative prompt for classifier-free guidance.
    # Note: NeTI's modified text encoder returns a tuple, hence the unpacking below.
    neg_prompt = get_neg_prompt_input_ids(pipeline, negative_prompt)
    negative_prompt_embeds, _ = pipeline.text_encoder(
        input_ids=neg_prompt.input_ids.to(device),
        attention_mask=None,
    )
    negative_prompt_embeds = negative_prompt_embeds[0]

    # Here `guidance_scale` is defined analogously to the guidance weight `w` in
    # equation (2) of the Imagen paper (https://arxiv.org/pdf/2205.11487.pdf);
    # `guidance_scale = 1` corresponds to doing no classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0
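    # Concretely, the two UNet passes in the loop below are combined as
    #   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond),
    # which for guidance_scale = 1 reduces to the conditional prediction alone.
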
    # 3. Prepare timesteps
    pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipeline.scheduler.timesteps

    # 4. Prepare latent variables
    num_channels_latents = pipeline.unet.config.in_channels
    latents = pipeline.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        pipeline.text_encoder.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare extra step kwargs
    extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
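    # (`prepare_extra_step_kwargs` only forwards `eta` and `generator` if the
    # scheduler's `step` accepts them; `eta` corresponds to η in the DDIM paper
    # and is ignored by other schedulers.)
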
    # 6. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
    with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            latent_model_input = pipeline.scheduler.scale_model_input(latents, t)

            if do_classifier_free_guidance:
                # predict the unconditional noise residual
                noise_pred_uncond = pipeline.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=negative_prompt_embeds.repeat(num_images_per_prompt, 1, 1),
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

            ###############################################################
            # NeTI logic: use the prompt embedding for the current timestep
            ###############################################################
            embed = prompt_embeds[i] if isinstance(prompt_embeds, list) else prompt_embeds
            noise_pred_text = pipeline.unet(
                latent_model_input,
                t,
                encoder_hidden_states=embed,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            else:
                noise_pred = noise_pred_text

            # compute the previous noisy sample x_t -> x_{t-1}
            latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if output_type == "latent":
        image = latents
        has_nsfw_concept = None
    elif output_type == "pil":
        # 7. Post-processing
        image = pipeline.decode_latents(latents)
        # 8. Run safety checker
        image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)
        # 9. Convert to PIL
        image = pipeline.numpy_to_pil(image)
    else:
        # 7. Post-processing
        image = pipeline.decode_latents(latents)
        # 8. Run safety checker
        image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)

    # Offload the last model to CPU if model offloading is enabled
    if hasattr(pipeline, "final_offload_hook") and pipeline.final_offload_hook is not None:
        pipeline.final_offload_hook.offload()

    if not return_dict:
        return image, has_nsfw_concept

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)


def get_neg_prompt_input_ids(pipeline: StableDiffusionPipeline,
                             negative_prompt: Optional[Union[str, List[str]]] = None):
    """Tokenize the negative prompt, falling back to the empty string when none is given."""
    if negative_prompt is None:
        negative_prompt = ""
    uncond_tokens = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
    uncond_input = pipeline.tokenizer(
        uncond_tokens,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    return uncond_input
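

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of NeTI itself). The checkpoint
# name, the prompt, and the way `per_step_embeds` is built are assumptions: in
# NeTI, `prompt_embeds` would be a list holding one embedding per denoising
# timestep produced by NeTI's mapper, whereas here we simply repeat a standard
# CLIP embedding to show the expected shapes and calling convention.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    # `sd_pipeline_call` unpacks the text encoder's output as a tuple (NeTI's
    # modified encoder returns one), so for this sketch with a stock encoder we
    # ask transformers for tuple outputs instead of a ModelOutput.
    pipe.text_encoder.config.return_dict = False

    num_steps = 50
    tokens = pipe.tokenizer(
        "a photo of a cat",
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # (1, 77, 768) conditioning tensor from the CLIP text encoder
    embeds = pipe.text_encoder(tokens.input_ids.to(pipe.device))[0]
    per_step_embeds = [embeds] * num_steps  # one entry per timestep

    result = sd_pipeline_call(pipe, prompt_embeds=per_step_embeds, num_inference_steps=num_steps)
    result.images[0].save("example.png")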