# GenAI-Arena: model/models/fal_api_models.py
# Author: yuanshengni -- commit 7d60da4 ("update fal_svd"), 4.28 kB
import fal_client
from PIL import Image
import requests
import io
import os
import base64
# Arena display name -> fal.ai application id, requested as "fal-ai/<id>" by
# FalModel for the "text2image" model type.
# NOTE: the lowercase "l" in "MODEl" is a historical typo; the name is kept
# unchanged because other modules may import it.
FAL_MODEl_NAME_MAP = {
    "SDXL": "fast-sdxl",
    "SDXLTurbo": "fast-turbo-diffusion",
    "SDXLLightning": "fast-lightning-sdxl",
    "LCM(v1.5/XL)": "fast-lcm-diffusion",
    "PixArtSigma": "pixart-sigma",
    "StableCascade": "stable-cascade",
}
class FalModel:
    """Callable wrapper around fal.ai hosted generation endpoints.

    An instance is bound to one display ``model_name`` and one ``model_type``.
    Calling it with ``prompt=...`` submits a request to the matching
    ``fal-ai/...`` application and blocks until the result is available.

    Supported ``model_type`` values:

    * ``"text2image"`` -- returns a ``PIL.Image.Image``; ``model_name`` must be
      a key of the module-level ``FAL_MODEl_NAME_MAP``.
    * ``"text2video"`` -- returns the generated video's URL (``str``);
      ``model_name`` must be a key of ``_VIDEO_MODEL_NAME_MAP``.
    * ``"image2image"`` -- recognised but not implemented yet.
    """

    # Display name -> fal application path (under "fal-ai/") for text2video.
    _VIDEO_MODEL_NAME_MAP = {
        'AnimateDiff': 'fast-animatediff/text-to-video',
        'AnimateDiffTurbo': 'fast-animatediff/turbo/text-to-video',
        'StableVideoDiffusion': 'fast-svd/text-to-video',
    }

    # Seconds to wait when downloading a generated image. Without a timeout,
    # requests.get can block indefinitely on a stalled connection.
    _DOWNLOAD_TIMEOUT = 60

    def __init__(self, model_name, model_type):
        """Store the model selection and export the fal API key.

        The key is read from the ``FalAPI`` environment variable and copied
        into ``FAL_KEY``, the variable fal_client reads. If ``FalAPI`` is unset,
        an already-configured ``FAL_KEY`` is used as-is.

        Raises:
            KeyError: if neither ``FalAPI`` nor ``FAL_KEY`` is set.
        """
        self.model_name = model_name
        self.model_type = model_type
        api_key = os.environ.get('FalAPI')
        if api_key is not None:
            os.environ['FAL_KEY'] = api_key
        elif 'FAL_KEY' not in os.environ:
            raise KeyError(
                "Set the 'FalAPI' (or 'FAL_KEY') environment variable to your fal API key"
            )

    @staticmethod
    def _decode_data_url(data_url):
        """Return the raw bytes of a base64 ``data:`` URL.

        Everything after the first ``,`` is treated as the base64 payload.

        Raises:
            ValueError: if ``data_url`` contains no ``,`` separator.
        """
        _header, separator, payload = data_url.partition(",")
        if not separator:
            raise ValueError("Invalid data URL provided")
        return base64.b64decode(payload)

    def _submit_and_wait(self, application, prompt):
        """Submit ``prompt`` to ``fal-ai/<application>`` and return the result.

        Progress logs are printed while the request is in progress. The
        return value is whatever ``handler.get()`` yields (the endpoint's
        result payload).
        """
        handler = fal_client.submit(
            f"fal-ai/{application}",
            arguments={"prompt": prompt},
        )
        for event in handler.iter_events(with_logs=True):
            if isinstance(event, fal_client.InProgress):
                print('Request in progress')
                print(event.logs)
        return handler.get()

    def _load_image(self, url):
        """Open the image at ``url``, given either inline or as a hosted file.

        Raises:
            requests.HTTPError: if downloading a hosted image fails.
        """
        # Some endpoints (previously special-cased by name: SDXLTurbo,
        # LCM(v1.5/XL)) return the image inline as a data: URL. Detecting
        # by scheme keeps working if an endpoint switches formats.
        if url.startswith("data:"):
            return Image.open(io.BytesIO(self._decode_data_url(url)))
        response = requests.get(url, timeout=self._DOWNLOAD_TIMEOUT)
        # Fail with the HTTP error rather than a confusing PIL decode error.
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))

    def __call__(self, *args, **kwargs):
        """Run the configured model on ``kwargs["prompt"]``.

        Returns:
            A ``PIL.Image.Image`` for ``"text2image"``, or the generated
            video's URL (``str``) for ``"text2video"``.

        Raises:
            AssertionError: if ``prompt`` is not passed as a keyword argument
                (checked with ``assert``, so skipped under ``python -O``).
            KeyError: for a text2image ``model_name`` missing from
                ``FAL_MODEl_NAME_MAP``.
            NotImplementedError: for ``"image2image"`` or an unknown
                text2video ``model_name``.
            ValueError: for any other ``model_type``.
        """
        if self.model_type == "text2image":
            assert "prompt" in kwargs, "prompt is required for text2image model"
            result = self._submit_and_wait(
                FAL_MODEl_NAME_MAP[self.model_name], kwargs["prompt"]
            )
            print(result)
            return self._load_image(result['images'][0]['url'])

        if self.model_type == "image2image":
            # TODO: forward kwargs["image"] / kwargs["image_url"] as the
            # endpoint's "image_url" argument once image2image is supported.
            raise NotImplementedError("image2image model is not implemented yet")

        if self.model_type == "text2video":
            assert "prompt" in kwargs, "prompt is required for text2video model"
            fal_model_name = self._VIDEO_MODEL_NAME_MAP.get(self.model_name)
            if fal_model_name is None:
                raise NotImplementedError(
                    f"text2video model of {self.model_name} in fal is not implemented yet"
                )
            result = self._submit_and_wait(fal_model_name, kwargs["prompt"])
            print("result video: ====")
            print(result)
            return result['video']['url']

        raise ValueError("model_type must be text2image, image2image or text2video")
def load_fal_model(model_name, model_type):
    """Build a ``FalModel`` bound to ``model_name`` and ``model_type``.

    Thin factory around the ``FalModel`` constructor. Any error raised while
    constructing the model propagates unchanged.
    """
    model = FalModel(model_name, model_type)
    return model