Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,283 Bytes
944dd2b 94bd22c 944dd2b 94bd22c e049190 e368cec 944dd2b e368cec 94bd22c e368cec 944dd2b 94bd22c e368cec 944dd2b e368cec 94bd22c 944dd2b 07e4294 94bd22c e368cec 944dd2b 26dad4e 765fb5e 7d60da4 26dad4e 944dd2b 26dad4e e368cec 944dd2b e368cec 944dd2b e368cec 944dd2b e368cec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import fal_client
from PIL import Image
import requests
import io
import os
import base64
# Maps the model names used by this project to the endpoint slug that is
# appended to "fal-ai/" when a text2image request is submitted.
FAL_MODEl_NAME_MAP = {
    "SDXL": "fast-sdxl",
    "SDXLTurbo": "fast-turbo-diffusion",
    "SDXLLightning": "fast-lightning-sdxl",
    "LCM(v1.5/XL)": "fast-lcm-diffusion",
    "PixArtSigma": "pixart-sigma",
    "StableCascade": "stable-cascade",
}
class FalModel():
    """Callable wrapper that runs a prompt through a fal.ai hosted model.

    An instance is configured with a project-level ``model_name`` and a
    ``model_type`` ("text2image", "image2image" or "text2video"). Calling the
    instance submits the request through ``fal_client`` and returns the
    generated result.
    """

    # Per-socket-operation timeout (seconds) for downloading a generated image.
    _DOWNLOAD_TIMEOUT = 60

    # Models whose image "url" field is always handled as an inline base64
    # data URL instead of being fetched over HTTP.
    _DATA_URL_MODELS = ("SDXLTurbo", "LCM(v1.5/XL)")

    # text2video model name -> endpoint slug appended to "fal-ai/".
    _TEXT2VIDEO_ENDPOINTS = {
        'AnimateDiff': 'fast-animatediff/text-to-video',
        'AnimateDiffTurbo': 'fast-animatediff/turbo/text-to-video',
        'StableVideoDiffusion': 'fast-svd/text-to-video',
    }

    def __init__(self, model_name, model_type):
        """Store the model selection and export the API key as ``FAL_KEY``.

        Args:
            model_name: Project-level model name (e.g. "SDXL", "AnimateDiff").
            model_type: One of "text2image", "image2image", "text2video".

        Raises:
            KeyError: If neither ``FalAPI`` nor ``FAL_KEY`` is set in the
                environment.
        """
        self.model_name = model_name
        self.model_type = model_type
        # The key is supplied as 'FalAPI' and copied to 'FAL_KEY'
        # (presumably the variable fal_client reads -- confirm against its
        # docs). An already-configured FAL_KEY is accepted when 'FalAPI' is
        # absent instead of failing.
        fal_api = os.environ.get('FalAPI')
        if fal_api is not None:
            os.environ['FAL_KEY'] = fal_api
        elif 'FAL_KEY' not in os.environ:
            raise KeyError('FalAPI')

    def __call__(self, *args, **kwargs):
        """Run the configured model.

        Keyword Args:
            prompt: Text prompt; required for text2image and text2video.

        Returns:
            For text2image, the image opened with ``PIL.Image.open``.
            For text2video, the value of ``result['video']['url']``.

        Raises:
            AssertionError: If ``prompt`` is missing for a prompt-based type.
            KeyError: If a text2image ``model_name`` is not in
                ``FAL_MODEl_NAME_MAP``.
            NotImplementedError: For image2image, or for a text2video
                ``model_name`` that has no known endpoint.
            ValueError: If ``model_type`` is not a recognised type, or an
                inline data URL has no "," separator.
        """
        if self.model_type == "text2image":
            assert "prompt" in kwargs, "prompt is required for text2image model"
            result = self._run(FAL_MODEl_NAME_MAP[self.model_name], kwargs["prompt"])
            print(result)
            return self._load_image(result['images'][0]['url'])

        if self.model_type == "image2image":
            # TODO: implement image2image (submit an "image_url" argument).
            raise NotImplementedError("image2image model is not implemented yet")

        if self.model_type == "text2video":
            assert "prompt" in kwargs, "prompt is required for text2video model"
            # Resolve the endpoint before submitting so unknown models fail
            # without making a request.
            fal_model_name = self._TEXT2VIDEO_ENDPOINTS.get(self.model_name)
            if fal_model_name is None:
                raise NotImplementedError(f"text2video model of {self.model_name} in fal is not implemented yet")
            result = self._run(fal_model_name, kwargs["prompt"])
            print("result video: ====")
            print(result)
            return result['video']['url']

        raise ValueError("model_type must be text2image, image2image or text2video")

    @staticmethod
    def _run(endpoint, prompt):
        """Submit ``prompt`` to ``fal-ai/<endpoint>``, print progress logs, return the result."""
        handler = fal_client.submit(
            f"fal-ai/{endpoint}",
            arguments={
                "prompt": prompt
            },
        )
        for event in handler.iter_events(with_logs=True):
            if isinstance(event, fal_client.InProgress):
                print('Request in progress')
                print(event.logs)
        return handler.get()

    def _load_image(self, url):
        """Open the generated image from an inline data URL or a remote URL."""
        # Besides the models known to return data URLs, honour any URL that
        # is itself a data URL so it is never passed to requests.
        if self.model_name in self._DATA_URL_MODELS or url.startswith("data:"):
            return Image.open(io.BytesIO(self._decode_data_url(url)))
        response = requests.get(url, timeout=self._DOWNLOAD_TIMEOUT)
        # Surface HTTP failures directly rather than as an image-decoding error.
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))

    @staticmethod
    def _decode_data_url(data_url):
        """Return the decoded bytes of the base64 payload following the first ","."""
        _header, separator, payload = data_url.partition(",")
        if not separator:
            raise ValueError("Invalid data URL provided")
        return base64.b64decode(payload)
def load_fal_model(model_name: str, model_type: str) -> "FalModel":
    """Build a ``FalModel`` for the given model name and model type.

    Args:
        model_name: Project-level model name passed through to ``FalModel``.
        model_type: Model type passed through to ``FalModel``.

    Returns:
        The newly constructed ``FalModel`` instance.
    """
    return FalModel(model_name, model_type)