|
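"""FastAPI service for facial skin analysis.

Exposes two endpoints: /segment/ segments a face image and scores skin conditions
(acne, dry, oily), and /detect-face/ returns the bounding box and crop of the first
face found by a YOLOv8 face detector. Run with uvicorn, e.g.
`uvicorn <module_name>:app --reload` (replace <module_name> with this file's name).
"""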
import base64
import io
import json
import os
import traceback
from typing import Optional, Union

import dotenv
from fastapi import FastAPI, File, Form, HTTPException, Response, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from PIL import Image
from rembg import remove
from supervision import Detections
from ultralytics import YOLO

import models.face_classifier as classifier
import procedures
from utils import handlers
from utils.helpers import calculate_mask_area, combine_images, image_to_base64, process_image

# Load environment variables from a local .env file, if present.
dotenv.load_dotenv()

app = FastAPI()

# Allow cross-origin requests from any origin (convenient for development; tighten for production).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
# Instantiate the face segmentation model once at startup so it is shared across requests.
model = classifier.FaceSegmentationModel()

# Download the YOLOv8 face-detection weights from the Hugging Face Hub.
yolo_model_path = hf_hub_download(repo_id="arnabdhar/YOLOv8-Face-Detection", filename="model.pt")
|
|
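# Example request to the /segment/ endpoint below (hypothetical host/port; adjust to your deployment):
#   curl -X POST -F "file=@face.jpg" http://localhost:8000/segment/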
@app.post("/segment/", summary="Classify skin type based on the given image", tags=["Classify"])
async def predict_image(file: UploadFile = File(...)):
    """Segment the uploaded face image and score each detected skin condition.

    Returns a JSON body containing the base64-encoded original image, the
    per-label segmentation masks with scores, and a combined overlay image.
    """
    try:
        # Read the upload and normalize it to a 500x500 RGB image.
        image_file = await file.read()
        pil_image = Image.open(io.BytesIO(image_file)).convert("RGB")
        pil_image = pil_image.resize((500, 500))

        # Remove the background so only the face region contributes to the segmentation.
        image_bg_removed = remove(pil_image, bgcolor=(0, 0, 255, 255))

        # Run the segmentation model; each result holds a label and its mask image.
        results = model.infer(image_bg_removed)
|
        # Locate the background mask so the face area can be computed as (image area - background area).
        background_element = next((element for element in results if element["label"] == "background"), None)
        if background_element is not None:
            background_area = calculate_mask_area(background_element["mask"], True)
        else:
            background_area = 0
|
        mask_dict = {
            "acne": (245, 177, 177),
            "dry": (208, 181, 166),
            "oily": (240, 230, 214),
            "background": (255, 255, 255),
        }
|
        # Colorize each mask, encode it as base64, and score every non-background label
        # by its mask area relative to the face area (image area minus background area).
        face_area = 500 * 500 - background_area
        for result in results:
            mask_area = calculate_mask_area(result["mask"])
            processed_image = process_image(result["mask"], mask_dict[result["label"]])
            result["mask"] = image_to_base64(processed_image, "PNG")
            if result["label"] == "background":
                continue
            result["score"] = mask_area / face_area if face_area > 0 else 0.0
|
        # Overlay the masks on the original image and assemble the JSON response.
        combined_image = combine_images(pil_image, results)
        combined_image_base64 = image_to_base64(combined_image, "PNG")

        response = {
            "original_image": image_to_base64(pil_image),
            "segmentation_results": results,
            "combined_image": combined_image_base64,
        }

        return Response(content=json.dumps(response), status_code=200, media_type="application/json")
    except Exception as e:
        # Log the full traceback server-side and surface a generic 500 to the client.
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
|
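# Example request to the /detect-face/ endpoint below (hypothetical host/port;
# send either an uploaded "file" or a "base64_string" form field):
#   curl -X POST -F "file=@face.jpg" http://localhost:8000/detect-face/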
@app.post("/detect-face/", summary="Detect a face in the given image", tags=["Classify"])
async def detect_face(file: Optional[UploadFile] = File(None), base64_string: Optional[str] = Form(None)):
    """Detect the first face in the input image and return its bounding box plus a base64-encoded crop."""
    try:
        if file is None and base64_string is None:
            raise HTTPException(status_code=400, detail="No input data provided")

        # Chain of handlers: try to interpret the input as an uploaded image file first,
        # then fall back to decoding it as a base64 string.
        base64_handler = handlers.Base64Handler()
        image_handler = handlers.ImageFileHandler(successor=base64_handler)
        input_data: Union[UploadFile, str, None] = file if file is not None else base64_string
        pil_image = await image_handler.handle(input_data)
        if pil_image is None:
            raise HTTPException(status_code=400, detail="Unsupported file type")
|
        # Load the YOLOv8 face detector (loaded per request here; it could be cached at startup).
        face_detector = YOLO(yolo_model_path)
        output = face_detector(pil_image)
        results = Detections.from_ultralytics(output[0])

        if len(results) == 0:
            raise HTTPException(status_code=404, detail="No face detected")
|
        # Crop the first detected face and return it alongside its bounding box.
        first_bbox = results[0].xyxy[0].tolist()
        x_min, y_min, x_max, y_max = map(int, first_bbox)
        cropped_image = pil_image.crop((x_min, y_min, x_max, y_max))

        buffered = io.BytesIO()
        cropped_image.save(buffered, format="JPEG")
        cropped_image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {"bounding_box": first_bbox, "cropped_image": cropped_image_base64}
    except HTTPException:
        # Propagate deliberate HTTP errors (400/404) unchanged.
        raise
    except Exception as e:
        # Log the full traceback server-side and surface a generic 500 to the client.
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")