Marcos12886 committed on
Commit
166aa6c
·
1 Parent(s): abdf62b

Decibelios. Llamar modelos mejor. Mejorar botones...

Browse files
Files changed (3) hide show
  1. app.py +69 -72
  2. interfaz.py +2 -2
  3. model.py +9 -9
app.py CHANGED
@@ -7,71 +7,63 @@ from interfaz import estilo, my_theme
7
 
8
  token = os.getenv("HF_TOKEN")
9
  client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
10
- model_cache = {}
 
 
11
 
12
- def load_model_and_dataset(model_path, dataset_path, filter_white_noise):
13
- if (model_path, dataset_path, filter_white_noise) not in model_cache:
14
- model, _, _, id2label = predict_params(dataset_path, model_path, filter_white_noise)
15
- model_cache[(model_path, dataset_path, filter_white_noise)] = (model, id2label)
16
- return model_cache[(model_path, dataset_path, filter_white_noise)]
17
-
18
- def predict(audio_path, model_path, dataset_path, filter_white_noise):
19
- model, id2label = load_model_and_dataset(model_path, dataset_path, filter_white_noise)
20
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
  model.to(device)
22
  model.eval()
23
- audios = AudioDataset(dataset_path, {}, filter_white_noise).preprocess_audio(audio_path)
24
- inputs = {"input_values": audios.to(device).unsqueeze(0)}
 
25
  with torch.no_grad():
26
  outputs = model(**inputs)
27
  logits = outputs.logits
28
- predicted_class_ids = torch.argmax(logits, dim=-1).item()
29
- label = id2label[predicted_class_ids]
30
- if dataset_path == "data/mixed_data":
31
- label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
32
- label = label_mapping.get(predicted_class_ids, label)
33
- return label
34
 
35
- def predict_stream(audio_path):
36
- model_mon, _ = load_model_and_dataset(
37
- model_path="distilhubert-finetuned-cry-detector",
38
- dataset_path="data/baby_cry_detection",
39
- filter_white_noise=False
40
- )
41
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
- model_mon.to(device)
43
- model_mon.eval()
44
- audio_dataset = AudioDataset(dataset_path="data/baby_cry_detection", label2id={}, filter_white_noise=False)
45
- processed_audio = audio_dataset.preprocess_audio(audio_path)
46
- inputs = {"input_values": processed_audio.to(device).unsqueeze(0)}
47
  with torch.no_grad():
48
- outputs = model_mon(**inputs)
49
- logits = outputs.logits
 
 
 
 
 
 
 
 
50
  probabilities = torch.nn.functional.softmax(logits, dim=-1)
51
  crying_probabilities = probabilities[:, 1]
52
- avg_crying_probability = crying_probabilities.mean().item()*100
53
- if avg_crying_probability < 25:
54
- model_class, id2label = load_model_and_dataset(
55
- model_path="distilhubert-finetuned-mixed-data",
56
- dataset_path="data/mixed_data",
57
- filter_white_noise=True
58
- )
59
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
- model_class.to(device)
61
- model_class.eval()
62
- audio_dataset_class = AudioDataset(dataset_path="data/mixed_data", label2id={}, filter_white_noise=True)
63
- processed_audio_class = audio_dataset_class.preprocess_audio(audio_path)
64
- inputs_class = {"input_values": processed_audio_class.to(device).unsqueeze(0)}
65
- with torch.no_grad():
66
- outputs_class = model_class(**inputs_class)
67
- logits_class = outputs_class.logits
68
- predicted_class_ids_class = torch.argmax(logits_class, dim=-1).item()
69
- label_class = id2label[predicted_class_ids_class]
70
- label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
71
- label_class = label_mapping.get(predicted_class_ids_class, label_class)
72
- return f"Bebé llorando por {label_class}. Probabilidad: {avg_crying_probability:.1f})"
 
 
 
 
 
73
  else:
74
- return f"No está llorando. Proabilidad: {avg_crying_probability:.1f})"
75
 
76
  def chatbot_config(message, history: list[tuple[str, str]]):
77
  system_message = "You are a Chatbot specialized in baby health and care."
@@ -105,12 +97,12 @@ with gr.Blocks(theme=my_theme) as demo:
105
  with gr.Row():
