Spaces:
Running
Running
File size: 7,350 Bytes
07c7745 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
import os
import re
import json
import torch
import shutil
import requests
import gradio as gr
from piano_transcription_inference import PianoTranscription, load_audio, sample_rate
from modelscope import snapshot_download
from tempfile import NamedTemporaryFile
from pydub.utils import mediainfo
from urllib.parse import urlparse
from convert import midi2xml, xml2abc, xml2mxl, xml2jpg
# Scratch directory for generated files; the inference entry points reset it
# through clean_cache() before each run.
CACHE_DIR = "./flagged"

# snapshot_download's return value is used as the directory prefix of the
# checkpoint file (presumably the local snapshot folder - confirm against the
# modelscope documentation).
_WEIGHTS_DIR = snapshot_download(
    "MuGeminorum/piano_transcription",
    cache_dir="./__pycache__",
)
WEIGHTS_PATH = _WEIGHTS_DIR + "/CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"
def clean_cache(cache_dir=CACHE_DIR):
    """Reset *cache_dir* to an empty directory.

    An existing directory tree at that path is deleted first; the directory
    is then created again.

    Args:
        cache_dir: Directory to reset. Defaults to ``CACHE_DIR``.
    """
    already_present = os.path.exists(cache_dir)
    if already_present:
        shutil.rmtree(cache_dir)

    os.mkdir(cache_dir)
def get_audio_file_type(file_path: str):
    """Return the media format of *file_path* as a dotted extension.

    Args:
        file_path: Path of the media file to inspect.

    Returns:
        ``"."`` followed by the ``format_name`` entry reported by
        ``mediainfo``, or ``None`` when the lookup fails for any reason
        (the error is printed, not raised).
    """
    try:
        # Ask pydub for the media information and read the format name.
        # NOTE(review): some containers may report several comma-separated
        # format names here - confirm callers tolerate such an extension.
        media = mediainfo(file_path)
        extension = "." + media["format_name"]
    except Exception as e:
        print(f"Error occurred: {e}")
        return None

    return extension
def download_audio(url: str, save_path: str):
    """Download *url* and save it as *save_path* plus a detected extension.

    The response body is streamed into a temporary file. After that file is
    closed, its type is probed with ``get_audio_file_type`` and the file is
    moved to its final name.

    Args:
        url: Address of the audio file to fetch.
        save_path: Destination path without a file extension.

    Returns:
        The path of the saved file, or ``""`` when the server answers with
        a status other than HTTP 200. When the file type cannot be detected
        the file is saved at *save_path* without an extension.

    Raises:
        requests.RequestException: On connection errors or timeouts.
    """
    temp_file_path = None
    try:
        with NamedTemporaryFile(delete=False, suffix="_temp") as tmp_file:
            temp_file_path = tmp_file.name
            # Send the HTTP GET request. The timeout keeps a stalled server
            # from blocking the call forever; the context manager releases
            # the streamed connection.
            with requests.get(url, stream=True, timeout=30) as response:
                # Only a 200 response carries the audio content.
                if response.status_code != 200:
                    print(f"Failed to download file: HTTP {response.status_code}")
                    return ""

                # Write the audio content to the temporary file.
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_file.write(chunk)

        # The temporary file is closed (and flushed) here, so the probe
        # below sees the complete download.
        ext = get_audio_file_type(temp_file_path) or ""
        full_path = f"{save_path}{ext}"
        # Rename the temporary file so it carries the detected extension.
        shutil.move(temp_file_path, full_path)
        return full_path
    finally:
        # delete=False means nothing else removes the temporary file; clean
        # it up on every path where it was not moved to its final name.
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
def is_url(s: str):
    """Return whether *s* parses as a URL with both a scheme and a host.

    Args:
        s: Candidate string.

    Returns:
        ``True`` when ``urlparse`` yields a non-empty scheme (e.g. ``http``,
        ``https``) and a non-empty network location; ``False`` otherwise,
        including when the value cannot be parsed at all.
    """
    try:
        # Parse the string into its URL components.
        result = urlparse(s)
    except (ValueError, TypeError, AttributeError):
        # Malformed URLs (e.g. a broken IPv6 literal) and non-string input
        # are simply "not a URL"; anything else should not be hidden.
        return False

    # Both the scheme and the host part must be present.
    return bool(result.scheme and result.netloc)
