from typing import Union
import dotenv
import traceback
import json
import io
import logging
import os
import base64
from fastapi import FastAPI, File, HTTPException, UploadFile, Response
import models.face_classifier as classifier
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from rembg import remove
from utils.helpers import image_to_base64, calculate_mask_area

dotenv.load_dotenv()

logger = logging.getLogger(__name__)

# Side length, in pixels, that every upload is resized to before segmentation.
IMAGE_SIZE = 512
# RGBA colour that rembg paints over removed background pixels (opaque blue).
BACKGROUND_RGBA = (0, 0, 255, 255)

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed cross-origin requests; restrict allow_origins before
# deploying publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# CLIENT = create_client(os.getenv("ROBOFLOW_API_KEY"))
# model = FaceClassifierModel(client=CLIENT)
model = classifier.FaceSegmentationModel()


def _load_rgb_image(data: bytes) -> Image.Image:
    """Decode raw uploaded bytes into an RGB PIL image.

    Args:
        data: The uploaded file contents.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        HTTPException: 400 if Pillow cannot identify or decode the data.
    """
    try:
        return Image.open(io.BytesIO(data)).convert("RGB")
    except OSError as e:
        # Pillow's UnidentifiedImageError (unrecognised format) and errors from
        # truncated/corrupt data are OSError subclasses: a bad upload, which is
        # a client error rather than a server failure.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e


def _score_segments(results: list, total_pixels: int) -> list:
    """Base64-encode every mask and score each non-background segment.

    A segment's score is its mask area divided by the non-background area
    (``total_pixels`` minus the area of the segment labelled "background",
    or 0 if there is none). If the non-background area is zero the score is
    0.0 instead of raising ZeroDivisionError.

    Args:
        results: Segments returned by the model; each is a dict with at least
            "label" and "mask" keys. Mutated in place.
        total_pixels: Pixel count of the image the masks were produced from.

    Returns:
        The same ``results`` list, with each "mask" replaced by its base64
        form and a "score" key added to every non-background segment.
    """
    background = next((seg for seg in results if seg["label"] == "background"), None)
    background_area = calculate_mask_area(background["mask"]) if background is not None else 0
    foreground_area = total_pixels - background_area

    for seg in results:
        raw_mask = seg["mask"]
        seg["mask"] = image_to_base64(raw_mask)
        if seg["label"] == "background":
            continue
        # Measure the raw mask (the same form used for background_area above),
        # not the base64 string that has just replaced it in the dict.
        mask_area = calculate_mask_area(raw_mask)
        seg["score"] = mask_area / foreground_area if foreground_area > 0 else 0.0
    return results


@app.post("/segment/", summary="Classify skin type based on image given", tags=["Classify"])
async def predict_image(file: UploadFile = File(...)):
    """Segment an uploaded face image and score each segment by relative area.

    The upload is decoded to RGB, resized to IMAGE_SIZE x IMAGE_SIZE, has its
    background replaced with BACKGROUND_RGBA by rembg, and is then passed to
    the segmentation model.

    Args:
        file: The uploaded image file.

    Returns:
        A JSON response containing:
          - "original_image": base64 of the background-replaced image (RGB).
          - "segmentation_results": the model's results, each with its "mask"
            base64-encoded and, for non-background labels, a "score".

    Raises:
        HTTPException: 400 if the upload is not a decodable image; 500 for any
            other failure.
    """
    try:
        # Handle the uploaded file.
        image_bytes = await file.read()
        pil_image = _load_rgb_image(image_bytes)
        # Fixed square resize: non-square inputs are stretched, not cropped.
        pil_image = pil_image.resize((IMAGE_SIZE, IMAGE_SIZE))
        image_bg_removed = remove(pil_image, bgcolor=BACKGROUND_RGBA)

        # NOTE(review): remove() and model.infer() are called without await, so
        # they run on the event-loop thread and stall other requests while they
        # work; consider offloading once model.infer is confirmed thread-safe.
        results = model.infer(image_bg_removed)
        logger.debug("Segmentation model returned %d segments", len(results))

        results = _score_segments(results, IMAGE_SIZE * IMAGE_SIZE)

        response = {
            "original_image": image_to_base64(image_bg_removed.convert("RGB")),
            "segmentation_results": results,
        }
        # Return the segmentation results; media_type sets the Content-Type
        # header so clients treat the body as JSON.
        return Response(
            content=json.dumps(response),
            status_code=200,
            media_type="application/json",
        )
    except HTTPException:
        # Already carries the correct status (e.g. 400 from _load_rgb_image).
        raise
    except Exception as e:
        # Record the full stack trace server-side; the client sees only the message.
        logger.exception("Segmentation request failed")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e