# GRiT / handler.py: custom handler for Hugging Face Inference Endpoints (GRiT dense captioning)
import os
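# Install the bundled detectron2 wheel and a pinned DeepSpeed at import time;
# the wheel file is assumed to be present in the repository root.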
os.system("pip install detectron2-0.6-cp39-cp39-linux_x86_64.whl")
os.system("pip install deepspeed==0.7.0")
from PIL import Image
from io import BytesIO
import argparse
import sys
import numpy as np
import torch
from detectron2.config import get_cfg
from detectron2.utils.logger import setup_logger
print("###################################", os.getcwd())
sys.path.insert(0, "third_party/CenterNet2/projects/CenterNet2/")
from centernet.config import add_centernet_config
from grit.config import add_grit_config
from grit.predictor import VisualizationDemo
def setup_cfg(args):
    """Build a frozen detectron2 config for the GRiT dense-captioning demo."""
    cfg = get_cfg()
    if args.cpu:
        cfg.MODEL.DEVICE = "cpu"
    add_centernet_config(cfg)
    add_grit_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # Set score_threshold for builtin models
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = (
        args.confidence_threshold
    )
    if args.test_task:
        cfg.MODEL.TEST_TASK = args.test_task
    cfg.MODEL.BEAM_SIZE = 1
    cfg.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
    cfg.USE_ACT_CHECKPOINT = False
    cfg.freeze()
    return cfg
# Build the demo arguments programmatically; parsing the real command line here
# could fail if the serving process is launched with its own flags.
args = argparse.Namespace()
args.config_file = "configs/GRiT_B_DenseCap_ObjectDet.yaml"
args.cpu = False
args.confidence_threshold = 0.5
args.opts = ["MODEL.WEIGHTS", "./models/grit_b_densecap_objectdet.pth"]
args.test_task = "DenseCap"
setup_logger(name="fvcore")
logger = setup_logger()
logger.info("Arguments: " + str(args))
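# Build the frozen config and construct the GRiT visualization demo once at import
# time, so the model weights are loaded when the endpoint starts.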
cfg = setup_cfg(args)
dense_captioning_demo = VisualizationDemo(cfg)
class EndpointHandler:
    def __init__(self, path=""):
        # Inference Endpoints instantiates the handler with the repository path;
        # the model is already loaded at module level, so nothing is needed here.
        pass

    def __call__(self, image_file):
        # Debug variant: called directly with a PIL image instead of the usual
        # {"inputs": ...} payload.
        image_array = np.array(image_file)[:, :, ::-1]  # RGB -> BGR for detectron2
        predictions, visualized_output = dense_captioning_demo.run_on_image(image_array)

        # Render the annotated image into an in-memory PNG.
        buffer = BytesIO()
        visualized_output.fig.savefig(buffer, format="png")
        buffer.seek(0)

        # Group predicted boxes by their generated object description.
        detections = {}
        predictions = predictions["instances"].to(torch.device("cpu"))
        for box, description, score in zip(
            predictions.pred_boxes,
            predictions.pred_object_descriptions.data,
            predictions.scores,
        ):
            if description not in detections:
                detections[description] = []
            detections[description].append(
                {
                    "xmin": float(box[0]),
                    "ymin": float(box[1]),
                    "xmax": float(box[2]),
                    "ymax": float(box[3]),
                    "score": float(score),
                }
            )

        output = {
            "dense_captioning_results": {
                "detections": detections,
            }
        }
        return Image.open(buffer), output
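

# Minimal local smoke test: a sketch only, not part of the Inference Endpoints
# contract. The image path is a hypothetical placeholder and assumes the script
# is run from the repository root with the weights already downloaded.
if __name__ == "__main__":
    handler = EndpointHandler()
    test_image = Image.open("sample.jpg").convert("RGB")  # hypothetical test image
    annotated, result = handler(test_image)
    annotated.save("annotated.png")
    print(result["dense_captioning_results"]["detections"])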