import imagen_hub


class ImagenHubModel:
    """Callable wrapper around a model obtained from ``imagen_hub.load``.

    Calling an instance forwards every positional and keyword argument to
    the loaded model's ``infer_one_image`` method and returns its result
    unchanged.
    """

    def __init__(self, model_name):
        """Load the model via ``imagen_hub.load(model_name)`` and keep it on ``self.model``."""
        self.model = imagen_hub.load(model_name)

    def __call__(self, *args, **kwargs):
        """Invoke ``infer_one_image`` on the wrapped model with the given arguments."""
        infer = self.model.infer_one_image
        return infer(*args, **kwargs)


class PNP(ImagenHubModel):
    """Wrapper for the ``'PNP'`` model that supplies a default for ``num_inversion_steps``."""

    def __init__(self):
        """Load the model named ``'PNP'``."""
        super().__init__('PNP')

    def __call__(self, *args, **kwargs):
        """Run inference, passing ``num_inversion_steps=200`` unless the caller set it.

        Only keyword arguments are inspected; a value passed positionally is
        not detected here.
        """
        kwargs.setdefault("num_inversion_steps", 200)
        return super().__call__(*args, **kwargs)


class Prompt2prompt(ImagenHubModel):
    """Wrapper for the ``'Prompt2prompt'`` model that supplies a default for ``num_inner_steps``."""

    def __init__(self):
        """Load the model named ``'Prompt2prompt'``."""
        super().__init__('Prompt2prompt')

    def __call__(self, *args, **kwargs):
        """Run inference, passing ``num_inner_steps=3`` unless the caller set it.

        Only keyword arguments are inspected; a value passed positionally is
        not detected here.
        """
        kwargs.setdefault("num_inner_steps", 3)
        return super().__call__(*args, **kwargs)


def load_imagenhub_model(model_name, model_type=None):
    """Return the wrapper instance appropriate for ``model_name``.

    ``'PNP'`` and ``'Prompt2prompt'`` get their dedicated subclasses (which
    inject per-model keyword defaults); any other name gets a plain
    ``ImagenHubModel``.

    Args:
        model_name: Name handed to ``imagen_hub.load`` by the wrapper.
        model_type: Accepted for interface compatibility; not used here.

    Returns:
        An ``ImagenHubModel`` (or subclass) instance.
    """
    if model_name == 'PNP':
        wrapper = PNP()
    elif model_name == 'Prompt2prompt':
        wrapper = Prompt2prompt()
    else:
        wrapper = ImagenHubModel(model_name)
    return wrapper