admin
sync
07c7745
raw
history blame
7.35 kB
import os
import re
import json
import torch
import shutil
import requests
import gradio as gr
from piano_transcription_inference import PianoTranscription, load_audio, sample_rate
from modelscope import snapshot_download
from tempfile import NamedTemporaryFile
from pydub.utils import mediainfo
from urllib.parse import urlparse
from convert import midi2xml, xml2abc, xml2mxl, xml2jpg
# Working directory for per-request artifacts (MIDI, scores, downloaded audio);
# clean_cache() wipes and recreates it at the start of every inference call.
CACHE_DIR = "./flagged"

# Local path of the CRNN checkpoint, resolved once at import time.
# snapshot_download() fetches the "MuGeminorum/piano_transcription" ModelScope
# repo into ./__pycache__ and returns its local directory (presumably reusing an
# existing copy on later runs — confirm against the modelscope docs).
WEIGHTS_PATH = snapshot_download(
    "MuGeminorum/piano_transcription", cache_dir="./__pycache__"
) + "/CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"
def clean_cache(cache_dir=CACHE_DIR):
    """Reset *cache_dir* to an empty directory, creating it if necessary.

    Any existing contents are deleted. A plain file occupying the path is
    removed as well, and missing parent directories are created.

    Args:
        cache_dir: Directory to reset; defaults to CACHE_DIR.
    """
    if os.path.isdir(cache_dir):
        shutil.rmtree(cache_dir)
    elif os.path.exists(cache_dir):
        # A stray file with the directory's name would otherwise make
        # rmtree raise NotADirectoryError.
        os.remove(cache_dir)
    # makedirs (unlike mkdir) also creates missing parents, and exist_ok avoids
    # a FileExistsError if the directory reappears between removal and creation.
    os.makedirs(cache_dir, exist_ok=True)
def get_audio_file_type(file_path: str):
    """Probe *file_path* and return a file extension for its container format.

    The format comes from pydub's ``mediainfo`` (the ``format_name`` field).

    Args:
        file_path: Path of the media file to inspect.

    Returns:
        An extension such as ".mp3", or None if probing fails or reports no
        format name.
    """
    try:
        info = mediainfo(file_path)
        # Demuxers that handle several formats report a comma-separated list
        # (e.g. "mov,mp4,m4a,3gp,3g2,mj2"); keep only the first name so the
        # result is a single usable extension.
        format_name = info["format_name"].split(",")[0].strip()
    except Exception as e:
        print(f"Error occurred: {e}")
        return None
    return "." + format_name if format_name else None
def download_audio(url: str, save_path: str, timeout: float = 60.0):
    """Download *url* to ``save_path`` plus an extension detected from the content.

    The body is streamed into a temporary file. Once the download is complete,
    its format is probed with get_audio_file_type() and the file is moved to
    ``f"{save_path}{ext}"``.

    Args:
        url: HTTP(S) address of the audio file.
        save_path: Destination path without an extension.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot block the request indefinitely.

    Returns:
        The final file path, or "" if the HTTP status is not 200 or the
        downloaded file's format could not be detected. The temporary file is
        removed in every failure case, including raised exceptions.
    """
    temp_file_path = ""
    try:
        with NamedTemporaryFile(delete=False, suffix="_temp") as tmp_file:
            temp_file_path = tmp_file.name
            # Using the response as a context manager closes it (releasing the
            # connection) even though stream=True defers reading the body.
            with requests.get(url, stream=True, timeout=timeout) as response:
                if response.status_code != 200:
                    print(f"Failed to download file: HTTP {response.status_code}")
                    return ""
                # Write the audio content into the temporary file chunk by chunk.
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_file.write(chunk)
        # The temp file is closed (and flushed) at this point, so the probe
        # below sees the complete download.
        ext = get_audio_file_type(temp_file_path)
        if ext is None:
            print("Failed to detect the format of the downloaded file")
            return ""
        full_path = f"{save_path}{ext}"
        # Rename the temp file so it carries the detected extension.
        shutil.move(temp_file_path, full_path)
        return full_path
    finally:
        # After a successful move the temp path no longer exists; otherwise
        # remove the partial or rejected download.
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
def is_url(s: str):
    """Return True if *s* parses as a URL with both a scheme and a network location.

    Malformed URLs that ``urlparse`` rejects (e.g. an unbalanced IPv6 bracket)
    and non-string arguments yield False instead of raising.
    """
    try:
        result = urlparse(s)
    except (ValueError, TypeError, AttributeError):
        # ValueError: malformed URL such as "http://[::1";
        # TypeError / AttributeError: non-string arguments.
        # Narrower than a bare except, so Ctrl-C / SystemExit still propagate.
        return False
    # Require both a scheme (http, https, ...) and a netloc (domain).
    return bool(result.scheme and result.netloc)