def audio2midi(audio_path: str):
    """Transcribe the recording at *audio_path* to a MIDI file.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        A ``(midi_path, title)`` tuple: the path passed to the transcriber
        as its MIDI output (``CACHE_DIR/output.mid``) and a title derived
        from the audio file name (its stem, capitalized).
    """
    # Load the audio as mono at the sample rate exported by the
    # transcription package.
    audio, _ = load_audio(audio_path, sr=sample_rate, mono=True)

    # Build the transcriptor on the GPU when one is available.
    transcriptor = PianoTranscription(
        device="cuda" if torch.cuda.is_available() else "cpu",
        checkpoint_path=WEIGHTS_PATH,
    )

    # Transcribe and write out to the MIDI file.
    midi_path = f"{CACHE_DIR}/output.mid"
    transcriptor.transcribe(audio, midi_path)

    # splitext keeps the whole stem for names containing several dots and
    # does not fail for names without any extension.
    stem = os.path.splitext(os.path.basename(audio_path))[0]
    return midi_path, stem.capitalize()
def upl_infer(audio_path: str):
    """Run the full conversion pipeline on an uploaded audio file.

    The cache directory is reset, the audio is transcribed to MIDI, and the
    MIDI is converted to MusicXML, abc, MXL and PDF/JPG in turn.

    Args:
        audio_path: Path of the uploaded audio file.

    Returns:
        A 6-tuple ``(midi, pdf, xml, mxl, abc, jpg)``. If any step raises,
        every slot is ``None`` except the fifth, which carries the error
        message.
    """
    clean_cache()
    try:
        midi_file, score_title = audio2midi(audio_path)
        xml_file = midi2xml(midi_file, score_title)
        abc_result = xml2abc(xml_file)
        mxl_file = xml2mxl(xml_file)
        pdf_file, jpg_file = xml2jpg(xml_file)
    except Exception as e:
        # Surface the failure through the abc output slot.
        return None, None, None, None, f"{e}", None

    return midi_file, pdf_file, xml_file, mxl_file, abc_result, jpg_file
def get_first_integer(input_string: str):
    """Return the first run of digits in *input_string* as a string.

    The digits are passed through ``int`` first, so leading zeros are
    dropped. An empty string is returned when there are no digits.
    """
    found = re.search(r"\d+", input_string)
    if found is None:
        return ""

    return str(int(found.group()))
def music163_song_info(id: str):
    """Query the NetEase Cloud Music detail API for a song.

    Args:
        id: Song id. (The name shadows the builtin but is kept so existing
            keyword callers continue to work.)

    Returns:
        A ``(song_name, free)`` tuple. When the response lists no songs the
        name is a "song not exist" message and ``free`` is ``False``.

    Raises:
        ConnectionError: When the API answers with a status other than 200.
        requests.RequestException: On connection errors or timeouts.
    """
    detail_api = "https://music.163.com/api/v3/song/detail"
    parm_dict = {"id": id, "c": str([{"id": id}]), "csrf_token": ""}

    # The timeout keeps an unresponsive API from blocking the call forever.
    response = requests.get(detail_api, params=parm_dict, timeout=10)

    # Check whether the request succeeded.
    if response.status_code != 200:
        raise ConnectionError(f"Error: {response.status_code}, {response.text}")

    data = json.loads(response.text)
    if not (data and "songs" in data and data["songs"]):
        return "歌曲不存在 Song not exist", False

    song = data["songs"][0]
    fee = int(song["fee"])
    # Fee codes 0 and 8 are treated as free here; the meaning of the codes
    # comes from the NetEase API - confirm against its documentation.
    free = fee in (0, 8)
    return str(song["name"]), free
