"""Inference-endpoint handler that generates images with Stable Cascade.

The prior pipeline turns a text prompt into image embeddings; the decoder
pipeline turns those embeddings into a PIL image.  Both pipelines are loaded
once at startup (in ``__init__``) rather than on every request.
"""

import base64
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

# Select the compute device; this handler is GPU-only.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("need to run on GPU")

# bfloat16 requires compute capability 8.x (Ampere or newer); fall back to fp16.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16


class EndpointHandler():
    """Callable handler compatible with Hugging Face Inference Endpoints."""

    def __init__(self, path: str = "") -> None:
        """Load both Stable Cascade pipelines once at startup.

        Args:
            path: Model directory supplied by the endpoint runtime (unused;
                the pipelines are fetched from the Hub by model id).
        """
        # Loading here instead of per request avoids repeating multi-gigabyte
        # model loads and GPU transfers on every call.
        self.prior_pipeline = StableCascadePriorPipeline.from_pretrained(
            "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
        ).to(device)
        self.decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
            "stabilityai/stable-cascade", torch_dtype=torch.float16
        ).to(device)
        # Fixed seed so repeated identical requests produce identical images;
        # passed to the prior pipeline below.
        self.generator = torch.Generator(device=device.type).manual_seed(3)

    def __call__(self, data: Any) -> "Image.Image":
        """Generate one image for the request payload.

        Args:
            data: Request dict.  ``data["inputs"]`` is the text prompt
                (defaults to the original demo prompt when absent); optional
                ``data["parameters"]`` may override ``negative_prompt``,
                ``guidance_scale``, ``height``, ``width``, ``prior_steps``
                and ``decoder_steps``.

        Returns:
            The generated ``PIL.Image.Image``.
        """
        payload = data if isinstance(data, dict) else {}
        # Fall back to the original hard-coded demo prompt for compatibility
        # with callers that send no usable payload.
        prompt = payload.get("inputs", "Anthropomorphic cat dressed as a pilot")
        params: Dict[str, Any] = payload.get("parameters", {}) or {}
        negative_prompt = params.get("negative_prompt", "")
        guidance_scale = float(params.get("guidance_scale", 7.0))
        num_images_per_prompt = 1

        prior_output = self.prior_pipeline(
            prompt=prompt,
            height=int(params.get("height", 512)),
            width=int(params.get("width", 512)),
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=int(params.get("prior_steps", 20)),
            generator=self.generator,
        )
        # The decoder runs in fp16, so cast the bf16 prior embeddings down.
        decoder_output = self.decoder_pipeline(
            image_embeddings=prior_output.image_embeddings.half(),
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            output_type="pil",
            num_inference_steps=int(params.get("decoder_steps", 10)),
        ).images
        return decoder_output[0]