Spaces:
Running
Running
import requests | |
import os | |
import sys | |
from pathlib import Path | |
import gradio as gr | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import geopandas as gpd | |
from pyproj.transformer import Transformer | |
sys.path.append(os.path.dirname(os.path.realpath(__file__))) | |
from MapItAnywhere.mia.bev import get_bev | |
from MapItAnywhere.mia.fpv import get_fpv | |
from MapItAnywhere.mia.fpv import filters | |
from MapItAnywhere.mia import logger | |
def get_city_boundary(query, fetch_shape=False, timeout=30):
    """Look up a place through the Nominatim search API and return its bounding box.

    Args:
        query: Free-form place description, sent as Nominatim's ``q`` parameter.
        fetch_shape: When true, also request the place polygon and return it
            wrapped in a GeoJSON ``FeatureCollection``.
        timeout: Seconds to wait for Nominatim before giving up.

    Returns:
        ``None`` when the request fails, returns a non-200 status, or has no
        match. Otherwise a dict with float ``west``/``south``/``east``/``north``
        entries, or a ``(bbox, boundary_geojson)`` tuple when ``fetch_shape``
        is true.
    """
    base_url = "https://nominatim.openstreetmap.org/search"
    params = {
        'q': query,
        'format': 'json',
        'limit': 1,
        'polygon_geojson': 1 if fetch_shape else 0,
    }
    # HTTP header values are latin-1 encoded when sent; replace characters
    # outside that range so non-Latin place names do not abort the request.
    # Queries that already fit in latin-1 produce the exact same header.
    agent_suffix = str(query).encode('latin-1', 'replace').decode('latin-1')
    headers = {
        'User-Agent': f'mapperceptionnet_{agent_suffix}',
    }

    try:
        response = requests.get(base_url, params=params, headers=headers, timeout=timeout)
    except requests.RequestException as e:
        # Network failures follow the same "return None" contract as HTTP errors.
        logger.error(f"Nominatim request failed when fetching boundary data for {query}. Error: {repr(e)}")
        return None

    if response.status_code != 200:
        logger.error(f"Nominatim error when fetching boundary data for {query}.\n"
                     f"Status code: {response.status_code}. Content: {response.content}")
        return None

    data = response.json()
    # An unmatched query comes back as an empty list, not None.
    if not data:
        logger.warn(f"No data returned by Nominatim for {query}")
        return None

    best_match = data[0]

    # 'boundingbox' is read as [south, north, west, east].
    bbox_data = best_match['boundingbox']
    bbox = {
        'west': float(bbox_data[2]),
        'south': float(bbox_data[0]),
        'east': float(bbox_data[3]),
        'north': float(bbox_data[1]),
    }

    if not fetch_shape:
        return bbox

    # Wrap the raw geometry so it is a self-contained GeoJSON document.
    boundary_geojson = {
        "type": "FeatureCollection",
        "features": [
            {"type": "Feature",
             "properties": {},
             "geometry": best_match['geojson']}],
    }
    return bbox, boundary_geojson
def split_dataframe(df, chunk_size = 100):
    """Split ``df`` into consecutive positional slices of at most ``chunk_size`` rows.

    Works on anything that supports ``len()`` and ``obj[start:stop]`` slicing.

    Args:
        df: The frame (or sequence) to split.
        chunk_size: Maximum number of rows per chunk.

    Returns:
        A list of non-empty slices in original order; an empty list for empty
        input.

    Raises:
        ValueError: If ``chunk_size`` is 0.
    """
    # Stepping by chunk_size yields exactly ceil(len / chunk_size) start
    # offsets, so no trailing empty chunk appears when the length is an exact
    # multiple of chunk_size.
    return [df[start:start + chunk_size] for start in range(0, len(df), chunk_size)]