106
  with gr.Column():
107
  gr.Markdown("<h2>Predictor</h2>")
108
- boton_pagina_1 = gr.Button("Prueba el predictor")
109
- gr.Markdown("<p>Descubre por qué llora tu bebé y resuelve dudas sobre su cuidado con nuestro Iremia assistant</p>")
110
  with gr.Column():
111
  gr.Markdown("<h2>Monitor</h2>")
112
- boton_pagina_2 = gr.Button("Prueba el monitor")
113
- gr.Markdown("<p>Un monitor inteligente que detecta si tu hijo está llorando y te indica el motivo antes de que puedas levantarte del sofá</p>")
114
  with gr.Column(visible=False) as pag_predictor:
115
  gr.Markdown("<h2>Predictor</h2>")
116
  audio_input = gr.Audio(
@@ -119,14 +111,8 @@ with gr.Blocks(theme=my_theme) as demo:
119
  label="Baby recorder",
120
  type="filepath",
121
  )
122
- classify_btn = gr.Button("¿Por qué llora?")
123
- classify_btn.click(
124
- lambda audio: predict( # Mirar porque usar lambda
125
- audio,
126
- model_path="distilhubert-finetuned-mixed-data",
127
- dataset_path="data/mixed_data",
128
- filter_white_noise=True
129
- ),
130
  inputs=audio_input,
131
  outputs=gr.Textbox(label="Tu bebé llora por:")
132
  )
@@ -134,18 +120,29 @@ with gr.Blocks(theme=my_theme) as demo:
134
  with gr.Column(visible=False) as pag_monitor:
135
  gr.Markdown("<h2>Monitor</h2>")
136
  audio_stream = gr.Audio(
137
- # min_length=1.0, # mirar por qué no va esto
138
  format="wav",
139
  label="Baby recorder",
140
  type="filepath",
141
  streaming=True
142
  )
 
 
 
 
 
 
 
 
 
 
 
 
143
  audio_stream.stream(
144
- predict_stream,
145
- inputs=audio_stream,
146
- outputs=gr.Textbox(label="Tu bebé está:"),
147
  )
148
  gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
149
- boton_pagina_1.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
150
- boton_pagina_2.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
151
  demo.launch(share=True)
 
7
 
8
  token = os.getenv("HF_TOKEN")
9
  client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ model_class, id2label_class = predict_params(model_path="distilhubert-finetuned-mixed-data", dataset_path="data/mixed_data", filter_white_noise=True)
12
+ model_mon, id2label_mon = predict_params(model_path="distilhubert-finetuned-cry-detector", dataset_path="data/baby_cry_detection", filter_white_noise=False)
13
 
14
def call(audiopath, model, dataset_path, filter_white_noise):
    """Run one audio file through ``model`` and return the raw logits.

    Args:
        audiopath: Path of the audio file to run through the model.
        model: Model to execute. It is moved to the module-level ``device``
            and switched to eval mode on every call.
        dataset_path: Dataset directory forwarded to ``AudioDataset``.
        filter_white_noise: Flag forwarded to ``AudioDataset``.

    Returns:
        The ``logits`` attribute of the model output.
    """
    model.to(device)
    model.eval()
    # An empty label map is passed; only ``preprocess_audio`` is used here.
    preprocessor = AudioDataset(dataset_path, {}, filter_white_noise)
    waveform = preprocessor.preprocess_audio(audiopath)
    # unsqueeze(0) adds a leading dimension before the tensor is fed to the
    # model (presumably the batch axis -- confirm against preprocess_audio).
    batch = waveform.to(device).unsqueeze(0)
    with torch.no_grad():
        return model(input_values=batch).logits
 
 
 
 
 
24
 
25
def predict(audio_path_pred):
    """Return the predicted label for a single audio file.

    The classifier ``model_class`` is run through ``call`` and the arg-max
    class id is translated with a fixed Spanish mapping. Ids missing from
    that mapping fall back to the entry in ``id2label_class``.

    Args:
        audio_path_pred: Path of the audio file to classify.

    Returns:
        The label string for the predicted class.
    """
    spanish_labels = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
    with torch.no_grad():
        logits = call(
            audio_path_pred,
            model=model_class,
            dataset_path="data/mixed_data",
            filter_white_noise=True,
        )
        class_id = torch.argmax(logits, dim=-1).item()
    # Looked up unconditionally, exactly like the original, so an id that is
    # unknown to id2label_class still raises.
    fallback_label = id2label_class[class_id]
    return spanish_labels.get(class_id, fallback_label)
