colt12 committed
Commit e8c954c · verified · 1 parent: 38fdab0

Update handler.py

Files changed (1):
  handler.py: +10 -4
handler.py CHANGED
@@ -2,28 +2,31 @@ import torch
 from diffusers import StableDiffusionXLPipeline
 import base64
 from io import BytesIO
+import os
 
 class InferenceHandler:
     def __init__(self):
+        # Determine the device to run on
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         model_name = "colt12/maxcushion"
 
-        # If your model is private, include the use_auth_token parameter
+        # Load the pipeline with authentication from environment variable
         self.pipe = StableDiffusionXLPipeline.from_pretrained(
             model_name,
             torch_dtype=torch.float16,
             use_safetensors=True,
-            # Uncomment the line below and replace with your token if needed
-            # use_auth_token="your_huggingface_token"
+            use_auth_token=os.getenv("HUGGINGFACE_TOKEN")  # Securely get the token
         ).to(self.device)
 
     def __call__(self, inputs):
+        # Extract the prompt from inputs
         prompt = inputs.get("prompt", "")
         if not prompt:
             raise ValueError("A prompt must be provided")
 
         negative_prompt = inputs.get("negative_prompt", "")
-
+
+        # Generate the image using the pipeline
         image = self.pipe(
             prompt=prompt,
             negative_prompt=negative_prompt,
@@ -31,10 +34,13 @@ class InferenceHandler:
             guidance_scale=7.5
         ).images[0]
 
+        # Convert the image to base64 encoding
         buffered = BytesIO()
         image.save(buffered, format="PNG")
         image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
 
+        # Return the base64 image
         return {"image_base64": image_base64}
 
+# Instantiate the handler
 handler = InferenceHandler()
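
For reference, a minimal smoke-test sketch of how the updated handler could be exercised locally; this is not part of the commit. It assumes HUGGINGFACE_TOKEN is exported before the process starts, that handler.py is on the import path, and that the prompt shown is a hypothetical example; the returned base64 PNG is decoded back into a PIL image.

    import base64
    from io import BytesIO

    from PIL import Image

    # Importing instantiates InferenceHandler, which loads the pipeline
    from handler import handler

    result = handler({
        "prompt": "a photo of a max cushion",      # hypothetical example prompt
        "negative_prompt": "blurry, low quality",  # optional
    })

    # Decode the returned base64 PNG back into a PIL image
    image = Image.open(BytesIO(base64.b64decode(result["image_base64"])))
    image.save("output.png")

Note that because the token is read via os.getenv at load time, HUGGINGFACE_TOKEN must be present in the environment of the process that imports handler.py.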