File size: 3,761 Bytes
adfbf24
9c9ae83
 
 
 
 
 
 
 
 
65aa862
 
 
9c9ae83
65aa862
 
 
 
 
 
 
9c9ae83
65aa862
 
 
 
 
 
9c9ae83
65aa862
9c9ae83
65aa862
 
 
9c9ae83
 
65aa862
 
9c9ae83
65aa862
 
9c9ae83
 
65aa862
9c9ae83
65aa862
9c9ae83
65aa862
9c9ae83
65aa862
9c9ae83
65aa862
9c9ae83
 
 
65aa862
 
 
 
 
 
 
 
9c9ae83
65aa862
 
 
 
 
9c9ae83
 
65aa862
9c9ae83
 
65aa862
 
 
 
 
 
 
9c9ae83
65aa862
 
 
 
9c9ae83
 
 
65aa862
9c9ae83
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import base64
from fastapi import FastAPI, File, UploadFile, HTTPException
import cv2
import numpy as np
from PIL import Image
import io
from transformers import ViTFeatureExtractor, ViTForImageClassification

app = FastAPI()

# Inicializamos el modelo de clasificaci贸n de edad y el extractor
model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')

# Mapeo de 铆ndices de clase a rangos de edad
age_mapping = [
    "0-2", "3-6", "7-9", "10-12", "13-15",
    "16-19", "20-24", "25-29", "30-34", "35-39",
    "40-44", "45-49", "50-54", "55-59", "60-64",
    "65-69", "70+"
]

# Endpoint para predecir la edad de los rostros detectados en una imagen
@app.post("/predict/")
async def predict_age(file: UploadFile = File(...)):
    """
    Endpoint para predecir el rango de edad de los rostros detectados en una imagen.
    """
    try:
        # Leer la imagen cargada
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")  # Convertimos a RGB si es necesario

        # Convertir la imagen a formato NumPy para usar OpenCV
        img_np = np.array(image)

        # Cargar el clasificador Haar para detecci贸n de rostros
        face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

        # Convertir la imagen a escala de grises para la detecci贸n de rostros
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

        # Verificamos si se detectaron rostros
        if len(faces) == 0:
            raise HTTPException(status_code=404, detail="No se detectaron rostros en la imagen.")

        # Lista para almacenar los resultados de predicci贸n de cada rostro
        results = []

        for (x, y, w, h) in faces:
            # Extraer la regi贸n del rostro
            face_img = img_np[y:y+h, x:x+w]
            pil_face_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))

            # Aplicar la transformaci贸n y hacer la predicci贸n de edad
            inputs = transforms(pil_face_img, return_tensors='pt')
            outputs = model(**inputs)

            # Calcular probabilidades y predicci贸n
            proba = outputs.logits.softmax(1)
            preds = proba.argmax(1).item()  # 脥ndice de la clase predicha
            predicted_age_range = age_mapping[preds]

            # Dibujar un rect谩ngulo alrededor del rostro y agregar la edad predicha
            cv2.rectangle(img_np, (x, y), (x + w, y + h), (255, 0, 0), 2)
            cv2.putText(img_np, f"Edad: {predicted_age_range}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)

            # Guardar el resultado de la predicci贸n de edad y las coordenadas del rostro
            results.append({
                "edad_predicha": predicted_age_range,
                "coordenadas_rostro": (x, y, w, h)
            })

        # Convertir la imagen procesada a base64 para la respuesta
        result_image = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB))
        img_byte_arr = io.BytesIO()
        result_image.save(img_byte_arr, format='JPEG')
        img_byte_arr = img_byte_arr.getvalue()

        # Devolver los resultados
        return {
            "message": "Rostros detectados y edad predicha",
            "rostros_detectados": len(faces),
            "resultados": results,
            "imagen_base64": base64.b64encode(img_byte_arr).decode('utf-8')
        }

    except Exception as e:
        # Manejo de errores generales
        raise HTTPException(status_code=500, detail=f"Error procesando la imagen: {str(e)}")