import io
import os

import fal_client
import requests
from PIL import Image

# Maps this module's text2video model names to fal endpoint paths
# (the "fal-ai/" prefix is added when the request is submitted).
_TEXT2VIDEO_ENDPOINTS = {
    'AnimateDiff': 'fast-animatediff/text-to-video',
    'AnimateDiffTurbo': 'fast-animatediff/turbo/text-to-video',
}

# Seconds requests waits on the image server (connect / between bytes) before
# giving up, so a stalled download cannot hang the caller indefinitely.
_DOWNLOAD_TIMEOUT = 60


class FalModel:
    """Callable wrapper around a hosted fal.ai model.

    Supported ``model_type`` values:

    * ``"text2image"`` -- ``model(prompt=...)`` submits to
      ``fal-ai/{model_name}`` and returns a ``PIL.Image.Image`` downloaded
      from the first image URL in the result.
    * ``"text2video"`` -- ``model(prompt=...)`` returns the URL string of the
      generated video. Only the names in ``_TEXT2VIDEO_ENDPOINTS`` are
      supported.
    * ``"image2image"`` -- not implemented yet; calling raises
      ``NotImplementedError``.
    """

    def __init__(self, model_name, model_type):
        """Store the model identity and expose the API key as ``FAL_KEY``.

        Args:
            model_name: For text2image, the endpoint path under ``fal-ai/``;
                for text2video, a key of ``_TEXT2VIDEO_ENDPOINTS``.
            model_type: One of ``"text2image"``, ``"image2image"``,
                ``"text2video"``.

        Raises:
            KeyError: if neither ``FalAPI`` nor ``FAL_KEY`` is set in the
                environment.
        """
        self.model_name = model_name
        self.model_type = model_type
        # The key is kept under ``FalAPI``; copy it to ``FAL_KEY``, the variable
        # fal_client authenticates with. Only fail if no key is available at all.
        fal_api_key = os.environ.get('FalAPI')
        if fal_api_key is not None:
            os.environ['FAL_KEY'] = fal_api_key
        elif 'FAL_KEY' not in os.environ:
            raise KeyError(
                "set the 'FalAPI' (or 'FAL_KEY') environment variable to your fal API key"
            )

    def __call__(self, *args, **kwargs):
        """Run the model on the given keyword arguments.

        Args:
            *args: Ignored; accepted for call-site compatibility.
            **kwargs: Must contain ``prompt`` for text2image and text2video.

        Returns:
            A ``PIL.Image.Image`` for text2image, or the generated video's URL
            for text2video.

        Raises:
            ValueError: if ``prompt`` is missing or ``model_type`` is unknown.
            NotImplementedError: for image2image, or an unsupported
                text2video model name.
            requests.HTTPError: if downloading the generated image fails.
        """
        if self.model_type == "text2image":
            prompt = self._require_prompt(kwargs)
            result = self._submit(f"fal-ai/{self.model_name}", prompt)
            result_url = result['images'][0]['url']
            response = requests.get(result_url, timeout=_DOWNLOAD_TIMEOUT)
            # Fail loudly on an HTTP error instead of passing an error page to PIL.
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content))

        if self.model_type == "image2image":
            # TODO: accept ``image`` / ``image_url`` kwargs and submit them as
            # ``{"image_url": ...}`` to ``fal-ai/{model_name}``.
            raise NotImplementedError("image2image model is not implemented yet")

        if self.model_type == "text2video":
            prompt = self._require_prompt(kwargs)
            fal_model_name = _TEXT2VIDEO_ENDPOINTS.get(self.model_name)
            if fal_model_name is None:
                raise NotImplementedError(
                    f"text2video model of {self.model_name} in fal is not implemented yet"
                )
            result = self._submit(f"fal-ai/{fal_model_name}", prompt)
            return result['video']['url']

        raise ValueError(
            "model_type must be one of text2image, image2image, text2video; "
            f"got {self.model_type!r}"
        )

    def _require_prompt(self, kwargs):
        """Return ``kwargs["prompt"]``, raising ValueError if it is absent.

        An explicit check is used instead of ``assert`` so the validation
        still runs under ``python -O``.
        """
        if "prompt" not in kwargs:
            raise ValueError(f"prompt is required for {self.model_type} model")
        return kwargs["prompt"]

    @staticmethod
    def _submit(endpoint, prompt):
        """Submit ``prompt`` to a fal endpoint, print progress logs, return the result."""
        handler = fal_client.submit(endpoint, arguments={"prompt": prompt})
        for event in handler.iter_events(with_logs=True):
            if isinstance(event, fal_client.InProgress):
                print('Request in progress')
                print(event.logs)
        return handler.get()


def load_fal_model(model_name, model_type):
    """Construct a :class:`FalModel`; see its docstring for argument meanings."""
    return FalModel(model_name, model_type)