import base64 from fastapi import FastAPI, File, UploadFile, HTTPException import cv2 import numpy as np from PIL import Image import io from transformers import ViTFeatureExtractor, ViTForImageClassification app = FastAPI() # Inicializamos el modelo de clasificación de edad y el extractor model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier') transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier') # Mapeo de índices de clase a rangos de edad age_mapping = [ "0-2", "3-6", "7-9", "10-12", "13-15", "16-19", "20-24", "25-29", "30-34", "35-39", "40-44", "45-49", "50-54", "55-59", "60-64", "65-69", "70+" ] # Endpoint para predecir la edad de los rostros detectados en una imagen @app.post("/predict/") async def predict_age(file: UploadFile = File(...)): """ Endpoint para predecir el rango de edad de los rostros detectados en una imagen. """ try: # Leer la imagen cargada image_bytes = await file.read() image = Image.open(io.BytesIO(image_bytes)).convert("RGB") # Convertimos a RGB si es necesario # Convertir la imagen a formato NumPy para usar OpenCV img_np = np.array(image) # Cargar el clasificador Haar para detección de rostros face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml') # Convertir la imagen a escala de grises para la detección de rostros gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY) faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)) # Verificamos si se detectaron rostros if len(faces) == 0: raise HTTPException(status_code=404, detail="No se detectaron rostros en la imagen.") # Lista para almacenar los resultados de predicción de cada rostro results = [] for (x, y, w, h) in faces: # Extraer la región del rostro face_img = img_np[y:y+h, x:x+w] pil_face_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB)) # Aplicar la transformación y hacer la predicción de edad inputs = transforms(pil_face_img, return_tensors='pt') outputs = model(**inputs) # Calcular probabilidades y predicción proba = outputs.logits.softmax(1) preds = proba.argmax(1).item() # Índice de la clase predicha predicted_age_range = age_mapping[preds] # Dibujar un rectángulo alrededor del rostro y agregar la edad predicha cv2.rectangle(img_np, (x, y), (x + w, y + h), (255, 0, 0), 2) cv2.putText(img_np, f"Edad: {predicted_age_range}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2) # Guardar el resultado de la predicción de edad y las coordenadas del rostro results.append({ "edad_predicha": predicted_age_range, "coordenadas_rostro": (x, y, w, h) }) # Convertir la imagen procesada a base64 para la respuesta result_image = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)) img_byte_arr = io.BytesIO() result_image.save(img_byte_arr, format='JPEG') img_byte_arr = img_byte_arr.getvalue() # Devolver los resultados return { "message": "Rostros detectados y edad predicha", "rostros_detectados": len(faces), "resultados": results, "imagen_base64": base64.b64encode(img_byte_arr).decode('utf-8') } except Exception as e: # Manejo de errores generales raise HTTPException(status_code=500, detail=f"Error procesando la imagen: {str(e)}")