"""Gradio demo: segment trees in satellite images with a MaskFormer model."""

import gradio as gr
import numpy as np
import torch
from torchvision import transforms as T
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor

# Per-channel mean/std applied by T.Normalize in test_transform below.
ade_mean = [0.485, 0.456, 0.406]
ade_std = [0.229, 0.224, 0.225]

# RGB color for each label id. Ids beyond the palette length wrap around
# (see visualize_instance_seg_mask).
palette = [
    [120, 120, 120],
    [4, 200, 4],
    [180, 120, 120],
    [6, 230, 230],
    [80, 50, 50],
    [120, 120, 80],
    [140, 140, 140],
    [204, 5, 255],
]

model_id = "thiagohersan/maskformer-satellite-trees"

# Built by hand instead of MaskFormerImageProcessor.from_pretrained(model_id)
# so that resizing, rescaling and normalization are all disabled: the image is
# already converted and normalized by test_transform before it reaches the
# processor.
preprocessor = MaskFormerImageProcessor(
    do_resize=False,
    do_normalize=False,
    do_rescale=False,
    ignore_index=255,
    reduce_labels=False,
)
model = MaskFormerForInstanceSegmentation.from_pretrained(model_id)

test_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=ade_mean, std=ade_std),
])


def visualize_instance_seg_mask(img_in, mask, id2label):
    """Blend a color-coded label mask over the input image and summarize tree labels.

    Args:
        img_in: Input image, presumably an (H, W, 3) array with 0-255 values
            (as delivered by the Gradio Image component) -- TODO confirm.
        mask: 2-D integer array of label ids with the same height/width as img_in.
        id2label: Mapping from label id to label name.

    Returns:
        A tuple ``(image_res, dataframe)``. ``image_res`` is the 50/50 blend of
        ``img_in`` and the colorized mask, divided by 255. ``dataframe`` is a
        list of ``[label, percent, value]`` string rows, one per label id
        present in the mask whose name contains "tree", ordered by label id.
    """
    mask = np.asarray(mask)
    colors = np.asarray(palette)
    # Vectorized palette lookup instead of a per-pixel Python loop. The modulo
    # wraps ids that exceed the palette rather than raising IndexError.
    img_out = colors[mask % len(colors)]

    image_res = (0.5 * img_in + 0.5 * img_out) / 255

    image_total_pixels = mask.shape[0] * mask.shape[1]
    label_ids, counts = np.unique(mask, return_counts=True)

    dataframe = []
    for label_id, count in zip(label_ids.tolist(), counts.tolist()):
        label = id2label[label_id]
        if "tree" not in label:
            continue
        # NOTE(review): the third column is the square root of the pixel
        # fraction with an " m" suffix; no pixel-to-meter scale is applied.
        # Confirm the intended quantity/unit.
        dataframe.append([
            f"{label}",
            f"{(100 * count / image_total_pixels):.2f} %",
            f"{np.sqrt(count / image_total_pixels):.2f} m",
        ])
    return image_res, dataframe


def query_image(img):
    """Run the segmentation model on one image and visualize the result.

    Args:
        img: Image from the Gradio ``Image`` input component (a numpy array),
            or ``None`` when the user submitted without an image.

    Returns:
        The ``(image, rows)`` tuple produced by visualize_instance_seg_mask.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please provide an input image.")

    img_size = (img.shape[0], img.shape[1])
    inputs = preprocessor(images=test_transform(np.array(img)), return_tensors="pt")

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
        segmentation = preprocessor.post_process_semantic_segmentation(
            outputs=outputs, target_sizes=[img_size]
        )[0]

    return visualize_instance_seg_mask(img, segmentation.numpy(), model.config.id2label)


demo = gr.Interface(
    query_image,
    inputs=[gr.Image(label="Input Image")],
    outputs=[gr.Image(label="Trees"), gr.DataFrame(headers=None, label="Area Info")],
    title="maskformer-satellite-trees",
    allow_flagging="never",
    analytics_enabled=None,
)

demo.launch(show_api=False)