33
+
34
def predict_stream(audio_path_stream):
    """Report whether the monitored audio is classified as crying, and why.

    The monitor model ``model_mon`` is run through ``call``; the softmax
    probability of class index 1 is averaged and scaled to a percentage.
    When that percentage is below 15 the classifier (``predict``) is run
    to obtain the label that is reported.

    Args:
        audio_path_stream: Path of the audio chunk to analyse.

    Returns:
        A ``(status, detail)`` tuple of strings.
    """
    with torch.no_grad():
        logits = call(
            audio_path_stream,
            model=model_mon,
            dataset_path="data/baby_cry_detection",
            filter_white_noise=False,
        )
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        crying_probabilities = probabilities[:, 1]
        # Convert to a plain Python float once so the comparison and the
        # f-string formatting below do not operate on a 0-dim tensor.
        avg_crying_probability = crying_probabilities.mean().item() * 100
    # NOTE(review): "crying" is reported when the class-1 probability is LOW;
    # confirm that index 1 is the non-crying class in id2label_mon.
    if avg_crying_probability < 15:
        label_class = predict(audio_path_stream)
        return "Está llorando por:", f"{label_class}. Probabilidad: {avg_crying_probability:.1f}%"
    return "No está llorando.", f"Probabilidad: {avg_crying_probability:.1f}%"
45
+
46
def decibelios(audio_path_stream):
    """Return a decibel-style level derived from the monitor model output.

    Args:
        audio_path_stream: Path of the audio chunk to analyse.

    Returns:
        ``20 * log10(rms + 1e-6)`` as a Python float, where ``rms`` is the
        root-mean-square of the logits returned by ``call``.
    """
    # NOTE(review): the RMS is taken over the model logits, not over the
    # audio samples -- confirm this is the intended loudness measure.
    with torch.no_grad():
        logits = call(
            audio_path_stream,
            model=model_mon,
            dataset_path="data/baby_cry_detection",
            filter_white_noise=False,
        )
        mean_square = torch.mean(torch.square(logits))
        rms = torch.sqrt(mean_square)
        # The small epsilon keeps log10 finite when rms is exactly zero.
        return 20 * torch.log10(rms + 1e-6).item()
52
+
53
def mostrar_decibelios(audio_path_stream, visual_threshold):
    """Build the status text shown next to the streaming recorder.

    Args:
        audio_path_stream: Path of the audio chunk to analyse.
        visual_threshold: Level compared against the value returned by
            ``decibelios``.

    Returns:
        A status string. Levels strictly below the threshold yield the
        "Prediciendo" message; every other level yields the "no noise"
        message.
    """
    db_level = decibelios(audio_path_stream)
    if db_level < visual_threshold:
        return f"Prediciendo. Decibelios: {db_level:.2f}"
    # Plain fallback instead of `elif db_level > visual_threshold`, so a level
    # exactly equal to the threshold no longer returns None. This matches the
    # `else` branch used by predict_stream_decib.
    return "No detectamos ruido..."
59
+
60
def predict_stream_decib(audio_path_stream, visual_threshold):
    """Run the streaming prediction only when the level is below threshold.

    Args:
        audio_path_stream: Path of the audio chunk to analyse.
        visual_threshold: Level compared against the value returned by
            ``decibelios``.

    Returns:
        The two strings from ``predict_stream`` joined by a space, or an
        empty string when the level is not below the threshold.
    """
    db_level = decibelios(audio_path_stream)
    # Guard clause written as a negated `<` so it behaves exactly like the
    # original if/else for every value.
    if not (db_level < visual_threshold):
        return ""
    llorando, probabilidad = predict_stream(audio_path_stream)
    return f"{llorando} {probabilidad}"
67
 
68
  def chatbot_config(message, history: list[tuple[str, str]]):