def audio2midi(audio_path: str):
    """Transcribe the piano recording at *audio_path* into a MIDI file.

    Args:
        audio_path: Path of the input audio file.

    Returns:
        (midi_path, title): midi_path is f"{CACHE_DIR}/output.mid"; title is
        the audio file name without its extension, passed through
        str.capitalize().
    """
    # Load the audio as mono at the package-provided `sample_rate`.
    audio, _ = load_audio(audio_path, sr=sample_rate, mono=True)
    # NOTE(review): the transcriber is rebuilt from WEIGHTS_PATH on every call;
    # consider caching it if per-request latency matters.
    transcriptor = PianoTranscription(
        device="cuda" if torch.cuda.is_available() else "cpu",
        checkpoint_path=WEIGHTS_PATH,
    )
    # Transcribe and write the result out to the MIDI file.
    midi_path = f"{CACHE_DIR}/output.mid"
    transcriptor.transcribe(audio, midi_path)
    # splitext keeps dotted names intact ("my.song.mp3" -> "my.song") and does
    # not fail on extension-less paths, unlike split(".")[-2].
    title = os.path.splitext(os.path.basename(audio_path))[0]
    return midi_path, title.capitalize()
def upl_infer(audio_path: str):
    """Convert an uploaded audio file into MIDI plus rendered score files.

    Args:
        audio_path: Filepath of the uploaded audio.

    Returns:
        A 6-tuple (midi, pdf, xml, mxl, abc, jpg) in the same order as the
        upload tab's outputs. On any failure, every slot is None except the
        abc slot (index 4), which carries the error text.
    """
    clean_cache()
    try:
        midi_file, piece_title = audio2midi(audio_path)
        musicxml = midi2xml(midi_file, piece_title)
        abc_text = xml2abc(musicxml)
        mxl_file = xml2mxl(musicxml)
        pdf_file, jpg_file = xml2jpg(musicxml)
    except Exception as err:
        # Show the error in the abc textbox; leave the other outputs empty.
        return None, None, None, None, f"{err}", None
    return midi_file, pdf_file, musicxml, mxl_file, abc_text, jpg_file
def get_first_integer(input_string: str):
    """Return the first run of digits in *input_string* as a normalized string.

    Leading zeros are dropped by round-tripping through int(). An empty
    string is returned when the input contains no digits.
    """
    first_digits = re.search(r"\d+", input_string)
    return "" if first_digits is None else str(int(first_digits.group()))
def music163_song_info(id: str, timeout: float = 30.0):
    """Look up a NetEase Cloud Music song's name and whether it is free.

    Args:
        id: Song ID as a string. The name is kept for caller compatibility
            even though it shadows the builtin.
        timeout: Seconds passed to ``requests.get``; without it a stalled API
            call would block the request indefinitely.

    Returns:
        (song_name, free): the song title, or "歌曲不存在 Song not exist" when
        the API returns no song; ``free`` is True when the song's ``fee`` code
        is 0 or 8.

    Raises:
        ConnectionError: If the API answers with a non-200 status.
    """
    detail_api = "https://music.163.com/api/v3/song/detail"
    # NOTE(review): str() yields a Python repr (single quotes), not strict
    # JSON; kept as-is so the request sent to the API is unchanged — confirm
    # whether json.dumps would be more appropriate.
    parm_dict = {"id": id, "c": str([{"id": id}]), "csrf_token": ""}
    response = requests.get(detail_api, params=parm_dict, timeout=timeout)
    if response.status_code != 200:
        raise ConnectionError(f"Error: {response.status_code}, {response.text}")

    data = json.loads(response.text)
    if not (data and "songs" in data and data["songs"]):
        return "歌曲不存在 Song not exist", False

    song = data["songs"][0]
    fee = int(song["fee"])
    # Fee codes 0 and 8 are the ones this app treats as playable without VIP.
    return str(song["name"]), fee in (0, 8)
