Spaces: Build error

Commit d1add7a by hsiangyualex
Parent(s): 0fc1d10

Upload 41 files
- app.py +133 -0
- audio_models.py +389 -0
- ckpt/emo_dim_model/config.json +122 -0
- ckpt/emo_dim_model/model.safetensors +3 -0
- ckpt/emo_dim_model/preprocessor_config.json +9 -0
- ckpt/emo_dim_model/vocab.json +1 -0
- ckpt/sepformer-wham-enhancement/decoder.ckpt +3 -0
- ckpt/sepformer-wham-enhancement/encoder.ckpt +3 -0
- ckpt/sepformer-wham-enhancement/hyperparams.yaml +66 -0
- ckpt/sepformer-wham-enhancement/masknet.ckpt +3 -0
- ckpt/ser_cn_audio/config.json +74 -0
- ckpt/ser_cn_audio/preprocessor_config.json +9 -0
- ckpt/ser_cn_audio/pytorch_model.bin +3 -0
- ckpt/ser_en_audio/config.json +148 -0
- ckpt/ser_en_audio/model.safetensors +3 -0
- ckpt/ser_en_audio/optimizer.pt +3 -0
- ckpt/ser_en_audio/preprocessor_config.json +10 -0
- ckpt/ser_en_audio/rng_state_0.pth +3 -0
- ckpt/ser_en_audio/rng_state_1.pth +3 -0
- ckpt/ser_en_audio/rng_state_2.pth +3 -0
- ckpt/ser_en_audio/rng_state_3.pth +3 -0
- ckpt/ser_en_audio/scheduler.pt +3 -0
- ckpt/ser_en_audio/trainer_state.json +652 -0
- ckpt/ser_en_audio/training_args.bin +3 -0
- ckpt/ser_en_text/config.json +45 -0
- ckpt/ser_en_text/merges.txt +0 -0
- ckpt/ser_en_text/pytorch_model.bin +3 -0
- ckpt/ser_en_text/special_tokens_map.json +1 -0
- ckpt/ser_en_text/tokenizer.json +0 -0
- ckpt/ser_en_text/tokenizer_config.json +1 -0
- ckpt/ser_en_text/training_args.bin +3 -0
- ckpt/ser_en_text/vocab.json +0 -0
- ckpt/zh-2-en/config.json +60 -0
- ckpt/zh-2-en/generation_config.json +16 -0
- ckpt/zh-2-en/metadata.json +1 -0
- ckpt/zh-2-en/pytorch_model.bin +3 -0
- ckpt/zh-2-en/rust_model.ot +3 -0
- ckpt/zh-2-en/source.spm +0 -0
- ckpt/zh-2-en/target.spm +0 -0
- ckpt/zh-2-en/tokenizer_config.json +1 -0
- ckpt/zh-2-en/vocab.json +0 -0
app.py
ADDED
@@ -0,0 +1,133 @@
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import time
import shutil
import argparse
import librosa
import numpy as np
import soundfile as sf  # librosa.output was removed in librosa >= 0.8; write wav files via soundfile instead
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from glob import glob
from audio_models import EnglishEmotionModel


def classify(audio, model_choice, preprocess, weight):
    return en_model.predict(audio, model_choice, preprocess, weight)


def handle_feedback(audio: str, model_prediction, unsatisfied, true_label, savedir='./user_feedback'):
    os.makedirs(savedir, exist_ok=True)
    if unsatisfied:
        audio_path = os.path.join(savedir, f"{int(time.time())}.wav")
        # save the audio clip that accompanies the user feedback
        if isinstance(audio, str):  # a temporary file path
            shutil.copy(audio, audio_path)
        elif isinstance(audio, tuple):  # raw audio data
            sr, data = audio
            sf.write(audio_path, data, sr)
        else:
            raise ValueError("Invalid audio input")
        # save model prediction and true label as a Python dict, serialized as a pickle file
        feedback = {
            "audio_path": audio_path,
            "model_prediction": model_prediction,
            "true_label": true_label
        }
        feedback_path = os.path.join(savedir, f"{int(time.time())}.pkl")
        torch.save(feedback, feedback_path)
        return f"Feedback submitted: True Label = {true_label}, Model Prediction = {model_prediction}"
    return "Thank you for using our SER demo!"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--duration', type=int, default=10, help='duration of audio')
    parser.add_argument('--sr', type=int, default=16000, help='sampling rate of audio')
    parser.add_argument('--device', type=str, default='cuda', help='device index to run model')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    en_model = EnglishEmotionModel(duration=args.duration, sr=args.sr, device=torch.device(args.device))
    english_audio_paths = glob('audio_files/english/*.wav')
    english_audio_paths.sort()
    english_audio_paths = {f"English Audio {idx + 1}": path for idx, path in enumerate(english_audio_paths)}

    # refresh the choices shown in the audio dropdown
    def update_audio_options():
        return gr.update(choices=list(english_audio_paths.keys()), value="English Audio 1")

    # point the audio player at the selected file
    def update_audio_file(audio_selection):
        return english_audio_paths[audio_selection]

    with gr.Blocks() as demo:
        # build the tab interface
        # with gr.Tab("Demo (Built-In Audio)"):
        #     gr.Markdown("""## Automatic Emotion Recognition Demo \n
        #     This is a demo for audio emotion recognition.
        #     Note that the model is still under active development. Please feel free to report any issues. \n
        #     The Chinese model is based on HuBERT and the English model is based on Wav2Vec2.""")
        #     with gr.Row():
        #         with gr.Column():
        #             # dropdown to select a built-in clip, defaulting to the first one
        #             audio_dropdown = gr.Dropdown(list(english_audio_paths.keys()), label="Select Audio", value="English Audio 1", interactive=True)
        #             # audio player, defaulting to "English Audio 1"
        #             audio_player = gr.Audio(value=english_audio_paths["English Audio 1"], interactive=False)
        #             slider = gr.Slider(label='Context Weight', minimum=0, maximum=1, step=0.01, value=0.6)
        #         with gr.Column():
        #             # display the emotion classification result
        #             emotion_label = gr.Label(label="Emotion Prediction")
        #             dim_label = gr.Plot(label="Emotion Dimension")
        #             transcripts = gr.Textbox(label="Transcription", type='text', lines=5, max_lines=20, placeholder="Transcription")
        #             # button that refreshes the emotion prediction when clicked
        #             classify_button = gr.Button("Classify Emotion")
        #
        #     audio_dropdown.change(
        #         fn=update_audio_file,
        #         inputs=audio_dropdown,
        #         outputs=audio_player
        #     )
        #
        #     # update the emotion prediction after the button is clicked
        #     classify_button.click(base_classify, inputs=[audio_player, slider], outputs=[emotion_label, dim_label, transcripts])

        with gr.Tab("Speech Emotion Recognition Demo"):
            gr.Markdown("""## Interactive SER Demo \n
            Please upload audio via file path or microphone. If you are recording audio via microphone, please make sure that the audio is clear. \n
            The performance could be affected by environmental noise. \n
            If you are recording in a noisy environment, please enable the noise reduction option. Note that this will lead to a slight deterioration in performance.\n""")
            with gr.Row():
                with gr.Column():
                    audio = gr.Audio(sources=['microphone', 'upload'], type='filepath')
                    text = gr.Textbox(label="Transcription", type='text', lines=5, max_lines=20, placeholder="Transcription")
                    model_choice = gr.Dropdown(choices=['中文', 'English'], label='语言 / Language', value='中文')
                    with gr.Accordion("Advanced Settings", open=False):
                        preprocess = gr.Checkbox(label='Noise Reduction (Do not tick the box unless the environment is noisy)', value=False)
                        weight_slider = gr.Slider(label='Context Weight', minimum=0, maximum=1, step=0.01, value=0.6)
                    demo_button = gr.Button("Analyze Emotion")
                with gr.Column():
                    emotion_pred = gr.Label(label="Emotion Prediction")
                    dim_pred = gr.Plot(label="Emotion Dimension")

            with gr.Accordion("Feedback", open=False) as feedback_section:
                gr.Markdown("### User Feedback")
                satisfied_checkbox = gr.Checkbox(label="Are you unsatisfied with the result?", value=False)
                true_label_dropdown = gr.Dropdown(
                    label="Select the correct label",
                    choices=["angry", "disgust", "fearful", "happy", "neutral", "sad", "surprised"],
                )
                submit_feedback_button = gr.Button("Submit Feedback")
                feedback_result = gr.Textbox(label="Feedback Result", interactive=False)

            demo_button.click(classify, inputs=[audio, model_choice, preprocess, weight_slider], outputs=[emotion_pred, dim_pred, text])
            submit_feedback_button.click(handle_feedback, inputs=[audio, emotion_pred, satisfied_checkbox, true_label_dropdown], outputs=[feedback_result])

    demo.launch(share=True)
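
Note on the feedback files: handle_feedback() above serializes each submission with torch.save() into ./user_feedback, so the .pkl records can simply be read back with torch.load(). The snippet below is a minimal illustrative sketch; the file name inspect_feedback.py and the printing loop are not part of this commit.

# inspect_feedback.py -- illustrative sketch only, not part of this commit
from glob import glob
import torch

# each record written by handle_feedback() in app.py is a plain dict:
# {"audio_path": ..., "model_prediction": ..., "true_label": ...}
for pkl_path in sorted(glob('./user_feedback/*.pkl')):
    record = torch.load(pkl_path)
    print(pkl_path, record["true_label"], record["model_prediction"])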
audio_models.py
ADDED
@@ -0,0 +1,389 @@
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# os.environ['HF_HUB_OFFLINE'] = '1'
import librosa
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
import torch.nn as nn
import torch.nn.functional as F
import whisper
from transformers import AutoConfig, AutoModelForAudioClassification, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, \
    Wav2Vec2FeatureExtractor, Wav2Vec2PreTrainedModel, HubertPreTrainedModel, HubertModel, Wav2Vec2Model
from transformers.modeling_outputs import SequenceClassifierOutput
from speechbrain.inference.separation import SepformerSeparation as separator


class HubertClassificationHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_class)

    def forward(self, x):
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class RegressionHead(nn.Module):
    r"""Regression head."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class HubertForSpeechClassification(HubertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.hubert = HubertModel(config)
        self.classifier = HubertClassificationHead(config)
        self.init_weights()

    def forward(self, x):
        outputs = self.hubert(x)
        hidden_states = outputs[0]
        x = torch.mean(hidden_states, dim=1)
        x = self.classifier(x)
        return SequenceClassifierOutput(
            loss=None,
            logits=x,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class Wav2VecForSpeechRegression(Wav2Vec2PreTrainedModel):
    r"""Speech emotion regressor (arousal/valence/dominance)."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = RegressionHead(config)
        self.init_weights()

    def forward(
        self,
        input_values,
    ):
        outputs = self.wav2vec2(input_values)
        hidden_states = outputs[0]
        hidden_states = torch.mean(hidden_states, dim=1)
        logits = self.classifier(hidden_states)
        return hidden_states, logits


class EmotionModel:
    def __init__(self, duration: int = 6, sr: int = 16000, device: torch.device = "cuda", use_text: bool = True):
        # basic configurations
        self.device = device
        self.duration = duration
        self.sr = sr
        self.use_text = use_text
        # audio config
        self.audio_id2label = {}
        self.processor = None
        self.audio_model = None
        # text config
        self.text_id2label = {}  # note that id2label should be identical; we perform classification on the intersection of classes
        self.tokenizer = None
        self.text_model = None
        # noise reduction
        self.nr_model = separator.from_hparams(source="speechbrain/sepformer-wham-enhancement", savedir='ckpt/sepformer-wham-enhancement', run_opts={'device': 'cuda'})
        # speech-to-text transcription via openai-whisper (the attribute keeps its original name)
        self.tts_model = whisper.load_model('turbo', device=device)

    def preprocess_audio(self, speech):
        """
        Preprocess the audio: including noise reduction and silence removal.
        Args:
            speech: audio waveform.
        """
        # noise reduction
        speech = self.nr_model.separate_batch(torch.as_tensor(speech).unsqueeze(0))[0, :, 0].detach().cpu().numpy()
        # speech = nr.reduce_noise(y=speech, sr=self.sr, stationary=True)
        # # remove silence in the segment
        # speech, index = librosa.effects.trim(speech, top_db=40)
        return speech

    def load_audio(self, audio, preprocess: bool = True):
        """
        Load the audio segment into np.ndarray.
        Args:
            audio: audio file path or audio data;
            preprocess: bool, whether to run the preprocess function.
        """
        if isinstance(audio, str):
            # load the speech and resample it to the target sampling rate
            speech, _ = librosa.load(path=audio, sr=self.sr)
            speech = librosa.to_mono(speech)
            # clip the very beginning and end of the audio
            speech = speech[int(0.5 * self.sr):int(-0.1 * self.sr)]
        elif isinstance(audio, tuple):
            assert len(audio) == 2, "audio tuple must have 2 elements: sr and speech"
            orig_sr, orig_speech = audio
            speech = librosa.resample(orig_speech.astype(np.float32), orig_sr=orig_sr, target_sr=self.sr)
            speech = librosa.to_mono(speech)
        else:
            raise ValueError("audio must be a file path or audio data, got file type: {}".format(type(audio)))
        if preprocess:
            speech = self.preprocess_audio(speech)
        return speech

    def id2label(self, id2label, indices, scores):
        """
        Map predicted class indices to labels.
        Args:
            id2label: dict mapping class index to emotion label;
            indices: emotion class indices;
            scores: emotion class scores.
        """
        output = {}
        for idx, score in zip(indices, scores):
            if idx in id2label.keys():
                output[id2label[idx]] = score
        return output

    def normalize_scores(self, audio_result, text_result, audio_weight: float = 0.25, text_weight: float = 0.75):
        """
        Normalize the scores based on the weights.
        Args:
            audio_result: a dict of audio predictions, keys being emotion labels and values being scores;
            text_result: a dict of text predictions, keys being emotion labels and values being scores;
            audio_weight: float, weight for audio;
            text_weight: float, weight for text.
        """
        audio_result = {k: v * audio_weight for k, v in audio_result.items()}
        text_result = {k: v * text_weight for k, v in text_result.items()}
        # merge the results; the order of classes should be the same
        result = {}
        for k in audio_result.keys():
            result[k] = audio_result[k] + text_result[k]
        # normalize the scores so they sum to 1
        total = sum(result.values())
        result = {k: v / total for k, v in result.items()}
        return result

    def audio_pred(self, inputs):
        """
        Predict emotion class on an audio segment.
        Args:
            inputs: audio inputs.
        """
        speech = self.processor(inputs, padding="max_length", truncation=True, max_length=self.duration * self.sr,
                                return_tensors="pt", sampling_rate=self.sr).input_values.to(self.device)
        with torch.no_grad():
            logits = self.audio_model(speech).logits
        scores, indices = torch.sort(logits.squeeze().detach().cpu(), descending=True)
        scores = F.softmax(scores, dim=0).numpy()
        indices = indices.numpy()
        return self.id2label(self.audio_id2label, indices, scores)

    def text_pred(self, text):
        """
        Predict emotion class on text.
        Args:
            text: text transcribed from the audio.
        """
        inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            logits = self.text_model(**inputs).logits
        scores, indices = torch.sort(logits.squeeze().detach().cpu(), descending=True)
        scores = F.softmax(scores, dim=0).numpy()
        indices = indices.numpy()
        return self.id2label(self.text_id2label, indices, scores)

    def predict(self, audio, preprocess: bool = True, scores_only: bool = False):
        """
        Run prediction based on the recipe.
        """
        speech = self.load_audio(audio, preprocess=preprocess)
        audio_scores = self.audio_pred(speech)
        if not self.use_text:
            return audio_scores if not scores_only else list(audio_scores.values())
        output = self.tts_model.transcribe(speech)
        text, language = output['text'], output['language']
        text_scores = self.text_pred(text)
        result = self.normalize_scores(audio_scores, text_scores)
        return result if not scores_only else list(result.values())


class EnglishEmotionModel(EmotionModel):
    def __init__(self, duration: int = 6, sr: int = 16000, device: torch.device = "cuda", use_text: bool = True):
        super().__init__(duration, sr, device, use_text)
        # english audio model
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained("./ckpt/ser_en_audio")
        self.audio_model = AutoModelForAudioClassification.from_pretrained("./ckpt/ser_en_audio").eval()
        self.audio_id2label = {
            0: "angry",
            6: "disgust",
            9: "fearful",
            10: "happy",
            11: "neutral",
            12: "sad",
            13: "surprised"
        }
        # chinese audio model
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path="./ckpt/ser_cn_audio")
        self.cn_processor = Wav2Vec2FeatureExtractor.from_pretrained("./ckpt/ser_cn_audio")
        self.cn_audio_model = HubertForSpeechClassification.from_pretrained("./ckpt/ser_cn_audio", config=config).eval()
        self.cn_audio_id2label = {
            0: "angry",
            1: "fearful",
            2: "happy",
            3: "neutral",
            4: "sad",
            5: "surprised"
        }
        # english text model
        self.text_model = AutoModelForSequenceClassification.from_pretrained("./ckpt/ser_en_text").eval()
        self.tokenizer = AutoTokenizer.from_pretrained("./ckpt/ser_en_text")
        self.text_id2label = {
            0: "angry",
            1: "disgust",
            2: "fearful",
            3: "happy",
            4: "neutral",
            5: "sad",
            6: "surprised"
        }
        self.audio_model.to(self.device)
        self.text_model.to(self.device)
        self.cn_audio_model.to(self.device)
        # load the MSP-DIM model
        self.msp_dim = Wav2VecForSpeechRegression.from_pretrained('./ckpt/emo_dim_model').to(device)
        self.msp_processor = Wav2Vec2FeatureExtractor.from_pretrained('./ckpt/emo_dim_model')
        # self.msp_dim = Wav2VecForSpeechRegression.from_pretrained('audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim').to(device)
        # self.msp_processor = Wav2Vec2FeatureExtractor.from_pretrained('audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim')
        # load the translation model (CN2EN)
        # self.translator = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').to(device)
        # self.translator_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
        self.translator = AutoModelForSeq2SeqLM.from_pretrained('./ckpt/zh-2-en').to(device)
        self.translator_tokenizer = AutoTokenizer.from_pretrained('./ckpt/zh-2-en')

    def plot_dim(self, dim_result, line_height=1):
        # set up the figure
        fig, ax = plt.subplots(figsize=(7, 3), dpi=300)

        # bar chart parameters
        labels = list(dim_result.keys())    # labels (A, V, D)
        values = list(dim_result.values())  # value for each label
        colors = ['blue', 'red', 'green']   # color for each label

        # draw one bar per dimension
        for i, (label, value) in enumerate(dim_result.items()):
            # draw the bar
            ax.barh(i, value, color=colors[i], height=line_height, align='center')

            # draw an outline box spanning -1 to 1
            rect = patches.Rectangle((-1, i - line_height/2), 2, line_height, edgecolor='black', facecolor='none', linewidth=1)
            ax.add_patch(rect)

        # set the x-axis range and style
        ax.set_xlim(-1.0, 1.1)
        ax.axvline(0, color='black', linewidth=1)  # center line

        # axis labels and ticks
        ax.set_xticks([-1, 0, 1])
        ax.set_xticklabels(['Low (-1)', 'Neutral (0)', 'High (1)'])
        ax.set_yticks(range(len(dim_result)))
        ax.set_yticklabels(labels)

        # remove the outer spines
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_visible(False)

        # hide the grid lines
        ax.grid(False)

        # finalize the figure
        plt.tight_layout()
        return fig

    def dim_pred(self, inputs, return_plot=True):
        inputs = self.msp_processor(inputs, return_tensors="pt", padding=False, truncation=True, max_length=160000, sampling_rate=16000).input_values.to(self.device)
        with torch.no_grad():
            hidden_states, logits = self.msp_dim(input_values=inputs)
        logits = logits[0].clamp_(0, 1).detach().cpu().numpy()
        logits = (logits - 0.5) * 2  # remap to (-1, 1)
        result = {'arousal': logits[0], 'valence': logits[2], 'dominance': logits[1]}
        if return_plot:
            result = self.plot_dim(result)
        return result

    def cn_audio_pred(self, inputs):
        """
        Predict emotion class on an audio segment (Chinese model).
        Args:
            inputs: audio inputs.
        """
        speech = self.cn_processor(inputs, padding="max_length", truncation=True, max_length=self.duration * self.sr,
                                   return_tensors="pt", sampling_rate=self.sr).input_values.to(self.device)
        with torch.no_grad():
            logits = self.cn_audio_model(speech).logits
        scores, indices = torch.sort(logits.squeeze().detach().cpu(), descending=True)
        scores = F.softmax(scores, dim=0).numpy()
        indices = indices.numpy()
        return self.id2label(self.cn_audio_id2label, indices, scores)

    def predict(self, audio, language_choice, preprocess: bool = True, text_weight: float = 0.5):
        """
        Run prediction based on the recipe.
        Args:
            audio: audio file path or audio data;
            language_choice: str, "中文" or "English";
            preprocess: bool, whether to run the preprocess function;
            text_weight: the weight given to the text prediction.
        """
        speech = self.load_audio(audio, preprocess=preprocess)
        if language_choice == '中文':
            audio_scores = self.cn_audio_pred(speech)
        else:
            audio_scores = self.audio_pred(speech)
        if not self.use_text:
            dim_result = self.dim_pred(speech)
            return audio_scores, dim_result, None
        output = self.tts_model.transcribe(speech)
        text, language = output['text'], output['language']
        if language != 'en':
            inputs = self.translator_tokenizer(text, return_tensors="pt").to(self.device)
            output = self.translator.generate(**inputs)
            text_en = self.translator_tokenizer.decode(output[0], skip_special_tokens=True)
            text_scores = self.text_pred(text_en)
        else:
            text_scores = self.text_pred(text)
        result = self.normalize_scores(audio_scores, text_scores, audio_weight=1-text_weight, text_weight=text_weight)
        dim_result = self.dim_pred(speech)
        # emotion, dim, text
        return result, dim_result, text
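
For reference, EnglishEmotionModel can also be driven outside the Gradio app. The sketch below mirrors the call made in app.py; it assumes the ./ckpt checkpoints from this commit are in place and a CUDA device is available, and sample.wav is a placeholder path.

# standalone usage sketch -- illustrative only, not part of this commit
import torch
from audio_models import EnglishEmotionModel

model = EnglishEmotionModel(duration=10, sr=16000, device=torch.device('cuda'))
# returns (emotion score dict, matplotlib figure of arousal/valence/dominance, transcript);
# preprocess=False skips the SepFormer noise-reduction pass
scores, dim_fig, transcript = model.predict('sample.wav', 'English', preprocess=False, text_weight=0.6)
print(scores)
print(transcript)
dim_fig.savefig('emotion_dimensions.png')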
ckpt/emo_dim_model/config.json
ADDED
@@ -0,0 +1,122 @@
{
  "_name_or_path": "torch",
  "activation_dropout": 0.1,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForSpeechClassification"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 768,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": true,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": true,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_dropout": 0.0,
  "feat_extract_norm": "layer",
  "feat_proj_dropout": 0.1,
  "feat_quantizer_dropout": 0.0,
  "final_dropout": 0.1,
  "finetuning_task": "wav2vec2_reg",
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout": 0.1,
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "id2label": {
    "0": "arousal",
    "1": "dominance",
    "2": "valence"
  },
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "label2id": {
    "arousal": 0,
    "dominance": 1,
    "valence": 2
  },
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.1,
  "mask_feature_length": 10,
  "mask_feature_min_masks": 0,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_min_masks": 2,
  "mask_time_prob": 0.05,
  "model_type": "wav2vec2",
  "num_adapter_layers": 3,
  "num_attention_heads": 16,
  "num_codevector_groups": 2,
  "num_codevectors_per_group": 320,
  "num_conv_pos_embedding_groups": 16,
  "num_conv_pos_embeddings": 128,
  "num_feat_extract_layers": 7,
  "num_hidden_layers": 12,
  "num_negatives": 100,
  "output_hidden_size": 1024,
  "pad_token_id": 0,
  "pooling_mode": "mean",
  "problem_type": "regression",
  "proj_codevector_dim": 768,
  "tdnn_dilation": [
    1,
    2,
    3,
    1,
    1
  ],
  "tdnn_dim": [
    512,
    512,
    512,
    512,
    1500
  ],
  "tdnn_kernel": [
    5,
    3,
    3,
    1,
    1
  ],
  "torch_dtype": "float32",
  "transformers_version": "4.17.0.dev0",
  "use_weighted_layer_sum": false,
  "vocab_size": null,
  "xvector_output_dim": 512
}
ckpt/emo_dim_model/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:efa5ac1a13b2d2f42182738e44794b1eb4c0cdd221a8b4ae11304c3a5f5fae95
size 661375508
ckpt/emo_dim_model/preprocessor_config.json
ADDED
@@ -0,0 +1,9 @@
{
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": true,
  "sampling_rate": 16000
}
ckpt/emo_dim_model/vocab.json
ADDED
@@ -0,0 +1 @@
{}
ckpt/sepformer-wham-enhancement/decoder.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4703b15d23ad5dd4c9b6b93b09539cf0048ba2e58a36c71a62fb860d5b0d343f
size 17272
ckpt/sepformer-wham-enhancement/encoder.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6c6b3e53a4061b81b7b0abf6a7faac8ee9714e0509a49ac60d249488e347430c
size 17272
ckpt/sepformer-wham-enhancement/hyperparams.yaml
ADDED
@@ -0,0 +1,66 @@
# ################################
# Model: Pretrained SepFormer for speech enhancement
# Dataset : WHAM!
# ################################

num_spks: 1
sample_rate: 8000

# Encoder parameters
N_encoder_out: 256
out_channels: 256
kernel_size: 16
kernel_stride: 8

# Specifying the network
Encoder: !new:speechbrain.lobes.models.dual_path.Encoder
    kernel_size: 16
    out_channels: 256

SBtfintra: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
    num_layers: 8
    d_model: 256
    nhead: 8
    d_ffn: 1024
    dropout: 0
    use_positional_encoding: true
    norm_before: true

SBtfinter: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
    num_layers: 8
    d_model: 256
    nhead: 8
    d_ffn: 1024
    dropout: 0
    use_positional_encoding: true
    norm_before: true

MaskNet: !new:speechbrain.lobes.models.dual_path.Dual_Path_Model
    num_spks: 1
    in_channels: 256
    out_channels: 256
    num_layers: 2
    K: 250
    intra_model: !ref <SBtfintra>
    inter_model: !ref <SBtfinter>
    norm: ln
    linear_layer_after_inter_intra: false
    skip_around_intra: true

Decoder: !new:speechbrain.lobes.models.dual_path.Decoder
    in_channels: 256
    out_channels: 1
    kernel_size: 16
    stride: 8
    bias: false

modules:
    encoder: !ref <Encoder>
    decoder: !ref <Decoder>
    masknet: !ref <MaskNet>

pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    loadables:
        encoder: !ref <Encoder>
        masknet: !ref <MaskNet>
        decoder: !ref <Decoder>
ckpt/sepformer-wham-enhancement/masknet.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:111312d682aee6b72610b83edc0dcf253d7a62f745136cd2828113fb75fbb6e4
size 112849478
ckpt/ser_cn_audio/config.json
ADDED
@@ -0,0 +1,74 @@
{
  "_name_or_path": "TencentGameMate/chinese-hubert-base",
  "activation_dropout": 0.1,
  "apply_spec_augment": true,
  "architectures": [
    "HubertForSpeechClassification"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_dropout": 0.1,
  "classifier_proj_size": 256,
  "conv_bias": false,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "do_stable_layer_norm": false,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_norm": "group",
  "feat_proj_dropout": 0.0,
  "feat_proj_layer_norm": true,
  "final_dropout": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.1,
  "mask_feature_length": 10,
  "mask_feature_min_masks": 0,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_min_masks": 2,
  "mask_time_prob": 0.05,
  "model_type": "hubert",
  "num_attention_heads": 12,
  "num_class": 6,
  "num_conv_pos_embedding_groups": 16,
  "num_conv_pos_embeddings": 128,
  "num_feat_extract_layers": 7,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "torch_dtype": "float32",
  "transformers_version": "4.24.0",
  "use_weighted_layer_sum": false,
  "vocab_size": 32
}
ckpt/ser_cn_audio/preprocessor_config.json
ADDED
@@ -0,0 +1,9 @@
{
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0,
  "return_attention_mask": false,
  "sampling_rate": 16000
}
ckpt/ser_cn_audio/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0cd2286572750ab6f4cf3d1a5283cf6c92b4a8ae9e87f38ebb515439a56c5b53
size 379939475
ckpt/ser_en_audio/config.json
ADDED
@@ -0,0 +1,148 @@
{
  "_name_or_path": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
  "activation_dropout": 0.05,
  "adapter_attn_dim": null,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 256,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": true,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "mean",
  "ctc_zero_infinity": true,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": true,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_dropout": 0.0,
  "feat_extract_norm": "layer",
  "feat_proj_dropout": 0.05,
  "feat_quantizer_dropout": 0.0,
  "final_dropout": 0.0,
  "hidden_act": "gelu",
  "hidden_dropout": 0.05,
  "hidden_size": 1024,
  "id2label": {
    "0": "angry",
    "1": "anxious",
    "2": "apologetic",
    "3": "assertive",
    "4": "calm",
    "5": "concerned",
    "6": "disgust",
    "7": "encouraging",
    "8": "excited",
    "9": "fearful",
    "10": "happy",
    "11": "neutral",
    "12": "sad",
    "13": "surprised"
  },
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "label2id": {
    "angry": 0,
    "anxious": 1,
    "apologetic": 2,
    "assertive": 3,
    "calm": 4,
    "concerned": 5,
    "disgust": 6,
    "encouraging": 7,
    "excited": 8,
    "fearful": 9,
    "happy": 10,
    "neutral": 11,
    "sad": 12,
    "surprised": 13
  },
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.05,
  "mask_channel_length": 10,
  "mask_channel_min_space": 1,
  "mask_channel_other": 0.0,
  "mask_channel_prob": 0.0,
  "mask_channel_selection": "static",
  "mask_feature_length": 10,
  "mask_feature_min_masks": 0,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_min_masks": 2,
  "mask_time_min_space": 1,
  "mask_time_other": 0.0,
  "mask_time_prob": 0.05,
  "mask_time_selection": "static",
  "model_type": "wav2vec2",
  "num_adapter_layers": 3,
  "num_attention_heads": 16,
  "num_codevector_groups": 2,
  "num_codevectors_per_group": 320,
  "num_conv_pos_embedding_groups": 16,
  "num_conv_pos_embeddings": 128,
  "num_feat_extract_layers": 7,
  "num_hidden_layers": 24,
  "num_negatives": 100,
  "output_hidden_size": 1024,
  "pad_token_id": 0,
  "proj_codevector_dim": 256,
  "tdnn_dilation": [
    1,
    2,
    3,
    1,
    1
  ],
  "tdnn_dim": [
    512,
    512,
    512,
    512,
    1500
  ],
  "tdnn_kernel": [
    5,
    3,
    3,
    1,
    1
  ],
  "torch_dtype": "float32",
  "transformers_version": "4.45.2",
  "use_weighted_layer_sum": false,
  "vocab_size": 33,
  "xvector_output_dim": 512
}
ckpt/ser_en_audio/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:03a81b9298d54f85fec594b071c6fa8d32483876219ad0912f7324aaff2c3d72
size 1262871640
ckpt/ser_en_audio/optimizer.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:334204fa336cc008160506dd3272196fde73f9b74039d8b0c2e9eac8c8998acf
size 2525994320
ckpt/ser_en_audio/preprocessor_config.json
ADDED
@@ -0,0 +1,10 @@
{
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "processor_class": "Wav2Vec2ProcessorWithLM",
  "return_attention_mask": true,
  "sampling_rate": 16000
}
ckpt/ser_en_audio/rng_state_0.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5cc40fede51cbf483771073afe9c2d734758236aad4608d6f0393100f5015429
size 15024
ckpt/ser_en_audio/rng_state_1.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:58842f3563b84c1197fe10d111b14adbe13041260bf73977633194268f547050
size 15024
ckpt/ser_en_audio/rng_state_2.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9a770c18a5735b020b633ee5c0dc23f405eb5244709c33f644d9f449cdbc2331
size 15024
ckpt/ser_en_audio/rng_state_3.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0ca2caa02ebff6c447147c9295ca4a54048a4e2d5d2d45b27fad60090f0084fc
size 15024
ckpt/ser_en_audio/scheduler.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8017f56711090ed38368e1dea349327d91f7f162e079c14836498d327bbfab77
size 1064
ckpt/ser_en_audio/trainer_state.json
ADDED
@@ -0,0 +1,652 @@
{
  "best_metric": null,
  "best_model_checkpoint": null,
  "epoch": 30.0,
  "eval_steps": 500,
  "global_step": 18600,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
  "log_history": [
    {
      "epoch": 0.8064516129032258,
      "grad_norm": 5.6446027755737305,
      "learning_rate": 9.798387096774194e-05,
      "loss": 0.8653,
      "step": 500
    },
    {
      "epoch": 1.0,
      "eval_accuracy": 0.8519116311913649,
      "eval_f1": 0.5612972035236246,
      "eval_loss": 0.41698312759399414,
      "eval_precision": 0.6519175226159025,
      "eval_recall": 0.5331925780455707,
      "eval_runtime": 444.2569,
      "eval_samples_per_second": 22.314,
      "eval_steps_per_second": 0.698,
      "step": 620
    },
    {
      "epoch": 1.6129032258064515,
      "grad_norm": 2.4346652030944824,
      "learning_rate": 9.596774193548387e-05,
      "loss": 0.3288,
      "step": 1000
    },
    {
      "epoch": 2.0,
      "eval_accuracy": 0.9004337738323414,
      "eval_f1": 0.7645083858520498,
      "eval_loss": 0.27656903862953186,
      "eval_precision": 0.775442200715531,
      "eval_recall": 0.7599977765696228,
      "eval_runtime": 441.6079,
      "eval_samples_per_second": 22.448,
      "eval_steps_per_second": 0.702,
      "step": 1240
    },
    {
      "epoch": 2.4193548387096775,
      "grad_norm": 3.350771903991699,
      "learning_rate": 9.395161290322582e-05,
      "loss": 0.2313,
      "step": 1500
    },
    {
      "epoch": 3.0,
      "eval_accuracy": 0.9010390396449107,
      "eval_f1": 0.7136957134951721,
      "eval_loss": 0.30654314160346985,
      "eval_precision": 0.7918800656051795,
      "eval_recall": 0.680703182570182,
      "eval_runtime": 443.8741,
      "eval_samples_per_second": 22.333,
      "eval_steps_per_second": 0.698,
      "step": 1860
    },
    {
      "epoch": 3.225806451612903,
      "grad_norm": 1.5124069452285767,
      "learning_rate": 9.193548387096774e-05,
      "loss": 0.1733,
      "step": 2000
    },
    {
      "epoch": 4.0,
      "eval_accuracy": 0.909210128114597,
      "eval_f1": 0.8132841024537738,
      "eval_loss": 0.3023378252983093,
      "eval_precision": 0.8052977470471581,
      "eval_recall": 0.8366718533903973,
      "eval_runtime": 444.6165,
      "eval_samples_per_second": 22.296,
      "eval_steps_per_second": 0.697,
      "step": 2480
    },
    {
      "epoch": 4.032258064516129,
      "grad_norm": 1.1565921306610107,
      "learning_rate": 8.991935483870968e-05,
      "loss": 0.1336,
      "step": 2500
    },
    {
      "epoch": 4.838709677419355,
      "grad_norm": 7.502715110778809,
      "learning_rate": 8.790322580645162e-05,
      "loss": 0.1044,
      "step": 3000
    },
    {
      "epoch": 5.0,
      "eval_accuracy": 0.9228286088974075,
      "eval_f1": 0.7768284943233322,
      "eval_loss": 0.27094921469688416,
      "eval_precision": 0.8088427869040292,
      "eval_recall": 0.7605468716950611,
      "eval_runtime": 446.1958,
      "eval_samples_per_second": 22.217,
      "eval_steps_per_second": 0.695,
      "step": 3100
    },
    {
      "epoch": 5.645161290322581,
      "grad_norm": 4.222163677215576,
      "learning_rate": 8.588709677419356e-05,
      "loss": 0.0891,
      "step": 3500
    },
    {
      "epoch": 6.0,
      "eval_accuracy": 0.9190961363865631,
      "eval_f1": 0.836047484191284,
      "eval_loss": 0.2973528802394867,
      "eval_precision": 0.8333888690415643,
      "eval_recall": 0.8462180948065121,
      "eval_runtime": 445.7518,
      "eval_samples_per_second": 22.239,
      "eval_steps_per_second": 0.695,
      "step": 3720
    },
    {
      "epoch": 6.451612903225806,
      "grad_norm": 0.9999768137931824,
      "learning_rate": 8.387096774193549e-05,
      "loss": 0.0738,
      "step": 4000
    },
    {
      "epoch": 7.0,
      "eval_accuracy": 0.9198022798345606,
      "eval_f1": 0.8339480533055649,
      "eval_loss": 0.32465115189552307,
      "eval_precision": 0.8531984900149399,
      "eval_recall": 0.8313244274621819,
      "eval_runtime": 444.2156,
      "eval_samples_per_second": 22.316,
      "eval_steps_per_second": 0.698,
      "step": 4340
    },
    {
      "epoch": 7.258064516129032,
      "grad_norm": 0.9860001802444458,
      "learning_rate": 8.185483870967743e-05,
      "loss": 0.0617,
      "step": 4500
    },
    {
      "epoch": 8.0,
      "eval_accuracy": 0.9283768788459599,
      "eval_f1": 0.8488005438808053,
      "eval_loss": 0.2583344280719757,
      "eval_precision": 0.8446683685244591,
      "eval_recall": 0.8628243687115127,
      "eval_runtime": 446.3564,
      "eval_samples_per_second": 22.209,
      "eval_steps_per_second": 0.695,
      "step": 4960
    },
    {
      "epoch": 8.064516129032258,
      "grad_norm": 0.3879972994327545,
      "learning_rate": 7.983870967741936e-05,
      "loss": 0.0574,
      "step": 5000
    },
    {
      "epoch": 8.870967741935484,
      "grad_norm": 4.836447715759277,
      "learning_rate": 7.78225806451613e-05,
      "loss": 0.0492,
      "step": 5500
    },
    {
      "epoch": 9.0,
      "eval_accuracy": 0.9225259759911227,
      "eval_f1": 0.816817658073689,
      "eval_loss": 0.34519901871681213,
      "eval_precision": 0.8676454880398878,
      "eval_recall": 0.7877172029441207,
      "eval_runtime": 447.3198,
      "eval_samples_per_second": 22.161,
      "eval_steps_per_second": 0.693,
      "step": 5580
    },
    {
      "epoch": 9.67741935483871,
      "grad_norm": 0.07304174453020096,
      "learning_rate": 7.580645161290323e-05,
      "loss": 0.0419,
      "step": 6000
    },
    {
      "epoch": 10.0,
      "eval_accuracy": 0.9251487945122566,
      "eval_f1": 0.8469126432067672,
      "eval_loss": 0.33950358629226685,
      "eval_precision": 0.8519618889252182,
      "eval_recall": 0.8489192187654137,
      "eval_runtime": 446.1732,
      "eval_samples_per_second": 22.218,
      "eval_steps_per_second": 0.695,
      "step": 6200
    },
    {
      "epoch": 10.483870967741936,
      "grad_norm": 3.8705227375030518,
      "learning_rate": 7.379032258064516e-05,
      "loss": 0.0331,
      "step": 6500
    },
    {
      "epoch": 11.0,
      "eval_accuracy": 0.9118329466357309,
      "eval_f1": 0.8376640160936598,
      "eval_loss": 0.4379476308822632,
      "eval_precision": 0.8312605064526365,
      "eval_recall": 0.8501611125751187,
      "eval_runtime": 444.981,
      "eval_samples_per_second": 22.277,
      "eval_steps_per_second": 0.697,
      "step": 6820
    },
    {
      "epoch": 11.290322580645162,
      "grad_norm": 0.8799965381622314,
      "learning_rate": 7.177419354838711e-05,
      "loss": 0.0353,
      "step": 7000
    },
    {
      "epoch": 12.0,
      "eval_accuracy": 0.9295874104710986,
      "eval_f1": 0.8496199395886753,
      "eval_loss": 0.3607427775859833,
      "eval_precision": 0.8474190215073337,
      "eval_recall": 0.8606261174217371,
      "eval_runtime": 445.7391,
      "eval_samples_per_second": 22.239,
      "eval_steps_per_second": 0.695,
      "step": 7440
    },
    {
      "epoch": 12.096774193548388,
      "grad_norm": 1.5596448183059692,
      "learning_rate": 6.975806451612904e-05,
      "loss": 0.0315,
      "step": 7500
    },
    {
      "epoch": 12.903225806451612,
      "grad_norm": 1.5588679313659668,
      "learning_rate": 6.774193548387096e-05,
      "loss": 0.0289,
      "step": 8000
    },
    {
      "epoch": 13.0,
      "eval_accuracy": 0.9273681024916776,
      "eval_f1": 0.8508873409327621,
      "eval_loss": 0.3614977300167084,
      "eval_precision": 0.8526893632091166,
      "eval_recall": 0.8572290757862527,
      "eval_runtime": 486.6477,
      "eval_samples_per_second": 20.37,
      "eval_steps_per_second": 0.637,
      "step": 8060
    },
    {
      "epoch": 13.709677419354838,
      "grad_norm": 0.03195716440677643,
      "learning_rate": 6.57258064516129e-05,
      "loss": 0.0261,
      "step": 8500
    },
    {
      "epoch": 14.0,
      "eval_accuracy": 0.9272672248562494,
      "eval_f1": 0.8630758757277098,
      "eval_loss": 0.36916211247444153,
      "eval_precision": 0.8586072821035567,
      "eval_recall": 0.8724563229411526,
      "eval_runtime": 472.9747,
      "eval_samples_per_second": 20.959,
      "eval_steps_per_second": 0.655,
      "step": 8680
    },
    {
      "epoch": 14.516129032258064,
      "grad_norm": 0.2278767228126526,
      "learning_rate": 6.370967741935485e-05,
      "loss": 0.0239,
      "step": 9000
    },
    {
      "epoch": 15.0,
      "eval_accuracy": 0.9261575708665389,
      "eval_f1": 0.8540811504617694,
      "eval_loss": 0.4022212028503418,
      "eval_precision": 0.8616150833457688,
      "eval_recall": 0.8575704808190433,
      "eval_runtime": 487.8531,
      "eval_samples_per_second": 20.32,
      "eval_steps_per_second": 0.635,
      "step": 9300
    },
    {
      "epoch": 15.32258064516129,
      "grad_norm": 0.053363025188446045,
      "learning_rate": 6.169354838709678e-05,
      "loss": 0.0255,
      "step": 9500
    },
    {
      "epoch": 16.0,
      "eval_accuracy": 0.9373549883990719,
      "eval_f1": 0.8401806488376877,
      "eval_loss": 0.33772599697113037,
      "eval_precision": 0.8640083214464809,
      "eval_recall": 0.8396382260534664,
      "eval_runtime": 462.9155,
      "eval_samples_per_second": 21.414,
      "eval_steps_per_second": 0.67,
      "step": 9920
    },
    {
      "epoch": 16.129032258064516,
      "grad_norm": 0.11232730746269226,
      "learning_rate": 5.9677419354838715e-05,
      "loss": 0.0199,
      "step": 10000
    },
    {
      "epoch": 16.93548387096774,
      "grad_norm": 0.48599764704704285,
      "learning_rate": 5.7661290322580655e-05,
      "loss": 0.0196,
      "step": 10500
    },
    {
      "epoch": 17.0,
      "eval_accuracy": 0.9315040855442348,
      "eval_f1": 0.8327759456218216,
      "eval_loss": 0.3767533302307129,
      "eval_precision": 0.8591712003111222,
      "eval_recall": 0.8342081502036376,
      "eval_runtime": 513.732,
      "eval_samples_per_second": 19.296,
      "eval_steps_per_second": 0.603,
      "step": 10540
    },
    {
      "epoch": 17.741935483870968,
      "grad_norm": 0.019181491807103157,
      "learning_rate": 5.5645161290322576e-05,
      "loss": 0.0192,
      "step": 11000
    },
    {
      "epoch": 18.0,
      "eval_accuracy": 0.9199031574699889,
      "eval_f1": 0.8519369732275665,
      "eval_loss": 0.45619821548461914,
      "eval_precision": 0.8639403373066186,
      "eval_recall": 0.8504081404701405,
      "eval_runtime": 504.1002,
      "eval_samples_per_second": 19.665,
      "eval_steps_per_second": 0.615,
      "step": 11160
    },
    {
      "epoch": 18.548387096774192,
      "grad_norm": 0.03667838126420975,
      "learning_rate": 5.362903225806452e-05,
      "loss": 0.0136,
      "step": 11500
    },
    {
      "epoch": 19.0,
      "eval_accuracy": 0.9227277312619793,
      "eval_f1": 0.792027873615724,
      "eval_loss": 0.43276721239089966,
      "eval_precision": 0.8409174974381394,
      "eval_recall": 0.770137877267392,
      "eval_runtime": 551.2838,
      "eval_samples_per_second": 17.982,
      "eval_steps_per_second": 0.562,
      "step": 11780
    },
    {
      "epoch": 19.35483870967742,
      "grad_norm": 4.472720623016357,
      "learning_rate": 5.161290322580645e-05,
      "loss": 0.0206,
      "step": 12000
    },
    {
      "epoch": 20.0,
      "eval_accuracy": 0.9308988197316654,
      "eval_f1": 0.8577732939026992,
      "eval_loss": 0.4217771291732788,
      "eval_precision": 0.8689115082442737,
      "eval_recall": 0.8597514273918758,
      "eval_runtime": 494.2923,
      "eval_samples_per_second": 20.055,
|
416 |
+
"eval_steps_per_second": 0.627,
|
417 |
+
"step": 12400
|
418 |
+
},
|
419 |
+
{
|
420 |
+
"epoch": 20.161290322580644,
|
421 |
+
"grad_norm": 0.23360569775104523,
|
422 |
+
"learning_rate": 4.959677419354839e-05,
|
423 |
+
"loss": 0.0145,
|
424 |
+
"step": 12500
|
425 |
+
},
|
426 |
+
{
|
427 |
+
"epoch": 20.967741935483872,
|
428 |
+
"grad_norm": 11.598882675170898,
|
429 |
+
"learning_rate": 4.7580645161290326e-05,
|
430 |
+
"loss": 0.0136,
|
431 |
+
"step": 13000
|
432 |
+
},
|
433 |
+
{
|
434 |
+
"epoch": 21.0,
|
435 |
+
"eval_accuracy": 0.9211136890951276,
|
436 |
+
"eval_f1": 0.8543572369365581,
|
437 |
+
"eval_loss": 0.4980121850967407,
|
438 |
+
"eval_precision": 0.8536818621962678,
|
439 |
+
"eval_recall": 0.8631371178427569,
|
440 |
+
"eval_runtime": 509.7213,
|
441 |
+
"eval_samples_per_second": 19.448,
|
442 |
+
"eval_steps_per_second": 0.608,
|
443 |
+
"step": 13020
|
444 |
+
},
|
445 |
+
{
|
446 |
+
"epoch": 21.774193548387096,
|
447 |
+
"grad_norm": 0.007597383111715317,
|
448 |
+
"learning_rate": 4.556451612903226e-05,
|
449 |
+
"loss": 0.0129,
|
450 |
+
"step": 13500
|
451 |
+
},
|
452 |
+
{
|
453 |
+
"epoch": 22.0,
|
454 |
+
"eval_accuracy": 0.9353374356905074,
|
455 |
+
"eval_f1": 0.8522499542325414,
|
456 |
+
"eval_loss": 0.37319281697273254,
|
457 |
+
"eval_precision": 0.8670789223875701,
|
458 |
+
"eval_recall": 0.8504049853786648,
|
459 |
+
"eval_runtime": 479.4485,
|
460 |
+
"eval_samples_per_second": 20.676,
|
461 |
+
"eval_steps_per_second": 0.647,
|
462 |
+
"step": 13640
|
463 |
+
},
|
464 |
+
{
|
465 |
+
"epoch": 22.580645161290324,
|
466 |
+
"grad_norm": 0.007587379310280085,
|
467 |
+
"learning_rate": 4.3548387096774194e-05,
|
468 |
+
"loss": 0.0097,
|
469 |
+
"step": 14000
|
470 |
+
},
|
471 |
+
{
|
472 |
+
"epoch": 23.0,
|
473 |
+
"eval_accuracy": 0.9339251487945123,
|
474 |
+
"eval_f1": 0.8622667883302407,
|
475 |
+
"eval_loss": 0.39090433716773987,
|
476 |
+
"eval_precision": 0.8768600760726798,
|
477 |
+
"eval_recall": 0.8598596008497976,
|
478 |
+
"eval_runtime": 549.8766,
|
479 |
+
"eval_samples_per_second": 18.028,
|
480 |
+
"eval_steps_per_second": 0.564,
|
481 |
+
"step": 14260
|
482 |
+
},
|
483 |
+
{
|
484 |
+
"epoch": 23.387096774193548,
|
485 |
+
"grad_norm": 0.0043108644895255566,
|
486 |
+
"learning_rate": 4.1532258064516135e-05,
|
487 |
+
"loss": 0.0124,
|
488 |
+
"step": 14500
|
489 |
+
},
|
490 |
+
{
|
491 |
+
"epoch": 24.0,
|
492 |
+
"eval_accuracy": 0.9375567436699284,
|
493 |
+
"eval_f1": 0.8596119869481005,
|
494 |
+
"eval_loss": 0.3828715682029724,
|
495 |
+
"eval_precision": 0.8788495808974863,
|
496 |
+
"eval_recall": 0.8587605197643191,
|
497 |
+
"eval_runtime": 518.7131,
|
498 |
+
"eval_samples_per_second": 19.111,
|
499 |
+
"eval_steps_per_second": 0.598,
|
500 |
+
"step": 14880
|
501 |
+
},
|
502 |
+
{
|
503 |
+
"epoch": 24.193548387096776,
|
504 |
+
"grad_norm": 0.013451021164655685,
|
505 |
+
"learning_rate": 3.951612903225806e-05,
|
506 |
+
"loss": 0.0082,
|
507 |
+
"step": 15000
|
508 |
+
},
|
509 |
+
{
|
510 |
+
"epoch": 25.0,
|
511 |
+
"grad_norm": 0.0018616759916767478,
|
512 |
+
"learning_rate": 3.7500000000000003e-05,
|
513 |
+
"loss": 0.0097,
|
514 |
+
"step": 15500
|
515 |
+
},
|
516 |
+
{
|
517 |
+
"epoch": 25.0,
|
518 |
+
"eval_accuracy": 0.9202057903762736,
|
519 |
+
"eval_f1": 0.8674300779359293,
|
520 |
+
"eval_loss": 0.4943585991859436,
|
521 |
+
"eval_precision": 0.8615130325799393,
|
522 |
+
"eval_recall": 0.8783789279568988,
|
523 |
+
"eval_runtime": 466.3291,
|
524 |
+
"eval_samples_per_second": 21.258,
|
525 |
+
"eval_steps_per_second": 0.665,
|
526 |
+
"step": 15500
|
527 |
+
},
|
528 |
+
{
|
529 |
+
"epoch": 25.806451612903224,
|
530 |
+
"grad_norm": 0.004003328271210194,
|
531 |
+
"learning_rate": 3.548387096774194e-05,
|
532 |
+
"loss": 0.0113,
|
533 |
+
"step": 16000
|
534 |
+
},
|
535 |
+
{
|
536 |
+
"epoch": 26.0,
|
537 |
+
"eval_accuracy": 0.9377584989407848,
|
538 |
+
"eval_f1": 0.8685888939122031,
|
539 |
+
"eval_loss": 0.3813043534755707,
|
540 |
+
"eval_precision": 0.870556797543307,
|
541 |
+
"eval_recall": 0.8777354113914896,
|
542 |
+
"eval_runtime": 698.3043,
|
543 |
+
"eval_samples_per_second": 14.196,
|
544 |
+
"eval_steps_per_second": 0.444,
|
545 |
+
"step": 16120
|
546 |
+
},
|
547 |
+
{
|
548 |
+
"epoch": 26.612903225806452,
|
549 |
+
"grad_norm": 19.115692138671875,
|
550 |
+
"learning_rate": 3.346774193548387e-05,
|
551 |
+
"loss": 0.0081,
|
552 |
+
"step": 16500
|
553 |
+
},
|
554 |
+
{
|
555 |
+
"epoch": 27.0,
|
556 |
+
"eval_accuracy": 0.9331181277110865,
|
557 |
+
"eval_f1": 0.8731036089741192,
|
558 |
+
"eval_loss": 0.433142751455307,
|
559 |
+
"eval_precision": 0.8634855877520003,
|
560 |
+
"eval_recall": 0.886410634024776,
|
561 |
+
"eval_runtime": 912.9943,
|
562 |
+
"eval_samples_per_second": 10.858,
|
563 |
+
"eval_steps_per_second": 0.34,
|
564 |
+
"step": 16740
|
565 |
+
},
|
566 |
+
{
|
567 |
+
"epoch": 27.419354838709676,
|
568 |
+
"grad_norm": 0.0007239320548251271,
|
569 |
+
"learning_rate": 3.1451612903225806e-05,
|
570 |
+
"loss": 0.0071,
|
571 |
+
"step": 17000
|
572 |
+
},
|
573 |
+
{
|
574 |
+
"epoch": 28.0,
|
575 |
+
"eval_accuracy": 0.9340260264299405,
|
576 |
+
"eval_f1": 0.8482513205734062,
|
577 |
+
"eval_loss": 0.401883065700531,
|
578 |
+
"eval_precision": 0.8646923523178319,
|
579 |
+
"eval_recall": 0.8467129210887174,
|
580 |
+
"eval_runtime": 829.4485,
|
581 |
+
"eval_samples_per_second": 11.951,
|
582 |
+
"eval_steps_per_second": 0.374,
|
583 |
+
"step": 17360
|
584 |
+
},
|
585 |
+
{
|
586 |
+
"epoch": 28.225806451612904,
|
587 |
+
"grad_norm": 0.008651689626276493,
|
588 |
+
"learning_rate": 2.9435483870967743e-05,
|
589 |
+
"loss": 0.008,
|
590 |
+
"step": 17500
|
591 |
+
},
|
592 |
+
{
|
593 |
+
"epoch": 29.0,
|
594 |
+
"eval_accuracy": 0.9370523554927872,
|
595 |
+
"eval_f1": 0.8750355664355425,
|
596 |
+
"eval_loss": 0.3932338356971741,
|
597 |
+
"eval_precision": 0.8682629852956361,
|
598 |
+
"eval_recall": 0.8885135737659093,
|
599 |
+
"eval_runtime": 591.0085,
|
600 |
+
"eval_samples_per_second": 16.773,
|
601 |
+
"eval_steps_per_second": 0.525,
|
602 |
+
"step": 17980
|
603 |
+
},
|
604 |
+
{
|
605 |
+
"epoch": 29.032258064516128,
|
606 |
+
"grad_norm": 0.0007982092211022973,
|
607 |
+
"learning_rate": 2.7419354838709678e-05,
|
608 |
+
"loss": 0.0077,
|
609 |
+
"step": 18000
|
610 |
+
},
|
611 |
+
{
|
612 |
+
"epoch": 29.838709677419356,
|
613 |
+
"grad_norm": 0.0016565436962991953,
|
614 |
+
"learning_rate": 2.5403225806451615e-05,
|
615 |
+
"loss": 0.0066,
|
616 |
+
"step": 18500
|
617 |
+
},
|
618 |
+
{
|
619 |
+
"epoch": 30.0,
|
620 |
+
"eval_accuracy": 0.940986583274488,
|
621 |
+
"eval_f1": 0.8921279164005701,
|
622 |
+
"eval_loss": 0.38233834505081177,
|
623 |
+
"eval_precision": 0.9059714000838046,
|
624 |
+
"eval_recall": 0.8867045383809143,
|
625 |
+
"eval_runtime": 470.8987,
|
626 |
+
"eval_samples_per_second": 21.051,
|
627 |
+
"eval_steps_per_second": 0.658,
|
628 |
+
"step": 18600
|
629 |
+
}
|
630 |
+
],
|
631 |
+
"logging_steps": 500,
|
632 |
+
"max_steps": 24800,
|
633 |
+
"num_input_tokens_seen": 0,
|
634 |
+
"num_train_epochs": 40,
|
635 |
+
"save_steps": 500,
|
636 |
+
"stateful_callbacks": {
|
637 |
+
"TrainerControl": {
|
638 |
+
"args": {
|
639 |
+
"should_epoch_stop": false,
|
640 |
+
"should_evaluate": false,
|
641 |
+
"should_log": false,
|
642 |
+
"should_save": true,
|
643 |
+
"should_training_stop": false
|
644 |
+
},
|
645 |
+
"attributes": {}
|
646 |
+
}
|
647 |
+
},
|
648 |
+
"total_flos": 3.607823021993519e+20,
|
649 |
+
"train_batch_size": 8,
|
650 |
+
"trial_name": null,
|
651 |
+
"trial_params": null
|
652 |
+
}
|
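The trainer state above ends with evaluation F1 peaking at epoch 30 (about 0.892 at step 18600). A minimal sketch, assuming the standard Hugging Face `trainer_state.json` layout with a `log_history` list, of how the best evaluation epoch can be read back out of this file:

```python
import json

# Minimal sketch: locate the best evaluation epoch in a Hugging Face trainer state file.
# Assumes the standard layout with a "log_history" list, as in the file shown above.
with open("ckpt/ser_en_audio/trainer_state.json", "r", encoding="utf-8") as f:
    state = json.load(f)

eval_entries = [e for e in state["log_history"] if "eval_f1" in e]
best = max(eval_entries, key=lambda e: e["eval_f1"])
print(f"best epoch {best['epoch']} (step {best['step']}): "
      f"f1={best['eval_f1']:.4f}, accuracy={best['eval_accuracy']:.4f}")
# For the log above this reports epoch 30.0 (step 18600) with f1 ≈ 0.8921.
```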
ckpt/ser_en_audio/training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e954588f3e2bd63d010ca82fe9ecff999d276a6113425e5185f8f5bb8f0caa3b
size 5240
ckpt/ser_en_text/config.json
ADDED
@@ -0,0 +1,45 @@
{
  "_name_or_path": "distilroberta-base",
  "architectures": [
    "RobertaForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "anger",
    "1": "disgust",
    "2": "fear",
    "3": "joy",
    "4": "neutral",
    "5": "sadness",
    "6": "surprise"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "anger": 0,
    "disgust": 1,
    "fear": 2,
    "joy": 3,
    "neutral": 4,
    "sadness": 5,
    "surprise": 6
  },
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "problem_type": "single_label_classification",
  "transformers_version": "4.6.1",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}
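This config describes a `distilroberta-base` sequence classifier over seven emotion labels (anger, disgust, fear, joy, neutral, sadness, surprise). A hedged sketch of loading this checkpoint directory with `transformers` and ranking the labels for one sentence; the example input and the local-path usage are illustrative and not taken from this repo's code:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Sketch only: load the text emotion classifier described by this config
# from the local checkpoint directory and rank the seven emotion labels.
ckpt_dir = "ckpt/ser_en_text"
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForSequenceClassification.from_pretrained(ckpt_dir)
model.eval()

text = "I can't believe how well this turned out!"  # illustrative input
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
    probs = torch.softmax(model(**inputs).logits, dim=-1).squeeze(0)

# id2label in the config maps indices 0..6 to anger/disgust/fear/joy/neutral/sadness/surprise.
for idx in probs.argsort(descending=True):
    print(model.config.id2label[int(idx)], round(float(probs[idx]), 3))
```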
ckpt/ser_en_text/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
ckpt/ser_en_text/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dde1eadd81741344dd707d1c482a3293810eb895c873053213ccdb2b57ca9e95
size 328544361
ckpt/ser_en_text/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
ckpt/ser_en_text/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
ckpt/ser_en_text/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
{"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "add_prefix_space": false, "errors": "replace", "sep_token": "</s>", "cls_token": "<s>", "pad_token": "<pad>", "mask_token": "<mask>", "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "distilroberta-base"}
ckpt/ser_en_text/training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4ed7a68d54395ab0be21726d6fcf25f942ed459b16387bbf9cf251051986766f
size 2415
ckpt/ser_en_text/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
ckpt/zh-2-en/config.json
ADDED
@@ -0,0 +1,60 @@
{
  "_name_or_path": "/tmp/Helsinki-NLP/opus-mt-zh-en",
  "activation_dropout": 0.0,
  "activation_function": "swish",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "MarianMTModel"
  ],
  "attention_dropout": 0.0,
  "bad_words_ids": [
    [
      65000
    ]
  ],
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "classifier_dropout": 0.0,
  "d_model": 512,
  "decoder_attention_heads": 8,
  "decoder_ffn_dim": 2048,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 65000,
  "decoder_vocab_size": 65001,
  "dropout": 0.1,
  "encoder_attention_heads": 8,
  "encoder_ffn_dim": 2048,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 0,
  "extra_pos_embeddings": 65001,
  "forced_eos_token_id": 0,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_length": 512,
  "max_position_embeddings": 512,
  "model_type": "marian",
  "normalize_before": false,
  "normalize_embedding": false,
  "num_beams": 6,
  "num_hidden_layers": 6,
  "pad_token_id": 65000,
  "scale_embedding": true,
  "share_encoder_decoder_embeddings": true,
  "static_position_embeddings": true,
  "transformers_version": "4.22.0.dev0",
  "use_cache": true,
  "vocab_size": 65001
}
ckpt/zh-2-en/generation_config.json
ADDED
@@ -0,0 +1,16 @@
{
  "bad_words_ids": [
    [
      65000
    ]
  ],
  "bos_token_id": 0,
  "decoder_start_token_id": 65000,
  "eos_token_id": 0,
  "forced_eos_token_id": 0,
  "max_length": 512,
  "num_beams": 6,
  "pad_token_id": 65000,
  "renormalize_logits": true,
  "transformers_version": "4.32.0.dev0"
}
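Together with `config.json` above, this generation config fixes the decoding defaults for the Marian zh-to-en model (beam search with 6 beams, `max_length` 512). A minimal sketch, with an illustrative input sentence, of running the translator from the `ckpt/zh-2-en` directory:

```python
from transformers import MarianMTModel, MarianTokenizer

# Sketch only: Chinese-to-English translation with the Marian checkpoint in ckpt/zh-2-en.
# The generation defaults above (num_beams=6, max_length=512) are picked up automatically.
ckpt_dir = "ckpt/zh-2-en"
tokenizer = MarianTokenizer.from_pretrained(ckpt_dir)
model = MarianMTModel.from_pretrained(ckpt_dir)

# Illustrative input: "The weather is really nice today."
batch = tokenizer(["今天天气真好。"], return_tensors="pt", padding=True)
generated = model.generate(**batch)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```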
ckpt/zh-2-en/metadata.json
ADDED
@@ -0,0 +1 @@
{"hf_name":"zho-eng","source_languages":"zho","target_languages":"eng","opus_readme_url":"https:\/\/github.com\/Helsinki-NLP\/Tatoeba-Challenge\/tree\/master\/models\/zho-eng\/README.md","original_repo":"Tatoeba-Challenge","tags":["translation"],"languages":["zh","en"],"src_constituents":["cmn_Hans","nan","nan_Hani","gan","yue","cmn_Kana","yue_Hani","wuu_Bopo","cmn_Latn","yue_Hira","cmn_Hani","cjy_Hans","cmn","lzh_Hang","lzh_Hira","cmn_Hant","lzh_Bopo","zho","zho_Hans","zho_Hant","lzh_Hani","yue_Hang","wuu","yue_Kana","wuu_Latn","yue_Bopo","cjy_Hant","yue_Hans","lzh","cmn_Hira","lzh_Yiii","lzh_Hans","cmn_Bopo","cmn_Hang","hak_Hani","cmn_Yiii","yue_Hant","lzh_Kana","wuu_Hani"],"tgt_constituents":["eng"],"src_multilingual":false,"tgt_multilingual":false,"prepro":" normalization + SentencePiece (spm32k,spm32k)","url_model":"https:\/\/object.pouta.csc.fi\/Tatoeba-MT-models\/zho-eng\/opus-2020-07-17.zip","url_test_set":"https:\/\/object.pouta.csc.fi\/Tatoeba-MT-models\/zho-eng\/opus-2020-07-17.test.txt","src_alpha3":"zho","tgt_alpha3":"eng","short_pair":"zh-en","chrF2_score":0.548,"bleu":36.1,"brevity_penalty":0.948,"ref_len":82826.0,"src_name":"Chinese","tgt_name":"English","train_date":"2020-07-17","src_alpha2":"zh","tgt_alpha2":"en","prefer_old":false,"long_pair":"zho-eng","helsinki_git_sha":"480fcbe0ee1bf4774bcbe6226ad9f58e63f6c535","transformers_git_sha":"2207e5d8cb224e954a7cba69fa4ac2309e9ff30b","port_machine":"brutasse","port_time":"2020-08-21-14:41"}
ckpt/zh-2-en/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9d8ceb91d103ef89400c9d9d62328b4858743cf8924878aee3b8afc594242ce0
size 312087009
ckpt/zh-2-en/rust_model.ot
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:859d0e2531693a5f003ea110aa5cee1b3439cea362980668923126bbb11d56de
size 578358061
ckpt/zh-2-en/source.spm
ADDED
Binary file (805 kB). View file
ckpt/zh-2-en/target.spm
ADDED
Binary file (807 kB). View file
ckpt/zh-2-en/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
{"target_lang": "eng", "source_lang": "zho"}
ckpt/zh-2-en/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff