Spaces:
Running
Running
""" | |
File: model.py | |
Author: Elena Ryumina and Dmitry Ryumin | |
Description: This module provides functions for loading and processing a pre-trained deep learning model | |
for facial expression recognition. | |
License: MIT License | |
""" | |
import torch | |
import requests | |
# Importing necessary components for the Gradio app | |
from app.config import config_data | |
from app.model_architectures import ResNet50, LSTMPyTorch, ExprModelV3 | |
from transformers import AutoFeatureExtractor | |
# Run everything on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def load_model(model_url, model_path, timeout=60):
    """Download the file at ``model_url`` and save it to ``model_path``.

    Despite the name, nothing is deserialized here: the response body is
    streamed to disk in 8192-byte chunks, and callers pass the returned path
    to ``torch.load``.

    Args:
        model_url: URL of the checkpoint to fetch.
        model_path: Local path to write to. An existing file is overwritten
            once the server has answered successfully.
        timeout: Seconds to wait for the connection and for each read from
            the server (forwarded to ``requests.get``). ``None`` waits
            forever.

    Returns:
        ``model_path`` on success, or ``None`` if the request failed, the
        server answered with an HTTP error status, or the file could not be
        written. The error is printed rather than raised.
    """
    try:
        with requests.get(model_url, stream=True, timeout=timeout) as response:
            # Without this check a 404/500 error page would be written to
            # ``model_path`` and only fail later, inside torch.load.
            response.raise_for_status()
            with open(model_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        return model_path
    except (requests.RequestException, OSError) as e:
        print(f"Error loading model: {e}")
        return None
# Maps a hook name to the last value captured by the hook that
# get_gradients(name) built.
gradients = {}


def get_gradients(name):
    """Return a module hook that stores its third argument in ``gradients[name]``.

    In this file the hook is passed to ``register_full_backward_hook``, which
    PyTorch calls as ``hook(module, grad_input, grad_output)``. The stored
    value is therefore ``grad_output``, kept as-is (not detached or copied),
    and each call overwrites the previous entry.
    """
    def _store_grad_output(module, grad_input, grad_output):
        gradients[name] = grad_output

    return _store_grad_output
# Maps a hook name to the last output captured by the hook that
# get_activations(name) built, detached from the autograd graph.
activations = {}


def get_activations(name):
    """Return a forward hook that stores ``output.detach()`` in ``activations[name]``.

    PyTorch calls forward hooks as ``hook(module, inputs, output)``. Only the
    output is kept, and each call overwrites the previous entry.
    """
    def _store_output(module, inputs, output):
        activations[name] = output.detach()

    return _store_output
# Random dummy inputs for one forward pass through each model after loading.
# NOTE(review): shapes presumably mirror the real inference inputs of each
# model (image batch, feature sequence, raw audio samples) — confirm.
test_static = torch.rand(1, 3, 224, 224)
test_dynamic = torch.rand(1, 10, 512)
test_audio = torch.rand(1, 64000)


def _load_checkpoint(path):
    """Deserialize a checkpoint file downloaded by ``load_model``.

    Raises:
        RuntimeError: if ``path`` is ``None`` (the download failed), instead of
            letting ``torch.load(None)`` fail with an unrelated error.
    """
    if path is None:
        raise RuntimeError(
            "Model checkpoint could not be downloaded (load_model returned None)."
        )
    # Load onto the CPU first so checkpoints saved from a GPU also work on
    # CPU-only hosts; every model is moved to ``device`` right afterwards.
    return torch.load(path, map_location="cpu")


def _warm_up(model, sample):
    """Run one forward pass of ``model`` on ``sample`` without building an autograd graph."""
    with torch.no_grad():
        model(sample.to(device))


# Static model: ResNet50 (first argument 7 — presumably the number of
# expression classes; confirm against app.model_architectures).
path_static = load_model(config_data.model_static_url, config_data.model_static_path)
pth_model_static = ResNet50(7, channels=3)
pth_model_static.load_state_dict(_load_checkpoint(path_static))
pth_model_static.to(device)
pth_model_static.eval()
_warm_up(pth_model_static, test_static)
# Hooks are attached after the dummy pass, so that pass does not populate
# ``gradients`` / ``activations``.
pth_model_static.layer4.register_full_backward_hook(get_gradients("layer4"))
pth_model_static.layer4.register_forward_hook(get_activations("layer4"))
pth_model_static.fc1.register_forward_hook(get_activations("features"))

# Dynamic model: LSTM over a sequence of feature vectors.
path_dynamic = load_model(config_data.model_dynamic_url, config_data.model_dynamic_path)
pth_model_dynamic = LSTMPyTorch()
pth_model_dynamic.load_state_dict(_load_checkpoint(path_dynamic))
pth_model_dynamic.to(device)
pth_model_dynamic.eval()
_warm_up(pth_model_dynamic, test_dynamic)

# Audio model: the feature extractor and base weights come from the
# pretrained id below; the downloaded checkpoint then overwrites the weights.
path_audio_model_1 = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
path_audio_model_2 = load_model(config_data.model_audio_url, config_data.model_audio_path)
audio_processor = AutoFeatureExtractor.from_pretrained(path_audio_model_1)
audio_model = ExprModelV3.from_pretrained(path_audio_model_1)
# This checkpoint stores the weights under the "model_state_dict" key.
audio_model.load_state_dict(_load_checkpoint(path_audio_model_2)["model_state_dict"])
audio_model.to(device)
audio_model.eval()
_warm_up(audio_model, test_audio)