def url_infer(audio_url: str):
    """Download audio from a URL and convert it into MIDI plus rendered scores.

    NetEase Cloud Music links (host 163.com or a subdomain, not already ending
    in ".mp3") are resolved to the song's outer media URL, and only after the
    song is confirmed free. The song's name is then used as the score title.
    Any other URL is downloaded directly.

    Args:
        audio_url: Direct audio link or NetEase song page link.

    Returns:
        A 7-tuple (audio, midi, pdf, xml, mxl, abc, jpg) in the same order as
        the direct-link tab's outputs. On failure, every slot is None except
        the abc slot (index 5), which carries the error message.
    """
    clean_cache()
    song_name = ""
    download_path = f"{CACHE_DIR}/output"
    try:
        if not is_url(audio_url):
            raise ValueError("无效的链接 Invalid URL")

        # Match on the host rather than a "163" substring anywhere in the URL,
        # which also matched unrelated links such as ".../file1634.wav".
        host = urlparse(audio_url).hostname or ""
        is_music163 = host == "163.com" or host.endswith(".163.com")
        if is_music163 and not audio_url.endswith(".mp3"):
            # Accept the id after either '?' or '&' so other query parameters
            # before it don't break parsing (the old split("?id=")[1] raised an
            # opaque IndexError in that case).
            id_match = re.search(r"[?&]id=(\d+)", audio_url)
            if id_match is None:
                raise ValueError("无法解析歌曲 ID Unable to parse the song ID")
            song_id = get_first_integer(id_match.group(1))
            audio_url = (
                f"https://music.163.com/song/media/outer/url?id={song_id}.mp3"
            )
            song_name, free = music163_song_info(song_id)
            if not free:
                raise ValueError("付费歌曲无法解析 Unable to parse VIP songs")

        download_path = download_audio(audio_url, download_path)
        # download_audio signals failure with an empty path; report it here
        # instead of letting the audio loader fail on "".
        if not download_path:
            raise ConnectionError("音频下载失败 Failed to download the audio")

        midi, title = audio2midi(download_path)
        if song_name:
            title = song_name

        xml = midi2xml(midi, title)
        abc = xml2abc(xml)
        mxl = xml2mxl(xml)
        pdf, jpg = xml2jpg(xml)
        return download_path, midi, pdf, xml, mxl, abc, jpg
    except Exception as e:
        # UI boundary: show the error in the abc textbox.
        return None, None, None, None, None, f"{e}", None
if __name__ == "__main__":
    # Build a two-tab Gradio UI: one tab transcribes an uploaded audio file,
    # the other downloads audio from a URL (url_infer also resolves NetEase
    # Cloud Music song page links) and transcribes it.
    with gr.Blocks() as iface:
        with gr.Tab("上传模式 (Upload Mode)"):
            # Upload tab. The outputs list must stay in the same order as the
            # 6-tuple returned by upl_infer: (midi, pdf, xml, mxl, abc, jpg).
            gr.Interface(
                fn=upl_infer,
                inputs=gr.Audio(
                    label="上传音频 (Upload an audio)",
                    type="filepath",
                ),
                outputs=[
                    gr.File(label="下载 MIDI (Download MIDI)"),
                    gr.File(label="下载 PDF 乐谱 (Download PDF score)"),
                    gr.File(label="下载 MusicXML (Download MusicXML)"),
                    gr.File(label="下载 MXL (Download MXL)"),
                    gr.Textbox(label="abc 乐谱 (abc notation)", show_copy_button=True),
                    gr.Image(label="五线谱 (Staff)", type="filepath"),
                ],
                title="请上传音频 100% 后再点提交<br>Please make sure the audio is completely uploaded before clicking Submit",
                allow_flagging="never",
            )
        with gr.Tab("直链模式 (Direct Link Mode)"):
            # Direct-link tab. The outputs list must stay in the same order as
            # the 7-tuple returned by url_infer:
            # (audio, midi, pdf, xml, mxl, abc, jpg).
            gr.Interface(
                fn=url_infer,
                inputs=gr.Textbox(label="输入音频直链 URL (Input audio direct link)"),
                outputs=[
                    gr.Audio(label="下载音频 (Download audio)", type="filepath"),
                    gr.File(label="下载 MIDI (Download MIDI)"),
                    gr.File(label="下载 PDF 乐谱 (Download PDF score)"),
                    gr.File(label="下载 MusicXML (Download MusicXML)"),
                    gr.File(label="下载 MXL (Download MXL)"),
                    gr.Textbox(label="abc 乐谱 (abc notation)", show_copy_button=True),
                    gr.Image(label="五线谱 (Staff)", type="filepath"),
                ],
                title="网易云音乐可直接输入非 VIP 歌曲页面链接自动解析<br>For Netease Cloud music, you can directly input the non-VIP song page link",
                # Example inputs are NetEase song page links.
                examples=[
                    "https://music.163.com/#/song?id=1945798894",
                    "https://music.163.com/#/song?id=1945798973",
                    "https://music.163.com/#/song?id=1946098771",
                ],
                allow_flagging="never",
                cache_examples=False,
            )
    iface.launch()