Spaces:
Sleeping
Sleeping
vardaan123
commited on
Commit
•
3dba732
1
Parent(s):
c3e2aa9
Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- README.md +5 -4
- __pycache__/model.cpython-38.pyc +0 -0
- app.py +92 -0
- dataset_subtree.csv +3 -0
- entity2id_subtree.json +1 -0
- model.py +118 -0
- overall_id_to_name.json +184 -0
- requirements.txt +11 -0
- species_class_model.pt +3 -0
- utils.py +74 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
dataset_subtree.csv filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,12 +1,13 @@
|
|
1 |
---
|
2 |
title: COSMO
|
3 |
-
emoji:
|
4 |
colorFrom: green
|
5 |
-
colorTo:
|
6 |
-
sdk:
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: COSMO
|
3 |
+
emoji: 🦀
|
4 |
colorFrom: green
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.29.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: mit
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
__pycache__/model.cpython-38.pyc
ADDED
Binary file (3.57 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pandas as pd
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import gradio as gr
|
6 |
+
from model import DistMult
|
7 |
+
from PIL import Image
|
8 |
+
from torchvision import transforms
|
9 |
+
import json
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
# Default image tensor normalization
|
13 |
+
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
|
14 |
+
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
|
15 |
+
|
16 |
+
def generate_target_list(data, entity2id):
|
17 |
+
sub = data.loc[(data["datatype_h"] == "image") & (data["datatype_t"] == "id"), ['t']]
|
18 |
+
sub = list(sub['t'])
|
19 |
+
categories = []
|
20 |
+
for item in tqdm(sub):
|
21 |
+
if entity2id[str(int(float(item)))] not in categories:
|
22 |
+
categories.append(entity2id[str(int(float(item)))])
|
23 |
+
# print('categories = {}'.format(categories))
|
24 |
+
# print("No. of target categories = {}".format(len(categories)))
|
25 |
+
return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)
|
26 |
+
|
27 |
+
# Load necessary data and initialize the model
|
28 |
+
entity2id = json.load(open('entity2id_subtree.json', 'r'))
|
29 |
+
id2entity = {v: k for k, v in entity2id.items()}
|
30 |
+
datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
|
31 |
+
num_ent_id = len(entity2id)
|
32 |
+
target_list = generate_target_list(datacsv, entity2id) # Assuming this function is defined elsewhere
|
33 |
+
overall_id_to_name = json.load(open('overall_id_to_name.json'))
|
34 |
+
|
35 |
+
# Initialize your model here
|
36 |
+
model = DistMult(num_ent_id, target_list, torch.device('cpu')) # Update arguments as necessary
|
37 |
+
model.eval()
|
38 |
+
|
39 |
+
ckpt = torch.load('species_class_model.pt', map_location=torch.device('cpu'))
|
40 |
+
model.load_state_dict(ckpt['model'], strict=False)
|
41 |
+
print('ckpt loaded...')
|
42 |
+
|
43 |
+
# Define your evaluation function
|
44 |
+
def evaluate(img):
|
45 |
+
transform_steps = transforms.Compose([
|
46 |
+
transforms.ToPILImage(),
|
47 |
+
transforms.Resize((448, 448)),
|
48 |
+
transforms.ToTensor(),
|
49 |
+
transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)
|
50 |
+
])
|
51 |
+
h = transform_steps(img)
|
52 |
+
r = torch.tensor([3])
|
53 |
+
|
54 |
+
# Assuming `move_to` is a function to move tensors to the desired device
|
55 |
+
h = h.unsqueeze(0)
|
56 |
+
r = r.unsqueeze(0)
|
57 |
+
|
58 |
+
outputs = F.softmax(model.forward_ce(h, r, triple_type=('image', 'id')), dim=-1)
|
59 |
+
|
60 |
+
# print('outputs = {}'.format(outputs.size()))
|
61 |
+
|
62 |
+
predictions = torch.topk(outputs, k=5, dim=-1).indices.squeeze(0).tolist()
|
63 |
+
|
64 |
+
# print('predictions', predictions)
|
65 |
+
|
66 |
+
result = {}
|
67 |
+
for i in predictions:
|
68 |
+
pred_label = target_list[i].item()
|
69 |
+
label = overall_id_to_name[str(id2entity[pred_label])]
|
70 |
+
prob = outputs[0, i].item()
|
71 |
+
result[label] = prob
|
72 |
+
|
73 |
+
# y_pred = outputs.argmax(-1).cpu()
|
74 |
+
# pred_label = target_list[y_pred].item()
|
75 |
+
# species_label = overall_id_to_name[str(id2entity[pred_label])]
|
76 |
+
|
77 |
+
# print('pred_label', pred_label)
|
78 |
+
# print('species_label', species_label)
|
79 |
+
|
80 |
+
# return species_label
|
81 |
+
return result
|
82 |
+
|
83 |
+
# Gradio interface
|
84 |
+
species_model = gr.Interface(
|
85 |
+
evaluate,
|
86 |
+
gr.inputs.Image(shape=(200, 200)),
|
87 |
+
outputs="label",
|
88 |
+
title='Camera Trap Species Classification demo',
|
89 |
+
# description='Species Classification',
|
90 |
+
# article='Species Classification'
|
91 |
+
)
|
92 |
+
species_model.launch(server_port=8977,share=True, debug=True)
|
dataset_subtree.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e89acb42f04c5593c492cf836ccf6b897d22e76c52b3a262c8e462813fb82cda
|
3 |
+
size 43352089
|
entity2id_subtree.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"93302": 0, "805080": 1, "304358": 2, "332573": 3, "5246131": 4, "691846": 5, "641038": 6, "117569": 7, "147604": 8, "125642": 9, "947318": 10, "801601": 11, "278114": 12, "114656": 13, "114654": 14, "458402": 15, "4940726": 16, "229562": 17, "229560": 18, "244265": 19, "229558": 20, "683263": 21, "847764": 22, "495017": 23, "273244": 24, "796671": 25, "796672": 26, "495018": 27, "495016": 28, "847766": 29, "490533": 30, "490538": 31, "273230": 32, "273227": 33, "5334778": 34, "392222": 35, "392220": 36, "644242": 37, "644258": 38, "864604": 39, "634572": 40, "864610": 41, "747873": 42, "864593": 43, "7067181": 44, "839752": 45, "671055": 46, "831080": 47, "1068783": 48, "922511": 49, "816256": 50, "320824": 51, "254744": 52, "254745": 53, "889045": 54, "271598": 55, "717794": 56, "782350": 57, "782347": 58, "3610624": 59, "3611156": 60, "844553": 61, "185338": 62, "23048": 63, "23039": 64, "666969": 65, "666961": 66, "976847": 67, "976856": 68, "1068778": 69, "237403": 70, "845966": 71, "671049": 72, "392236": 73, "764826": 74, "407000": 75, "44557": 76, "220323": 77, "173067": 78, "276723": 79, "220325": 80, "67361": 81, "220326": 82, "410922": 83, "848923": 84, "848914": 85, "592588": 86, "438471": 87, "438474": 88, "296191": 89, "44559": 90, "384218": 91, "630990": 92, "649553": 93, "866983": 94, "421036": 95, "970404": 96, "394011": 97, "474585": 98, "3609124": 99, "319614": 100, "524854": 101, "173836": 102, "765432": 103, "201068": 104, "970408": 105, "173811": 106, "675197": 107, "675198": 108, "913935": 109, "702152": 110, "386195": 111, "842867": 112, "386191": 113, "770311": 114, "312031": 115, "417957": 116, "417950": 117, "386194": 118, "842868": 119, "741061": 120, "989398": 121, "512437": 122, "842860": 123, "115460": 124, "115449": 125, "268324": 126, "837394": 127, "268346": 128, "203191": 129, "386004": 130, "571323": 131, "392223": 132, "622916": 133, "7655791": 134, "7655792": 135, "510764": 136, "510761": 137, "510762": 138, "986971": 139, "403912": 140, "768685": 141, "768687": 142, "768674": 143, "460505": 144, "534970": 145, "844149": 146, "844145": 147, "534996": 148, "194503": 149, "194523": 150, "194507": 151, "1030872": 152, "1030860": 153, "3611950": 154, "92562": 155, "410156": 156, "410145": 157, "768677": 158, "122647": 159, "19014": 160, "19015": 161, "122641": 162, "798021": 163, "540244": 164, "70819": 165, "346071": 166, "122649": 167, "644255": 168, "70831": 169, "561121": 170, "70827": 171, "70832": 172, "1066581": 173, "490099": 174, "385449": 175, "989809": 176, "989807": 177, "910691": 178, "768678": 179, "768679": 180, "70835": 181, "1016642": 182, "346068": 183, "513794": 184, "513789": 185, "591989": 186, "40168": 187, "1036727": 188, "702522": 189, "513800": 190, "122644": 191, "122645": 192, "591984": 193, "591988": 194, "98208": 195, "591990": 196, "591987": 197, "380144": 198, "436155": 199, "510773": 200, "510775": 201, "510767": 202, "510752": 203, "916745": 204, "730004": 205, "637442": 206, "1037242": 207, "1037247": 208, "906307": 209, "730008": 210, "995191": 211, "995183": 212, "1036752": 213, "1036755": 214, "730021": 215, "730013": 216, "644252": 217, "644249": 218, "644247": 219, "644245": 220, "44565": 221, "827263": 222, "297458": 223, "297460": 224, "679701": 225, "445986": 226, "231614": 227, "1023230": 228, "348043": 229, "67323": 230, "381139": 231, "381140": 232, "348045": 233, "213517": 234, "770319": 235, "837603": 236, "313163": 237, "821952": 238, "821973": 239, "821959": 240, "821953": 241, "372706": 242, "666235": 243, "621176": 244, "247341": 245, "264179": 246, "685113": 247, "3612582": 248, "821960": 249, "872571": 250, "348029": 251, "348040": 252, "914060": 253, "348030": 254, "736280": 255, "348031": 256, "827259": 257, "397138": 258, "397140": 259, "397157": 260, "397160": 261, "563161": 262, "383900": 263, "383901": 264, "397144": 265, "5681": 266, "211399": 267, "159587": 268, "350016": 269, "194343": 270, "5685": 271, "194345": 272, "194340": 273, "5686": 274, "397135": 275, "397136": 276, "252751": 277, "194342": 278, "194349": 279, "42311": 280, "159578": 281, "159576": 282, "42306": 283, "3613295": 284, "563159": 285, "626916": 286, "570215": 287, "280108": 288, "1033548": 289, "1033549": 290, "86169": 291, "86170": 292, "86161": 293, "86162": 294, "42307": 295, "563165": 296, "563163": 297, "774314": 298, "507553": 299, "752746": 300, "626917": 301, "763018": 302, "882766": 303, "86186": 304, "660452": 305, "563154": 306, "42314": 307, "42322": 308, "42324": 309, "563151": 310, "626920": 311, "752758": 312, "752759": 313, "541948": 314, "1070066": 315, "541951": 316, "94003": 317, "520756": 318, "615442": 319, "1068209": 320, "1068227": 321, "1087514": 322, "1034223": 323, "6146951": 324, "9419": 325, "746703": 326, "561107": 327, "561109": 328, "561113": 329, "561114": 330, "561100": 331, "561103": 332, "561106": 333, "561087": 334, "226176": 335, "541924": 336, "541933": 337, "541936": 338, "16033": 339, "277697": 340, "16069": 341, "608046": 342, "393366": 343, "170433": 344, "762047": 345, "919176": 346, "362785": 347, "639642": 348, "329823": 349, "35881": 350, "35888": 351, "4945781": 352, "4945815": 353, "4945816": 354, "4945872": 355, "139516": 356, "4945873": 357, "4945874": 358, "713776": 359, "713772": 360, "872963": 361, "4947372": 362, "335588": 363, "90215": 364, "90223": 365, "664350": 366, "664351": 367, "81461": 368, "241846": 369, "363030": 370, "938413": 371, "931109": 372, "150851": 373, "664463": 374, "244142": 375, "1032057": 376, "1032049": 377, "83286": 378, "604964": 379, "449653": 380, "664480": 381, "539139": 382, "539141": 383, "843074": 384, "772741": 385, "5839486": 386, "241841": 387, "765193": 388, "7068148": 389, "860117": 390, "693339": 391, "837585": 392, "684043": 393, "684045": 394, "684040": 395, "109893": 396, "109881": 397, "157741": 398, "109892": 399, "109882": 400, "979429": 401, "132829": 402, "728070": 403, "51353": 404, "102704": 405, "110936": 406, "521341": 407, "521339": 408, "624441": 409, "53692": 410, "53708": 411, "781250": 412, "446481": 413, "446490": 414, "136462": 415, "446477": 416, "3596058": 417, "204731": 418, "1080967": 419, "352754": 420, "352755": 421, "969837": 422, "609781": 423, "307211": 424, "3596764": 425, "938409": 426, "489432": 427, "989084": 428, "989081": 429, "4947835": 430, "1036185": 431, "786440": 432, "584448": 433, "266054": 434, "313124": 435, "261310": 436, "261316": 437, "427706": 438, "967304": 439, "966318": 440, "5025": 441, "5021": 442, "5030": 443, "3600024": 444, "521835": 445, "521834": 446, "966314": 447, "521837": 448, "414340": 449, "381374": 450, "906602": 451, "1041547": 452, "131990": 453, "774534": 454, "3598135": 455, "96286": 456, "568571": 457, "96367": 458, "176458": 459, "28338": 460, "695334": 461, "645461": 462, "7068500": 463, "3599375": 464, "1076202": 465, "451623": 466, "259942": 467, "1051167": 468, "907909": 469, "106895": 470, "635217": 471, "187411": 472, "320098": 473, "3598028": 474, "81443": 475, "292467": 476, "292469": 477, "402450": 478, "402466": 479, "857847": 480, "857849": 481, "292466": 482, "647692": 483, "8032375": 484, "8032203": 485, "8032276": 486, "8032351": 487, "8032318": 488, "8032251": 489, "8032381": 490, "8032285": 491, "8032224": 492, "8032345": 493, "8032358": 494, "8032284": 495, "8032368": 496, "8032286": 497, "8032289": 498, "8032377": 499, "8032372": 500, "8032362": 501, "8032369": 502, "8032295": 503, "8032363": 504, "8032294": 505, "8032384": 506, "8032383": 507, "8032326": 508, "8032325": 509, "8032234": 510}
|
model.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from torch import Tensor
|
5 |
+
from typing import Tuple
|
6 |
+
|
7 |
+
from torchvision.models import resnet18, resnet50
|
8 |
+
from torchvision.models import ResNet18_Weights, ResNet50_Weights
|
9 |
+
|
10 |
+
class DistMult(nn.Module):
|
11 |
+
def __init__(self, num_ent_uid, target_list, device, all_locs=None, num_habitat=None, all_timestamps=None):
|
12 |
+
super(DistMult, self).__init__()
|
13 |
+
self.num_ent_uid = num_ent_uid
|
14 |
+
|
15 |
+
self.num_relations = 4
|
16 |
+
|
17 |
+
self.ent_embedding = torch.nn.Embedding(self.num_ent_uid, 512, sparse=False)
|
18 |
+
self.rel_embedding = torch.nn.Embedding(self.num_relations, 512, sparse=False)
|
19 |
+
|
20 |
+
self.location_embedding = MLP(2, 512, 3)
|
21 |
+
|
22 |
+
self.time_embedding = MLP(1, 512, 3)
|
23 |
+
|
24 |
+
self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
|
25 |
+
self.image_embedding.fc = nn.Linear(2048, 512)
|
26 |
+
|
27 |
+
self.target_list = target_list
|
28 |
+
|
29 |
+
if all_locs is not None:
|
30 |
+
self.all_locs = all_locs.to(device)
|
31 |
+
if all_timestamps is not None:
|
32 |
+
self.all_timestamps = all_timestamps.to(device)
|
33 |
+
|
34 |
+
self.device = device
|
35 |
+
|
36 |
+
self.init()
|
37 |
+
|
38 |
+
def init(self):
|
39 |
+
nn.init.xavier_uniform_(self.ent_embedding.weight.data)
|
40 |
+
nn.init.xavier_uniform_(self.rel_embedding.weight.data)
|
41 |
+
nn.init.xavier_uniform_(self.image_embedding.fc.weight.data)
|
42 |
+
|
43 |
+
def forward_ce(self, h, r, triple_type=None):
|
44 |
+
emb_h = self.batch_embedding_concat_h(h) # [batch, hid]
|
45 |
+
|
46 |
+
emb_r = self.rel_embedding(r.squeeze(-1)) # [batch, hid]
|
47 |
+
|
48 |
+
emb_hr = emb_h * emb_r # [batch, hid]
|
49 |
+
|
50 |
+
if triple_type == ('image', 'id'):
|
51 |
+
score = torch.mm(emb_hr, self.ent_embedding.weight[self.target_list.squeeze(-1)].T) # [batch, n_ent]
|
52 |
+
elif triple_type == ('id', 'id'):
|
53 |
+
score = torch.mm(emb_hr, self.ent_embedding.weight.T) # [batch, n_ent]
|
54 |
+
elif triple_type == ('image', 'location'):
|
55 |
+
loc_emb = self.location_embedding(self.all_locs) # computed for each batch
|
56 |
+
score = torch.mm(emb_hr, loc_emb.T)
|
57 |
+
elif triple_type == ('image', 'time'):
|
58 |
+
time_emb = self.time_embedding(self.all_timestamps)
|
59 |
+
score = torch.mm(emb_hr, time_emb.T)
|
60 |
+
else:
|
61 |
+
raise NotImplementedError
|
62 |
+
|
63 |
+
return score
|
64 |
+
|
65 |
+
def batch_embedding_concat_h(self, e1):
|
66 |
+
e1_embedded = None
|
67 |
+
|
68 |
+
if len(e1.size())==1 or e1.size(1) == 1: # uid
|
69 |
+
# print('ent_embedding = {}'.format(self.ent_embedding.weight.size()))
|
70 |
+
e1_embedded = self.ent_embedding(e1.squeeze(-1))
|
71 |
+
elif e1.size(1) == 15: # time
|
72 |
+
e1_embedded = self.time_embedding(e1)
|
73 |
+
elif e1.size(1) == 2: # GPS
|
74 |
+
e1_embedded = self.location_embedding(e1)
|
75 |
+
elif e1.size(1) == 3: # Image
|
76 |
+
e1_embedded = self.image_embedding(e1)
|
77 |
+
|
78 |
+
return e1_embedded
|
79 |
+
|
80 |
+
|
81 |
+
class MLP(nn.Module):
|
82 |
+
def __init__(self,
|
83 |
+
input_dim,
|
84 |
+
output_dim,
|
85 |
+
num_layers=3,
|
86 |
+
p_dropout=0.0,
|
87 |
+
bias=True):
|
88 |
+
|
89 |
+
super().__init__()
|
90 |
+
|
91 |
+
self.input_dim = input_dim
|
92 |
+
self.output_dim = output_dim
|
93 |
+
|
94 |
+
self.p_dropout = p_dropout
|
95 |
+
step_size = (input_dim - output_dim) // num_layers
|
96 |
+
hidden_dims = [output_dim + (i * step_size)
|
97 |
+
for i in reversed(range(num_layers))]
|
98 |
+
|
99 |
+
mlp = list()
|
100 |
+
layer_indim = input_dim
|
101 |
+
for hidden_dim in hidden_dims:
|
102 |
+
mlp.extend([nn.Linear(layer_indim, hidden_dim, bias),
|
103 |
+
nn.Dropout(p=self.p_dropout, inplace=True),
|
104 |
+
nn.PReLU()])
|
105 |
+
|
106 |
+
layer_indim = hidden_dim
|
107 |
+
|
108 |
+
self.mlp = nn.Sequential(*mlp)
|
109 |
+
|
110 |
+
# initialize weights
|
111 |
+
self.init()
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
return self.mlp(x)
|
115 |
+
|
116 |
+
def init(self):
|
117 |
+
for param in self.parameters():
|
118 |
+
nn.init.uniform_(param)
|
overall_id_to_name.json
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"8032375": "motorcycle",
|
3 |
+
"8032203": "empty",
|
4 |
+
"8032276": "pardofelis temminckii",
|
5 |
+
"8032351": "agouti paca",
|
6 |
+
"8032318": "cercopithecus lhoesti",
|
7 |
+
"8032251": "equus quagga",
|
8 |
+
"8032381": "ave desconocida",
|
9 |
+
"8032285": "unknown bird",
|
10 |
+
"8032224": "mazama gouazoubira",
|
11 |
+
"8032345": "francolinus africanus",
|
12 |
+
"8032358": "mazama pandora",
|
13 |
+
"8032284": "canis familiaris",
|
14 |
+
"8032368": "lophura sp",
|
15 |
+
"8032286": "unknown bat",
|
16 |
+
"8032289": "geotrygon sp",
|
17 |
+
"8032377": "puma yagoroundi",
|
18 |
+
"8032372": "myiophoneus caeruleus",
|
19 |
+
"8032362": "arctonyx hoevenii",
|
20 |
+
"8032369": "myiophoneus glaucinus",
|
21 |
+
"8032295": "brotogeris sp",
|
22 |
+
"8032363": "tragulus sp",
|
23 |
+
"8032294": "phaetornis sp",
|
24 |
+
"8032384": "mazama temama",
|
25 |
+
"8032383": "unknown dove",
|
26 |
+
"8032326": "andropadus virens",
|
27 |
+
"8032325": "andropadus latirostris",
|
28 |
+
"8032234": "herpestes sanguineus",
|
29 |
+
"906307": "tayassu pecari",
|
30 |
+
"848914": "dasyprocta punctata",
|
31 |
+
"296191": "cuniculus paca",
|
32 |
+
"42307": "puma concolor",
|
33 |
+
"1034223": "tapirus terrestris",
|
34 |
+
"1037242": "pecari tajacu",
|
35 |
+
"1030860": "mazama americana",
|
36 |
+
"752746": "leopardus pardalis",
|
37 |
+
"664480": "geotrygon montana",
|
38 |
+
"348031": "nasua nasua",
|
39 |
+
"796672": "dasypus novemcinctus",
|
40 |
+
"381140": "eira barbara",
|
41 |
+
"919176": "didelphis marsupialis",
|
42 |
+
"914060": "procyon cancrivorus",
|
43 |
+
"42322": "panthera onca",
|
44 |
+
"490538": "myrmecophaga tridactyla",
|
45 |
+
"402466": "tinamus major",
|
46 |
+
"634572": "sylvilagus brasiliensis",
|
47 |
+
"86162": "puma yagouaroundi",
|
48 |
+
"507553": "leopardus wiedii",
|
49 |
+
"170433": "philander opossum",
|
50 |
+
"19015": "capra aegagrus",
|
51 |
+
"490099": "bos taurus",
|
52 |
+
"70819": "ovis aries",
|
53 |
+
"247341": "canis lupus",
|
54 |
+
"747873": "lepus saxatilis",
|
55 |
+
"115449": "papio anubis",
|
56 |
+
"194343": "genetta genetta",
|
57 |
+
"561121": "tragelaphus scriptus",
|
58 |
+
"541936": "loxodonta africana",
|
59 |
+
"922511": "cricetomys gambianus",
|
60 |
+
"513789": "raphicerus campestris",
|
61 |
+
"383901": "hyaena hyaena",
|
62 |
+
"768679": "aepyceros melampus",
|
63 |
+
"397157": "crocuta crocuta",
|
64 |
+
"1033549": "caracal caracal",
|
65 |
+
"520756": "equus ferus",
|
66 |
+
"563151": "panthera leo",
|
67 |
+
"70832": "tragelaphus oryx",
|
68 |
+
"122645": "kobus ellipsiprymnus",
|
69 |
+
"1036755": "phacochoerus africanus",
|
70 |
+
"42324": "panthera pardus",
|
71 |
+
"159576": "ichneumia albicauda",
|
72 |
+
"666235": "canis mesomelas",
|
73 |
+
"644255": "syncerus caffer",
|
74 |
+
"768674": "giraffa camelopardalis",
|
75 |
+
"989807": "alcelaphus buselaphus",
|
76 |
+
"571323": "chlorocebus pygerythrus",
|
77 |
+
"40168": "madoqua guentheri",
|
78 |
+
"995183": "potamochoerus larvatus",
|
79 |
+
"346068": "nanger granti",
|
80 |
+
"702522": "eudorcas thomsonii",
|
81 |
+
"647692": "struthio camelus",
|
82 |
+
"561087": "orycteropus afer",
|
83 |
+
"752759": "acinonyx jubatus",
|
84 |
+
"521834": "eupodotis senegalensis",
|
85 |
+
"563163": "felis silvestris",
|
86 |
+
"98208": "oryx beisa",
|
87 |
+
"3600024": "lophotis gindiana",
|
88 |
+
"521837": "ardeotis kori",
|
89 |
+
"5021": "lissotis melanogaster",
|
90 |
+
"521339": "argusianus argus",
|
91 |
+
"280108": "prionailurus bengalensis",
|
92 |
+
"194340": "hemigalus derbyanus",
|
93 |
+
"194523": "muntiacus muntjak",
|
94 |
+
"730013": "sus scrofa",
|
95 |
+
"679701": "helarctos malayanus",
|
96 |
+
"844145": "rusa unicolor",
|
97 |
+
"67361": "hystrix brachyura",
|
98 |
+
"42314": "panthera tigris",
|
99 |
+
"201068": "lariscus insignis",
|
100 |
+
"1032049": "chalcophaps indica",
|
101 |
+
"350016": "genetta tigrina",
|
102 |
+
"220326": "hystrix cristata",
|
103 |
+
"821953": "lycaon pictus",
|
104 |
+
"561114": "procavia capensis",
|
105 |
+
"989081": "momotus momota",
|
106 |
+
"592588": "dasyprocta fuliginosa",
|
107 |
+
"736280": "nasua narica",
|
108 |
+
"273227": "tamandua mexicana",
|
109 |
+
"362785": "didelphis sp",
|
110 |
+
"157741": "penelope purpurascens",
|
111 |
+
"510752": "camelus dromedarius",
|
112 |
+
"821973": "otocyon megalotis",
|
113 |
+
"684040": "acryllium vulturinum",
|
114 |
+
"1068209": "equus grevyi",
|
115 |
+
"563161": "proteles cristata",
|
116 |
+
"86170": "leptailurus serval",
|
117 |
+
"70827": "tragelaphus strepsiceros",
|
118 |
+
"510762": "hippopotamus amphibius",
|
119 |
+
"427706": "burhinus capensis",
|
120 |
+
"397136": "paguma larvata",
|
121 |
+
"660452": "pardofelis marmorata",
|
122 |
+
"313163": "cuon alpinus",
|
123 |
+
"872963": "varanus salvator",
|
124 |
+
"213517": "martes flavigula",
|
125 |
+
"194349": "prionodon linsang",
|
126 |
+
"352755": "rollulus rouloul",
|
127 |
+
"53708": "lophura inornata",
|
128 |
+
"110936": "polyplectron chalcurum",
|
129 |
+
"644245": "manis javanica",
|
130 |
+
"798021": "capricornis sumatraensis",
|
131 |
+
"837394": "macaca sp",
|
132 |
+
"1080967": "francolinus nobilis",
|
133 |
+
"436155": "cephalophus nigrifrons",
|
134 |
+
"276723": "atherurus africanus",
|
135 |
+
"417950": "pan troglodytes",
|
136 |
+
"203191": "cercopithecus mitis",
|
137 |
+
"524854": "funisciurus carruthersi",
|
138 |
+
"645461": "motacilla flava",
|
139 |
+
"3611156": "thamnomys venustus",
|
140 |
+
"675198": "protoxerus stangeri",
|
141 |
+
"3609124": "paraxerus boehmi",
|
142 |
+
"380144": "cephalophus silvicultor",
|
143 |
+
"976856": "oenomys hypoxanthus",
|
144 |
+
"106895": "melocichla mentalis",
|
145 |
+
"666961": "hybomys univittatus",
|
146 |
+
"23039": "colomys goslingi",
|
147 |
+
"185338": "hylomyscus stella",
|
148 |
+
"159587": "genetta servalina",
|
149 |
+
"621176": "canis adustus",
|
150 |
+
"845966": "mus minutoides",
|
151 |
+
"772741": "musophaga rossae",
|
152 |
+
"150851": "turtur tympanistria",
|
153 |
+
"717794": "praomys tullbergi",
|
154 |
+
"782347": "malacomys longipes",
|
155 |
+
"693339": "alopochen aegyptiaca",
|
156 |
+
"254745": "deomys ferrugineus",
|
157 |
+
"96367": "turdus olivaceus",
|
158 |
+
"92562": "mazama sp",
|
159 |
+
"685113": "urocyon cinereoargenteus",
|
160 |
+
"446490": "meleagris ocellata",
|
161 |
+
"132829": "crax rubra",
|
162 |
+
"9419": "tapirus bairdii",
|
163 |
+
"348040": "procyon lotor",
|
164 |
+
"410145": "odocoileus virginianus",
|
165 |
+
"244142": "leptotila plumbeiceps",
|
166 |
+
"3611950": "mazama temama",
|
167 |
+
"1023230": "conepatus semistriatus",
|
168 |
+
"109882": "ortalis vetula",
|
169 |
+
"512437": "presbytis thomasi",
|
170 |
+
"882766": "neofelis diardi",
|
171 |
+
"3598028": "dendrocitta occipitalis",
|
172 |
+
"3598135": "niltava sumatrana",
|
173 |
+
"451623": "leiothrix argentauris",
|
174 |
+
"3596058": "arborophila rubrirostris",
|
175 |
+
"53692": "lophura erythrophthalma",
|
176 |
+
"266054": "spilornis cheela",
|
177 |
+
"3613295": "herpestes semitorquatus",
|
178 |
+
"821960": "cerdocyon thous",
|
179 |
+
"407000": "peromyscus sp",
|
180 |
+
"3596764": "tigrisoma mexicanum",
|
181 |
+
"604964": "claravis pretiosa",
|
182 |
+
"421036": "sciurus sp",
|
183 |
+
"906602": "aramides cajanea"
|
184 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.0.0
|
2 |
+
pandas==1.5.3
|
3 |
+
numpy==1.24.2
|
4 |
+
Pillow==9.4.0
|
5 |
+
scipy==1.10.1
|
6 |
+
tensorboard==2.12.2
|
7 |
+
torchvision==0.15.1
|
8 |
+
tqdm==4.64.1
|
9 |
+
wilds==2.0.0
|
10 |
+
matplotlib==3.7.1
|
11 |
+
gradio==3.50.0
|
species_class_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:429e1eac2a4cc58b6e3ed0c660ed93b01525628530e5822a076ebf6d878120b9
|
3 |
+
size 301166871
|
utils.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import csv
|
4 |
+
import argparse
|
5 |
+
import random
|
6 |
+
from pathlib import Path
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import pandas as pd
|
10 |
+
import re
|
11 |
+
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
|
14 |
+
try:
|
15 |
+
from torch_geometric.data import Batch
|
16 |
+
except ImportError:
|
17 |
+
pass
|
18 |
+
|
19 |
+
def set_seed(seed):
|
20 |
+
"""Sets seed"""
|
21 |
+
if torch.cuda.is_available():
|
22 |
+
torch.cuda.manual_seed(seed)
|
23 |
+
torch.manual_seed(seed)
|
24 |
+
np.random.seed(seed)
|
25 |
+
random.seed(seed)
|
26 |
+
torch.backends.cudnn.benchmark = False
|
27 |
+
torch.backends.cudnn.deterministic = True
|
28 |
+
|
29 |
+
|
30 |
+
def move_to(obj, device):
|
31 |
+
if isinstance(obj, dict):
|
32 |
+
return {k: move_to(v, device) for k, v in obj.items()}
|
33 |
+
elif isinstance(obj, list):
|
34 |
+
return [move_to(v, device) for v in obj]
|
35 |
+
elif isinstance(obj, float) or isinstance(obj, int):
|
36 |
+
return obj
|
37 |
+
else:
|
38 |
+
# Assume obj is a Tensor or other type
|
39 |
+
# (like Batch, for MolPCBA) that supports .to(device)
|
40 |
+
return obj.to(device)
|
41 |
+
|
42 |
+
def detach_and_clone(obj):
|
43 |
+
if torch.is_tensor(obj):
|
44 |
+
return obj.detach().clone()
|
45 |
+
elif isinstance(obj, dict):
|
46 |
+
return {k: detach_and_clone(v) for k, v in obj.items()}
|
47 |
+
elif isinstance(obj, list):
|
48 |
+
return [detach_and_clone(v) for v in obj]
|
49 |
+
elif isinstance(obj, float) or isinstance(obj, int):
|
50 |
+
return obj
|
51 |
+
else:
|
52 |
+
raise TypeError("Invalid type for detach_and_clone")
|
53 |
+
|
54 |
+
def collate_list(vec):
|
55 |
+
"""
|
56 |
+
If vec is a list of Tensors, it concatenates them all along the first dimension.
|
57 |
+
|
58 |
+
If vec is a list of lists, it joins these lists together, but does not attempt to
|
59 |
+
recursively collate. This allows each element of the list to be, e.g., its own dict.
|
60 |
+
|
61 |
+
If vec is a list of dicts (with the same keys in each dict), it returns a single dict
|
62 |
+
with the same keys. For each key, it recursively collates all entries in the list.
|
63 |
+
"""
|
64 |
+
if not isinstance(vec, list):
|
65 |
+
raise TypeError("collate_list must take in a list")
|
66 |
+
elem = vec[0]
|
67 |
+
if torch.is_tensor(elem):
|
68 |
+
return torch.cat(vec)
|
69 |
+
elif isinstance(elem, list):
|
70 |
+
return [obj for sublist in vec for obj in sublist]
|
71 |
+
elif isinstance(elem, dict):
|
72 |
+
return {k: collate_list([d[k] for d in vec]) for k in elem}
|
73 |
+
else:
|
74 |
+
raise TypeError("Elements of the list to collate must be tensors or dicts.")
|