fungi2024 / script.py
Stefan Wolf
Fixed inference script.
import subprocess
import sys
import zipfile

# Custom script arguments: model config, checkpoint and score threshold
# passed to the test tool.
CONFIG_PATH = 'models/swinv2_base_w24_b16x4-fp16_fungi+val_res_384_cb_epochs_6.py'
CHECKPOINT_PATH = 'models/swinv2_base_w24_b16x4-fp16_fungi+val_res_384_cb_epochs_6_epoch_6_20240514-de00365e.pth'
SCORE_THRESHOLD = 0.2
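# Note (read off the file names, so an assumption): the config/checkpoint
# appear to be a SwinV2-Base model at 384x384 input resolution fine-tuned on
# the FungiCLEF training + validation data. SCORE_THRESHOLD is simply
# forwarded to the test tool via --threshold; its exact semantics are defined
# in tools/test_generate_result_pre-consensus.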


def run_inference(input_csv, output_csv, data_root_path):
    """Run the test tool on the images listed in input_csv and write predictions to output_csv."""
    if not data_root_path.endswith('/'):
        data_root_path += '/'
    data_cfg_opts = [
        'test_dataloader.dataset.data_root=',
        f'test_dataloader.dataset.ann_file={input_csv}',
        f'test_dataloader.dataset.data_prefix={data_root_path}']
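    # (data_root is deliberately set to the empty string so that ann_file and
    # data_prefix above are used as given, rather than being joined with a
    # root directory taken from the config file.)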
    # Launch the test tool as a subprocess with config, checkpoint, output
    # path and the dataset overrides above.
    inference = subprocess.Popen([
        sys.executable, '-m',
        'tools.test_generate_result_pre-consensus',
        CONFIG_PATH, CHECKPOINT_PATH,
        output_csv,
        '--threshold', str(SCORE_THRESHOLD),
        '--no-scores',
        '--cfg-options'] + data_cfg_opts)
    return_code = inference.wait()
    if return_code != 0:
        print(f'Inference crashed with exit code {return_code}', file=sys.stderr)
        sys.exit(return_code)
    print(f'Wrote {output_csv}')
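
# For reference, run_inference(input_csv, output_csv, data_root_path) builds a
# command line equivalent to:
#
#   python -m tools.test_generate_result_pre-consensus \
#       <CONFIG_PATH> <CHECKPOINT_PATH> <output_csv> \
#       --threshold 0.2 --no-scores \
#       --cfg-options \
#           test_dataloader.dataset.data_root= \
#           test_dataloader.dataset.ann_file=<input_csv> \
#           test_dataloader.dataset.data_prefix=<data_root_path>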


if __name__ == "__main__":
    # Unpack the private test set images before running inference.
    with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
        zip_ref.extractall("/tmp/data")

    metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
    run_inference(metadata_file_path, "./submission.csv", "/tmp/data/private_testset/")
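
# The block above assumes the competition container layout: the private test
# set zip under /tmp/data and the test metadata CSV next to this script. A
# local run would look like the following sketch (hypothetical paths):
#
#   run_inference(
#       input_csv="data/FungiCLEF2024_TestMetadata.csv",
#       output_csv="submission.csv",
#       data_root_path="data/private_testset/",
#   )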