# Spaces:
# Runtime error
import os
import tempfile

import numpy as np
import soundfile as sf
import torch
from diffusers import StableUnCLIPImg2ImgPipeline
from PIL import Image

from . import imagebind


class Anything2Image:
    """Generate an image conditioned on audio, an image, and/or text.

    Inputs are encoded with an ImageBind model (``imagebind.imagebind_huge``) and
    the resulting embedding is passed as ``image_embeds`` to the
    ``stabilityai/stable-diffusion-2-1-unclip`` ``StableUnCLIPImg2ImgPipeline``.
    """

    def __init__(
        self,
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        imagebind_download_dir="checkpoints",
    ):
        """Load the diffusion pipeline and the ImageBind encoder onto ``device``.

        Args:
            device: Device string both models are moved to. On any value other
                than ``'cpu'`` the diffusion pipeline is loaded in float16.
                The default is chosen once, when the class is defined.
            imagebind_download_dir: ``download_dir`` passed to
                ``imagebind.imagebind_huge(pretrained=True, ...)``.
        """
        # Decide the pipeline precision once; __call__ casts embeddings to match it.
        self._use_fp16 = device != 'cpu'
        self.pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1-unclip",
            torch_dtype=torch.float16 if self._use_fp16 else None,
        ).to(device)
        self.model = imagebind.imagebind_huge(
            pretrained=True, download_dir=imagebind_download_dir,
        ).eval().to(device)
        self.device = device

    def __call__(self, prompt=None, audio=None, image=None, text=None):
        """Generate one image from the given prompt and conditioning inputs.

        Args:
            prompt: Text prompt forwarded to the diffusion pipeline.
            audio: Optional ``(sample_rate, waveform)`` pair; the waveform is
                written to a WAV file with ``soundfile`` before encoding.
            image: Optional array accepted by ``PIL.Image.fromarray``.
            text: Optional string. When non-empty, its embedding is used and
                ``audio``/``image`` are ignored.

        If both ``audio`` and ``image`` are given (and no text), their
        embeddings are averaged. If no conditioning input is given,
        ``image_embeds=None`` is passed to the pipeline.
        NOTE(review): the pipeline also receives no ``image`` argument in that
        case, so it presumably rejects the call; confirm.

        Returns:
            The first element of the pipeline output's ``images``.
        """
        # The encoder is used for inference only; no_grad stops its forward passes
        # from recording an autograd graph and holding on to the activations.
        with torch.no_grad():
            embeddings = self._condition_embeddings(audio, image, text)
        if embeddings is not None and self._use_fp16:
            # The pipeline weights are float16 off-CPU (see __init__); match them.
            embeddings = embeddings.half()
        images = self.pipe(prompt=prompt, image_embeds=embeddings).images
        return images[0]

    def _condition_embeddings(self, audio, image, text):
        """Return the conditioning embedding for the inputs, or None if there are none."""
        # Non-empty text overrides audio/image entirely, so don't encode them at all.
        if text:
            return self._embed_text(text)
        if audio is None and image is None:
            return None
        # A private scratch directory per call: concurrent calls cannot overwrite
        # each other's input files, and nothing is left behind in the CWD.
        with tempfile.TemporaryDirectory() as tmp_dir:
            audio_embeddings = None if audio is None else self._embed_audio(audio, tmp_dir)
            image_embeddings = None if image is None else self._embed_image(image, tmp_dir)
        if audio_embeddings is None:
            return image_embeddings
        if image_embeddings is None:
            return audio_embeddings
        # Both modalities provided: use the mean of the two embeddings.
        return (audio_embeddings + image_embeddings) / 2

    def _embed_audio(self, audio, tmp_dir):
        """Encode a ``(sample_rate, waveform)`` pair with ImageBind via a temp WAV file."""
        sample_rate, waveform = audio
        wav_path = os.path.join(tmp_dir, 'input.wav')
        sf.write(wav_path, waveform, sample_rate)
        inputs = {
            imagebind.ModalityType.AUDIO: imagebind.load_and_transform_audio_data(
                [wav_path], self.device
            ),
        }
        # NOTE(review): unlike the image and text paths, this keeps forward()'s
        # default ``normalize`` argument; confirm that averaging it with the
        # normalize=False image embedding is intended.
        outputs = self.model.forward(inputs)
        return outputs[imagebind.ModalityType.AUDIO]

    def _embed_image(self, image, tmp_dir):
        """Encode an image array with ImageBind (``normalize=False``) via a temp PNG file."""
        png_path = os.path.join(tmp_dir, 'input.png')
        Image.fromarray(image).save(png_path)
        inputs = {
            imagebind.ModalityType.VISION: imagebind.load_and_transform_vision_data(
                [png_path], self.device
            ),
        }
        outputs = self.model.forward(inputs, normalize=False)
        return outputs[imagebind.ModalityType.VISION]

    def _embed_text(self, text):
        """Encode a single text string with ImageBind (``normalize=False``)."""
        inputs = {
            imagebind.ModalityType.TEXT: imagebind.load_and_transform_text(
                [text], self.device
            ),
        }
        outputs = self.model.forward(inputs, normalize=False)
        return outputs[imagebind.ModalityType.TEXT]