Spaces:
Running
on
Zero
Running
on
Zero
import fal_client | |
from PIL import Image | |
import requests | |
import io | |
import os | |
import base64 | |
# Model names accepted by FalModel for "text2image", mapped to the fal.ai
# endpoint id; FalModel sends requests to "fal-ai/<endpoint id>".
FAL_MODEl_NAME_MAP = {
    "SDXL": "fast-sdxl",
    "SDXLTurbo": "fast-turbo-diffusion",
    "SDXLLightning": "fast-lightning-sdxl",
    "LCM(v1.5/XL)": "fast-lcm-diffusion",
    "PixArtSigma": "pixart-sigma",
    "StableCascade": "stable-cascade",
}
class FalModel:
    """Callable wrapper that runs a fal.ai hosted model and returns its output.

    ``model_type`` selects the task:

    * ``"text2image"``: returns a ``PIL.Image.Image``. ``model_name`` must be a
      key of ``FAL_MODEl_NAME_MAP``.
    * ``"text2video"``: returns the URL string of the generated video.
      ``model_name`` must be a key of ``_TEXT2VIDEO_ENDPOINTS``.
    * ``"image2image"``: not implemented yet.
    """

    # fal endpoint id (requests go to "fal-ai/<id>") for each text2video model.
    _TEXT2VIDEO_ENDPOINTS = {
        'AnimateDiff': 'fast-animatediff/text-to-video',
        'AnimateDiffTurbo': 'fast-animatediff/turbo/text-to-video',
        'StableVideoDiffusion': 'fast-svd/text-to-video',
    }

    # Seconds to wait for the HTTP download of a generated image. Without a
    # timeout, requests.get can block indefinitely on a stalled connection.
    _DOWNLOAD_TIMEOUT = 60

    def __init__(self, model_name, model_type):
        """Remember the model selection and publish the API key as ``FAL_KEY``.

        The key is read from the ``FalAPI`` environment variable and copied to
        ``FAL_KEY``. If ``FalAPI`` is absent but ``FAL_KEY`` is already set,
        that existing value is kept.

        Raises:
            KeyError: if neither ``FalAPI`` nor ``FAL_KEY`` is set.
        """
        self.model_name = model_name
        self.model_type = model_type
        api_key = os.environ.get('FalAPI')
        if api_key is not None:
            os.environ['FAL_KEY'] = api_key
        elif 'FAL_KEY' not in os.environ:
            raise KeyError("Set the 'FalAPI' (or 'FAL_KEY') environment variable to your fal API key")

    @staticmethod
    def _decode_data_url(data_url):
        """Return the bytes of the base64 payload of a ``data:`` URL.

        Everything after the first ``,`` is base64-decoded.

        Raises:
            ValueError: if ``data_url`` contains no ``,``.
        """
        _header, separator, payload = data_url.partition(',')
        if not separator:
            raise ValueError("Invalid data URL provided")
        return base64.b64decode(payload)

    @staticmethod
    def _submit_and_wait(endpoint, arguments):
        """Submit ``arguments`` to ``fal-ai/<endpoint>``, print progress logs, and return the result."""
        handler = fal_client.submit(f"fal-ai/{endpoint}", arguments=arguments)
        for event in handler.iter_events(with_logs=True):
            if isinstance(event, fal_client.InProgress):
                print('Request in progress')
                print(event.logs)
        return handler.get()

    def _load_image(self, url):
        """Open the image behind ``url`` (a ``data:`` URL or an HTTP(S) link) with PIL.

        The scheme is checked instead of the model name, so any endpoint that
        inlines its image as a data URL is decoded locally and every other
        URL is downloaded.
        """
        if url.startswith('data:'):
            return Image.open(io.BytesIO(self._decode_data_url(url)))
        response = requests.get(url, timeout=self._DOWNLOAD_TIMEOUT)
        # Surface HTTP failures directly instead of letting PIL choke on an error page.
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))

    def __call__(self, *args, **kwargs):
        """Run the configured model. ``prompt`` must be given as a keyword.

        Positional arguments are accepted but ignored.

        Returns:
            For ``text2image``, a ``PIL.Image.Image``. For ``text2video``, the
            URL string of the generated video.

        Raises:
            ValueError: ``prompt`` is missing, or ``model_type`` is unsupported.
            NotImplementedError: ``image2image``, or an unknown text2video model.
            KeyError: ``model_name`` is not in ``FAL_MODEl_NAME_MAP`` (text2image).
        """
        if self.model_type == "text2image":
            if "prompt" not in kwargs:
                raise ValueError("prompt is required for text2image model")
            result = self._submit_and_wait(FAL_MODEl_NAME_MAP[self.model_name], {"prompt": kwargs["prompt"]})
            print(result)
            return self._load_image(result['images'][0]['url'])

        if self.model_type == "image2image":
            # TODO: accept an ``image`` or ``image_url`` kwarg and submit it to
            # the matching fal image-to-image endpoint.
            raise NotImplementedError("image2image model is not implemented yet")

        if self.model_type == "text2video":
            if "prompt" not in kwargs:
                raise ValueError("prompt is required for text2video model")
            fal_model_name = self._TEXT2VIDEO_ENDPOINTS.get(self.model_name)
            if fal_model_name is None:
                raise NotImplementedError(f"text2video model of {self.model_name} in fal is not implemented yet")
            result = self._submit_and_wait(fal_model_name, {"prompt": kwargs["prompt"]})
            print("result video: ====")
            print(result)
            return result['video']['url']

        raise ValueError("model_type must be text2image, image2image or text2video")
def load_fal_model(model_name, model_type):
    """Create and return a ``FalModel`` for ``model_name`` running task ``model_type``."""
    model = FalModel(model_name, model_type)
    return model