File size: 1,769 Bytes
aafe4a6
9fb755b
fa41512
 
e8c954c
eb40866
fa41512
 
9fb755b
fa41512
f0c76db
9fb755b
 
 
 
 
 
 
da36afd
9fb755b
da36afd
f0c76db
50456b3
f0c76db
eb40866
50456b3
 
 
fa41512
9fb755b
fa41512
 
 
50456b3
fa41512
e8c954c
9fb755b
fa41512
 
 
 
 
 
50456b3
9fb755b
fa41512
 
 
50456b3
9fb755b
fa41512
eb40866
9fb755b
fa41512
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import base64
from io import BytesIO
import os

class InferenceHandler:
    """Text-to-image handler wrapping a Stable Diffusion XL pipeline.

    The pipeline is loaded once at construction time. Each call turns a prompt
    into a single PNG image, returned as a base64 string.
    """

    # Sampling defaults used when a request does not override them.
    DEFAULT_NUM_INFERENCE_STEPS = 30
    DEFAULT_GUIDANCE_SCALE = 7.5

    def __init__(self, model_dir=None):
        """Load the SDXL pipeline and switch it to a DDIM scheduler.

        Args:
            model_dir: Location passed to ``from_pretrained``. Defaults to the
                directory containing this script (the original behaviour).
        """
        self.device, dtype = self._select_device_and_dtype()

        if model_dir is None:
            model_dir = os.path.dirname(os.path.abspath(__file__))

        # Print the model directory for debugging purposes
        print("Loading model from directory:", model_dir)

        # ``token`` is the current name of the deprecated ``use_auth_token``
        # keyword; it only matters when weights must come from a gated/private
        # hub repository rather than a local directory.
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model_dir,
            torch_dtype=dtype,
            use_safetensors=True,
            token=os.getenv("HUGGINGFACE_TOKEN"),
        ).to(self.device)

        # Swap in a DDIM scheduler built from the loaded scheduler's config.
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)

    @staticmethod
    def _select_device_and_dtype():
        """Return ``(device, dtype)``: float16 on CUDA, float32 on the CPU fallback."""
        if torch.cuda.is_available():
            return "cuda", torch.float16
        # Many half-precision kernels are missing or very slow on CPU, so the
        # CPU fallback runs in full precision.
        return "cpu", torch.float32

    def __call__(self, inputs):
        """Generate one image for the request in ``inputs``.

        Args:
            inputs: Mapping with a required ``"prompt"`` and optional
                ``"negative_prompt"``, ``"num_inference_steps"`` (default 30)
                and ``"guidance_scale"`` (default 7.5).

        Returns:
            ``{"image_base64": <base64-encoded PNG, as str>}``.

        Raises:
            ValueError: if the prompt is missing, empty or whitespace-only, or
                if ``num_inference_steps`` is smaller than 1.
        """
        prompt = inputs.get("prompt", "")
        if not prompt or (isinstance(prompt, str) and not prompt.strip()):
            raise ValueError("A prompt must be provided")

        # Pass a missing/empty negative prompt as None, not "": SDXL only applies
        # its ``force_zeros_for_empty_prompt`` handling when it receives None.
        negative_prompt = inputs.get("negative_prompt") or None

        num_inference_steps = int(
            inputs.get("num_inference_steps", self.DEFAULT_NUM_INFERENCE_STEPS)
        )
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be at least 1")
        guidance_scale = float(
            inputs.get("guidance_scale", self.DEFAULT_GUIDANCE_SCALE)
        )

        image = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]

        # Serialize the image to PNG in memory and base64-encode it as text.
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {"image_base64": image_base64}

# Instantiate the handler once at import time. This runs InferenceHandler.__init__,
# which loads the pipeline weights and moves them to the selected device, so
# importing this module performs the full model load.
# NOTE(review): presumably the serving runtime looks up this module-level name
# ``handler`` — confirm before renaming.
handler = InferenceHandler()