Spaces:
Running
Running
import torch | |
from diffusers.loaders import AttnProcsLayers | |
from modules.BEATs.BEATs import BEATs, BEATsConfig | |
from modules.AudioToken.embedder import FGAEmbedder | |
from diffusers import AutoencoderKL, UNet2DConditionModel | |
from diffusers.models.attention_processor import LoRAAttnProcessor | |
class AudioTokenWrapper(torch.nn.Module):
    """Simple wrapper module for Stable Diffusion that holds all the models together.

    The wrapper owns a CLIP text encoder, the Stable Diffusion UNet and VAE,
    a BEATs audio encoder, and an ``FGAEmbedder``. When LoRA is requested it
    also installs LoRA attention processors on the UNet and exposes them as
    ``self.lora_layers``.

    Args:
        args: Configuration object. The attributes read here are
            ``pretrained_model_name_or_path``, ``revision``, ``data_set``
            (``'train'`` or ``'test'``), the optional ``lora`` flag and, in
            ``'test'`` mode, ``learned_embeds`` (plus ``lora_learned_embeds``
            when LoRA is enabled).
        accelerator: Object whose ``device`` attribute is used as the
            ``map_location`` when loading learned weights in ``'test'`` mode.

    Raises:
        ValueError: If LoRA is enabled and a UNet attention processor name
            does not belong to ``mid_block``, ``up_blocks`` or ``down_blocks``.
    """

    BEATS_CHECKPOINT_PATH = 'models/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt'

    def __init__(
            self,
            args,
            accelerator,
    ):
        super().__init__()

        # Evaluate the optional flag once; `args` may not define `lora` at all.
        use_lora = 'lora' in args and args.lora

        # Load scheduler and models.
        # 'test' mode uses the stock transformers text encoder, every other
        # mode uses the project's CLIP variant. Picking the class up front
        # avoids loading the text encoder twice in 'test' mode.
        if args.data_set == 'test':
            from transformers import CLIPTextModel
        else:
            from modules.clip_text_model.modeling_clip import CLIPTextModel

        self.text_encoder = CLIPTextModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
        )
        self.vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
        )

        # map_location="cpu" lets a checkpoint saved from GPU load on CPU-only
        # hosts; load_state_dict copies the values into the module either way.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(self.BEATS_CHECKPOINT_PATH, map_location="cpu")
        cfg = BEATsConfig(checkpoint['cfg'])
        self.aud_encoder = BEATs(cfg)
        self.aud_encoder.load_state_dict(checkpoint['model'])
        # The predictor attribute is discarded once the weights are loaded.
        self.aud_encoder.predictor = None

        input_size = 768 * 3
        # NOTE(review): every checkpoint other than SD v1-4 gets a 1024-wide
        # embedder output - presumably sized for SD 2.x; confirm before using
        # other base models.
        if args.pretrained_model_name_or_path == "CompVis/stable-diffusion-v1-4":
            self.embedder = FGAEmbedder(input_size=input_size, output_size=768)
        else:
            self.embedder = FGAEmbedder(input_size=input_size, output_size=1024)

        self.vae.eval()
        self.unet.eval()
        self.text_encoder.eval()
        self.aud_encoder.eval()

        if use_lora:
            # Set correct lora layers
            block_out_channels = self.unet.config.block_out_channels
            lora_attn_procs = {}
            for name in self.unet.attn_processors.keys():
                # Processors named "...attn1.processor" get no cross-attention dim.
                cross_attention_dim = None if name.endswith(
                    "attn1.processor") else self.unet.config.cross_attention_dim
                lora_attn_procs[name] = LoRAAttnProcessor(
                    hidden_size=self._lora_hidden_size(name, block_out_channels),
                    cross_attention_dim=cross_attention_dim)

            self.unet.set_attn_processor(lora_attn_procs)
            self.lora_layers = AttnProcsLayers(self.unet.attn_processors)

        if args.data_set == 'train':
            # Freeze vae, unet, text_enc and aud_encoder
            self.vae.requires_grad_(False)
            self.unet.requires_grad_(False)
            self.text_encoder.requires_grad_(False)
            self.aud_encoder.requires_grad_(False)
            self.embedder.requires_grad_(True)
            self.embedder.train()

            if use_lora:
                # The LoRA processors were installed on the unet above, so the
                # unet-wide freeze would otherwise leave them untrainable.
                self.lora_layers.requires_grad_(True)
                self.unet.train()

        elif args.data_set == 'test':
            self.embedder.eval()
            embedder_learned_embeds = args.learned_embeds
            self.embedder.load_state_dict(
                torch.load(embedder_learned_embeds, map_location=accelerator.device))

            if use_lora:
                self.lora_layers.eval()
                lora_layers_learned_embeds = args.lora_learned_embeds
                self.lora_layers.load_state_dict(
                    torch.load(lora_layers_learned_embeds, map_location=accelerator.device))
                self.unet.load_attn_procs(lora_layers_learned_embeds)

    @staticmethod
    def _lora_hidden_size(name, block_out_channels):
        """Return the hidden size for the attention processor called ``name``.

        ``block_out_channels`` is the UNet's per-block channel list. Mid-block
        processors use its last entry, ``down_blocks.<i>`` uses entry ``i``
        and ``up_blocks.<i>`` uses entry ``i`` of the reversed list.

        Raises:
            ValueError: If ``name`` starts with none of the known prefixes.
        """
        if name.startswith("mid_block"):
            return block_out_channels[-1]
        if name.startswith("up_blocks"):
            # Parse the whole dotted component so indices >= 10 also work.
            block_id = int(name.split(".")[1])
            return list(reversed(block_out_channels))[block_id]
        if name.startswith("down_blocks"):
            block_id = int(name.split(".")[1])
            return block_out_channels[block_id]
        raise ValueError(f"Unexpected attention processor name: {name!r}")