# Source snapshot: commit d3bf65a by haydpw — "change mask colors".
import base64
import functools
import io
import json
import logging
import os
import traceback
from typing import Optional, Union

import dotenv
from fastapi import FastAPI, File, Form, HTTPException, UploadFile, Response
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from PIL import Image
from rembg import remove
from supervision import Detections
from ultralytics import YOLO

import models.face_classifier as classifier
import procedures
from utils import handlers
from utils.helpers import combine_images, image_to_base64, calculate_mask_area, process_image
# Pull variables from a local .env file into the process environment first,
# so anything constructed below can read them.
dotenv.load_dotenv()

app = FastAPI()

# CORS policy: every origin, method and header is accepted, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# website make credentialed requests to this API — restrict allow_origins
# before exposing it publicly.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Segmentation model used by the /segment/ endpoint, built once at import time.
model = classifier.FaceSegmentationModel()

# The YOLOv8 face-detector weights are fetched from the Hugging Face Hub at
# import time; only the local file path is kept, for the /detect-face/ endpoint.
yolo_model_path = hf_hub_download(
    repo_id="arnabdhar/YOLOv8-Face-Detection",
    filename="model.pt",
)
# Side length (pixels) that uploads are resized to before segmentation; the
# score denominator below is derived from it as well.
_SEGMENT_IMAGE_SIZE = 500

# RGB colour used to paint each segment's mask, keyed by model label.
_SEGMENT_MASK_COLORS = {
    'acne': (245, 177, 177),
    'dry': (208, 181, 166),
    'oily': (240, 230, 214),
    'background': (255, 255, 255),
}


@app.post("/segment/", summary="Classify skin type based on image given", tags=["Classify"])
async def predict_image(file: UploadFile = File(...)):
    """Segment skin conditions in an uploaded face image.

    The upload is decoded as RGB and resized to a fixed square (aspect ratio
    is not preserved). Its background is replaced with solid blue by
    ``rembg.remove``, and the result is passed to the segmentation model.

    For every segment returned by the model:
    - its mask is recoloured with the colour for its label and replaced by a
      base64 PNG string;
    - every non-background segment also gets a ``score``: its mask area
      divided by the non-background area of the image (0.0 when there is no
      non-background area).

    Args:
        file: Uploaded image file.

    Returns:
        A ``Response`` whose JSON body has these keys:
        - ``original_image``: the resized input, encoded by ``image_to_base64``.
        - ``segmentation_results``: the model's segments, updated as above.
        - ``combined_image``: a base64 PNG of the input combined with the masks.

    Raises:
        HTTPException: 500 when any step fails; the traceback is logged.
    """
    log = logging.getLogger(__name__)
    try:
        image_bytes = await file.read()
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        pil_image = pil_image.resize((_SEGMENT_IMAGE_SIZE, _SEGMENT_IMAGE_SIZE))
        image_bg_removed = remove(pil_image, bgcolor=(0, 0, 255, 255))

        results = model.infer(image_bg_removed)
        log.debug("segmentation returned %d segments", len(results))

        background_element = next(
            (element for element in results if element['label'] == 'background'), None
        )
        # Only touch the background mask once we know it exists; a missing
        # background segment previously crashed here with a TypeError.
        if background_element is not None:
            # NOTE(review): the meaning of the second argument is defined in
            # utils.helpers.calculate_mask_area — confirm.
            background_area = calculate_mask_area(background_element['mask'], True)
        else:
            background_area = 0
        log.debug("background area: %s", background_area)

        foreground_area = _SEGMENT_IMAGE_SIZE * _SEGMENT_IMAGE_SIZE - background_area

        # Replace each mask with a coloured base64 PNG and score the non-background segments.
        for segment in results:
            label = segment['label']
            # The area must be measured before the mask is swapped for its base64 string.
            mask_area = calculate_mask_area(segment["mask"])
            processed_image = process_image(segment["mask"], _SEGMENT_MASK_COLORS[label])
            segment["mask"] = image_to_base64(processed_image, "PNG")
            if label == "background":
                continue
            log.debug("%s area: %s", label, mask_area)
            # Guard against a frame that is entirely background.
            segment["score"] = mask_area / foreground_area if foreground_area > 0 else 0.0

        # Combine the resized original image with the recoloured masks.
        combined_image = combine_images(pil_image, results)
        combined_image_base64 = image_to_base64(combined_image, "PNG")
        response = {
            "original_image": image_to_base64(pil_image),
            "segmentation_results": results,
            "combined_image": combined_image_base64,
        }
        # An explicit media_type makes clients receive a JSON Content-Type header.
        return Response(
            content=json.dumps(response),
            status_code=200,
            media_type="application/json",
        )
    except Exception as e:
        log.exception("Segmentation request failed")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