69
  system_message = "You are a Chatbot specialized in baby health and care."
 
97
  with gr.Row():
98
  with gr.Column():
99
  gr.Markdown("<h2>Predictor</h2>")
100
+ boton_predictor = gr.Button("Prueba el predictor")
101
+ gr.Markdown("<p>Descubre por qué llora tu bebé</p>")
102
  with gr.Column():
103
  gr.Markdown("<h2>Monitor</h2>")
104
+ boton_monitor = gr.Button("Prueba el monitor")
105
+ gr.Markdown("<p>Monitoriza si tu hijo está llorando y por qué, sin levantarte del sofá</p>")
106
  with gr.Column(visible=False) as pag_predictor:
107
  gr.Markdown("<h2>Predictor</h2>")
108
  audio_input = gr.Audio(
 
111
  label="Baby recorder",
112
  type="filepath",
113
  )
114
+ gr.Button("¿Por qué llora?").click(
115
+ predict,
 
 
 
 
 
 
116
  inputs=audio_input,
117
  outputs=gr.Textbox(label="Tu bebé llora por:")
118
  )
 
120
  with gr.Column(visible=False) as pag_monitor:
121
  gr.Markdown("<h2>Monitor</h2>")
122
  audio_stream = gr.Audio(
 
123
  format="wav",
124
  label="Baby recorder",
125
  type="filepath",
126
  streaming=True
127
  )
128
+ threshold_db = gr.Slider(
129
+ minimum=0,
130
+ maximum=100,
131
+ step=1,
132
+ value=30,
133
+ label="Umbral de dB para activar la predicción"
134
+ )
135
+ audio_stream.stream(
136
+ mostrar_decibelios,
137
+ inputs=[audio_stream, threshold_db],
138
+ outputs=gr.Textbox(value="Esperando...", label="Estado")
139
+ )
140
  audio_stream.stream(
141
+ predict_stream_decib,
142
+ inputs=[audio_stream, threshold_db],
143
+ outputs=gr.Textbox(value="", label="Tu bebé:")
144
  )
145
  gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
146
+ boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
147
+ boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
148
  demo.launch(share=True)
interfaz.py CHANGED
@@ -93,9 +93,9 @@ def inicio():
93
  with gr.Column():
94
  gr.Markdown("<h2>Predictor</h2>")
95
  boton_pagina_1 = gr.Button("Prueba el predictor")
96
- gr.Markdown("<p>Descubre por qué llora tu bebé y resuelve dudas sobre su cuidado con nuestro Iremia assistant</p>")
97
  with gr.Column():
98
  gr.Markdown("<h2>Monitor</h2>")
99
  boton_pagina_2 = gr.Button("Prueba el monitor")
100
- gr.Markdown("<p>Un monitor inteligente que detecta si tu hijo está llorando y te indica el motivo antes de que puedas levantarte del sofá</p>")
101
  return boton_pagina_1, boton_pagina_2
 
93
  with gr.Column():
94
  gr.Markdown("<h2>Predictor</h2>")
95
  boton_pagina_1 = gr.Button("Prueba el predictor")
96
+ gr.Markdown("<p>Descubre por qué llora tu bebé</p>")
97
  with gr.Column():
98
  gr.Markdown("<h2>Monitor</h2>")
99
  boton_pagina_2 = gr.Button("Prueba el monitor")
100
+ gr.Markdown("<p>Detecta si tu hijo está llorando y por qué antes de que puedas levantarte del sofá</p>")
101
  return boton_pagina_1, boton_pagina_2
model.py CHANGED
@@ -5,8 +5,8 @@ import torch
5
  import torchaudio
6
  from torch.utils.data import Dataset, DataLoader
7
  from huggingface_hub import upload_folder
8
- from transformers.integrations import TensorBoardCallback
9
  from sklearn.metrics import accuracy_score, precision_recall_fscore_support
 
