Spaces:
Running
Running
import base64 | |
from fastapi import FastAPI, File, UploadFile, HTTPException | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import io | |
from transformers import ViTFeatureExtractor, ViTForImageClassification | |
app = FastAPI() | |
# Inicializamos el modelo de clasificaci贸n de edad y el extractor | |
model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier') | |
transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier') | |
# Mapeo de 铆ndices de clase a rangos de edad | |
age_mapping = [ | |
"0-2", "3-6", "7-9", "10-12", "13-15", | |
"16-19", "20-24", "25-29", "30-34", "35-39", | |
"40-44", "45-49", "50-54", "55-59", "60-64", | |
"65-69", "70+" | |
] | |
# Endpoint para predecir la edad de los rostros detectados en una imagen | |
async def predict_age(file: UploadFile = File(...)): | |
""" | |
Endpoint para predecir el rango de edad de los rostros detectados en una imagen. | |
""" | |
try: | |
# Leer la imagen cargada | |
image_bytes = await file.read() | |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB") # Convertimos a RGB si es necesario | |
# Convertir la imagen a formato NumPy para usar OpenCV | |
img_np = np.array(image) | |
# Cargar el clasificador Haar para detecci贸n de rostros | |
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml') | |
# Convertir la imagen a escala de grises para la detecci贸n de rostros | |
gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY) | |
faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)) | |
# Verificamos si se detectaron rostros | |
if len(faces) == 0: | |
raise HTTPException(status_code=404, detail="No se detectaron rostros en la imagen.") | |
# Lista para almacenar los resultados de predicci贸n de cada rostro | |
results = [] | |
for (x, y, w, h) in faces: | |
# Extraer la regi贸n del rostro | |
face_img = img_np[y:y+h, x:x+w] | |
pil_face_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB)) | |
# Aplicar la transformaci贸n y hacer la predicci贸n de edad | |
inputs = transforms(pil_face_img, return_tensors='pt') | |
outputs = model(**inputs) | |
# Calcular probabilidades y predicci贸n | |
proba = outputs.logits.softmax(1) | |
preds = proba.argmax(1).item() # 脥ndice de la clase predicha | |
predicted_age_range = age_mapping[preds] | |
# Dibujar un rect谩ngulo alrededor del rostro y agregar la edad predicha | |
cv2.rectangle(img_np, (x, y), (x + w, y + h), (255, 0, 0), 2) | |
cv2.putText(img_np, f"Edad: {predicted_age_range}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2) | |
# Guardar el resultado de la predicci贸n de edad y las coordenadas del rostro | |
results.append({ | |
"edad_predicha": predicted_age_range, | |
"coordenadas_rostro": (x, y, w, h) | |
}) | |
# Convertir la imagen procesada a base64 para la respuesta | |
result_image = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)) | |
img_byte_arr = io.BytesIO() | |
result_image.save(img_byte_arr, format='JPEG') | |
img_byte_arr = img_byte_arr.getvalue() | |
# Devolver los resultados | |
return { | |
"message": "Rostros detectados y edad predicha", | |
"rostros_detectados": len(faces), | |
"resultados": results, | |
"imagen_base64": base64.b64encode(img_byte_arr).decode('utf-8') | |
} | |
except Exception as e: | |
# Manejo de errores generales | |
raise HTTPException(status_code=500, detail=f"Error procesando la imagen: {str(e)}") | |