angelo9830 commited on
Commit
65aa862
1 Parent(s): 9c9ae83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -30
app.py CHANGED
@@ -4,62 +4,87 @@ import numpy as np
4
  from PIL import Image
5
  import io
6
  from transformers import ViTFeatureExtractor, ViTForImageClassification
7
- import torch
8
 
9
  app = FastAPI()
10
 
11
- # Cargar el modelo y el extractor de características
12
- model = ViTForImageClassification.from_pretrained("nateraw/vit-age-classifier")
13
- transforms = ViTFeatureExtractor.from_pretrained("nateraw/vit-age-classifier")
14
 
15
- @app.post("/detect/")
16
- async def detect_face(file: UploadFile = File(...)):
17
- # Validar tipo de archivo
18
- if not file.content_type.startswith("image/"):
19
- raise HTTPException(status_code=400, detail="El archivo no es una imagen válida.")
 
 
20
 
 
 
 
 
 
 
21
  try:
22
- # Leer la imagen y convertirla a numpy array
23
  image_bytes = await file.read()
24
- image = Image.open(io.BytesIO(image_bytes))
 
 
25
  img_np = np.array(image)
26
 
27
- # Convertir a formato adecuado (BGR) si tiene transparencia (4 canales)
28
- if img_np.shape[2] == 4:
29
- img_np = cv2.cvtColor(img_np, cv2.COLOR_BGRA2BGR)
30
 
31
- # Detectar rostros con Haar
32
- face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
33
- gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
34
  faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
35
 
 
36
  if len(faces) == 0:
37
- return {"message": "No se detectaron rostros en la imagen."}
38
 
 
39
  results = []
 
40
  for (x, y, w, h) in faces:
41
- # Extraer el rostro y convertirlo a RGB
42
  face_img = img_np[y:y+h, x:x+w]
43
  pil_face_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))
44
 
45
- # Predicción del rango de edad
46
- inputs = transforms(pil_face_img, return_tensors="pt")
47
- output = model(**inputs)
48
- proba = output.logits.softmax(1)
49
- predicted_class = proba.argmax(1).item()
50
- predicted_age_range = str(predicted_class) # Mapeo de clases a rangos
 
 
51
 
52
- # Añadir datos al resultado
 
 
 
 
53
  results.append({
54
  "edad_predicha": predicted_age_range,
55
- "coordenadas_rostro": {"x": x, "y": y, "w": w, "h": h}
56
  })
57
 
 
 
 
 
 
 
 
58
  return {
59
- "message": "Rostros detectados y edad predicha.",
60
- "cantidad_rostros": len(faces),
61
- "detalles": results
 
62
  }
63
 
64
  except Exception as e:
 
65
  raise HTTPException(status_code=500, detail=f"Error procesando la imagen: {str(e)}")
 
4
  from PIL import Image
5
  import io
6
  from transformers import ViTFeatureExtractor, ViTForImageClassification
 
7
 
8
  app = FastAPI()
9
 
10
+ # Inicializamos el modelo de clasificación de edad y el extractor
11
+ model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
12
+ transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
13
 
14
+ # Mapeo de índices de clase a rangos de edad
15
+ age_mapping = [
16
+ "0-2", "3-6", "7-9", "10-12", "13-15",
17
+ "16-19", "20-24", "25-29", "30-34", "35-39",
18
+ "40-44", "45-49", "50-54", "55-59", "60-64",
19
+ "65-69", "70+"
20
+ ]
21
 
22
+ # Endpoint para predecir la edad de los rostros detectados en una imagen
23
+ @app.post("/predict/")
24
+ async def predict_age(file: UploadFile = File(...)):
25
+ """
26
+ Endpoint para predecir el rango de edad de los rostros detectados en una imagen.
27
+ """
28
  try:
29
+ # Leer la imagen cargada
30
  image_bytes = await file.read()
31
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB") # Convertimos a RGB si es necesario
32
+
33
+ # Convertir la imagen a formato NumPy para usar OpenCV
34
  img_np = np.array(image)
35
 
36
+ # Cargar el clasificador Haar para detección de rostros
37
+ face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
 
38
 
39
+ # Convertir la imagen a escala de grises para la detección de rostros
40
+ gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
 
41
  faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
42
 
43
+ # Verificamos si se detectaron rostros
44
  if len(faces) == 0:
45
+ raise HTTPException(status_code=404, detail="No se detectaron rostros en la imagen.")
46
 
47
+ # Lista para almacenar los resultados de predicción de cada rostro
48
  results = []
49
+
50
  for (x, y, w, h) in faces:
51
+ # Extraer la región del rostro
52
  face_img = img_np[y:y+h, x:x+w]
53
  pil_face_img = Image.fromarray(cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB))
54
 
55
+ # Aplicar la transformación y hacer la predicción de edad
56
+ inputs = transforms(pil_face_img, return_tensors='pt')
57
+ outputs = model(**inputs)
58
+
59
+ # Calcular probabilidades y predicción
60
+ proba = outputs.logits.softmax(1)
61
+ preds = proba.argmax(1).item() # Índice de la clase predicha
62
+ predicted_age_range = age_mapping[preds]
63
 
64
+ # Dibujar un rectángulo alrededor del rostro y agregar la edad predicha
65
+ cv2.rectangle(img_np, (x, y), (x + w, y + h), (255, 0, 0), 2)
66
+ cv2.putText(img_np, f"Edad: {predicted_age_range}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)
67
+
68
+ # Guardar el resultado de la predicción de edad y las coordenadas del rostro
69
  results.append({
70
  "edad_predicha": predicted_age_range,
71
+ "coordenadas_rostro": (x, y, w, h)
72
  })
73
 
74
+ # Convertir la imagen procesada a base64 para la respuesta
75
+ result_image = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB))
76
+ img_byte_arr = io.BytesIO()
77
+ result_image.save(img_byte_arr, format='JPEG')
78
+ img_byte_arr = img_byte_arr.getvalue()
79
+
80
+ # Devolver los resultados
81
  return {
82
+ "message": "Rostros detectados y edad predicha",
83
+ "rostros_detectados": len(faces),
84
+ "resultados": results,
85
+ "imagen_base64": base64.b64encode(img_byte_arr).decode('utf-8')
86
  }
87
 
88
  except Exception as e:
89
+ # Manejo de errores generales
90
  raise HTTPException(status_code=500, detail=f"Error procesando la imagen: {str(e)}")