# GRiT / handler.py: Hugging Face Inference Endpoints handler for GRiT
# dense captioning.
import os

# Dependencies such as detectron2 may be installed at runtime (see the
# commented commands below); reloading the site module makes freshly
# installed packages importable in the current process.
# os.system("cd detectron2 && pip install detectron2-0.6-cp310-cp310-linux_x86_64.whl")
# os.system("pip install deepspeed==0.7.0")
import site
from importlib import reload

reload(site)

import argparse
import sys
from io import BytesIO

import numpy as np
import torch
from PIL import Image

from detectron2.config import get_cfg
from detectron2.utils.logger import setup_logger
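
# GRiT builds on CenterNet2; its project code must be on sys.path before the
# centernet/grit imports below will resolve.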
sys.path.insert(0, "third_party/CenterNet2/projects/CenterNet2/")
from centernet.config import add_centernet_config
from grit.config import add_grit_config
from grit.predictor import VisualizationDemo


def setup_cfg(args):
    # Build the detectron2 config: start from the defaults, register the
    # CenterNet2 and GRiT extensions, then apply the file and list overrides.
    cfg = get_cfg()
if args.cpu:
cfg.MODEL.DEVICE = "cpu"
add_centernet_config(cfg)
add_grit_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
# Set score_threshold for builtin models
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = (
args.confidence_threshold
)
if args.test_task:
cfg.MODEL.TEST_TASK = args.test_task
cfg.MODEL.BEAM_SIZE = 1
cfg.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
cfg.USE_ACT_CHECKPOINT = False
cfg.freeze()
return cfg


# This module is imported by the endpoint runtime, so nothing is read from the
# command line; build the argument namespace statically instead of calling
# parse_args(), which would fail on any unexpected sys.argv contents.
args = argparse.Namespace()
args.config_file = "configs/GRiT_B_DenseCap_ObjectDet.yaml"
args.cpu = False
args.confidence_threshold = 0.5
args.opts = ["MODEL.WEIGHTS", "./models/grit_b_densecap_objectdet.pth"]
args.test_task = "DenseCap"
setup_logger(name="fvcore")
logger = setup_logger()
logger.info("Arguments: " + str(args))
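
# Build the config and load the GRiT weights once at import time; each request
# then only pays for inference.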
cfg = setup_cfg(args)
dense_captioning_demo = VisualizationDemo(cfg)


class EndpointHandler:
    def __init__(self):
        # All heavy state lives at module level; nothing to set up per instance.
        pass

    def __call__(self, image_file):
        # detectron2 models expect BGR input, so reverse the PIL image's
        # RGB channel order after converting to a numpy array.
        image_array = np.array(image_file)[:, :, ::-1]  # RGB -> BGR
        predictions, visualized_output = dense_captioning_demo.run_on_image(image_array)

        # Render the annotated matplotlib figure to an in-memory PNG.
        buffer = BytesIO()
        visualized_output.fig.savefig(buffer, format="png")
        buffer.seek(0)

        # Group boxes by generated caption; one caption can describe several
        # regions, so each maps to a list of boxes with scores.
        detections = {}
        predictions = predictions["instances"].to(torch.device("cpu"))
for box, description, score in zip(
predictions.pred_boxes,
predictions.pred_object_descriptions.data,
predictions.scores,
):
if description not in detections:
detections[description] = []
detections[description].append(
{
"xmin": float(box[0]),
"ymin": float(box[1]),
"xmax": float(box[2]),
"ymax": float(box[3]),
"score": float(score),
}
)
output = {
"dense_captioning_results": {
"detections": detections,
}
}
return Image.open(buffer), output
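

# A minimal local smoke test, as a hedged sketch rather than part of the
# endpoint contract; "sample.jpg" is a hypothetical input path used only to
# illustrate the handler's return values.
if __name__ == "__main__":
    handler = EndpointHandler()
    annotated, result = handler(Image.open("sample.jpg").convert("RGB"))
    annotated.save("annotated.png")  # image with region captions drawn on it
    print(result["dense_captioning_results"]["detections"])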