def url_infer(audio_url: str):
    """Download audio from *audio_url* and run the full conversion pipeline.

    A URL containing ``163`` that does not end in ``.mp3`` is handled as a
    NetEase Cloud Music song page: the song id is taken from its ``?id=``
    part, the song is checked to be free, and the URL is rewritten to the
    outer media link for that id. Any other URL is downloaded as given.

    Args:
        audio_url: Direct audio link or NetEase Cloud Music song page link.

    Returns:
        A 7-tuple ``(audio, midi, pdf, xml, mxl, abc, jpg)``. If any step
        fails, every slot is ``None`` except the sixth, which carries the
        error message.
    """
    clean_cache()
    song_name = ""
    download_path = f"{CACHE_DIR}/output"
    try:
        # Reject non-URLs up front instead of failing later with an
        # unrelated error.
        if not is_url(audio_url):
            raise ValueError("无效的音频链接 Invalid audio URL")

        if "163" in audio_url and not audio_url.endswith(".mp3"):
            _, marker, tail = audio_url.partition("?id=")
            song_id = get_first_integer(tail) if marker else ""
            if not song_id:
                raise ValueError(
                    "无法从链接中解析网易云歌曲 ID "
                    "Unable to parse a NetEase song id from the URL"
                )

            audio_url = (
                f"https://music.163.com/song/media/outer/url?id={song_id}.mp3"
            )
            song_name, free = music163_song_info(song_id)
            if not free:
                raise ValueError("付费歌曲无法解析 Unable to parse VIP songs")

        download_path = download_audio(audio_url, download_path)
        # download_audio signals an HTTP failure with an empty path.
        if not download_path:
            raise ValueError("音频下载失败 Failed to download the audio")

        midi, title = audio2midi(download_path)
        # Prefer the song name reported by the API over the file name.
        if song_name:
            title = song_name

        xml = midi2xml(midi, title)
        abc = xml2abc(xml)
        mxl = xml2mxl(xml)
        pdf, jpg = xml2jpg(xml)
        return download_path, midi, pdf, xml, mxl, abc, jpg
    except Exception as e:
        # Surface the failure through the abc output slot.
        return None, None, None, None, None, f"{e}", None
if __name__ == "__main__":
    # Build a two-tab UI: one tab per inference entry point. Components are
    # created inside their tab, in the same order as the Interface arguments
    # that consume them.
    with gr.Blocks() as iface:
        # Tab 1: transcribe an uploaded audio file.
        with gr.Tab("上传模式 (Upload Mode)"):
            upload_input = gr.Audio(
                label="上传音频 (Upload an audio)",
                type="filepath",
            )
            # Order matches the tuple returned by upl_infer.
            upload_outputs = [
                gr.File(label="下载 MIDI (Download MIDI)"),
                gr.File(label="下载 PDF 乐谱 (Download PDF score)"),
                gr.File(label="下载 MusicXML (Download MusicXML)"),
                gr.File(label="下载 MXL (Download MXL)"),
                gr.Textbox(label="abc 乐谱 (abc notation)", show_copy_button=True),
                gr.Image(label="五线谱 (Staff)", type="filepath"),
            ]
            gr.Interface(
                fn=upl_infer,
                inputs=upload_input,
                outputs=upload_outputs,
                title="请上传音频 100% 后再点提交<br>Please make sure the audio is completely uploaded before clicking Submit",
                allow_flagging="never",
            )

        # Tab 2: transcribe audio fetched from a URL.
        with gr.Tab("直链模式 (Direct Link Mode)"):
            link_input = gr.Textbox(label="输入音频直链 URL (Input audio direct link)")
            # Order matches the tuple returned by url_infer.
            link_outputs = [
                gr.Audio(label="下载音频 (Download audio)", type="filepath"),
                gr.File(label="下载 MIDI (Download MIDI)"),
                gr.File(label="下载 PDF 乐谱 (Download PDF score)"),
                gr.File(label="下载 MusicXML (Download MusicXML)"),
                gr.File(label="下载 MXL (Download MXL)"),
                gr.Textbox(label="abc 乐谱 (abc notation)", show_copy_button=True),
                gr.Image(label="五线谱 (Staff)", type="filepath"),
            ]
            gr.Interface(
                fn=url_infer,
                inputs=link_input,
                outputs=link_outputs,
                title="网易云音乐可直接输入非 VIP 歌曲页面链接自动解析<br>For Netease Cloud music, you can directly input the non-VIP song page link",
                examples=[
                    "https://music.163.com/#/song?id=1945798894",
                    "https://music.163.com/#/song?id=1945798973",
                    "https://music.163.com/#/song?id=1946098771",
                ],
                allow_flagging="never",
                cache_examples=False,
            )

    iface.launch()
|