# Image modes Pillow's JPEG encoder can write; anything else (e.g. RGBA, P
# from PNG uploads) must be converted before saving as JPEG.
_JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})


@functools.lru_cache(maxsize=1)
def _get_face_detector() -> YOLO:
    """Return the YOLOv8 face detector, loading weights from ``yolo_model_path`` once per process."""
    return YOLO(yolo_model_path)


@app.post("/detect-face/", summary="Detect face from image", tags=["Classify"])
async def detect_face(file: Optional[UploadFile] = File(None), base64_string: Optional[str] = Form(None)):
    """Detect a face and return its bounding box and a JPEG crop.

    Exactly one of ``file`` or ``base64_string`` is used; ``file`` takes
    precedence when both are sent. The first detection reported by the YOLO
    face detector is used.

    Args:
        file: Optional uploaded image file.
        base64_string: Optional base64-encoded image sent as a form field.

    Returns:
        A dict with these keys:
        - ``bounding_box``: ``[x_min, y_min, x_max, y_max]`` as floats, from
          the detector.
        - ``cropped_image``: a base64 (UTF-8) string of the face crop encoded
          as JPEG.

    Raises:
        HTTPException: 400 if no input is provided or the handlers cannot
            decode it; 404 if no face is detected; 500 for any other failure
            (the traceback is logged).
    """
    log = logging.getLogger(__name__)
    try:
        if file is None and base64_string is None:
            raise HTTPException(status_code=400, detail="No input data provided")

        # The file handler is given the base64 handler as its `successor`,
        # presumably deferring to it for inputs it does not handle — see utils.handlers.
        base64_handler = handlers.Base64Handler()
        image_handler = handlers.ImageFileHandler(successor=base64_handler)
        input_data: Union[UploadFile, str, None] = file if file is not None else base64_string
        # Log only the input type; the payload itself can be a large base64 string.
        log.debug("detect-face input type: %s", type(input_data).__name__)
        pil_image = await image_handler.handle(input_data)
        if pil_image is None:
            raise HTTPException(status_code=400, detail="Unsupported file type")

        # Reuse the cached detector instead of reloading the weights per request.
        output = _get_face_detector()(pil_image)
        detections = Detections.from_ultralytics(output[0])
        if len(detections) == 0:
            raise HTTPException(status_code=404, detail="No face detected")

        first_bbox = detections[0].xyxy[0].tolist()
        x_min, y_min, x_max, y_max = map(int, first_bbox)
        cropped_image = pil_image.crop((x_min, y_min, x_max, y_max))
        if cropped_image.mode not in _JPEG_WRITABLE_MODES:
            cropped_image = cropped_image.convert("RGB")

        buffered = io.BytesIO()
        cropped_image.save(buffered, format="JPEG")
        cropped_image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"bounding_box": first_bbox, "cropped_image": cropped_image_base64}
    except HTTPException:
        # Deliberate client-facing errors (400/404) pass through unchanged.
        raise
    except Exception as e:
        log.exception("Face detection request failed")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e