10
  from transformers import (
11
  Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
12
  Trainer, TrainingArguments,
@@ -121,7 +121,7 @@ def create_dataloader(dataset_path, filter_white_noise, test_size=0.2, shuffle=T
121
  )
122
  return train_dataloader, test_dataloader, label2id, id2label
123
 
124
- def load_model(model_path, num_labels, label2id, id2label):
125
  config = HubertConfig.from_pretrained(
126
  pretrained_model_name_or_path=model_path,
127
  num_labels=num_labels,
@@ -140,13 +140,13 @@ def load_model(model_path, num_labels, label2id, id2label):
140
 
141
  def train_params(dataset_path, filter_white_noise):
142
  train_dataloader, test_dataloader, label2id, id2label = create_dataloader(dataset_path, filter_white_noise)
143
- model = load_model(model_path=MODEL, num_labels=len(id2label), label2id=label2id, id2label=id2label)
144
  return model, train_dataloader, test_dataloader, id2label
145
 
146
  def predict_params(dataset_path, model_path, filter_white_noise):
147
  _, _, label2id, id2label = create_dataloader(dataset_path, filter_white_noise)
148
- model = load_model(model_path, num_labels=len(id2label), label2id=label2id, id2label=id2label)
149
- return model, None, None, id2label
150
 
151
  def compute_metrics(eval_pred):
152
  predictions = torch.argmax(torch.tensor(eval_pred.predictions), dim=-1)
@@ -187,10 +187,10 @@ def load_config(model_name):
187
  return model_config
188
 
189
  if __name__ == "__main__":
190
- config = load_config(clasificador) # PARA CAMBIAR MODELOS
191
- filter_white_noise = True
192
- # config = load_config(monitor) # PARA CAMBIAR MODELOS
193
- # filter_white_noise = False
194
  training_args = config["training_args"]
195
  output_dir = config["output_dir"]
196
  dataset_path = config["dataset_path"]
 
5
  import torchaudio
6
  from torch.utils.data import Dataset, DataLoader
7
  from huggingface_hub import upload_folder
 
8
  from sklearn.metrics import accuracy_score, precision_recall_fscore_support
9
+ from transformers.integrations import TensorBoardCallback
10
  from transformers import (
11
  Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
12
  Trainer, TrainingArguments,
 
121
  )
122
  return train_dataloader, test_dataloader, label2id, id2label
123
 
124
+ def load_model(model_path, label2id, id2label, num_labels):
125
  config = HubertConfig.from_pretrained(
126
  pretrained_model_name_or_path=model_path,
127
  num_labels=num_labels,
 
140
 
141
def train_params(dataset_path, filter_white_noise):
    """Create the dataloaders and the model used for training.

    Args:
        dataset_path: Dataset directory forwarded to ``create_dataloader``.
        filter_white_noise: Flag forwarded to ``create_dataloader``.

    Returns:
        ``(model, train_dataloader, test_dataloader, id2label)``.
    """
    loaders_and_labels = create_dataloader(dataset_path, filter_white_noise)
    train_dataloader, test_dataloader, label2id, id2label = loaders_and_labels
    # The number of labels is taken from the label map built for the dataset.
    model = load_model(MODEL, label2id, id2label, num_labels=len(id2label))
    return model, train_dataloader, test_dataloader, id2label
145
 
146
def predict_params(dataset_path, model_path, filter_white_noise):
    """Load a model for inference together with its id-to-label map.

    Args:
        dataset_path: Dataset directory forwarded to ``create_dataloader``.
        model_path: Model location forwarded to ``load_model``.
        filter_white_noise: Flag forwarded to ``create_dataloader``.

    Returns:
        ``(model, id2label)``.
    """
    # Only the label maps are needed here; the two dataloaders are discarded.
    _train_loader, _test_loader, label2id, id2label = create_dataloader(
        dataset_path, filter_white_noise
    )
    model = load_model(model_path, label2id, id2label, num_labels=len(id2label))
    return model, id2label
150
 
151
  def compute_metrics(eval_pred):
152
  predictions = torch.argmax(torch.tensor(eval_pred.predictions), dim=-1)
 
187
  return model_config
188
 
189
  if __name__ == "__main__":
190
+ # config = load_config(clasificador) # PARA CAMBIAR MODELOS
191
+ # filter_white_noise = True
192
+ config = load_config(monitor) # PARA CAMBIAR MODELOS
193
+ filter_white_noise = False
194
  training_args = config["training_args"]
195
  output_dir = config["output_dir"]
196
  dataset_path = config["dataset_path"]