Spaces:
Sleeping
Sleeping
Marcos12886
commited on
Commit
·
763091b
1
Parent(s):
1565b0a
TODO FUNCIONANDO. Igual que github
Browse files
README.md
CHANGED
@@ -1,13 +1,31 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Instalación
|
2 |
+
La instalación y uso están pensados para una gráfica NVIDIA decentilla. Si no dispones de una gráfica NVIDIA, ejecutar en las gráficas de Colab.
|
3 |
+
|
4 |
+
Instalaciones necesarias para local:
|
5 |
+
- pip install transformers[torch] gradio tensorboardX scikit-learn
|
6 |
+
- pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
7 |
+
|
8 |
+
#### GitHub
|
9 |
+
En el archivo .gitignore están puestas las carpetas que no se deben subir a github.
|
10 |
+
|
11 |
+
## Estructura
|
12 |
+
Dos funcionalidades:
|
13 |
+
- Monitor de bebés: identificar si tu bebé llora y por qué
|
14 |
+
- Clasificador de llantos: conocer por qué llora tu bebé
|
15 |
+
- Chatbot: poder hablar con llama 3 8B sobre las preocupaciones con tu bebé
|
16 |
+
|
17 |
+
Flujo de archivos:
|
18 |
+
1. Construir la estructura de los modelos y entrenarlos [model.py](model.py)
|
19 |
+
2. Chatbot en el que grabar audio y conectar con el llm [app.py](app.py)
|
20 |
+
|
21 |
+
Un modelo ([model.py](model.py)) entrenado con distintos datos:
|
22 |
+
- Modelo para monitorizar: --n monitor
|
23 |
+
- Modelo clasificador de llantos: --n class
|
24 |
+
|
25 |
+
Chatbot [app.py](app.py)
|
26 |
+
|
27 |
+
### Datos utilizados
|
28 |
+
- https://data.mendeley.com/datasets/hbppd883sd/1
|
29 |
+
- https://zenodo.org/records/2535878
|
30 |
+
- https://paperswithcode.com/dataset/esc50
|
31 |
+
- https://osf.io/usr8d
|
app.py
CHANGED
@@ -3,25 +3,24 @@ import torch
|
|
3 |
import gradio as gr
|
4 |
from huggingface_hub import InferenceClient
|
5 |
from model import predict_params, AudioDataset
|
6 |
-
|
7 |
-
|
8 |
token = os.getenv("HF_TOKEN")
|
9 |
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
|
10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
model_class, id2label_class = predict_params(
|
12 |
-
model_path="
|
13 |
dataset_path="data/mixed_data",
|
14 |
filter_white_noise=True,
|
15 |
undersample_normal=True
|
16 |
)
|
17 |
model_mon, id2label_mon = predict_params(
|
18 |
-
model_path="
|
19 |
dataset_path="data/baby_cry_detection",
|
20 |
filter_white_noise=False,
|
21 |
undersample_normal=False
|
22 |
)
|
23 |
|
24 |
-
def call(audiopath, model, dataset_path, filter_white_noise, undersample_normal):
|
25 |
model.to(device)
|
26 |
model.eval()
|
27 |
audio_dataset = AudioDataset(dataset_path, {}, filter_white_noise, undersample_normal)
|
@@ -34,10 +33,10 @@ def call(audiopath, model, dataset_path, filter_white_noise, undersample_normal)
|
|
34 |
|
35 |
def predict(audio_path_pred):
|
36 |
with torch.no_grad():
|
37 |
-
logits = call(audio_path_pred, model=model_class, dataset_path="data/mixed_data", filter_white_noise=True, undersample_normal=
|
38 |
predicted_class_ids_class = torch.argmax(logits, dim=-1).item()
|
39 |
label_class = id2label_class[predicted_class_ids_class]
|
40 |
-
label_mapping = {0: '
|
41 |
label_class = label_mapping.get(predicted_class_ids_class, label_class)
|
42 |
return label_class
|
43 |
|
@@ -49,9 +48,9 @@ def predict_stream(audio_path_stream):
|
|
49 |
avg_crying_probability = crying_probabilities.mean()*100
|
50 |
if avg_crying_probability < 15:
|
51 |
label_class = predict(audio_path_stream)
|
52 |
-
return "Está llorando por:
|
53 |
else:
|
54 |
-
return "No está llorando."
|
55 |
|
56 |
def decibelios(audio_path_stream):
|
57 |
with torch.no_grad():
|
@@ -70,15 +69,15 @@ def mostrar_decibelios(audio_path_stream, visual_threshold):
|
|
70 |
def predict_stream_decib(audio_path_stream, visual_threshold):
|
71 |
db_level = decibelios(audio_path_stream)
|
72 |
if db_level < visual_threshold:
|
73 |
-
llorando
|
74 |
-
return f"{llorando}
|
75 |
else:
|
76 |
return ""
|
77 |
|
78 |
def chatbot_config(message, history: list[tuple[str, str]]):
|
79 |
system_message = "You are a Chatbot specialized in baby health and care."
|
80 |
max_tokens = 512
|
81 |
-
temperature = 0.
|
82 |
top_p = 0.95
|
83 |
messages = [{"role": "system", "content": system_message}]
|
84 |
for val in history:
|
@@ -96,25 +95,100 @@ def chatbot_config(message, history: list[tuple[str, str]]):
|
|
96 |
def cambiar_pestaña():
|
97 |
return gr.update(visible=False), gr.update(visible=True)
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
with gr.Blocks(theme=my_theme) as demo:
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
)
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
with gr.Column(visible=False) as pag_predictor:
|
117 |
-
gr.Markdown("<h2>
|
118 |
audio_input = gr.Audio(
|
119 |
min_length=1.0,
|
120 |
format="wav",
|
@@ -126,7 +200,7 @@ with gr.Blocks(theme=my_theme) as demo:
|
|
126 |
inputs=audio_input,
|
127 |
outputs=gr.Textbox(label="Tu bebé llora por:")
|
128 |
)
|
129 |
-
gr.Button("Volver
|
130 |
with gr.Column(visible=False) as pag_monitor:
|
131 |
gr.Markdown("<h2>Monitor</h2>")
|
132 |
audio_stream = gr.Audio(
|
@@ -140,7 +214,7 @@ with gr.Blocks(theme=my_theme) as demo:
|
|
140 |
maximum=100,
|
141 |
step=1,
|
142 |
value=30,
|
143 |
-
label="
|
144 |
)
|
145 |
audio_stream.stream(
|
146 |
mostrar_decibelios,
|
@@ -152,7 +226,8 @@ with gr.Blocks(theme=my_theme) as demo:
|
|
152 |
inputs=[audio_stream, threshold_db],
|
153 |
outputs=gr.Textbox(value="", label="Tu bebé:")
|
154 |
)
|
155 |
-
gr.Button("Volver
|
|
|
156 |
boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
|
157 |
boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
|
158 |
demo.launch(share=True)
|
|
|
3 |
import gradio as gr
|
4 |
from huggingface_hub import InferenceClient
|
5 |
from model import predict_params, AudioDataset
|
6 |
+
# TODO: Que no diga lo de que no hay 1s_normal al predecir
|
|
|
7 |
token = os.getenv("HF_TOKEN")
|
8 |
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
|
9 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
10 |
model_class, id2label_class = predict_params(
|
11 |
+
model_path="distilhubert-finetuned-mixed-data",
|
12 |
dataset_path="data/mixed_data",
|
13 |
filter_white_noise=True,
|
14 |
undersample_normal=True
|
15 |
)
|
16 |
model_mon, id2label_mon = predict_params(
|
17 |
+
model_path="distilhubert-finetuned-cry-detector",
|
18 |
dataset_path="data/baby_cry_detection",
|
19 |
filter_white_noise=False,
|
20 |
undersample_normal=False
|
21 |
)
|
22 |
|
23 |
+
def call(audiopath, model, dataset_path, filter_white_noise, undersample_normal=False):
|
24 |
model.to(device)
|
25 |
model.eval()
|
26 |
audio_dataset = AudioDataset(dataset_path, {}, filter_white_noise, undersample_normal)
|
|
|
33 |
|
34 |
def predict(audio_path_pred):
|
35 |
with torch.no_grad():
|
36 |
+
logits = call(audio_path_pred, model=model_class, dataset_path="data/mixed_data", filter_white_noise=True, undersample_normal=False)
|
37 |
predicted_class_ids_class = torch.argmax(logits, dim=-1).item()
|
38 |
label_class = id2label_class[predicted_class_ids_class]
|
39 |
+
label_mapping = {0: 'Cansancio/Incomodidad', 1: 'Dolor', 2: 'Hambre', 3: 'Problemas para respirar'}
|
40 |
label_class = label_mapping.get(predicted_class_ids_class, label_class)
|
41 |
return label_class
|
42 |
|
|
|
48 |
avg_crying_probability = crying_probabilities.mean()*100
|
49 |
if avg_crying_probability < 15:
|
50 |
label_class = predict(audio_path_stream)
|
51 |
+
return f"Está llorando por: {label_class}"
|
52 |
else:
|
53 |
+
return "No está llorando."
|
54 |
|
55 |
def decibelios(audio_path_stream):
|
56 |
with torch.no_grad():
|
|
|
69 |
def predict_stream_decib(audio_path_stream, visual_threshold):
|
70 |
db_level = decibelios(audio_path_stream)
|
71 |
if db_level < visual_threshold:
|
72 |
+
llorando = predict_stream(audio_path_stream)
|
73 |
+
return f"{llorando}"
|
74 |
else:
|
75 |
return ""
|
76 |
|
77 |
def chatbot_config(message, history: list[tuple[str, str]]):
|
78 |
system_message = "You are a Chatbot specialized in baby health and care."
|
79 |
max_tokens = 512
|
80 |
+
temperature = 0.5
|
81 |
top_p = 0.95
|
82 |
messages = [{"role": "system", "content": system_message}]
|
83 |
for val in history:
|
|
|
95 |
def cambiar_pestaña():
|
96 |
return gr.update(visible=False), gr.update(visible=True)
|
97 |
|
98 |
+
my_theme = gr.themes.Soft(
|
99 |
+
primary_hue="emerald",
|
100 |
+
secondary_hue="green",
|
101 |
+
neutral_hue="slate",
|
102 |
+
text_size="sm",
|
103 |
+
spacing_size="sm",
|
104 |
+
font=[gr.themes.GoogleFont('Nunito'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
105 |
+
font_mono=[gr.themes.GoogleFont('Nunito'), 'ui-monospace', 'Consolas', 'monospace'],
|
106 |
+
).set(
|
107 |
+
body_background_fill='*neutral_50',
|
108 |
+
body_text_color='*neutral_600',
|
109 |
+
body_text_size='*text_sm',
|
110 |
+
embed_radius='*radius_md',
|
111 |
+
shadow_drop='*shadow_spread',
|
112 |
+
shadow_spread='*button_shadow_active'
|
113 |
+
)
|
114 |
+
|
115 |
with gr.Blocks(theme=my_theme) as demo:
|
116 |
+
with gr.Column(visible=True) as inicial:
|
117 |
+
gr.HTML(
|
118 |
+
"""
|
119 |
+
<style>
|
120 |
+
@import url('https://fonts.googleapis.com/css2?family=Lobster&display=swap');
|
121 |
+
@import url('https://fonts.googleapis.com/css2?family=Roboto&display=swap');
|
122 |
+
|
123 |
+
h1 {
|
124 |
+
font-family: 'Lobster', cursive;
|
125 |
+
font-size: 5em !important;
|
126 |
+
text-align: center;
|
127 |
+
margin: 0;
|
128 |
+
}
|
129 |
+
|
130 |
+
.gr-button {
|
131 |
+
background-color: #4CAF50 !important;
|
132 |
+
color: white !important;
|
133 |
+
border: none;
|
134 |
+
padding: 25px 50px; /* Increase the padding for bigger buttons */
|
135 |
+
text-align: center;
|
136 |
+
text-decoration: none;
|
137 |
+
display: inline-block;
|
138 |
+
font-family: 'Lobster', cursive; /* Apply the Lobster font */
|
139 |
+
font-size: 2em !important; /* Increase the button text size */
|
140 |
+
margin: 4px 2px;
|
141 |
+
cursor: pointer;
|
142 |
+
border-radius: 12px;
|
143 |
+
}
|
144 |
+
|
145 |
+
.gr-button:hover {
|
146 |
+
background-color: #45a049;
|
147 |
+
}
|
148 |
+
h2 {
|
149 |
+
font-family: 'Lobster', cursive;
|
150 |
+
font-size: 3em !important;
|
151 |
+
text-align: center;
|
152 |
+
margin: 0;
|
153 |
+
}
|
154 |
+
p.slogan, h4, p, h3 {
|
155 |
+
font-family: 'Roboto', sans-serif;
|
156 |
+
text-align: center;
|
157 |
+
}
|
158 |
+
</style>
|
159 |
+
<h1>Iremia</h1>
|
160 |
+
<h4 style='text-align: center; font-size: 1.5em'>Tu aliado para el bienestar de tu bebé</h4>
|
161 |
+
"""
|
162 |
+
)
|
163 |
+
gr.Markdown(
|
164 |
+
"<h4 style='text-align: left; font-size: 1.5em;'>¿Qué es Iremia?</h4>"
|
165 |
+
"<p style='text-align: left'>Iremia es un proyecto llevado a cabo por un grupo de estudiantes interesados en el desarrollo de modelos de inteligencia artificial, enfocados específicamente en casos de uso relevantes para ayudar a cuidar a los más pequeños de la casa.</p>"
|
166 |
+
"<h4 style='text-align: left; font-size: 1.5em;'>Nuestra misión</h4>"
|
167 |
+
"<p style='text-align: left'>Sabemos que la paternidad puede suponer un gran desafío. Nuestra misión es brindarles a todos los padres unas herramientas de última tecnología que los ayuden a navegar esos primeros meses de vida tan cruciales en el desarrollo de sus pequeños.</p>"
|
168 |
+
"<h4 style='text-align: left; font-size: 1.5em;'>¿Qué ofrece Iremia?</h4>"
|
169 |
+
"<p style='text-align: left'>Chatbot: Pregunta a nuestro asistente que te ayudará con cualquier duda que tengas sobre el cuidado de tu bebé.</p>"
|
170 |
+
"<p style='text-align: left'>Analizador: Con nuestro modelo de inteligencia artificial somos capaces de predecir por qué tu hijo de menos de 2 años está llorando.</p>"
|
171 |
+
"<p style='text-align: left'>Monitor: Nuestro monitor no es como otros que hay en el mercado, ya que es capaz de reconocer si un sonido es un llanto del bebé o no; y si está llorando, predice automáticamente la causa. Dándote la tranquilidad de saber siempre qué pasa con tu pequeño, ahorrándote tiempo y horas de sueño.</p>"
|
172 |
+
)
|
173 |
+
boton_inicial = gr.Button("Comenzar")
|
174 |
+
with gr.Column(visible=False) as chatbot:
|
175 |
+
gr.Markdown("<h2>Asistente</h2>")
|
176 |
+
gr.ChatInterface(
|
177 |
+
chatbot_config,
|
178 |
+
theme=my_theme,
|
179 |
+
retry_btn=None,
|
180 |
+
undo_btn=None,
|
181 |
+
clear_btn="Limpiar 🗑️",
|
182 |
+
autofocus=True,
|
183 |
+
fill_height=True,
|
184 |
+
)
|
185 |
+
with gr.Row():
|
186 |
+
with gr.Column():
|
187 |
+
boton_predictor = gr.Button("Analizador")
|
188 |
+
with gr.Column():
|
189 |
+
boton_monitor = gr.Button("Monitor")
|
190 |
with gr.Column(visible=False) as pag_predictor:
|
191 |
+
gr.Markdown("<h2>Analizador</h2>")
|
192 |
audio_input = gr.Audio(
|
193 |
min_length=1.0,
|
194 |
format="wav",
|
|
|
200 |
inputs=audio_input,
|
201 |
outputs=gr.Textbox(label="Tu bebé llora por:")
|
202 |
)
|
203 |
+
gr.Button("Volver").click(cambiar_pestaña, outputs=[pag_predictor, chatbot])
|
204 |
with gr.Column(visible=False) as pag_monitor:
|
205 |
gr.Markdown("<h2>Monitor</h2>")
|
206 |
audio_stream = gr.Audio(
|
|
|
214 |
maximum=100,
|
215 |
step=1,
|
216 |
value=30,
|
217 |
+
label="Decibelios para activar la predicción:"
|
218 |
)
|
219 |
audio_stream.stream(
|
220 |
mostrar_decibelios,
|
|
|
226 |
inputs=[audio_stream, threshold_db],
|
227 |
outputs=gr.Textbox(value="", label="Tu bebé:")
|
228 |
)
|
229 |
+
gr.Button("Volver").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
|
230 |
+
boton_inicial.click(cambiar_pestaña, outputs=[inicial, chatbot])
|
231 |
boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
|
232 |
boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
|
233 |
demo.launch(share=True)
|
model.py
CHANGED
@@ -30,7 +30,7 @@ class AudioDataset(Dataset):
|
|
30 |
self.dataset_path = dataset_path
|
31 |
self.label2id = label2id
|
32 |
self.file_paths = []
|
33 |
-
self.filter_white_noise = filter_white_noise
|
34 |
self.labels = []
|
35 |
for label_dir, label_id in self.label2id.items():
|
36 |
label_path = os.path.join(self.dataset_path, label_dir)
|
@@ -39,33 +39,25 @@ class AudioDataset(Dataset):
|
|
39 |
audio_path = os.path.join(label_path, file_name)
|
40 |
self.file_paths.append(audio_path)
|
41 |
self.labels.append(label_id)
|
42 |
-
if undersample_normal:
|
43 |
self.undersample_normal_class()
|
44 |
|
45 |
def undersample_normal_class(self):
|
46 |
normal_label = self.label2id.get('1s_normal')
|
47 |
-
if normal_label is None:
|
48 |
-
print("Warning: No '1s_normal' class found. Skipping undersampling.")
|
49 |
-
return
|
50 |
label_counts = Counter(self.labels)
|
51 |
other_counts = [count for label, count in label_counts.items() if label != normal_label]
|
52 |
-
if
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
if label != normal_label or i in keep_indices:
|
65 |
-
new_file_paths.append(path)
|
66 |
-
new_labels.append(label)
|
67 |
-
self.file_paths = new_file_paths
|
68 |
-
self.labels = new_labels
|
69 |
|
70 |
def __len__(self):
|
71 |
return len(self.file_paths)
|
@@ -107,12 +99,11 @@ def is_white_noise(audio):
|
|
107 |
std = torch.std(audio)
|
108 |
return torch.abs(mean) < 0.001 and std < 0.01
|
109 |
|
110 |
-
def seed_everything():
|
111 |
torch.manual_seed(seed)
|
112 |
torch.cuda.manual_seed(seed)
|
113 |
-
torch.backends.cudnn.deterministic = True
|
114 |
-
torch.backends.cudnn.benchmark = False
|
115 |
-
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16384:8'
|
116 |
|
117 |
def build_label_mappings(dataset_path):
|
118 |
label2id = {}
|
@@ -165,10 +156,10 @@ def load_model(model_path, id2label, num_labels):
|
|
165 |
finetuning_task="audio-classification"
|
166 |
)
|
167 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
168 |
-
model = HubertForSequenceClassification.from_pretrained(
|
169 |
pretrained_model_name_or_path=model_path,
|
170 |
config=config,
|
171 |
-
torch_dtype=torch.float32,
|
172 |
)
|
173 |
model.to(device)
|
174 |
return model
|
|
|
30 |
self.dataset_path = dataset_path
|
31 |
self.label2id = label2id
|
32 |
self.file_paths = []
|
33 |
+
self.filter_white_noise = filter_white_noise
|
34 |
self.labels = []
|
35 |
for label_dir, label_id in self.label2id.items():
|
36 |
label_path = os.path.join(self.dataset_path, label_dir)
|
|
|
39 |
audio_path = os.path.join(label_path, file_name)
|
40 |
self.file_paths.append(audio_path)
|
41 |
self.labels.append(label_id)
|
42 |
+
if undersample_normal and self.label2id:
|
43 |
self.undersample_normal_class()
|
44 |
|
45 |
def undersample_normal_class(self):
|
46 |
normal_label = self.label2id.get('1s_normal')
|
|
|
|
|
|
|
47 |
label_counts = Counter(self.labels)
|
48 |
other_counts = [count for label, count in label_counts.items() if label != normal_label]
|
49 |
+
if other_counts: # Ensure there are other counts before taking max
|
50 |
+
target_count = max(other_counts)
|
51 |
+
normal_indices = [i for i, label in enumerate(self.labels) if label == normal_label]
|
52 |
+
keep_indices = random.sample(normal_indices, target_count)
|
53 |
+
new_file_paths = []
|
54 |
+
new_labels = []
|
55 |
+
for i, (path, label) in enumerate(zip(self.file_paths, self.labels)):
|
56 |
+
if label != normal_label or i in keep_indices:
|
57 |
+
new_file_paths.append(path)
|
58 |
+
new_labels.append(label)
|
59 |
+
self.file_paths = new_file_paths
|
60 |
+
self.labels = new_labels
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
def __len__(self):
|
63 |
return len(self.file_paths)
|
|
|
99 |
std = torch.std(audio)
|
100 |
return torch.abs(mean) < 0.001 and std < 0.01
|
101 |
|
102 |
+
def seed_everything(): # TODO: mirar si es necesario algo más
|
103 |
torch.manual_seed(seed)
|
104 |
torch.cuda.manual_seed(seed)
|
105 |
+
# torch.backends.cudnn.deterministic = True # Para reproducibilidad
|
106 |
+
# torch.backends.cudnn.benchmark = False # Para reproducibilidad
|
|
|
107 |
|
108 |
def build_label_mappings(dataset_path):
|
109 |
label2id = {}
|
|
|
156 |
finetuning_task="audio-classification"
|
157 |
)
|
158 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
159 |
+
model = HubertForSequenceClassification.from_pretrained(
|
160 |
pretrained_model_name_or_path=model_path,
|
161 |
config=config,
|
162 |
+
torch_dtype=torch.float32, # TODO: Comprobar si se necesita float32 y ver si se puede cambiar por float16
|
163 |
)
|
164 |
model.to(device)
|
165 |
return model
|