import contextlib
import logging
import os
import tempfile
import threading

import numpy as np
import uvicorn
from fastapi import FastAPI, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
import tensorflow.keras.backend as K
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array

logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# web site make credentialed cross-origin requests to this API. Restrict
# `origins` to the known front-end host(s) before exposing the service publicly.
origins = ['*']

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Custom F1 score function
def f1_score(y_true, y_pred):
    """Return the batch-level F1 score the model was compiled with.

    This is required only so ``load_model`` can deserialize the saved model,
    which references a custom object named ``f1_score``. Values are clipped
    to [0, 1] and rounded (i.e. thresholded at 0.5) before counting, and
    ``K.epsilon()`` guards every division against zero.
    """
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    precision = true_positives / K.maximum(
        K.sum(K.round(K.clip(y_pred, 0, 1))), K.epsilon()
    )
    recall = true_positives / K.maximum(
        K.sum(K.round(K.clip(y_true, 0, 1))), K.epsilon()
    )
    return 2 * (precision * recall) / (precision + recall + K.epsilon())


# Load model. The default path is resolved relative to the current working
# directory; set the MODEL_PATH environment variable to load it from elsewhere.
MODEL_PATH = os.environ.get("MODEL_PATH", "Trained_after_EFF0.keras")
model = load_model(MODEL_PATH, custom_objects={'f1_score': f1_score})

# Inference runs on worker threads (see predict_image). This lock serializes
# access to the shared Keras model, as the original event-loop-bound code did.
_MODEL_LOCK = threading.Lock()

# Image size for the model
IMAGE_SIZE = 224

# Directory where uploads are staged briefly so load_img can read them from disk.
UPLOAD_DIR = "./uploads"


def preprocess_image(image_path, target_size):
    """Load an image file as a single-image batch for the model.

    Args:
        image_path: Path to the image file on disk.
        target_size: Side length in pixels. The image is resized to a
            ``target_size`` x ``target_size`` square.

    Returns:
        A NumPy array holding one RGB image with a leading batch axis of
        size 1. Raw pixel values are kept; no rescaling is applied here.
    """
    image = load_img(image_path, target_size=(target_size, target_size))
    image_array = img_to_array(image)
    image_array = np.expand_dims(image_array, axis=0)
    return image_array


def _save_upload(data):
    """Write upload bytes to a uniquely named file in UPLOAD_DIR and return its path.

    The client-supplied filename is deliberately not used. Joining it onto
    the upload directory allowed path traversal (e.g. ``../../x``), made
    concurrent uploads with the same name clobber each other, and crashed
    when no filename was sent.
    """
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    fd, file_path = tempfile.mkstemp(prefix="upload_", dir=UPLOAD_DIR)
    try:
        with os.fdopen(fd, "wb") as buffer:
            buffer.write(data)
    except BaseException:
        # Don't leave a half-written file behind if the write itself fails.
        os.remove(file_path)
        raise
    return file_path


def _predict_file(file_path):
    """Run the model on one image file and return ``(label_index, confidence)``."""
    image_array = preprocess_image(file_path, target_size=IMAGE_SIZE)
    with _MODEL_LOCK:
        # verbose=0 stops Keras from printing a progress bar on every request.
        prediction = model.predict(image_array, verbose=0)
    return int(np.argmax(prediction)), float(np.max(prediction))


# API to predict image
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns:
        On success, JSON ``{"predicted_label": int, "confidence": float}``
        containing the arg-max class index and its score. On any failure,
        JSON ``{"error": str}`` with HTTP status 500.
    """
    file_path = None
    try:
        data = await file.read()
        # Disk I/O, image decoding and inference are blocking/CPU-bound.
        # Running them in the threadpool keeps the event loop responsive.
        file_path = await run_in_threadpool(_save_upload, data)
        predicted_label, confidence = await run_in_threadpool(_predict_file, file_path)
        return JSONResponse(
            content=jsonable_encoder(
                {"predicted_label": predicted_label, "confidence": confidence}
            )
        )
    except Exception as e:
        logger.exception("Prediction failed")
        return JSONResponse(content=jsonable_encoder({"error": str(e)}), status_code=500)
    finally:
        # Always remove the staged upload, including when decoding or inference
        # fails. A cleanup problem must not turn a finished request into an error.
        if file_path is not None:
            try:
                os.remove(file_path)
            except FileNotFoundError:
                pass
            except OSError:
                logger.warning("Could not remove staged upload %s", file_path, exc_info=True)


# Run FastAPI server
if __name__ == '__main__':
    uvicorn.run(app, host="0.0.0.0", port=8002)