async def fetch(location, filter_undistort, disable_cam_filter, map_length, mpp):
    """Stream one randomly chosen FPV/BEV sample for ``location`` to the UI.

    This is an async generator. Each item is a
    ``(metadata_text, fpv_image, bev_image)`` triple; fields that are not ready
    yet are ``None``. On success it yields three times: metadata only, then
    metadata plus the first-person image, then all three. When no sample can
    be produced it yields a single ``(message, None, None)`` triple and stops.

    Args:
        location: Free-form place name resolved with ``get_city_boundary``.
        filter_undistort: When true, run a filter pipeline over the image
            metadata and show the undistorted image written by
            ``get_fpv.process_sequence``.
        disable_cam_filter: With ``filter_undistort``, selects
            ``no_cam_filter_pipeline`` instead of ``filter_pipeline``.
        map_length: BEV dimension chosen in the UI.
        mpp: Meters per pixel chosen in the UI.
    """
    N = 1
    TOTAL_LOOKED_INTO_LIMIT = 10000

    ################ FPV
    downloader = get_fpv.MapillaryDownloader(os.getenv("MLY_TOKEN"))
    bbox = get_city_boundary(query=location)
    if bbox is None:
        # The geocoder signals failure with None; report it instead of
        # crashing further down on the missing bounding box.
        yield (f"Could not find a bounding box for '{location}'. "
               "Please check the location and try again.", None, None)
        return

    tiles = get_fpv.get_tiles_from_boundary(boundary_info=dict(bound_type="auto_bbox", bbox=bbox), zoom=14)
    np.random.shuffle(tiles)
    total_looked_into = 0
    total_rows = 0
    dfs_meta = list()
    for tile in tiles:
        image_points_response = await downloader.get_tiles_image_points([tile])
        if image_points_response is None:
            continue

        try:
            df = get_fpv.parse_image_points_json_data(image_points_response)
            if len(df) == 0:
                continue

            total_looked_into += len(df)
            for chunk in split_dataframe(df, chunk_size=100):
                if len(chunk) == 0:
                    # Nothing to look up for an empty slice.
                    continue
                image_ids = chunk["id"]
                image_infos, num_fail = await get_fpv.fetch_image_infos(image_ids, downloader, infos_dir)
                df_meta = get_fpv.geojson_feature_list_to_pandas(image_infos.values())

                # Some standardization of the data
                df_meta["model"] = df_meta["model"].str.lower().str.replace(' ', '').str.replace('_', '')
                df_meta["make"] = df_meta["make"].str.lower().str.replace(' ', '').str.replace('_', '')

                if filter_undistort:
                    fp = no_cam_filter_pipeline if disable_cam_filter else filter_pipeline
                    df_meta = fp(df_meta)

                dfs_meta.append(df_meta)
                total_rows = sum(len(x) for x in dfs_meta)
                if total_rows > N:
                    break
                elif total_looked_into > TOTAL_LOOKED_INTO_LIMIT:
                    yield (f"Went through {total_looked_into} images and could not find images satisfying the filters."
                           "\nPlease rerun or run the data engine locally for bulk time consuming operations.", None, None)
                    return

            if total_rows > N:
                break
        except Exception as e:
            # Best effort per tile: log and move on. Catching Exception (not a
            # bare except) lets GeneratorExit and task cancellation propagate,
            # so closing or cancelling this generator at the yield above works.
            logger.error(f"Skipping tile {tile} after error: {repr(e)}.")

    total_rows = sum(len(x) for x in dfs_meta)
    if total_rows < N:
        # pd.concat([]) and sample(N) both raise when there is nothing to pick from.
        yield (f"Could not find any images for '{location}' satisfying the filters."
               "\nPlease rerun or try a different location.", None, None)
        return

    df_meta = pd.concat(dfs_meta)
    df_meta = df_meta.sample(N)

    # Calc derivative attributes
    df_meta["loc_descrip"] = filters.haversine_np(
        lon1=df_meta["geometry.long"], lat1=df_meta["geometry.lat"],
        lon2=df_meta["computed_geometry.long"], lat2=df_meta["computed_geometry.lat"]
    )

    df_meta["angle_descrip"] = filters.angle_dist(
        df_meta["compass_angle"],
        df_meta["computed_compass_angle"]
    )

    # Display attributes
    keys = ["id", "geometry.long", "geometry.lat", "compass_angle",
            "loc_descrip", "angle_descrip",
            "make", "model", "camera_type",
            "quality_score"]

    # With N == 1 this loop runs once. `row` and `metadata_fmt` are reused
    # below to locate the image files and to fill the later yields.
    for index, row in df_meta.iterrows():
        desc = list()
        for k in keys:
            v = row[k]
            if isinstance(v, float):
                v = f"{v:.4f}"
            desc.append(f"{k}: {v}")
        metadata_fmt = "\n".join(desc)

    yield metadata_fmt, None, None

    image_urls = list(df_meta.set_index("id")["thumb_2048_url"].items())
    num_fail = await get_fpv.fetch_images_pixels(image_urls, downloader, raw_image_dir)
    if num_fail > 0:
        logger.error(f"Failed to download {num_fail} images.")

    seq_to_image_ids = df_meta.groupby('sequence')['id'].agg(list).to_dict()
    lon_center = (bbox['east'] + bbox['west']) / 2
    lat_center = (bbox['north'] + bbox['south']) / 2
    projection = get_fpv.Projection(lat_center, lon_center, max_extent=200e3)

    df_meta.index = df_meta["id"]
    image_infos = df_meta.to_dict(orient="index")
    process_sequence_args = get_fpv.default_cfg

    if filter_undistort:
        for seq_id, seq_image_ids in seq_to_image_ids.items():
            try:
                d, pi = get_fpv.process_sequence(
                    seq_image_ids,
                    image_infos,
                    projection,
                    process_sequence_args,
                    raw_image_dir,
                    out_image_dir,
                )
                if d is None or pi is None:
                    raise Exception("process_sequence returned None")
            except Exception as e:
                logger.error(f"Failed to process sequence {seq_id} skipping it. Error: {repr(e)}.")

        fpv = plt.imread(out_image_dir/ f"{row['id']}_undistorted.jpg")
    else:
        fpv = plt.imread(raw_image_dir/ f"{row['id']}.jpg")

    yield metadata_fmt, fpv, None

    ################ BEV
    df = df_meta
    # convert pandas dataframe to geopandas dataframe
    gdf = gpd.GeoDataFrame(df,
                           geometry=gpd.points_from_xy(
                               df['computed_geometry.long'],
                               df['computed_geometry.lat']),
                           crs=4326)

    # convert the geopandas dataframe to UTM
    utm_crs = gdf.estimate_utm_crs()
    gdf_utm = gdf.to_crs(utm_crs)
    transformer = Transformer.from_crs(utm_crs, 4326)

    # load OSM data, if available
    padding = 50

    # calculate the required distance from the center to the edge of the image
    # so that the image will not be out of bounds when we rotate it
    # NOTE(review): map_length is overwritten with the diagonal here, so the
    # enlarged value is also what get_bev_from_bbox_worker_init receives below
    # -- confirm the worker is not meant to get the UI's original dimension.
    map_length = np.ceil(np.sqrt(map_length**2 + map_length**2))
    distance = map_length * mpp

    # create bounding boxes for each point
    gdf_utm['bounding_box_utm_p1'] = gdf_utm.apply(lambda row: (
        row.geometry.x - distance - padding,
        row.geometry.y - distance - padding,
    ), axis=1)

    gdf_utm['bounding_box_utm_p2'] = gdf_utm.apply(lambda row: (
        row.geometry.x + distance + padding,
        row.geometry.y + distance + padding,
    ), axis=1)

    # convert the bounding box back to lat, long
    gdf_utm['bounding_box_lat_long_p1'] = gdf_utm.apply(lambda row: transformer.transform(*row['bounding_box_utm_p1']), axis=1)
    gdf_utm['bounding_box_lat_long_p2'] = gdf_utm.apply(lambda row: transformer.transform(*row['bounding_box_utm_p2']), axis=1)
    gdf_utm['bbox_min_lat'] = gdf_utm['bounding_box_lat_long_p1'].apply(lambda x: x[0])
    gdf_utm['bbox_min_long'] = gdf_utm['bounding_box_lat_long_p1'].apply(lambda x: x[1])
    gdf_utm['bbox_max_lat'] = gdf_utm['bounding_box_lat_long_p2'].apply(lambda x: x[0])
    gdf_utm['bbox_max_long'] = gdf_utm['bounding_box_lat_long_p2'].apply(lambda x: x[1])
    gdf_utm['bbox_formatted'] = gdf_utm.apply(lambda row: f"{row['bbox_min_long']},{row['bbox_min_lat']},{row['bbox_max_long']},{row['bbox_max_lat']}", axis=1)

    # iterate over the dataframe and get BEV images
    jobs = gdf_utm[['id', 'bbox_formatted', 'computed_compass_angle']]  # only these columns are handed to the BEV worker
    jobs = jobs.to_dict(orient='records').copy()

    get_bev.get_bev_from_bbox_worker_init(osm_cache_dir, bev_dir, semantic_mask_dir, rendered_mask_dir,
                                          "MapItAnywhere/mia/bev/styles/mia.yml", map_length, mpp,
                                          None, True, False, True, True, 1)
    for job_dict in jobs:
        get_bev.get_bev_from_bbox_worker(job_dict)

    bev = plt.imread(rendered_mask_dir / f"{row['id']}.png")

    yield metadata_fmt, fpv, bev
