colt12 commited on
Commit
9fb755b
·
verified ·
1 Parent(s): faac009

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +15 -3
handler.py CHANGED
@@ -1,16 +1,23 @@
1
  import torch
2
- from diffusers import StableDiffusionXLPipeline, DDIMScheduler # Import your desired scheduler
3
  import base64
4
  from io import BytesIO
5
  import os
6
 
7
  class InferenceHandler:
8
  def __init__(self):
 
9
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_name = "./" # Use the current directory
11
 
 
 
 
 
 
 
 
12
  self.pipe = StableDiffusionXLPipeline.from_pretrained(
13
- model_name,
14
  torch_dtype=torch.float16,
15
  use_safetensors=True,
16
  use_auth_token=os.getenv("HUGGINGFACE_TOKEN")
@@ -20,12 +27,14 @@ class InferenceHandler:
20
  self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
21
 
22
  def __call__(self, inputs):
 
23
  prompt = inputs.get("prompt", "")
24
  if not prompt:
25
  raise ValueError("A prompt must be provided")
26
 
27
  negative_prompt = inputs.get("negative_prompt", "")
28
 
 
29
  image = self.pipe(
30
  prompt=prompt,
31
  negative_prompt=negative_prompt,
@@ -33,10 +42,13 @@ class InferenceHandler:
33
  guidance_scale=7.5
34
  ).images[0]
35
 
 
36
  buffered = BytesIO()
37
  image.save(buffered, format="PNG")
38
  image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
39
 
 
40
  return {"image_base64": image_base64}
41
 
 
42
  handler = InferenceHandler()
 
1
  import torch
2
+ from diffusers import StableDiffusionXLPipeline, DDIMScheduler
3
  import base64
4
  from io import BytesIO
5
  import os
6
 
7
  class InferenceHandler:
8
  def __init__(self):
9
+ # Determine the device to run on
10
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
11
 
12
+ # Get the directory where this script is located
13
+ model_dir = os.path.dirname(os.path.abspath(__file__))
14
+
15
+ # Print the model directory for debugging purposes
16
+ print("Loading model from directory:", model_dir)
17
+
18
+ # Load the pipeline with authentication
19
  self.pipe = StableDiffusionXLPipeline.from_pretrained(
20
+ model_dir,
21
  torch_dtype=torch.float16,
22
  use_safetensors=True,
23
  use_auth_token=os.getenv("HUGGINGFACE_TOKEN")
 
27
  self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
28
 
29
  def __call__(self, inputs):
30
+ # Extract the prompt from inputs
31
  prompt = inputs.get("prompt", "")
32
  if not prompt:
33
  raise ValueError("A prompt must be provided")
34
 
35
  negative_prompt = inputs.get("negative_prompt", "")
36
 
37
+ # Generate the image using the pipeline
38
  image = self.pipe(
39
  prompt=prompt,
40
  negative_prompt=negative_prompt,
 
42
  guidance_scale=7.5
43
  ).images[0]
44
 
45
+ # Convert the image to base64 encoding
46
  buffered = BytesIO()
47
  image.save(buffered, format="PNG")
48
  image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
49
 
50
+ # Return the base64 image
51
  return {"image_base64": image_base64}
52
 
53
+ # Instantiate the handler
54
  handler = InferenceHandler()