Inference Endpoints
import os

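# Install pinned binary dependencies at startup; the detectron2 wheel here
# is expected to ship alongside the handler code.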
os.system("pip install detectron2-0.6-cp39-cp39-linux_x86_64.whl")
os.system("pip install deepspeed==0.7.0")

from PIL import Image
from io import BytesIO
import argparse
import sys
import numpy as np
import torch

from detectron2.config import get_cfg
from detectron2.utils.logger import setup_logger

# Make the vendored CenterNet2 project importable before the GRiT imports.
sys.path.insert(0, "third_party/CenterNet2/projects/CenterNet2/")
from centernet.config import add_centernet_config
from grit.config import add_grit_config

from grit.predictor import VisualizationDemo


def setup_cfg(args):
    cfg = get_cfg()
    if args.cpu:
        cfg.MODEL.DEVICE = "cpu"
    add_centernet_config(cfg)
    add_grit_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # Set score_threshold for builtin models
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = (
        args.confidence_threshold
    )
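    # GRiT-specific settings: select the captioning task, use greedy decoding
    # (beam size 1), and disable soft-NMS and activation checkpointing at
    # inference time.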
    if args.test_task:
        cfg.MODEL.TEST_TASK = args.test_task
    cfg.MODEL.BEAM_SIZE = 1
    cfg.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
    cfg.USE_ACT_CHECKPOINT = False
    cfg.freeze()
    return cfg


# Build the argument namespace directly instead of parsing sys.argv:
# this module runs inside an endpoint, not as a command-line tool, and an
# empty ArgumentParser would fail on any stray CLI arguments.
args = argparse.Namespace()
args.config_file = "configs/GRiT_B_DenseCap_ObjectDet.yaml"
args.cpu = False
args.confidence_threshold = 0.5
args.opts = ["MODEL.WEIGHTS", "./models/grit_b_densecap_objectdet.pth"]
args.test_task = "DenseCap"

setup_logger(name="fvcore")
logger = setup_logger()
logger.info("Arguments: " + str(args))

cfg = setup_cfg(args)

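# Load the GRiT model once at module import so it is reused across requests.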
dense_captioning_demo = VisualizationDemo(cfg)


class EndpointHandler:
    def __init__(self):
        pass

    def __call__(self, image_file):
        # PIL delivers RGB; detectron2 expects BGR, so reverse the channel axis.
        image_array = np.array(image_file)[:, :, ::-1]
        predictions, visualized_output = dense_captioning_demo.run_on_image(image_array)

        # Render the annotated image into an in-memory PNG.
        buffer = BytesIO()
        visualized_output.fig.savefig(buffer, format="png")
        buffer.seek(0)

        # Group predicted boxes by their generated caption.
        detections = {}
        predictions = predictions["instances"].to(torch.device("cpu"))

        for box, description, score in zip(
            predictions.pred_boxes,
            predictions.pred_object_descriptions.data,
            predictions.scores,
        ):
            if description not in detections:
                detections[description] = []
            detections[description].append(
                {
                    "xmin": float(box[0]),
                    "ymin": float(box[1]),
                    "xmax": float(box[2]),
                    "ymax": float(box[3]),
                    "score": float(score),
                }
            )

        output = {
            "dense_captioning_results": {
                "detections": detections,
            }
        }

        return Image.open(buffer), output
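

# Minimal local smoke test: a sketch, not part of the deployed handler.
# "sample.jpg" is an assumed example path, not a file from the original repo.
if __name__ == "__main__":
    handler = EndpointHandler()
    input_image = Image.open("sample.jpg").convert("RGB")
    annotated_image, result = handler(input_image)
    annotated_image.save("annotated.png")
    print(result["dense_captioning_results"]["detections"])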