aiqcamp commited on
Commit
59b0bed
·
verified ·
1 Parent(s): df0d8fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -43
app.py CHANGED
@@ -9,8 +9,15 @@ import os
9
  import requests
10
  from transformers import pipeline
11
  import tempfile
 
 
 
 
 
 
 
 
12
 
13
- # 기본 설정
14
  try:
15
  import mmaudio
16
  except ImportError:
@@ -24,15 +31,22 @@ from mmaudio.model.networks import MMAudio, get_my_mmaudio
24
  from mmaudio.model.sequence_config import SequenceConfig
25
  from mmaudio.model.utils.features_utils import FeaturesUtils
26
 
27
- # CUDA 설정
28
- torch.backends.cuda.matmul.allow_tf32 = True
29
- torch.backends.cudnn.allow_tf32 = True
30
-
31
  # 로깅 설정
 
 
 
 
32
  log = logging.getLogger()
33
 
34
- # 장치 및 데이터 타입 설정
35
- device = 'cuda'
 
 
 
 
 
 
 
36
  dtype = torch.bfloat16
37
 
38
  # 모델 설정
@@ -43,23 +57,9 @@ output_dir = Path('./output/gradio')
43
  setup_eval_logging()
44
 
45
  # 번역기 및 Pixabay API 설정
46
- translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
47
  PIXABAY_API_KEY = "33492762-a28a596ec4f286f84cd328b17"
48
 
49
- def search_pixabay_videos(query, api_key):
50
- base_url = "https://pixabay.com/api/videos/"
51
- params = {
52
- "key": api_key,
53
- "q": query,
54
- "per_page": 80
55
- }
56
-
57
- response = requests.get(base_url, params=params)
58
- if response.status_code == 200:
59
- data = response.json()
60
- return [video['videos']['large']['url'] for video in data.get('hits', [])]
61
- return []
62
-
63
  # CSS 스타일 정의
64
  custom_css = """
65
  .gradio-container {
@@ -111,34 +111,71 @@ button:hover {
111
  }
112
  """
113
 
114
- def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
115
- seq_cfg = model.seq_cfg
116
-
117
- net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
118
- net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
119
- log.info(f'Loaded weights from {model.model_path}')
 
 
120
 
121
- feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
122
- synchformer_ckpt=model.synchformer_ckpt,
123
- enable_conditions=True,
124
- mode=model.mode,
125
- bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
126
- need_vae_encoder=False)
127
- feature_utils = feature_utils.to(device, dtype).eval()
128
 
129
- return net, feature_utils, seq_cfg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  net, feature_utils, seq_cfg = get_model()
132
 
133
  def translate_prompt(text):
134
- if text and any(ord(char) >= 0x3131 and ord(char) <= 0xD7A3 for char in text):
135
- translation = translator(text)[0]['translation_text']
136
- return translation
137
- return text
 
 
 
 
 
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  def search_videos(query):
140
- query = translate_prompt(query)
141
- return search_pixabay_videos(query, PIXABAY_API_KEY)
 
142
 
143
  @spaces.GPU
144
  @torch.inference_mode()
@@ -209,7 +246,8 @@ video_search_tab = gr.Interface(
209
  fn=search_videos,
210
  inputs=gr.Textbox(label="검색어 입력"),
211
  outputs=gr.Gallery(label="검색 결과", columns=4, rows=20),
212
- css=custom_css
 
213
  )
214
 
