face2 / app.py
angelo9830's picture
Update app.py
adfbf24 verified
raw
history blame
3.76 kB
import base64
from fastapi import FastAPI, File, UploadFile, HTTPException
import cv2
import numpy as np
from PIL import Image
import io
from transformers import ViTFeatureExtractor, ViTForImageClassification
app = FastAPI()
# Inicializamos el modelo de clasificaci贸n de edad y el extractor
model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
# Mapeo de 铆ndices de clase a rangos de edad
age_mapping = [
"0-2", "3-6", "7-9", "10-12", "13-15",
"16-19", "20-24", "25-29", "30-34", "35-39",
"40-44", "45-49", "50-54", "55-59", "60-64",
"65-69", "70+"
]
# Endpoint para predecir la edad de los rostros detectados en una imagen
@app.post("/predict/")
async def predict_age(file: UploadFile = File(...)):
"""
Endpoint para predecir el rango de edad de los rostros detectados en una imagen.
"""
try:
# Leer la imagen cargada
image_bytes = await file.read()
image = Image.open(io.BytesIO(image_bytes)).convert("RGB") # Convertimos a RGB si es necesario
# Convertir la imagen a formato NumPy para usar OpenCV
img_np = np.array(image)
# Cargar el clasificador Haar para detecci贸n de rostros
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
# Convertir la imagen a escala de grises para la detecci贸n de rostros
gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
# Verificamos si se detectaron rostros
if len(faces) == 0:
raise HTTPException(status_code=404, detail="No se detectaron rostros en la imagen.")
# Lista para almacenar los resultados de predicci贸n de cada rostro
results = []
for (x, y, w, h) in faces:
# Extraer la regi贸n del rostro
face_img = img_np[y:y+h, x:x+w]
pil_face_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))
# Aplicar la transformaci贸n y hacer la predicci贸n de edad
inputs = transforms(pil_face_img, return_tensors='pt')
outputs = model(**inputs)
# Calcular probabilidades y predicci贸n
proba = outputs.logits.softmax(1)
preds = proba.argmax(1).item() # 脥ndice de la clase predicha
predicted_age_range = age_mapping[preds]
# Dibujar un rect谩ngulo alrededor del rostro y agregar la edad predicha
cv2.rectangle(img_np, (x, y), (x + w, y + h), (255, 0, 0), 2)
cv2.putText(img_np, f"Edad: {predicted_age_range}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)
# Guardar el resultado de la predicci贸n de edad y las coordenadas del rostro
results.append({
"edad_predicha": predicted_age_range,
"coordenadas_rostro": (x, y, w, h)
})
# Convertir la imagen procesada a base64 para la respuesta
result_image = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB))
img_byte_arr = io.BytesIO()
result_image.save(img_byte_arr, format='JPEG')
img_byte_arr = img_byte_arr.getvalue()
# Devolver los resultados
return {
"message": "Rostros detectados y edad predicha",
"rostros_detectados": len(faces),
"resultados": results,
"imagen_base64": base64.b64encode(img_byte_arr).decode('utf-8')
}
except Exception as e:
# Manejo de errores generales
raise HTTPException(status_code=500, detail=f"Error procesando la imagen: {str(e)}")