"""Thin wrappers around fal.ai hosted generation models.

``load_fal_model(model_name, model_type)`` returns a callable ``FalModel``.
Calling it with ``prompt=...`` submits the prompt to the matching fal.ai
application and returns the output:

* a ``PIL.Image.Image`` for ``text2image``, or
* the generated video's URL (``str``) for ``text2video``.
"""

import base64
import io
import os

import fal_client
import requests
from PIL import Image

# text2image model name -> fal.ai application id (the part after "fal-ai/").
FAL_MODEl_NAME_MAP = {
    "SDXL": "fast-sdxl",
    "SDXLTurbo": "fast-turbo-diffusion",
    "SDXLLightning": "fast-lightning-sdxl",
    "LCM(v1.5/XL)": "fast-lcm-diffusion",
    "PixArtSigma": "pixart-sigma",
    "StableCascade": "stable-cascade",
}
# Correctly spelled alias. The original (typo'd) name above is kept because
# other modules may already import it.
FAL_MODEL_NAME_MAP = FAL_MODEl_NAME_MAP

# text2video model name -> fal.ai application id (the part after "fal-ai/").
_FAL_TEXT2VIDEO_MAP = {
    "AnimateDiff": "fast-animatediff/text-to-video",
    "AnimateDiffTurbo": "fast-animatediff/turbo/text-to-video",
    "StableVideoDiffusion": "fast-svd/text-to-video",
}

# Seconds to wait for the HTTP download of a generated image before giving up.
_DOWNLOAD_TIMEOUT = 60


def _decode_data_url(data_url):
    """Return the raw bytes carried by a base64 ``data:`` URL.

    Everything after the first ``,`` is treated as the base64 payload.

    Raises:
        ValueError: if *data_url* contains no ``,`` separator, or (via
            ``base64.b64decode``) if the payload is not valid base64.
    """
    base64_start = data_url.find(",") + 1
    if base64_start == 0:
        raise ValueError("Invalid data URL provided")
    return base64.b64decode(data_url[base64_start:])


def _submit_and_wait(application, arguments):
    """Submit *arguments* to ``fal-ai/<application>`` and block until done.

    Progress logs streamed by fal are printed while the request runs.
    Returns whatever ``handler.get()`` returns, i.e. the fal result payload.
    """
    handler = fal_client.submit(f"fal-ai/{application}", arguments=arguments)
    for event in handler.iter_events(with_logs=True):
        if isinstance(event, fal_client.InProgress):
            print('Request in progress')
            print(event.logs)
    return handler.get()


def _load_image(url):
    """Open the image at *url* with PIL.

    *url* may be an inline base64 ``data:`` URL, which some fal models
    (e.g. SDXLTurbo, LCM) return, or an ordinary HTTP(S) URL, which is
    downloaded.

    Raises:
        requests.HTTPError: if the download returns an error status.
    """
    if url.startswith("data:"):
        return Image.open(io.BytesIO(_decode_data_url(url)))
    response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT)
    # Surface HTTP failures directly instead of letting PIL choke on an
    # error page.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def _require_prompt(kwargs, model_type):
    """Raise ``AssertionError`` if ``prompt`` was not passed.

    An explicit ``raise`` is used instead of ``assert`` so the check still
    runs under ``python -O``. The exception type is kept as
    ``AssertionError`` for existing callers.
    """
    if "prompt" not in kwargs:
        raise AssertionError(f"prompt is required for {model_type} model")


class FalModel():
    """Callable wrapper around a single fal.ai model.

    Args:
        model_name: Model key. For ``text2image`` it must be a key of
            ``FAL_MODEl_NAME_MAP``. For ``text2video`` it must be one of
            ``"AnimateDiff"``, ``"AnimateDiffTurbo"`` or
            ``"StableVideoDiffusion"``.
        model_type: ``"text2image"``, ``"image2image"`` (not implemented yet)
            or ``"text2video"``.

    Raises:
        KeyError: if neither the ``FalAPI`` nor the ``FAL_KEY`` environment
            variable is set.
    """

    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type
        # The key is stored under FalAPI in this project. Mirror it into
        # FAL_KEY, which is presumably where fal_client looks for
        # credentials. An already-set FAL_KEY is honoured when FalAPI is
        # absent.
        api_key = os.environ.get('FalAPI')
        if api_key is not None:
            os.environ['FAL_KEY'] = api_key
        elif 'FAL_KEY' not in os.environ:
            raise KeyError('FalAPI')

    def __call__(self, *args, **kwargs):
        """Run the model on ``kwargs["prompt"]``.

        Positional arguments are accepted but ignored.

        Returns:
            A ``PIL.Image.Image`` for ``text2image``, or the generated video's
            URL (``str``) for ``text2video``.

        Raises:
            AssertionError: if ``prompt`` is not passed.
            KeyError: if a ``text2image`` model name is not in
                ``FAL_MODEl_NAME_MAP``.
            NotImplementedError: for ``image2image``, or an unknown
                ``text2video`` model name.
            ValueError: for any other ``model_type``.
        """
        if self.model_type == "text2image":
            _require_prompt(kwargs, "text2image")
            application = FAL_MODEl_NAME_MAP[self.model_name]
            result = _submit_and_wait(application, {"prompt": kwargs["prompt"]})
            print(result)
            return _load_image(result['images'][0]['url'])

        if self.model_type == "image2image":
            raise NotImplementedError("image2image model is not implemented yet")

        if self.model_type == "text2video":
            _require_prompt(kwargs, "text2video")
            application = _FAL_TEXT2VIDEO_MAP.get(self.model_name)
            if application is None:
                raise NotImplementedError(
                    f"text2video model of {self.model_name} in fal is not implemented yet"
                )
            result = _submit_and_wait(application, {"prompt": kwargs["prompt"]})
            print("result video: ====")
            print(result)
            return result['video']['url']

        raise ValueError("model_type must be text2image, image2image or text2video")


def load_fal_model(model_name, model_type):
    """Construct a ``FalModel`` for *model_name* / *model_type*."""
    return FalModel(model_name, model_type)