Spaces:
Runtime error
Runtime error
Fabrice-TIERCELIN
commited on
Commit
•
ca7518d
1
Parent(s):
8a37844
Comment
Browse files
models.py
CHANGED
@@ -69,41 +69,41 @@ class AudioDiffusion(nn.Module):
|
|
69 |
):
|
70 |
super().__init__()
|
71 |
|
72 |
-
assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
|
73 |
-
|
74 |
-
self.text_encoder_name = text_encoder_name
|
75 |
-
self.scheduler_name = scheduler_name
|
76 |
-
self.unet_model_name = unet_model_name
|
77 |
-
self.unet_model_config_path = unet_model_config_path
|
78 |
-
self.snr_gamma = snr_gamma
|
79 |
-
self.freeze_text_encoder = freeze_text_encoder
|
80 |
-
self.uncondition = uncondition
|
81 |
-
|
82 |
-
# https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
|
83 |
-
self.noise_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
|
84 |
-
self.inference_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
|
85 |
-
|
86 |
-
if unet_model_config_path:
|
87 |
-
unet_config = UNet2DConditionModel.load_config(unet_model_config_path)
|
88 |
-
self.unet = UNet2DConditionModel.from_config(unet_config, subfolder="unet")
|
89 |
-
self.set_from = "random"
|
90 |
-
print("UNet initialized randomly.")
|
91 |
-
else:
|
92 |
-
self.unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder="unet")
|
93 |
-
self.set_from = "pre-trained"
|
94 |
-
self.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
|
95 |
-
self.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
|
96 |
-
print("UNet initialized from stable diffusion checkpoint.")
|
97 |
-
|
98 |
-
if "stable-diffusion" in self.text_encoder_name:
|
99 |
-
self.tokenizer = CLIPTokenizer.from_pretrained(self.text_encoder_name, subfolder="tokenizer")
|
100 |
-
self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder_name, subfolder="text_encoder")
|
101 |
-
elif "t5" in self.text_encoder_name:
|
102 |
-
self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
|
103 |
-
self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
|
104 |
-
else:
|
105 |
-
self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
|
106 |
-
self.text_encoder = AutoModel.from_pretrained(self.text_encoder_name)
|
107 |
|
108 |
def compute_snr(self, timesteps):
|
109 |
"""
|
|
|
69 |
):
|
70 |
super().__init__()
|
71 |
|
72 |
+
# assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
|
73 |
+
#
|
74 |
+
# self.text_encoder_name = text_encoder_name
|
75 |
+
# self.scheduler_name = scheduler_name
|
76 |
+
# self.unet_model_name = unet_model_name
|
77 |
+
# self.unet_model_config_path = unet_model_config_path
|
78 |
+
# self.snr_gamma = snr_gamma
|
79 |
+
# self.freeze_text_encoder = freeze_text_encoder
|
80 |
+
# self.uncondition = uncondition
|
81 |
+
#
|
82 |
+
# # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
|
83 |
+
# self.noise_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
|
84 |
+
# self.inference_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler")
|
85 |
+
#
|
86 |
+
# if unet_model_config_path:
|
87 |
+
# unet_config = UNet2DConditionModel.load_config(unet_model_config_path)
|
88 |
+
# self.unet = UNet2DConditionModel.from_config(unet_config, subfolder="unet")
|
89 |
+
# self.set_from = "random"
|
90 |
+
# print("UNet initialized randomly.")
|
91 |
+
# else:
|
92 |
+
# self.unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder="unet")
|
93 |
+
# self.set_from = "pre-trained"
|
94 |
+
# self.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
|
95 |
+
# self.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
|
96 |
+
# print("UNet initialized from stable diffusion checkpoint.")
|
97 |
+
#
|
98 |
+
# if "stable-diffusion" in self.text_encoder_name:
|
99 |
+
# self.tokenizer = CLIPTokenizer.from_pretrained(self.text_encoder_name, subfolder="tokenizer")
|
100 |
+
# self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder_name, subfolder="text_encoder")
|
101 |
+
# elif "t5" in self.text_encoder_name:
|
102 |
+
# self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
|
103 |
+
# self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
|
104 |
+
# else:
|
105 |
+
# self.tokenizer = AutoTokenizer.from_pretrained(self.text_encoder_name)
|
106 |
+
# self.text_encoder = AutoModel.from_pretrained(self.text_encoder_name)
|
107 |
|
108 |
def compute_snr(self, timesteps):
|
109 |
"""
|