Inference Endpoints
File size: 3,002 Bytes
a567fa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ac2cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a567fa4
 
4ac2cad
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os

# os.system("cd detectron2 && pip install detectron2-0.6-cp310-cp310-linux_x86_64.whl")
# os.system("pip install deepspeed==0.7.0")

import site
from importlib import reload

reload(site)

from PIL import Image
from io import BytesIO
import argparse
import sys
import numpy as np
import torch
import gradio as gr

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger

sys.path.insert(0, "third_party/CenterNet2/projects/CenterNet2/")
from centernet.config import add_centernet_config
from grit.config import add_grit_config

from grit.predictor import VisualizationDemo


def setup_cfg(args):
    """Assemble the frozen detectron2/GRiT inference config described by *args*.

    ``args`` must expose ``cpu``, ``config_file``, ``opts``,
    ``confidence_threshold`` and ``test_task``. CenterNet and GRiT config
    keys are added before the YAML file and the ``opts`` overrides are
    merged; a handful of inference-time settings are then forced and the
    config is frozen before being returned.
    """
    config = get_cfg()
    if args.cpu:
        config.MODEL.DEVICE = "cpu"

    # Add the CenterNet and GRiT config keys before merging the YAML file.
    for register in (add_centernet_config, add_grit_config):
        register(config)

    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)

    # One confidence cut-off feeds both ROI_HEADS.SCORE_THRESH_TEST and
    # PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH.
    threshold = args.confidence_threshold
    config.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
    config.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = threshold

    task = args.test_task
    if task:
        config.MODEL.TEST_TASK = task

    # Forced after all merges so neither the YAML file nor opts can override them.
    config.MODEL.BEAM_SIZE = 1
    config.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
    config.USE_ACT_CHECKPOINT = False

    config.freeze()
    return config


# Fixed inference settings. A plain ``argparse.Namespace`` is built directly
# instead of calling ``ArgumentParser().parse_args()``: that call parses the
# hosting process's ``sys.argv`` and exits with an "unrecognized arguments"
# error whenever the server process was started with command-line flags.
args = argparse.Namespace(
    config_file="configs/GRiT_B_DenseCap_ObjectDet.yaml",
    cpu=False,
    confidence_threshold=0.5,
    # Extra overrides for cfg.merge_from_list, as alternating KEY, VALUE items.
    opts=["MODEL.WEIGHTS", "./models/grit_b_densecap_objectdet.pth"],
    test_task="DenseCap",
)
setup_logger(name="fvcore")
logger = setup_logger()
logger.info("Arguments: %s", args)

cfg = setup_cfg(args)

# Built once at import time; every EndpointHandler call reuses this instance.
dense_captioning_demo = VisualizationDemo(cfg)


class EndpointHandler:
    """Dense-captioning request handler backed by the module-level GRiT demo.

    Each call runs ``dense_captioning_demo`` on one image and returns the
    rendered visualization plus the detected boxes grouped by caption.
    """

    def __init__(self, path=""):
        """Create the handler.

        Args:
            path: Optional model directory. Accepted (and ignored) so the class
                can be constructed as ``EndpointHandler(path)``, the signature
                used for Inference Endpoints custom handlers; the model itself
                is built once when this module is imported.
        """

    @staticmethod
    def _to_bgr_array(image_file):
        """Return *image_file* as an ``(H, W, 3)`` ndarray in BGR channel order.

        PIL images are converted to RGB first, so palette, grayscale and RGBA
        images all work. For array input, a 2-D (grayscale) array or a
        single-channel last axis is replicated to three channels and a fourth
        (alpha) channel is dropped. Arrays are assumed to be RGB(A) ordered —
        TODO confirm for callers that pass arrays directly.

        Raises:
            ValueError: if the input does not convert to a 2-D or 3-D array.
        """
        if isinstance(image_file, np.ndarray):
            # Copy, matching the previous np.array(...) behaviour.
            pixels = np.array(image_file)
        else:
            if isinstance(image_file, Image.Image):
                # Normalise any PIL mode (L, P, RGBA, CMYK, ...) to 3-channel RGB.
                image_file = image_file.convert("RGB")
            pixels = np.array(image_file)

        if pixels.ndim == 2:
            pixels = pixels[:, :, np.newaxis]
        if pixels.ndim != 3:
            raise ValueError(
                f"expected a 2-D or 3-D image, got array of shape {pixels.shape}"
            )
        if pixels.shape[2] == 1:
            pixels = np.repeat(pixels, 3, axis=2)
        elif pixels.shape[2] == 4:
            pixels = pixels[:, :, :3]

        # Reverse the channel axis: RGB -> BGR.
        return pixels[:, :, ::-1]

    def __call__(self, image_file):
        """Run dense captioning on a single image.

        Args:
            image_file: A PIL image, or an RGB(A)/grayscale ndarray.

        Returns:
            A ``(visualization, output)`` tuple. ``visualization`` is a PIL image
            opened from a PNG render of the predictions. ``output`` is
            ``{"dense_captioning_results": {"detections": {caption: [box, ...]}}}``
            where each box is a dict of float ``xmin``, ``ymin``, ``xmax``,
            ``ymax`` and ``score``.
        """
        image_array = self._to_bgr_array(image_file)
        predictions, visualized_output = dense_captioning_demo.run_on_image(image_array)

        # Render the visualization into an in-memory PNG for the returned image.
        buffer = BytesIO()
        visualized_output.fig.savefig(buffer, format="png")
        buffer.seek(0)

        instances = predictions["instances"].to(torch.device("cpu"))
        detections = {}
        for box, description, score in zip(
            instances.pred_boxes,
            instances.pred_object_descriptions.data,
            instances.scores,
        ):
            # Several boxes can share the same caption; group them under it.
            detections.setdefault(description, []).append(
                {
                    "xmin": float(box[0]),
                    "ymin": float(box[1]),
                    "xmax": float(box[2]),
                    "ymax": float(box[3]),
                    "score": float(score),
                }
            )

        output = {
            "dense_captioning_results": {
                "detections": detections,
            }
        }

        return Image.open(buffer), output