import base64
from io import BytesIO
from typing import Dict, List, Any

import torch
from PIL import Image
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# This handler only supports GPU inference; fail at import time otherwise.
if device.type != 'cuda':
    raise ValueError("need to run on GPU")

# set mixed precision dtype:
# bfloat16 is supported from compute capability 8.x onwards (including newer
# major versions such as 9.x), so use it whenever the major version is >= 8;
# older GPUs fall back to float16.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16


class EndpointHandler:
    """Generate one image per request with Stable Diffusion.

    Uses the ``Lykon/dreamshaper-8`` checkpoint with the
    ``CompVis/stable-diffusion-safety-checker`` safety checker, both loaded in
    the module-level mixed-precision ``dtype`` on the GPU.
    """

    def __init__(self, path: str = ""):
        """Load the pipeline onto the GPU and record its convolution layers.

        :param path: Not used by this handler; kept so the constructor matches
            the signature the hosting runtime calls (presumably the model
            repository path — confirm against the deployment).
        """
        self.stable_diffusion_id = "Lykon/dreamshaper-8"
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", torch_dtype=dtype
        )
        self.pipe = StableDiffusionPipeline.from_pretrained(
            self.stable_diffusion_id,
            torch_dtype=dtype,
            safety_checker=safety_checker,
        ).to(device.type)
        # Optional memory savers, intentionally left disabled:
        # self.pipe.enable_xformers_memory_efficient_attention()
        # self.pipe.enable_vae_tiling()

        # Generator with a fixed seed, shared across requests: its state advances
        # with every call, so the sequence of outputs is reproducible per process
        # but consecutive requests do not repeat the same noise.
        self.generator = torch.Generator(device=device.type).manual_seed(3)

        targets = [
            self.pipe.vae,
            self.pipe.text_encoder,
            self.pipe.unet,
        ]
        # Record every 2D (transposed) convolution and its original padding_mode.
        # NOTE(review): nothing in the visible code reads these lists; presumably
        # they exist so padding can be switched (e.g. for tileable output) and
        # later restored — confirm against the rest of the project.
        self.conv_layers = []
        self.conv_layers_original_paddings = []
        for target in targets:
            for module in target.modules():
                if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                    self.conv_layers.append(module)
                    self.conv_layers_original_paddings.append(module.padding_mode)

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """Generate a single image from a text prompt.

        Every recognised key is popped from ``data``, so the caller's dict is
        mutated.

        :param data: Request payload with these keys:
            - ``inputs`` (required): the prompt.
            - ``negative_prompt``: defaults to None.
            - ``num_inference_steps``: defaults to 30.
            - ``guidance_scale``: defaults to 7.4.
            - ``height`` / ``width``: default to None, which leaves the choice
              to the pipeline.
            - ``seed`` (optional int): when given, a fresh generator seeded
              with it is used for this request only, making the result
              reproducible. Otherwise the handler's shared generator is used.
        :return: The first generated image (a PIL image).
        :raises ValueError: If the ``inputs`` prompt is missing.
        """
        prompt = data.pop("inputs", None)
        num_inference_steps = data.pop("num_inference_steps", 30)
        guidance_scale = data.pop("guidance_scale", 7.4)
        negative_prompt = data.pop("negative_prompt", None)
        height = data.pop("height", None)
        width = data.pop("width", None)
        seed = data.pop("seed", None)

        if prompt is None:
            raise ValueError("request payload is missing the 'inputs' prompt")

        if seed is None:
            generator = self.generator
        else:
            generator = torch.Generator(device=device.type).manual_seed(int(seed))

        # run inference pipeline
        out = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            height=height,
            width=width,
            generator=generator,
        )

        # return the first generated PIL image
        return out.images[0]