GenAI-Arena / model /models /fal_api_models.py
tianleliphoebe's picture
update video generation
26dad4e
raw
history blame
3.18 kB
import fal_client
from PIL import Image
import requests
import io
import os
class FalModel():
    """Callable wrapper around a model hosted on fal.ai.

    Supported ``model_type`` values:

    * ``"text2image"`` -- calling with ``prompt=...`` returns a
      ``PIL.Image.Image`` downloaded from the first image in the result.
    * ``"text2video"`` -- calling with ``prompt=...`` returns the URL string
      of the generated video (only the model names in
      ``_TEXT2VIDEO_ENDPOINTS`` are supported).
    * ``"image2image"`` -- recognised but not implemented yet.
    """

    # Arena model name -> fal.ai endpoint path (relative to the "fal-ai/"
    # namespace) for the text2video models this wrapper supports.
    _TEXT2VIDEO_ENDPOINTS = {
        'AnimateDiff': 'fast-animatediff/text-to-video',
        'AnimateDiffTurbo': 'fast-animatediff/turbo/text-to-video',
    }

    # Seconds to wait when downloading a generated image from its result URL;
    # without a timeout a stalled download would hang the caller forever.
    download_timeout = 60

    def __init__(self, model_name, model_type):
        """Store the model identity and make the fal.ai API key available.

        Args:
            model_name: For text2image, the fal.ai endpoint name under
                "fal-ai/"; for text2video, a key of ``_TEXT2VIDEO_ENDPOINTS``.
            model_type: One of "text2image", "image2image", "text2video".

        Raises:
            KeyError: if neither the ``FalAPI`` nor the ``FAL_KEY``
                environment variable is set.
        """
        self.model_name = model_name
        self.model_type = model_type
        # fal_client authenticates via FAL_KEY; this project stores the key
        # under FalAPI. Prefer FalAPI when present (original behaviour), but
        # accept an already-configured FAL_KEY instead of failing.
        api_key = os.environ.get('FalAPI')
        if api_key is not None:
            os.environ['FAL_KEY'] = api_key
        elif 'FAL_KEY' not in os.environ:
            raise KeyError("Set the 'FalAPI' (or 'FAL_KEY') environment variable")

    @staticmethod
    def _submit_and_wait(endpoint, arguments):
        """Submit a request to ``endpoint``, print progress logs, and return the result."""
        handler = fal_client.submit(
            endpoint,
            arguments=arguments,
        )
        for event in handler.iter_events(with_logs=True):
            if isinstance(event, fal_client.InProgress):
                print('Request in progress')
                print(event.logs)
        return handler.get()

    def __call__(self, *args, **kwargs):
        """Run the model.

        Keyword Args:
            prompt: Text prompt (required for text2image and text2video).

        Returns:
            A ``PIL.Image.Image`` for text2image, or the video URL (str) for
            text2video.

        Raises:
            NotImplementedError: for image2image, or an unsupported
                text2video model name.
            ValueError: for an unknown ``model_type``.
            requests.HTTPError: if downloading the generated image fails.
        """
        if self.model_type == "text2image":
            assert "prompt" in kwargs, "prompt is required for text2image model"
            result = self._submit_and_wait(
                f"fal-ai/{self.model_name}",
                {"prompt": kwargs["prompt"]},
            )
            result_url = result['images'][0]['url']
            response = requests.get(result_url, timeout=self.download_timeout)
            # Fail loudly on an HTTP error instead of handing an error page
            # to PIL, which would raise an opaque "cannot identify image".
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content))
        elif self.model_type == "image2image":
            # TODO: accept `image` / `image_url` kwargs and submit them to the
            # fal.ai endpoint once image2image support is needed.
            raise NotImplementedError("image2image model is not implemented yet")
        elif self.model_type == "text2video":
            assert "prompt" in kwargs, "prompt is required for text2video model"
            fal_model_name = self._TEXT2VIDEO_ENDPOINTS.get(self.model_name)
            if fal_model_name is None:
                raise NotImplementedError(f"text2video model of {self.model_name} in fal is not implemented yet")
            result = self._submit_and_wait(
                f"fal-ai/{fal_model_name}",
                {"prompt": kwargs["prompt"]},
            )
            print("result video: ====")
            print(result)
            return result['video']['url']
        else:
            raise ValueError("model_type must be text2image, image2image or text2video")
def load_fal_model(model_name, model_type):
    """Construct and return a ``FalModel`` for the given name and type."""
    model = FalModel(model_name=model_name, model_type=model_type)
    return model