Fabrice-TIERCELIN committed on
Commit
ca7518d
1 Parent(s): 8a37844
Files changed (1) hide show
  1. models.py +35 -35
models.py CHANGED
@@ -69,41 +69,41 @@ class AudioDiffusion(nn.Module):
69
  ):
70
  super().__init__()
71
 
72
- assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
73
-
74
- self.text_encoder_name = text_encoder_name
75
- self.scheduler_name = scheduler_name
76
- self.unet_model_name = unet_model_name
77
- self.unet_model_config_path = unet_model_config_path
78
- self.snr_gamma = snr_gamma
79
- self.freeze_text_encoder = freeze_text_encoder
80
- self.uncondition = uncondition
81
-
82
- # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
83
- self.noise_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
84
- self.inference_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
85
-
86
- if unet_model_config_path:
87
- unet_config = UNet2DConditionModel.load_config(unet_model_config_path)
88
- self.unet = UNet2DConditionModel.from_config(unet_config, subfolder="unet")
89
- self.set_from = "random"
90
- print("UNet initialized randomly.")
91
- else:
92
- self.unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder="unet")
93
- self.set_from = "pre-trained"
94
- self.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
95
- self.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
96
- print("UNet initialized from stable diffusion checkpoint.")
97
-
98
- if "stable-diffusion" in self.text_encoder_name:
99
- self.tokenizer = CLIPTokenizer.from_pretrained(self.text_encoder_name, subfolder="tokenizer")
100
- self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder_name, subfolder="text_encoder")
101
- elif "t5" in self.text_encoder_name:
102
- self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
103
- self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
104
- else:
105
- self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
106
- self.text_encoder = AutoModel.from_pretrained(self.text_encoder_name)
107
 
108
  def compute_snr(self, timesteps):
109
  """
 
69
  ):
70
  super().__init__()
71
 
72
+ # assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
73
+ #
74
+ # self.text_encoder_name = text_encoder_name
75
+ # self.scheduler_name = scheduler_name
76
+ # self.unet_model_name = unet_model_name
77
+ # self.unet_model_config_path = unet_model_config_path
78
+ # self.snr_gamma = snr_gamma
79
+ # self.freeze_text_encoder = freeze_text_encoder
80
+ # self.uncondition = uncondition
81
+ #
82
+ # # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
83
+ # self.noise_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
84
+ # self.inference_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
85
+ #
86
+ # if unet_model_config_path:
87
+ # unet_config = UNet2DConditionModel.load_config(unet_model_config_path)
88
+ # self.unet = UNet2DConditionModel.from_config(unet_config, subfolder="unet")
89
+ # self.set_from = "random"
90
+ # print("UNet initialized randomly.")
91
+ # else:
92
+ # self.unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder="unet")
93
+ # self.set_from = "pre-trained"
94
+ # self.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
95
+ # self.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
96
+ # print("UNet initialized from stable diffusion checkpoint.")
97
+ #
98
+ # if "stable-diffusion" in self.text_encoder_name:
99
+ # self.tokenizer = CLIPTokenizer.from_pretrained(self.text_encoder_name, subfolder="tokenizer")
100
+ # self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder_name, subfolder="text_encoder")
101
+ # elif "t5" in self.text_encoder_name:
102
+ # self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
103
+ # self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
104
+ # else:
105
+ # self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
106
+ # self.text_encoder = AutoModel.from_pretrained(self.text_encoder_name)
107
 
108
  def compute_snr(self, timesteps):
109
  """