File size: 1,673 Bytes
c3a1897
 
 
 
 
eb902b3
 
c3a1897
 
 
eb902b3
 
c3a1897
 
 
 
 
 
 
 
 
 
 
 
eb902b3
c3a1897
 
 
 
eb902b3
c3a1897
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from models.segment_models.semgent_anything_model import SegmentAnything
from models.segment_models.semantic_segment_anything_model import SemanticSegment


class RegionSemantic:
    """Build a textual "semantic prompt" describing the largest regions of an image.

    Wraps two project models: ``SegmentAnything`` produces region masks and
    ``SemanticSegment`` attaches a class name to each of them. The labelled
    regions are then summarised as a single ``"class: bbox; "`` string.
    """

    def __init__(self, device):
        """Store ``device`` and instantiate both models on it.

        Args:
            device: Passed unchanged to the ``SegmentAnything`` and
                ``SemanticSegment`` constructors.
        """
        self.device = device
        self.init_models()

    def init_models(self):
        """Create the segmentation and semantic-labelling models on ``self.device``."""
        self.segment_model = SegmentAnything(self.device)
        self.semantic_segment_model = SemanticSegment(self.device)

    def semantic_prompt_gen(self, anns, top_k=10):
        """Format the ``top_k`` largest regions as a semantic prompt string.

        No filtering on stability score or minimum size is performed here;
        regions are simply ranked by their ``'area'`` value and the largest
        ``top_k`` are kept.

        Args:
            anns: Iterable of region dicts. Each one must provide the keys
                ``'area'`` (sort key), ``'class_name'`` and ``'bbox'``, e.g.
                ``{'class_name': 'person', 'bbox': [0.0, 0.0, 0.0, 0.0], 'area': 0}``.
                Any other keys are ignored.
            top_k: Maximum number of regions to include (default 10, the
                original hard-coded value). Must be non-negative.

        Returns:
            A string such as ``"person: [0.0, 0.0, 0.0, 0.0]; ..."``. Every
            entry, including the last, is followed by ``"; "``. Empty input
            yields ``""``.

        Raises:
            ValueError: If ``top_k`` is negative.
            KeyError: If a region lacks one of the required keys.
        """
        if top_k < 0:
            # A negative slice bound would silently drop the *smallest*
            # regions instead of limiting the count, so reject it outright.
            raise ValueError("top_k must be non-negative")

        # Largest regions first. sorted() is stable, so regions with equal
        # area keep their incoming relative order.
        largest_regions = sorted(
            anns, key=lambda region: region['area'], reverse=True
        )[:top_k]

        # Build the prompt in one pass rather than repeated string +=.
        semantic_prompt = "".join(
            f"{region['class_name']}: {region['bbox']}; "
            for region in largest_regions
        )

        banner = '\033[1;35m' + '*' * 100 + '\033[0m'  # magenta separator line
        print(banner)
        print("\nStep3, Semantic Prompt:")
        print(semantic_prompt)
        print(banner)
        return semantic_prompt

    def region_semantic(self, img_src, top_k=10):
        """Segment ``img_src``, label the regions, and return the semantic prompt.

        Args:
            img_src: Image source forwarded unchanged to both models
                (presumably a file path -- confirm against the model classes).
            top_k: Maximum number of regions to include in the prompt.

        Returns:
            The prompt string produced by ``semantic_prompt_gen``.
        """
        anns = self.segment_model.generate_mask(img_src)
        anns_w_class = self.semantic_segment_model.semantic_class_w_mask(img_src, anns)
        return self.semantic_prompt_gen(anns_w_class, top_k=top_k)

    def region_semantic_debug(self, img_src):
        """Return a fixed placeholder string without running any model.

        ``img_src`` is accepted for signature parity with ``region_semantic``
        but is not used.
        """
        return "region_semantic_debug"