215
  video_to_audio_tab = gr.Interface(
 
9
  import requests
10
  from transformers import pipeline
11
  import tempfile
12
+ import numpy as np
13
+ from einops import rearrange
14
+ import cv2
15
+ from scipy.io import wavfile
16
+ import librosa
17
+ import json
18
+ from typing import Optional, Tuple, List
19
+ import atexit
20
 
 
21
  try:
22
  import mmaudio
23
  except ImportError:
 
31
  from mmaudio.model.sequence_config import SequenceConfig
32
  from mmaudio.model.utils.features_utils import FeaturesUtils
33
 
 
 
 
 
34
  # 로깅 설정
35
+ logging.basicConfig(
36
+ level=logging.INFO,
37
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
38
+ )
39
  log = logging.getLogger()
40
 
41
+ # CUDA 설정
42
+ if torch.cuda.is_available():
43
+ device = torch.device("cuda")
44
+ torch.backends.cuda.matmul.allow_tf32 = True
45
+ torch.backends.cudnn.allow_tf32 = True
46
+ torch.backends.cudnn.benchmark = True
47
+ else:
48
+ device = torch.device("cpu")
49
+
50
  dtype = torch.bfloat16
51
 
52
  # 모델 설정
 
57
  setup_eval_logging()
58
 
59
  # 번역기 및 Pixabay API 설정
60
+ translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cpu")
61
  PIXABAY_API_KEY = "33492762-a28a596ec4f286f84cd328b17"
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # CSS 스타일 정의
64
  custom_css = """
65
  .gradio-container {
 
111
  }
112
  """
113
 
114
+ def cleanup_temp_files():
115
+ temp_dir = tempfile.gettempdir()
116
+ for file in os.listdir(temp_dir):
117
+ if file.endswith(('.mp4', '.flac')):
118
+ try:
119
+ os.remove(os.path.join(temp_dir, file))
120
+ except:
121
+ pass
122
 
123
+ atexit.register(cleanup_temp_files)
 
 
 
 
 
 
124
 
125
+ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
126
+ with torch.cuda.device(device):
127
+ seq_cfg = model.seq_cfg
128
+ net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
129
+ net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
130
+ log.info(f'Loaded weights from {model.model_path}')
131
+
132
+ feature_utils = FeaturesUtils(
133
+ tod_vae_ckpt=model.vae_path,
134
+ synchformer_ckpt=model.synchformer_ckpt,
135
+ enable_conditions=True,
136
+ mode=model.mode,
137
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
138
+ need_vae_encoder=False
139
+ ).to(device, dtype).eval()
140
+
141
+ return net, feature_utils, seq_cfg
142
 
143
  net, feature_utils, seq_cfg = get_model()
144
 
145
  def translate_prompt(text):
146
+ try:
147
+ if text and any(ord(char) >= 0x3131 and ord(char) <= 0xD7A3 for char in text):
148
+ with torch.no_grad():
149
+ translation = translator(text)[0]['translation_text']
150
+ return translation
151
+ return text
152
+ except Exception as e:
153
+ logging.error(f"Translation error: {e}")
154
+ return text
155
 
156
+ def search_pixabay_videos(query, api_key):
157
+ try:
158
+ base_url = "https://pixabay.com/api/videos/"
159
+ params = {
160
+ "key": api_key,
161
+ "q": query,
162
+ "per_page": 80
163
+ }
164
+
165
+ response = requests.get(base_url, params=params)
166
+ if response.status_code == 200:
167
+ data = response.json()
168
+ return [video['videos']['large']['url'] for video in data.get('hits', [])]
169
+ return []
170
+ except Exception as e:
171
+ logging.error(f"Pixabay API error: {e}")
172
+ return []
173
+
174
+ @torch.no_grad()
175
  def search_videos(query):
176
+ with torch.cuda.device("cpu"):
177
+ query = translate_prompt(query)
178
+ return search_pixabay_videos(query, PIXABAY_API_KEY)
179
 
180
  @spaces.GPU
181
  @torch.inference_mode()
 
246
  fn=search_videos,
247
  inputs=gr.Textbox(label="검색어 입력"),
248
  outputs=gr.Gallery(label="검색 결과", columns=4, rows=20),
249
+ css=custom_css,
250
+ api_name=False
251
  )
252
 
253
  video_to_audio_tab = gr.Interface(