# Metadata filter pipelines used by fetch() when "Filter & Undistort" is on;
# the second one is selected when camera model filtering is disabled.
filter_pipeline = filters.FilterPipeline.load_from_yaml("MapItAnywhere/mia/fpv/filter_pipelines/mia.yaml")
filter_pipeline.verbose = False
no_cam_filter_pipeline = filters.FilterPipeline.load_from_yaml("MapItAnywhere/mia/fpv/filter_pipelines/mia_rural.yaml")
no_cam_filter_pipeline.verbose = False

# Working directories, all relative to the current working directory.
loc = Path(".")
infos_dir = loc / "infos_dir"
raw_image_dir = loc / "raw_images"
out_image_dir = loc / "images"
osm_cache_dir = loc / "osm_cache"
bev_dir = loc / "bev_raw"
semantic_mask_dir = loc / "semantic_masks"
rendered_mask_dir = loc / "rendered_semantic_masks"

# infos_dir is included so every directory handed to the pipeline exists
# before the first request.
all_dirs = [loc, infos_dir, osm_cache_dir, bev_dir, semantic_mask_dir, rendered_mask_dir, out_image_dir, raw_image_dir]
for d in all_dirs:
    os.makedirs(d, exist_ok=True)

logger.info(f"Current working directory: {os.getcwd()}, listdir: {os.listdir('.')}")

# Inputs are passed positionally to fetch(); outputs match the triples it yields.
demo = gr.Interface(
    fn=fetch,
    inputs=[gr.Text("Pittsburgh, PA, United States", label="Location"),
            gr.Checkbox(value=False, label="Filter & Undistort"),
            gr.Checkbox(value=False, label="Disable camera model filtering"),
            gr.Slider(minimum=64, maximum=512, step=1, label="BEV Dimension", value=224),
            gr.Slider(minimum=0.1, maximum=2, label="Meters Per Pixel", value=0.5)],
    outputs=[gr.Text(label="METADATA"), gr.Image(label="FPV"), gr.Image(label="BEV")],
    title="MapItAnywhere (Data Engine)",
    # The leading space keeps the two sentences apart once the adjacent
    # literals are concatenated.
    description="A demo showcasing samples of MIA's capability to retrieve FPV-BEV pairs worldwide."
                " For bulk download/heavy filtering please visit the github and follow the instructions to run locally"
)

logger.info("Starting server")
demo.launch(server_name="0.0.0.0", server_port=7860,share=False)