File size: 1,780 Bytes
c3a1897
 
 
 
 
 
 
 
 
 
 
 
eb902b3
 
c3a1897
 
 
 
 
b25eb4e
 
c3a1897
9b4b3ea
c3a1897
 
b25eb4e
c3a1897
 
 
 
b25eb4e
eb902b3
c3a1897
 
 
 
 
 
 
 
 
 
 
 
 
eb902b3
 
c3a1897
 
eb902b3
 
c3a1897
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import cv2
import torch
import numpy as np
from PIL import Image
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    UniPCMultistepScheduler,
)


class TextToImage:
    """Generate an image from a text prompt, guided by the Canny edges of an input image.

    Wraps a diffusers ``StableDiffusionControlNetPipeline`` built from the
    ``runwayml/stable-diffusion-v1-5`` base model and the
    ``fusing/stable-diffusion-v1-5-controlnet-canny`` ControlNet, using a
    UniPC multistep scheduler and no safety checker.
    """

    def __init__(self, device, cpu_offload=False):
        """Load the pipeline.

        Args:
            device: Torch device spec (e.g. ``"cuda"``, ``"cuda:1"``, ``"cpu"``)
                that the whole pipeline is moved to.
            cpu_offload: If True, use diffusers model CPU offloading instead of
                moving the pipeline to ``device``. Weights stay on the CPU and
                each sub-model is moved to the GPU only while it runs. Per the
                diffusers docs this needs a CUDA GPU and the ``accelerate``
                package.
        """
        self.device = device
        self.cpu_offload = cpu_offload
        self.model = self.initialize_model()

    def _torch_dtype(self):
        """Return the weight dtype: float16 when running on an accelerator, float32 on CPU."""
        if self.cpu_offload:
            # Offloaded sub-models execute on the GPU, so half precision is fine.
            return torch.float16
        if torch.device(self.device).type == "cpu":
            # Many CPU kernels lack float16 support; use full precision there.
            return torch.float32
        return torch.float16

    def initialize_model(self):
        """Build the ControlNet pipeline and place it according to the constructor settings.

        Returns:
            The configured ``StableDiffusionControlNetPipeline``.
        """
        dtype = self._torch_dtype()
        controlnet = ControlNetModel.from_pretrained(
            "fusing/stable-diffusion-v1-5-controlnet-canny",
            torch_dtype=dtype,
        )
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            safety_checker=None,
            torch_dtype=dtype,
        )
        # Replace the default scheduler with UniPC, reusing its configuration.
        pipeline.scheduler = UniPCMultistepScheduler.from_config(
            pipeline.scheduler.config
        )
        if self.cpu_offload:
            # Offloading installs hooks that move each sub-model to the GPU on
            # demand. Also calling .to(device) would undo the memory savings
            # (diffusers warns about exactly this), so the two are exclusive.
            gpu_id = torch.device(self.device).index or 0
            pipeline.enable_model_cpu_offload(gpu_id=gpu_id)
        else:
            pipeline.to(self.device)
        return pipeline

    @staticmethod
    def preprocess_image(image, low_threshold=100, high_threshold=200):
        """Convert ``image`` into a 3-channel Canny edge map used as the ControlNet condition.

        Args:
            image: PIL image or array-like; converted with ``np.array``.
                Presumably 8-bit (uint8) pixels, which ``cv2.Canny`` expects.
            low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
            high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.

        Returns:
            A PIL ``Image`` whose three channels all hold the same edge map.
        """
        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        # Replicate the single-channel edge map into three identical channels
        # so it can be handed to the pipeline as an RGB conditioning image.
        edges = np.stack([edges, edges, edges], axis=2)
        return Image.fromarray(edges)

    def text_to_image(self, text, image, num_inference_steps=20):
        """Generate an image for ``text`` whose layout follows the edges of ``image``.

        Args:
            text: Prompt passed to the pipeline.
            image: Source image. Its Canny edges (see ``preprocess_image``)
                condition the generation.
            num_inference_steps: Number of denoising steps (default 20).

        Returns:
            The first entry of the pipeline output's ``images`` list, presumably
            a PIL image given the pipeline's default output type. The image is
            only returned; nothing is written to disk.
        """
        print('\033[1;35m' + '*' * 100 + '\033[0m')
        print('\nStep5, Text to Image:')
        control_image = self.preprocess_image(image)
        # Pass arguments by keyword so the call does not depend on the
        # pipeline's positional parameter order.
        result = self.model(
            prompt=text,
            image=control_image,
            num_inference_steps=num_inference_steps,
        )
        generated_image = result.images[0]
        print("Generated image has been created.")
        print('\033[1;35m' + '*' * 100 + '\033[0m')
        return generated_image

    def text_to_image_debug(self, text, image):
        """Debug stand-in for ``text_to_image``: skip generation and return ``image`` unchanged."""
        print("text_to_image_debug")
        return image