File size: 1,516 Bytes
6880dce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f72a8e
 
 
 
 
 
 
 
 
 
 
 
6880dce
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from huggingface_hub import from_pretrained_keras
from keras_cv import models
import gradio as gr

from tensorflow import keras

# Run Keras layers under the mixed_float16 policy globally. This is set before
# the model below is constructed so that the model's layers pick up the policy.
keras.mixed_precision.set_global_policy("mixed_float16")

# prepare model
# Width and height, in pixels, of the generated images.
resolution = 512
# Base KerasCV Stable Diffusion model; its diffusion component is replaced below.
sd_dreambooth_model = models.StableDiffusion(
    img_width=resolution, img_height=resolution
)
# Load the DreamBooth fine-tuned diffusion network from the Hugging Face Hub repo
# "kfahn/dreambooth-mandelbulb" and swap it in for the base diffusion component.
db_diffusion_model = from_pretrained_keras("kfahn/dreambooth-mandelbulb")
# NOTE(review): this assigns the private `_diffusion_model` attribute of KerasCV's
# StableDiffusion; confirm it still exists in the keras_cv version this Space pins.
sd_dreambooth_model._diffusion_model = db_diffusion_model

# generate images
def infer(
    prompt,
    negative_prompt=None,
    guidance_scale=10,
    num_inference_steps=50,
    num_images=2,
):
    """Generate images for *prompt* with the DreamBooth-tuned Stable Diffusion model.

    Args:
        prompt: Text prompt; include the ``sks`` token to invoke the learned
            mandelbulb concept.
        negative_prompt: Optional text to steer generation away from. ``None``
            or an empty string means "no negative prompt". Defaults to ``None``
            so the function can be called with just a prompt (as a single-input
            Gradio interface does).
        guidance_scale: Classifier-free guidance strength; forwarded to KerasCV
            as ``unconditional_guidance_scale``.
        num_inference_steps: Number of diffusion steps; forwarded to KerasCV as
            ``num_steps``.
        num_images: Number of images generated in one batch (default 2, matching
            the two-column output gallery).

    Returns:
        A list with one image per generated sample, suitable for a
        ``gr.Gallery`` output.
    """
    # KerasCV's StableDiffusion.text_to_image takes different keyword names
    # than the diffusers-style pipeline the original code called. That pipeline
    # (`pipeline`) was never defined in this file.
    generation_kwargs = {
        "batch_size": num_images,
        "num_steps": num_inference_steps,
        "unconditional_guidance_scale": guidance_scale,
    }
    # Forward a negative prompt only when one was actually given, so an empty
    # Gradio text box behaves like "no negative prompt".
    if negative_prompt:
        generation_kwargs["negative_prompt"] = negative_prompt

    # KerasCV has no NSFW detector (unlike diffusers), so there is nothing to
    # filter on. Generating exactly `num_images` in one batch also removes the
    # original retry loop, which could spin forever.
    images = sd_dreambooth_model.text_to_image(prompt, **generation_kwargs)
    # presumably a (batch, height, width, 3) array per the KerasCV docs; split
    # it into per-image entries for the gallery.
    return list(images)
    
# Gallery showing the generated images side by side.
# NOTE(review): `.style()` is deprecated in later Gradio 3.x releases and removed
# in Gradio 4; if upgrading, pass the layout options to `gr.Gallery` directly.
output = gr.Gallery(label="Outputs").style(grid=(1, 2))

# customize interface
title = "Dreambooth Mandelbulb flower"
description = "This is a dreambooth model fine-tuned on mandelbulb images. To try it, input the concept with {sks a hydrangea floweret shaped like a mandelbulb}."
# Each example row supplies one value per input component: (prompt, negative prompt).
examples = [["sks a hydrangea floweret shaped like a mandelbulb on a bush", ""]]

# `infer` takes (prompt, negative_prompt, ...) positionally. A single "text"
# input passed only the prompt, so expose both fields as text boxes.
demo = gr.Interface(
    infer,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Negative prompt"),
    ],
    outputs=[output],
    title=title,
    description=description,
    examples=examples,
)
demo.queue().launch()