|
|
|
import argparse |
|
import os |
|
from typing import Sequence |
|
|
|
import mmcv |
|
from mmdet.apis import inference_detector, init_detector |
|
from mmengine import Config, DictAction |
|
from mmengine.registry import init_default_scope |
|
from mmengine.utils import ProgressBar |
|
|
|
from mmyolo.registry import VISUALIZERS |
|
from mmyolo.utils.misc import auto_arrange_images, get_file_list |
|
|
|
|
|
def parse_args():
    """Define the command-line interface and parse the process arguments.

    Three positionals are required (``img``, ``config``, ``checkpoint``);
    everything else is optional and falls back to the defaults declared
    below.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    cli = argparse.ArgumentParser(description='Visualize feature map')

    # Required positional inputs.
    cli.add_argument(
        'img', help='Image path, include image file, dir and URL.')
    cli.add_argument('config', help='Config file')
    cli.add_argument('checkpoint', help='Checkpoint file')

    # Output location and layer selection.
    cli.add_argument(
        '--out-dir', default='./output', help='Path to output file')
    cli.add_argument(
        '--target-layers',
        type=str,
        nargs='+',
        default=['backbone'],
        help='The target layers to get feature map, if not set, the tool will '
        'specify the backbone')
    # ``store_true`` already defaults to False, so no explicit default.
    cli.add_argument(
        '--preview-model',
        action='store_true',
        help='To preview all the model layers')

    # Inference and display controls.
    cli.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    cli.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    cli.add_argument(
        '--show', action='store_true', help='Show the featmap results')

    # Feature-map rendering options.
    cli.add_argument(
        '--channel-reduction',
        default='select_max',
        help='Reduce multiple channels to a single channel')
    cli.add_argument(
        '--topk',
        type=int,
        default=4,
        help='Select topk channel to show by the sum of each channel')
    cli.add_argument(
        '--arrangement',
        type=int,
        nargs='+',
        default=[2, 2],
        help='The arrangement of featmap when channel_reduction is '
        'not None and topk > 0')

    # Free-form config overrides.
    cli.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')

    return cli.parse_args()
|
|
|
|
|
class ActivationsWrapper:
    """Run inference while recording the outputs of selected layers.

    On construction a forward hook is registered on every entry of
    ``target_layers``. Each time a hook fires, the layer output is appended
    to ``self.activations``.

    Args:
        model: Object handed to ``inference_detector`` on every call.
        target_layers: Iterable of objects that provide
            ``register_forward_hook``.
    """

    def __init__(self, model, target_layers):
        self.model = model
        self.activations = []
        # Keep the objects returned by hook registration so that
        # ``release`` can remove the hooks again.
        self.handles = [
            layer.register_forward_hook(self.save_activation)
            for layer in target_layers
        ]
        # Not read or written anywhere else in this class.
        self.image = None

    def save_activation(self, module, input, output):
        """Forward-hook callback: record ``output`` and ignore the rest."""
        self.activations.append(output)

    def __call__(self, img_path):
        """Run the detector on ``img_path``.

        Returns:
            tuple: ``(results, activations)`` where ``results`` is whatever
            ``inference_detector`` returns and ``activations`` is the list of
            outputs recorded by the hooks during this call.
        """
        # Rebind to a fresh list so lists returned by earlier calls are not
        # modified by this one.
        self.activations = []
        detections = inference_detector(self.model, img_path)
        return detections, self.activations

    def release(self):
        """Remove every hook registered in ``__init__``."""
        for hook_handle in self.handles:
            hook_handle.remove()
|
|
|
|
|
def _resolve_target_layers(model, layer_names):
    """Return the objects found at ``model.<name>`` for each given name.

    Raises:
        RuntimeError: If any name cannot be evaluated against ``model``; the
            model is printed first so the available layers can be inspected.
    """
    resolved = []
    for target_layer in layer_names:
        try:
            # NOTE(security): ``eval`` runs arbitrary code taken from the
            # command line. It is kept so that any expression relative to
            # ``model`` (attribute access, indexing) continues to work; do
            # not feed this tool untrusted arguments.
            resolved.append(eval(f'model.{target_layer}'))
        except Exception as e:
            print(model)
            raise RuntimeError(
                f'layer does not exist: {target_layer}') from e
    return resolved


def _flatten_featmaps(featmaps):
    """Expand one level of nesting so each item is a single recorded output."""
    if not isinstance(featmaps, Sequence):
        featmaps = [featmaps]

    flat = []
    for featmap in featmaps:
        if isinstance(featmap, Sequence):
            flat.extend(featmap)
        else:
            flat.append(featmap)
    return flat


def main():
    """Visualize feature maps of the selected layers for every input image.

    Loads the config and checkpoint, hooks the requested layers, runs
    inference image by image, overlays each recorded feature map on the
    detection drawing, and either shows the arranged result (``--show``) or
    writes it under ``--out-dir``.

    Raises:
        ValueError: If ``--arrangement`` does not contain exactly two values.
        RuntimeError: If a ``--target-layers`` entry cannot be resolved.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    init_default_scope(cfg.get('default_scope', 'mmyolo'))

    # The CLI can only deliver strings, so the literal 'None' stands for
    # "no channel reduction".
    channel_reduction = args.channel_reduction
    if channel_reduction == 'None':
        channel_reduction = None

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if len(args.arrangement) != 2:
        raise ValueError(
            '--arrangement expects exactly 2 integers, '
            f'got {len(args.arrangement)}')

    # Pass the merged Config (not the file path) so that --cfg-options
    # actually affect the model that gets built.
    model = init_detector(cfg, args.checkpoint, device=args.device)

    if not args.show:
        # makedirs also handles nested output paths and existing directories.
        os.makedirs(args.out_dir, exist_ok=True)

    if args.preview_model:
        print(model)
        print('\n This flag is only show model, if you want to continue, '
              'please remove `--preview-model` to get the feature map.')
        return

    target_layers = _resolve_target_layers(model, args.target_layers)

    activations_wrapper = ActivationsWrapper(model, target_layers)
    try:
        visualizer = VISUALIZERS.build(model.cfg.visualizer)
        visualizer.dataset_meta = model.dataset_meta

        image_list, source_type = get_file_list(args.img)

        progress_bar = ProgressBar(len(image_list))
        for image_path in image_list:
            result, featmaps = activations_wrapper(image_path)
            flatten_featmaps = _flatten_featmaps(featmaps)

            img = mmcv.imread(image_path)
            img = mmcv.imconvert(img, 'bgr', 'rgb')

            if source_type['is_dir']:
                # Collapse the relative path into one file name; cover both
                # the platform separator and '/'.
                relative = os.path.relpath(image_path, args.img)
                filename = relative.replace(os.sep, '_').replace('/', '_')
            else:
                filename = os.path.basename(image_path)
            out_file = None if args.show else os.path.join(
                args.out_dir, filename)

            # Draw the predictions once; every feature map is overlaid on
            # this same image.
            visualizer.add_datasample(
                'result',
                img,
                data_sample=result,
                draw_gt=False,
                show=False,
                wait_time=0,
                out_file=None,
                pred_score_thr=args.score_thr)
            drawn_img = visualizer.get_image()

            shown_imgs = []
            for featmap in flatten_featmaps:
                # Index 0 is drawn for each recorded output -- presumably the
                # first (only) image of the batch; TODO confirm.
                shown_img = visualizer.draw_featmap(
                    featmap[0],
                    drawn_img,
                    channel_reduction=channel_reduction,
                    topk=args.topk,
                    arrangement=args.arrangement)
                shown_imgs.append(shown_img)

            shown_imgs = auto_arrange_images(shown_imgs)

            progress_bar.update()
            if out_file:
                # Reverse the last axis again (mirrors the bgr -> rgb
                # conversion above) before writing to disk.
                mmcv.imwrite(shown_imgs[..., ::-1], out_file)

            if args.show:
                visualizer.show(shown_imgs)
    finally:
        # Always detach the forward hooks, even if a step above failed.
        activations_wrapper.release()

    if not args.show:
        print(f'All done!'
              f'\nResults have been saved at {os.path.abspath(args.out_dir)}')
|
|
|
|
|
|
|
|
|
# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|